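"""Plotting helpers for time-series exploratory analysis.

Each public function renders a plot, saves it to disk, and returns ``None``
on success or an error message string on failure.
"""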
import logging
|
|
from pathlib import Path
|
|
import pandas as pd
|
|
import numpy as np # Import numpy for CI calculation
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
from typing import Optional, List
|
|
|
|
# Import analysis tools for plotting results
|
|
from statsmodels.tsa.seasonal import DecomposeResult
|
|
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf, seasonal_plot
|
|
from statsmodels.tsa.stattools import ccf # Import ccf
|
|
|
|
# Module-level logger shared by every plotting helper in this file.
logger = logging.getLogger(__name__)

# --- Plotting Configuration ---
# Use a larger default figure size for better readability.
plt.rcParams.update({'figure.figsize': (15, 7)})
# Apply a clean white-grid style to all figures.
plt.style.use('seaborn-v0_8-whitegrid')
|
|
|
|
|
|
def _save_plot(fig: plt.Figure, output_path: Path) -> Optional[str]:
    """Write ``fig`` to ``output_path`` and close it.

    The figure is closed on both the success and the failure path.

    Args:
        fig: Figure to save.
        output_path: Destination file path.

    Returns:
        ``None`` when the file was written, otherwise a message describing
        the failure (which is also logged).
    """
    try:
        # Adjust the layout before writing the file.
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.info(f"Plot saved to: {output_path}")
        # Close the figure to free memory.
        plt.close(fig)
    except Exception as exc:
        failure = f"Failed to save plot to {output_path}: {exc}"
        logger.error(failure, exc_info=True)
        # Still close the figure when saving failed.
        plt.close(fig)
        return failure
    return None
|
|
|
|
|
|
def plot_full_time_series(df: pd.DataFrame, price_col: str, output_path: Path) -> Optional[str]:
    """Draw ``price_col`` against the DataFrame index over the whole series.

    Args:
        df: Data to plot; its index is used as the x-axis.
        price_col: Name of the column plotted on the y-axis.
        output_path: Destination file for the figure.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating full time series plot to {output_path}")
    fig, ax = plt.subplots()
    try:
        sns.lineplot(data=df, x=df.index, y=price_col, ax=ax, linewidth=1)
        ax.set(
            title='Full Time Series: Price Over Time',
            xlabel='Time',
            ylabel=price_col,
        )
        # _save_plot closes the figure and reports its own failures.
        return _save_plot(fig, output_path)
    except Exception as exc:
        message = f"Error plotting full time series: {exc}"
        logger.error(message, exc_info=True)
        plt.close(fig)
        return message
|
|
|
|
|
|
def plot_zoomed_time_series(df: pd.DataFrame, price_col: str, start_date: str, end_date: str, output_path: Path) -> Optional[str]:
    """Draw ``price_col`` for the slice ``df.loc[start_date:end_date]``.

    Args:
        df: Data to plot; its index is used as the x-axis.
        price_col: Name of the column plotted on the y-axis.
        start_date: Lower bound of the label slice (must be compatible with
            the index type).
        end_date: Upper bound of the label slice.
        output_path: Destination file for the figure.

    Returns:
        ``None`` on success, otherwise an error message. An empty slice is
        reported as a message (logged at warning level) and nothing is saved.
    """
    logger.info(f"Generating zoomed time series plot ({start_date} to {end_date}) to {output_path}")
    fig, ax = plt.subplots()
    try:
        window = df.loc[start_date:end_date]
        if window.empty:
            # An empty range is not necessarily an error, hence a warning.
            message = f"No data found in the specified zoom range: {start_date} to {end_date}"
            logger.warning(message)
            plt.close(fig)
            return message

        sns.lineplot(data=window, x=window.index, y=price_col, ax=ax, linewidth=1)
        ax.set(
            title=f'Time Series: {start_date} to {end_date}',
            xlabel='Time',
            ylabel=price_col,
        )
        return _save_plot(fig, output_path)
    except Exception as exc:
        message = f"Error plotting zoomed time series: {exc}"
        logger.error(message, exc_info=True)
        plt.close(fig)
        return message
|
|
|
|
|
|
def plot_boxplot_by_period(df: pd.DataFrame, price_col: str, period: str, output_path: Path) -> Optional[str]:
    """Draw box plots of ``price_col`` grouped by a calendar component.

    Args:
        df: Data to plot; the grouping values are read from attributes of
            its index (``hour``, ``dayofweek``, ``month`` or ``year``).
        price_col: Name of the column whose distribution is plotted.
        period: One of ``'hour'``, ``'dayofweek'``, ``'month'``, ``'year'``.
        output_path: Destination file for the figure.

    Returns:
        ``None`` on success, otherwise an error message (including when
        ``period`` is not one of the supported values).
    """
    logger.info(f"Generating box plot by {period} to {output_path}")
    fig, ax = plt.subplots()

    # period -> (plot title, x-axis label); the key doubles as the name of
    # the index attribute that supplies the grouping values.
    period_specs = {
        'hour': ('Price Distribution by Hour of Day', 'Hour'),
        'dayofweek': ('Price Distribution by Day of Week', 'Day of Week (0=Mon, 6=Sun)'),
        'month': ('Price Distribution by Month', 'Month'),
        'year': ('Price Distribution by Year', 'Year'),
    }

    try:
        spec = period_specs.get(period)
        if spec is None:
            message = f"Unsupported period '{period}' for boxplot."
            logger.error(message)
            plt.close(fig)
            return message

        plot_title, axis_label = spec
        # For 'dayofweek' the values run Monday=0 .. Sunday=6.
        groups = getattr(df.index, period)

        sns.boxplot(x=groups, y=df[price_col], ax=ax, palette="viridis", hue=groups)
        ax.set(title=plot_title, xlabel=axis_label, ylabel=price_col)
        return _save_plot(fig, output_path)
    except Exception as exc:
        message = f"Error plotting boxplot by {period}: {exc}"
        logger.error(message, exc_info=True)
        plt.close(fig)
        return message
|
|
|
|
# --- Seasonal subseries plot ---
|
|
def plot_seasonal_subseries(df: pd.DataFrame, price_col: str, period: int, period_name: str, output_path: Path) -> Optional[str]:
    """Generate a seasonal subseries plot of ``price_col``.

    Grouping depends on ``period``:

    * ``24``  -> grouped by hour of day (labels ``00:00`` .. ``23:00``).
    * ``168`` -> grouped by day of week (labels ``Mon`` .. ``Sun``).
    * other   -> grouped by row position modulo ``period``. This assumes
      the rows are evenly spaced with no gaps; verify against the caller.

    Args:
        df: Data to plot; must be indexed by a ``DatetimeIndex``.
        price_col: Name of the column to plot.
        period: Seasonal period selecting the grouping (see above).
        period_name: Human-readable period name used in the title and logs.
        output_path: Destination file for the figure.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating seasonal subseries plot for {period_name} (period={period}) to {output_path}")
    err = None
    try:
        if not isinstance(df.index, pd.DatetimeIndex):
            err = "DataFrame index must be a DatetimeIndex for seasonal subseries plot."
            logger.error(err)
            return err

        if period < 1:
            err = f"Seasonal period must be a positive integer, got {period}."
            logger.error(err)
            return err

        # Pick the grouping key and the full set of labels for that key.
        if period == 24:  # Daily
            keys = df.index.hour
            all_labels = [f"{i:02d}:00" for i in range(24)]
        elif period == 168:  # Weekly
            keys = df.index.dayofweek
            all_labels = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
        else:
            # A DatetimeIndex does not support ``%``, so bucket rows by
            # their position in the frame instead.
            keys = np.arange(len(df)) % period
            all_labels = [str(i) for i in range(period)]

        grouped = df[price_col].groupby(keys)

        # Label only the groups that actually occur, in iteration order, so
        # the number of labels always matches the number of ticks drawn.
        xticklabels = [all_labels[int(key)] for key, _ in grouped]

        fig = seasonal_plot(grouped, xticklabels=xticklabels, ylabel=price_col)
        fig.suptitle(f'Seasonal Subseries Plot ({period_name})', y=1.02)
        fig.set_size_inches(15, 10)
        err = _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting seasonal subseries ({period_name}): {e}"
        logger.error(err, exc_info=True)
        # seasonal_plot may have created a figure we hold no reference to,
        # so fall back to closing every open figure.
        plt.close('all')
    return err
|
|
|
|
|
|
def plot_histogram(df: pd.DataFrame, price_col: str, output_path: Path, bins: int = 50) -> Optional[str]:
    """Draw a histogram (with a KDE overlay) of ``price_col``.

    Args:
        df: Data containing the column to plot.
        price_col: Name of the column whose values are binned.
        output_path: Destination file for the figure.
        bins: Number of histogram bins.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating histogram of '{price_col}' to {output_path}")
    fig, ax = plt.subplots()
    try:
        sns.histplot(data=df, x=price_col, bins=bins, kde=True, ax=ax)
        ax.set(
            title=f'Distribution of {price_col}',
            xlabel=price_col,
            ylabel='Frequency',
        )
        return _save_plot(fig, output_path)
    except Exception as exc:
        message = f"Error plotting histogram: {exc}"
        logger.error(message, exc_info=True)
        plt.close(fig)
        return message
|
|
|
|
|
|
def plot_decomposition(decomposition_result: DecomposeResult, period_name: str, output_path: Path) -> Optional[str]:
    """Save the figure produced by a decomposition result's ``plot()``.

    Args:
        decomposition_result: Result object whose ``plot()`` method returns
            the figure to save.
        period_name: Seasonality name used in the title and log messages.
        output_path: Destination file for the figure.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating {period_name} decomposition plot to {output_path}")
    try:
        figure = decomposition_result.plot()
        # Enlarge the figure for better visibility.
        figure.set_size_inches(15, 10)
        figure.suptitle(f'Time Series Decomposition ({period_name} Seasonality)', y=1.02)
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting decomposition ({period_name}): {exc}"
        logger.error(message, exc_info=True)
        # If plot() failed early there is no figure handle to close, so
        # close every open figure as a fallback.
        plt.close('all')
        return message
|
|
|
|
|
|
def plot_residuals(residuals: pd.Series, title_suffix: str, output_path: Path) -> Optional[str]:
    """Plot a residual series with a dashed reference line at zero.

    Args:
        residuals: Series to plot against its index.
        title_suffix: Text shown in parentheses in the title and logs.
        output_path: Destination file for the figure.

    Returns:
        ``None`` on success, otherwise an error message.
    """
    logger.info(f"Generating residuals plot ({title_suffix}) to {output_path}")
    fig, ax = plt.subplots()
    try:
        residuals.plot(ax=ax, title=f'Residuals ({title_suffix})')
        ax.set(xlabel='Time', ylabel='Residual Value')
        # Horizontal reference line at zero.
        ax.axhline(0, color='r', linestyle='--', alpha=0.7)
        return _save_plot(fig, output_path)
    except Exception as exc:
        message = f"Error plotting residuals ({title_suffix}): {exc}"
        logger.error(message, exc_info=True)
        plt.close(fig)
        return message
|
|
|
|
def plot_acf_pacf(series: pd.Series, series_name: str, lags: int | None, output_path_base: Path) -> Optional[str]:
    """Plot the ACF and PACF of ``series`` and save them as separate files.

    The two figures are written next to ``output_path_base`` as
    ``<stem>_acf.png`` and ``<stem>_pacf.png``. The PACF is attempted even
    if the ACF fails.

    Args:
        series: Series to analyse.
        series_name: Name used in plot titles and log messages.
        lags: Number of lags passed to the statsmodels plot functions
            (``None`` lets them choose).
        output_path_base: Path whose parent directory and stem determine
            the two output file names.

    Returns:
        ``None`` if both plots succeeded, otherwise the first error message
        encountered (ACF before PACF).
    """
    logger.info(f"Generating ACF/PACF plots for {series_name} to {output_path_base.parent}")

    # (label, file suffix, plotting function, extra keyword arguments)
    plot_specs = (
        ("ACF", "acf", plot_acf, {}),
        # method='ywm' selects the Yule-Walker method for the PACF.
        ("PACF", "pacf", plot_pacf, {"method": "ywm"}),
    )

    errors = []
    for label, suffix, plot_fn, extra_kwargs in plot_specs:
        # Bound before the try so the handler never sees an unbound name
        # if figure creation itself fails.
        fig = None
        try:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            plot_fn(series, lags=lags, ax=ax, title=f'{label} - {series_name}', **extra_kwargs)
            target = output_path_base.with_name(f"{output_path_base.stem}_{suffix}.png")
            err = _save_plot(fig, target)
        except Exception as e:
            err = f"Error plotting {label} for {series_name}: {e}"
            logger.error(err, exc_info=True)
            if fig is not None:
                plt.close(fig)
        errors.append(err)

    # Return the first error encountered, or None if both succeeded.
    return next((err for err in errors if err), None)
|
|
|
|
|
|
# --- Cross-correlation plot ---
|
|
def plot_cross_correlation(
    target_series: pd.Series,
    exog_series: pd.Series,
    target_name: str,
    exog_name: str,
    max_lags: int,
    output_path: Path
) -> Optional[str]:
    """
    Generates and saves a cross-correlation plot between a target series and an exogenous series.
    Plots correlation of target_series(t) with exog_series(t-lag) for lags 0..max_lags.

    The two series are aligned on their index and rows with a NaN in either
    series are dropped before the correlation is computed.

    Args:
        target_series: The main time series to analyze
        exog_series: The exogenous time series to correlate with
        target_name: Name of the target series for labeling
        exog_name: Name of the exogenous series for labeling
        max_lags: Maximum number of lags to compute correlation for
        output_path: Where to save the plot

    Returns:
        Optional[str]: Error message if something went wrong, None if successful.
        The "not enough data" and "no variation" conditions are returned
        without being logged here; the caller is expected to report them.
    """
    logger.info(f"Generating cross-correlation plot ({target_name} vs {exog_name}) for lags up to {max_lags} to {output_path}")
    err = None
    # Bound before the try so the handler can tell whether a figure exists;
    # alignment or the CCF computation can fail before one is created.
    fig = None
    try:
        # Ensure series are aligned and have no NaNs affecting calculation
        combined = pd.concat([target_series.rename(target_name), exog_series.rename(exog_name)], axis=1).dropna()

        # Check if we have enough data points
        if combined.empty or len(combined) <= max_lags:
            err = f"Not enough overlapping non-NaN data points between {target_name} and {exog_name} for CCF calculation (need > {max_lags})."
            return err

        # A constant series has zero standard deviation, which makes the
        # correlation undefined. Check the aligned data, since that is what
        # is actually correlated.
        if combined[exog_name].nunique() <= 1:
            err = f"Cannot compute cross-correlation: {exog_name} has no variation (all values are the same)."
            return err
        if combined[target_name].nunique() <= 1:
            err = f"Cannot compute cross-correlation: {target_name} has no variation (all values are the same)."
            return err

        # Calculate CCF: ccf(x, y) computes corr(x[t], y[t-lag])
        # We want corr(target[t], exog[t-lag]), so order is ccf(target, exog).
        # statsmodels truncates the result to the first ``nlags`` values
        # (lags 0..nlags-1), so ask for one extra to include lag ``max_lags``.
        cross_corr_values = ccf(
            combined[target_name], combined[exog_name], adjusted=False, nlags=max_lags + 1
        )[:max_lags + 1]
        # Derive the x values from the result so both arrays always match.
        lags_range = np.arange(len(cross_corr_values))

        # Plotting
        fig, ax = plt.subplots()
        markerline, stemlines, baseline = ax.stem(
            lags_range, cross_corr_values, markerfmt='o', basefmt="gray"
        )
        plt.setp(markerline, markersize=5)
        plt.setp(stemlines, linewidth=1)

        # Add approximate 95% confidence band (+/- 1.96 / sqrt(N))
        conf_level = 1.96 / np.sqrt(len(combined))
        ax.axhspan(-conf_level, conf_level, alpha=0.2, color='blue', zorder=0)

        ax.set_title(f'Cross-Correlation: {target_name}(t) vs {exog_name}(t-lag)')
        ax.set_xlabel('Lag (k)')
        ax.set_ylabel(f'Corr({target_name}(t), {exog_name}(t-k))')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        err = _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting cross-correlation ({target_name} vs {exog_name}): {e}"
        logger.error(err, exc_info=True)
        if fig is not None:
            plt.close(fig)
    return err
|
|
|
|
def plot_weekly_autocorrelation(
    series: pd.Series,
    series_name: str,
    output_path: Path,
    max_weeks: int = 4
) -> Optional[str]:
    """
    Generates and saves an autocorrelation plot between a series and its weekly lags.
    This helps identify weekly seasonality patterns.

    One week is taken to be 168 observations, i.e. the series is treated as
    hourly; lags 0, 1, ..., max_weeks weeks are plotted (fewer if the series
    is too short to provide them).

    Args:
        series: The time series to analyze
        series_name: Name of the series for labeling
        output_path: Where to save the plot
        max_weeks: Maximum number of weeks to look back (default: 4)

    Returns:
        Optional[str]: Error message if something went wrong, None if successful
    """
    logger.info(f"Generating weekly autocorrelation plot for {series_name} up to {max_weeks} weeks to {output_path}")
    err = None
    # Bound before the try so the handler can tell whether a figure exists;
    # the autocorrelation computation can fail before one is created.
    fig = None
    try:
        # Ensure series has no NaNs
        series = series.dropna()
        if series.empty:
            err = f"Series {series_name} is empty after dropping NaNs."
            logger.warning(err)
            return err

        # Calculate weekly lags (168 hours = 1 week)
        hours_per_week = 24 * 7
        max_lags = max_weeks * hours_per_week

        # statsmodels truncates the result to the first ``nlags`` values
        # (lags 0..nlags-1), so ask for one extra to include the final week.
        autocorr_values = ccf(series, series, adjusted=False, nlags=max_lags + 1)[:max_lags + 1]

        # Keep only whole-week lags that are actually available.
        lags_range = list(range(0, len(autocorr_values), hours_per_week))

        # Plotting
        fig, ax = plt.subplots()
        markerline, stemlines, baseline = ax.stem(
            [lag / hours_per_week for lag in lags_range],  # Convert to weeks for x-axis
            autocorr_values[lags_range],
            markerfmt='o',
            basefmt="gray"
        )
        plt.setp(markerline, markersize=5)
        plt.setp(stemlines, linewidth=1)

        # Add approximate 95% confidence band (+/- 1.96 / sqrt(N))
        conf_level = 1.96 / np.sqrt(len(series))
        ax.axhspan(-conf_level, conf_level, alpha=0.2, color='blue', zorder=0)

        ax.set_title(f'Weekly Autocorrelation: {series_name}')
        ax.set_xlabel('Lag (weeks)')
        ax.set_ylabel(f'Corr({series_name}(t), {series_name}(t-lag))')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Add vertical lines at each week
        for week in range(max_weeks + 1):
            ax.axvline(x=week, color='gray', linestyle=':', alpha=0.3)

        err = _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting weekly autocorrelation for {series_name}: {e}"
        logger.error(err, exc_info=True)
        if fig is not None:
            plt.close(fig)
    return err