init
This commit is contained in:
398
data_analysis/io/plotting.py
Normal file
398
data_analysis/io/plotting.py
Normal file
@ -0,0 +1,398 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np # Import numpy for CI calculation
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from typing import Optional, List
|
||||
|
||||
# Import analysis tools for plotting results
|
||||
from statsmodels.tsa.seasonal import DecomposeResult
|
||||
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf, seasonal_plot
|
||||
from statsmodels.tsa.stattools import ccf # Import ccf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Plotting Configuration ---
|
||||
# Increase default figure size for better readability
|
||||
plt.rcParams['figure.figsize'] = (15, 7)
|
||||
# Use a clean style
|
||||
plt.style.use('seaborn-v0_8-whitegrid')
|
||||
|
||||
|
||||
def _save_plot(fig: plt.Figure, output_path: Path) -> Optional[str]:
    """Write ``fig`` to ``output_path`` and close the figure.

    The layout is tightened first, then the figure is saved at 150 dpi with
    a tight bounding box. The figure is closed on both the success and the
    failure path.

    Args:
        fig: Figure to save.
        output_path: Destination file for the image.

    Returns:
        ``None`` when the file was written, otherwise a message describing
        the failure (the failure is also logged with its traceback).
    """
    try:
        fig.tight_layout()  # adjust layout before saving
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.info(f"Plot saved to: {output_path}")
        plt.close(fig)  # release the figure's memory
    except Exception as exc:
        failure = f"Failed to save plot to {output_path}: {exc}"
        logger.error(failure, exc_info=True)
        plt.close(fig)  # still try to release the figure after a failure
        return failure
    return None
|
||||
|
||||
|
||||
def plot_full_time_series(df: pd.DataFrame, price_col: str, output_path: Path) -> Optional[str]:
    """Draw ``price_col`` against the frame's index over the whole frame.

    Args:
        df: Data to plot; its index is used for the x-axis.
        price_col: Name of the column plotted on the y-axis.
        output_path: Destination file for the image.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating full time series plot to {output_path}")
    figure, axes = plt.subplots()
    try:
        sns.lineplot(data=df, x=df.index, y=price_col, ax=axes, linewidth=1)
        axes.set_title('Full Time Series: Price Over Time')
        axes.set_xlabel('Time')
        axes.set_ylabel(price_col)
        # _save_plot closes the figure and reports its own failures.
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting full time series: {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
|
||||
|
||||
|
||||
def plot_zoomed_time_series(df: pd.DataFrame, price_col: str, start_date: str, end_date: str, output_path: Path) -> Optional[str]:
    """Draw ``price_col`` for the index slice ``start_date``..``end_date``.

    Args:
        df: Data to plot; its index is sliced by label and used for the x-axis.
        price_col: Name of the column plotted on the y-axis.
        start_date: Lower bound of the label slice.
        end_date: Upper bound of the label slice.
        output_path: Destination file for the image.

    Returns:
        ``None`` on success, otherwise a message. An empty slice is reported
        as a warning and nothing is saved.
    """
    logger.info(f"Generating zoomed time series plot ({start_date} to {end_date}) to {output_path}")
    figure, axes = plt.subplots()
    try:
        # The bounds must be compatible with the index type for .loc slicing.
        window = df.loc[start_date:end_date]
        if window.empty:
            message = f"No data found in the specified zoom range: {start_date} to {end_date}"
            # An empty range is not necessarily an error, hence a warning.
            logger.warning(message)
            plt.close(figure)
            return message

        sns.lineplot(data=window, x=window.index, y=price_col, ax=axes, linewidth=1)
        axes.set_title(f'Time Series: {start_date} to {end_date}')
        axes.set_xlabel('Time')
        axes.set_ylabel(price_col)
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting zoomed time series: {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
|
||||
|
||||
|
||||
def plot_boxplot_by_period(df: pd.DataFrame, price_col: str, period: str, output_path: Path) -> Optional[str]:
    """Draw box plots of ``price_col`` grouped by a calendar component.

    Args:
        df: Data to plot; the grouping values are read from its index.
        price_col: Name of the column whose distribution is shown.
        period: One of 'hour', 'dayofweek', 'month', 'year'.
        output_path: Destination file for the image.

    Returns:
        ``None`` on success, otherwise an error message (including for an
        unsupported ``period``).
    """
    logger.info(f"Generating box plot by {period} to {output_path}")
    figure, axes = plt.subplots()

    # period -> (plot title, x-axis label). Each key is also the name of the
    # index attribute that supplies the grouping values.
    captions = {
        'hour': ('Price Distribution by Hour of Day', 'Hour'),
        'dayofweek': ('Price Distribution by Day of Week', 'Day of Week (0=Mon, 6=Sun)'),
        'month': ('Price Distribution by Month', 'Month'),
        'year': ('Price Distribution by Year', 'Year'),
    }
    try:
        if period not in captions:
            message = f"Unsupported period '{period}' for boxplot."
            logger.error(message)
            plt.close(figure)
            return message

        heading, x_caption = captions[period]
        grouping = getattr(df.index, period)  # e.g. df.index.dayofweek (Monday=0)

        sns.boxplot(x=grouping, y=df[price_col], ax=axes, palette="viridis", hue=grouping)
        axes.set_title(heading)
        axes.set_xlabel(x_caption)
        axes.set_ylabel(price_col)
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting boxplot by {period}: {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
|
||||
|
||||
# New function signature for seasonal subseries plot
|
||||
def plot_seasonal_subseries(df: pd.DataFrame, price_col: str, period: int, period_name: str, output_path: Path) -> Optional[str]:
    """Generate a seasonal subseries plot for a given period.

    Grouping depends on ``period``:

    * ``24``  - grouped by hour of day (``df.index.hour``).
    * ``168`` - grouped by day of week (``df.index.dayofweek``).
    * other   - grouped by row position modulo ``period``.

    Args:
        df: Data to plot; must have a ``DatetimeIndex``.
        price_col: Name of the column to plot.
        period: Seasonal period (e.g. 24 for daily).
        period_name: Human-readable period name used in the title and logs.
        output_path: Destination file for the image.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating seasonal subseries plot for {period_name} (period={period}) to {output_path}")
    fig = None
    try:
        if not isinstance(df.index, pd.DatetimeIndex):
            err = "DataFrame index must be a DatetimeIndex for seasonal subseries plot."
            logger.error(err)
            return err

        values = df[price_col]
        if period == 24:  # Daily
            grouped = values.groupby(df.index.hour)
            all_labels = [f"{i:02d}:00" for i in range(24)]
        elif period == 168:  # Weekly
            grouped = values.groupby(df.index.dayofweek)
            all_labels = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
        else:
            if period < 1:
                err = f"Seasonal period must be a positive integer, got {period}."
                logger.error(err)
                return err
            # A DatetimeIndex has no modulo operation, so the position of
            # each row within the seasonal cycle is used instead.
            grouped = values.groupby(np.arange(len(values)) % period)
            all_labels = [str(i) for i in range(period)]

        # seasonal_plot draws one tick per group that actually occurs, so
        # only the labels of those groups may be passed.
        xticklabels = [all_labels[int(key)] for key, _ in grouped]
        if not xticklabels:
            err = f"No data available for seasonal subseries plot ({period_name})."
            logger.warning(err)
            return err

        # Own the figure so that a failure can close exactly this figure
        # instead of every open pyplot figure.
        fig, ax = plt.subplots(figsize=(15, 10))
        seasonal_plot(grouped, xticklabels=xticklabels, ylabel=price_col, ax=ax)
        fig.suptitle(f'Seasonal Subseries Plot ({period_name})', y=1.02)
        return _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting seasonal subseries ({period_name}): {e}"
        logger.error(err, exc_info=True)
        if fig is not None:
            plt.close(fig)
        return err
|
||||
|
||||
|
||||
def plot_histogram(df: pd.DataFrame, price_col: str, output_path: Path, bins: int = 50) -> Optional[str]:
    """Draw a histogram with a KDE overlay for the values of ``price_col``.

    Args:
        df: Data containing the column to plot.
        price_col: Name of the column whose distribution is shown.
        output_path: Destination file for the image.
        bins: Number of histogram bins.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating histogram of '{price_col}' to {output_path}")
    figure, axes = plt.subplots()
    try:
        sns.histplot(data=df, x=price_col, bins=bins, kde=True, ax=axes)
        axes.set_title(f'Distribution of {price_col}')
        axes.set_xlabel(price_col)
        axes.set_ylabel('Frequency')
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting histogram: {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
|
||||
|
||||
|
||||
def plot_decomposition(decomposition_result: DecomposeResult, period_name: str, output_path: Path) -> Optional[str]:
    """Plot the components of a time series decomposition result.

    The figure is produced by ``decomposition_result.plot()``, resized to
    15x10 inches, titled, and saved.

    Args:
        decomposition_result: Decomposition whose ``plot()`` builds the figure.
        period_name: Seasonality name used in the title and log messages.
        output_path: Destination file for the image.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating {period_name} decomposition plot to {output_path}")

    # Remember which figures already exist: if plotting fails part-way there
    # is no handle to the new figure, and closing *all* figures would also
    # destroy figures owned by the caller.
    preexisting = set(plt.get_fignums())
    try:
        # The plot method of DecomposeResult returns a Figure
        fig = decomposition_result.plot()
        fig.set_size_inches(15, 10)  # Adjust size for better visibility
        fig.suptitle(f'Time Series Decomposition ({period_name} Seasonality)', y=1.02)
        return _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting decomposition ({period_name}): {e}"
        logger.error(err, exc_info=True)
        # Close only the figures that appeared during this call.
        for fig_num in set(plt.get_fignums()) - preexisting:
            plt.close(fig_num)
        return err
|
||||
|
||||
|
||||
def plot_residuals(residuals: pd.Series, title_suffix: str, output_path: Path) -> Optional[str]:
    """Draw the residual series with a dashed reference line at zero.

    Args:
        residuals: Series to plot against its own index.
        title_suffix: Text placed in parentheses in the title and log messages.
        output_path: Destination file for the image.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating residuals plot ({title_suffix}) to {output_path}")
    figure, axes = plt.subplots()
    try:
        residuals.plot(ax=axes, title=f'Residuals ({title_suffix})')
        axes.set_xlabel('Time')
        axes.set_ylabel('Residual Value')
        # Horizontal reference line at zero.
        axes.axhline(0, color='r', linestyle='--', alpha=0.7)
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting residuals ({title_suffix}): {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
|
||||
|
||||
def plot_acf_pacf(series: pd.Series, series_name: str, lags: int | None, output_path_base: Path) -> Optional[str]:
    """Plot the ACF and PACF of a series, saving them as separate files.

    The files are written next to ``output_path_base`` as
    ``<stem>_acf.png`` and ``<stem>_pacf.png``. The two plots are
    independent: a failure in one does not prevent the other.

    Args:
        series: Series to analyse.
        series_name: Name used in plot titles and log messages.
        lags: Number of lags passed to the statsmodels plot functions.
        output_path_base: Path whose parent and stem define the output files.

    Returns:
        The first error message encountered (ACF before PACF), or ``None``
        if both plots were saved.
    """
    logger.info(f"Generating ACF/PACF plots for {series_name} to {output_path_base.parent}")

    # (label, plotting function, extra keyword arguments)
    plot_specs = (
        ("ACF", plot_acf, {}),
        # method='ywm' selects a Yule-Walker estimator for the PACF.
        ("PACF", plot_pacf, {"method": "ywm"}),
    )

    errors = []
    for label, plot_func, extra_kwargs in plot_specs:
        # Initialised before the try so the handler never touches an
        # unbound name if figure creation itself fails.
        fig = None
        try:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            plot_func(series, lags=lags, ax=ax, title=f'{label} - {series_name}', **extra_kwargs)
            plot_path = output_path_base.with_name(f"{output_path_base.stem}_{label.lower()}.png")
            errors.append(_save_plot(fig, plot_path))
        except Exception as e:
            err = f"Error plotting {label} for {series_name}: {e}"
            logger.error(err, exc_info=True)
            if fig is not None:
                plt.close(fig)
            errors.append(err)

    # Return the first error encountered, or None if both succeeded
    return next((err for err in errors if err), None)
|
||||
|
||||
|
||||
# Update cross-correlation plot function
|
||||
def plot_cross_correlation(
    target_series: pd.Series,
    exog_series: pd.Series,
    target_name: str,
    exog_name: str,
    max_lags: int,
    output_path: Path
) -> Optional[str]:
    """
    Generates and saves a cross-correlation plot between a target series and an exogenous series.
    Plots correlation of target_series(t) with exog_series(t-lag) for lags 0..max_lags.

    Args:
        target_series: The main time series to analyze
        exog_series: The exogenous time series to correlate with
        target_name: Name of the target series for labeling
        exog_name: Name of the exogenous series for labeling
        max_lags: Maximum number of lags to compute correlation for
        output_path: Where to save the plot

    Returns:
        Optional[str]: Error message if something went wrong, None if successful
    """
    logger.info(f"Generating cross-correlation plot ({target_name} vs {exog_name}) for lags up to {max_lags} to {output_path}")
    # Initialised before the try so the handler can tell whether a figure
    # exists; failures before plt.subplots() must not hit an unbound name.
    fig = None
    try:
        # Ensure series are aligned and have no NaNs affecting calculation
        combined = pd.concat([target_series.rename(target_name), exog_series.rename(exog_name)], axis=1).dropna()

        # Check if we have enough data points
        if combined.empty or len(combined) <= max_lags:
            # Intentionally not logged here; the original note indicates the
            # caller reports the returned message.
            return f"Not enough overlapping non-NaN data points between {target_name} and {exog_name} for CCF calculation (need > {max_lags})."

        # Select by position so identical series names still yield two Series.
        aligned_target = combined.iloc[:, 0]
        aligned_exog = combined.iloc[:, 1]

        # Check variation on the aligned data actually fed to ccf: the raw
        # series may vary only in rows dropped by the alignment.
        if aligned_exog.nunique() <= 1:
            # Intentionally not logged here (see above).
            return f"Cannot compute cross-correlation: {exog_name} has no variation (all values are the same)."

        # Calculate CCF: ccf(x, y) computes corr(x[t], y[t-lag])
        # We want corr(target[t], exog[t-lag]), so order is ccf(target, exog).
        # statsmodels documents `nlags` as the number of lags returned, so
        # covering lag 0..max_lags needs max_lags + 1 values. The slice keeps
        # the lengths consistent should a longer array ever come back.
        cross_corr_values = np.asarray(
            ccf(aligned_target, aligned_exog, adjusted=False, nlags=max_lags + 1)
        )[:max_lags + 1]
        lags_range = np.arange(len(cross_corr_values))  # derived from the data, includes lag 0

        # Plotting
        fig, ax = plt.subplots()
        markerline, stemlines, baseline = ax.stem(
            lags_range, cross_corr_values, markerfmt='o', basefmt="gray"
        )
        plt.setp(markerline, markersize=5)
        plt.setp(stemlines, linewidth=1)

        # Add approximate 95% confidence band (+/- 1.96 / sqrt(n))
        conf_level = 1.96 / np.sqrt(len(combined))
        ax.axhspan(-conf_level, conf_level, alpha=0.2, color='blue', zorder=0)

        ax.set_title(f'Cross-Correlation: {target_name}(t) vs {exog_name}(t-lag)')
        ax.set_xlabel('Lag (k)')
        ax.set_ylabel(f'Corr({target_name}(t), {exog_name}(t-k))')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        return _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting cross-correlation ({target_name} vs {exog_name}): {e}"
        logger.error(err, exc_info=True)
        if fig is not None:
            plt.close(fig)
        return err
|
||||
|
||||
def plot_weekly_autocorrelation(
    series: pd.Series,
    series_name: str,
    output_path: Path,
    max_weeks: int = 4
) -> Optional[str]:
    """
    Generates and saves an autocorrelation plot between a series and its weekly lags.
    This helps identify weekly seasonality patterns.

    One week is taken to be 168 observations (24 * 7), i.e. the series is
    treated as hourly. Lags 0, 1, ..., max_weeks weeks are plotted, limited
    to the lags the series is long enough to provide.

    Args:
        series: The time series to analyze
        series_name: Name of the series for labeling
        output_path: Where to save the plot
        max_weeks: Maximum number of weeks to look back (default: 4)

    Returns:
        Optional[str]: Error message if something went wrong, None if successful
    """
    logger.info(f"Generating weekly autocorrelation plot for {series_name} up to {max_weeks} weeks to {output_path}")
    # Initialised before the try so the handler can tell whether a figure
    # exists; a failure inside ccf must not hit an unbound name.
    fig = None
    try:
        # Ensure series has no NaNs
        series = series.dropna()
        if series.empty:
            err = f"Series {series_name} is empty after dropping NaNs."
            logger.warning(err)
            return err

        # Calculate weekly lags (168 hours = 1 week)
        hours_per_week = 24 * 7
        max_lags = max_weeks * hours_per_week

        # Calculate autocorrelation. statsmodels documents `nlags` as the
        # number of lags returned, so max_lags + 1 values are requested to
        # include the lag at exactly max_weeks weeks.
        autocorr_values = np.asarray(ccf(series, series, adjusted=False, nlags=max_lags + 1))

        # Only plot weekly intervals, and only lags that were computed
        # (a short series yields fewer values than requested).
        last_lag = min(max_lags, len(autocorr_values) - 1)
        lags_range = list(range(0, last_lag + 1, hours_per_week))

        # Plotting
        fig, ax = plt.subplots()
        markerline, stemlines, baseline = ax.stem(
            [lag / hours_per_week for lag in lags_range],  # Convert to weeks for x-axis
            autocorr_values[lags_range],
            markerfmt='o',
            basefmt="gray"
        )
        plt.setp(markerline, markersize=5)
        plt.setp(stemlines, linewidth=1)

        # Add approximate 95% confidence band (+/- 1.96 / sqrt(n))
        conf_level = 1.96 / np.sqrt(len(series))
        ax.axhspan(-conf_level, conf_level, alpha=0.2, color='blue', zorder=0)

        ax.set_title(f'Weekly Autocorrelation: {series_name}')
        ax.set_xlabel('Lag (weeks)')
        ax.set_ylabel(f'Corr({series_name}(t), {series_name}(t-lag))')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Add vertical lines at each week
        for week in range(max_weeks + 1):
            ax.axvline(x=week, color='gray', linestyle=':', alpha=0.3)

        return _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting weekly autocorrelation for {series_name}: {e}"
        logger.error(err, exc_info=True)
        if fig is not None:
            plt.close(fig)
        return err
|
Reference in New Issue
Block a user