Files
entrix_case_challange/data_analysis/io/plotting.py
2025-05-02 10:45:06 +02:00

398 lines
16 KiB
Python

import logging
from pathlib import Path
import pandas as pd
import numpy as np # Import numpy for CI calculation
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, List
# Import analysis tools for plotting results
from statsmodels.tsa.seasonal import DecomposeResult
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf, seasonal_plot
from statsmodels.tsa.stattools import ccf # Import ccf
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# --- Plotting Configuration ---
# NOTE: the two statements below run at import time and mutate matplotlib's
# global state, so they affect every figure created after this module is imported.
# Increase default figure size for better readability
plt.rcParams['figure.figsize'] = (15, 7)
# Use a clean style
# NOTE(review): the style sheet is applied after the figsize override above;
# confirm that 'seaborn-v0_8-whitegrid' does not reset 'figure.figsize'.
plt.style.use('seaborn-v0_8-whitegrid')
def _save_plot(fig: plt.Figure, output_path: Path) -> Optional[str]:
    """Save a figure to disk and close it, reporting failures as a message.

    Args:
        fig: Figure to write; it is always closed before this function returns.
        output_path: Destination file for the rendered figure.

    Returns:
        Optional[str]: Error message if saving failed, None if successful.
    """
    err = None
    try:
        fig.tight_layout()  # Adjust layout before saving
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.info(f"Plot saved to: {output_path}")
    except Exception as e:
        err = f"Failed to save plot to {output_path}: {e}"
        logger.error(err, exc_info=True)
    finally:
        # Close exactly once on every exit path (success, failure, or an
        # interrupt) so the figure never lingers in pyplot's registry.
        plt.close(fig)
    return err
def plot_full_time_series(df: pd.DataFrame, price_col: str, output_path: Path) -> Optional[str]:
    """Draw the complete series as one line chart and save it.

    Args:
        df: Frame whose index supplies the x-axis values.
        price_col: Name of the column drawn on the y-axis.
        output_path: Where to save the plot.

    Returns:
        Optional[str]: Error message if something went wrong, None if successful.
    """
    logger.info(f"Generating full time series plot to {output_path}")
    figure, axes = plt.subplots()
    try:
        sns.lineplot(data=df, x=df.index, y=price_col, ax=axes, linewidth=1)
        axes.set_title('Full Time Series: Price Over Time')
        axes.set_xlabel('Time')
        axes.set_ylabel(price_col)
        # _save_plot closes the figure itself, whatever the outcome.
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting full time series: {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
def plot_zoomed_time_series(df: pd.DataFrame, price_col: str, start_date: str, end_date: str, output_path: Path) -> Optional[str]:
    """Draw the slice of the series between two index labels and save it.

    Args:
        df: Frame whose index supplies the x-axis values.
        price_col: Name of the column drawn on the y-axis.
        start_date: First index label of the window (passed to ``df.loc``).
        end_date: Last index label of the window (passed to ``df.loc``).
        output_path: Where to save the plot.

    Returns:
        Optional[str]: Error message if the window is empty or plotting
        failed, None if successful.
    """
    logger.info(f"Generating zoomed time series plot ({start_date} to {end_date}) to {output_path}")
    figure, axes = plt.subplots()
    try:
        # Label-based slice; the bounds must be compatible with the index type.
        window = df.loc[start_date:end_date]
        if not window.empty:
            sns.lineplot(data=window, x=window.index, y=price_col, ax=axes, linewidth=1)
            axes.set_title(f'Time Series: {start_date} to {end_date}')
            axes.set_xlabel('Time')
            axes.set_ylabel(price_col)
            return _save_plot(figure, output_path)
        # An empty window is reported as a warning rather than an error.
        message = f"No data found in the specified zoom range: {start_date} to {end_date}"
        logger.warning(message)
        plt.close(figure)
        return message
    except Exception as exc:
        message = f"Error plotting zoomed time series: {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
def plot_boxplot_by_period(df: pd.DataFrame, price_col: str, period: str, output_path: Path) -> Optional[str]:
    """Generate box plots of the price grouped by a calendar component.

    Args:
        df: Frame whose index exposes ``hour``, ``dayofweek``, ``month`` and
            ``year`` attributes (e.g. a DatetimeIndex).
        price_col: Name of the column whose distribution is plotted.
        period: One of 'hour', 'dayofweek', 'month', 'year'.
        output_path: Where to save the plot.

    Returns:
        Optional[str]: Error message if the period is unsupported or plotting
        failed, None if successful.
    """
    logger.info(f"Generating box plot by {period} to {output_path}")
    # period -> (index attribute, plot title, x-axis label)
    period_specs = {
        'hour': ('hour', 'Price Distribution by Hour of Day', 'Hour'),
        # dayofweek: Monday=0, Sunday=6
        'dayofweek': ('dayofweek', 'Price Distribution by Day of Week', 'Day of Week (0=Mon, 6=Sun)'),
        'month': ('month', 'Price Distribution by Month', 'Month'),
        'year': ('year', 'Price Distribution by Year', 'Year'),
    }
    spec = period_specs.get(period)
    if spec is None:
        # Reject before allocating a figure so there is nothing to clean up.
        err = f"Unsupported period '{period}' for boxplot."
        logger.error(err)
        return err
    index_attr, title, x_label = spec

    fig, ax = plt.subplots()
    err = None
    try:
        group_col = getattr(df.index, index_attr)
        # hue mirrors x only so that the palette colors each box.
        sns.boxplot(x=group_col, y=df[price_col], ax=ax, palette="viridis", hue=group_col)
        # Assigning hue makes seaborn draw a legend that merely repeats the
        # x-axis categories; drop it so it cannot cover the boxes.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(price_col)
        err = _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting boxplot by {period}: {e}"
        logger.error(err, exc_info=True)
        plt.close(fig)
    return err
# --- Seasonal Subseries Plot ---
def plot_seasonal_subseries(df: pd.DataFrame, price_col: str, period: int, period_name: str, output_path: Path) -> Optional[str]:
    """Generate a seasonal subseries plot for a given period.

    Observations are grouped by hour of day when ``period == 24``, by day of
    week when ``period == 168``, and otherwise by row position modulo
    ``period``.

    Args:
        df: Frame with a DatetimeIndex.
        price_col: Name of the column to plot.
        period: Length of the seasonal cycle (e.g. 24 for daily).
        period_name: Human-readable name of the cycle, used in the title.
        output_path: Where to save the plot.

    Returns:
        Optional[str]: Error message if something went wrong, None if successful.
    """
    logger.info(f"Generating seasonal subseries plot for {period_name} (period={period}) to {output_path}")
    err = None
    fig = None  # Bound up front so the except handler can always test it
    try:
        if not isinstance(df.index, pd.DatetimeIndex):
            err = "DataFrame index must be a DatetimeIndex for seasonal subseries plot."
            logger.error(err)
            return err

        # Choose the grouping key and the label for every possible key value.
        if period == 24:  # Daily
            keys = df.index.hour
            all_labels = [f"{i:02d}:00" for i in range(24)]
        elif period == 168:  # Weekly
            keys = df.index.dayofweek
            all_labels = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
        else:
            if period <= 0:
                raise ValueError(f"period must be a positive integer, got {period}")
            # A DatetimeIndex does not support '%', so group by row position.
            keys = np.arange(len(df)) % period
            all_labels = [str(i) for i in range(period)]

        grouped = df[price_col].groupby(keys)
        # seasonal_plot draws one tick per group, so label only the groups
        # that actually occur (short samples may not cover the whole cycle).
        present_keys = sorted(grouped.groups)
        if not present_keys:
            raise ValueError("no data available to group")
        xticklabels = [all_labels[int(key)] for key in present_keys]

        # Own the figure so that a failure can close this figure only.
        fig, ax = plt.subplots(figsize=(15, 10))
        seasonal_plot(grouped, xticklabels=xticklabels, ylabel=price_col, ax=ax)
        fig.suptitle(f'Seasonal Subseries Plot ({period_name})', y=1.02)
        err = _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting seasonal subseries ({period_name}): {e}"
        logger.error(err, exc_info=True)
        if fig is not None:
            plt.close(fig)
    return err
def plot_histogram(df: pd.DataFrame, price_col: str, output_path: Path, bins: int = 50) -> Optional[str]:
    """Draw a histogram with a KDE overlay for one column and save it.

    Args:
        df: Frame containing the column to plot.
        price_col: Name of the column whose values are binned.
        output_path: Where to save the plot.
        bins: Number of histogram bins (default: 50).

    Returns:
        Optional[str]: Error message if something went wrong, None if successful.
    """
    logger.info(f"Generating histogram of '{price_col}' to {output_path}")
    figure, axes = plt.subplots()
    try:
        sns.histplot(data=df, x=price_col, bins=bins, kde=True, ax=axes)
        axes.set_title(f'Distribution of {price_col}')
        axes.set_xlabel(price_col)
        axes.set_ylabel('Frequency')
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting histogram: {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
def plot_decomposition(decomposition_result: DecomposeResult, period_name: str, output_path: Path) -> Optional[str]:
    """Render a time series decomposition result and save the figure.

    Args:
        decomposition_result: Result object whose ``plot()`` method produces
            the figure.
        period_name: Name of the seasonality, used in the title.
        output_path: Where to save the plot.

    Returns:
        Optional[str]: Error message if something went wrong, None if successful.
    """
    logger.info(f"Generating {period_name} decomposition plot to {output_path}")
    try:
        figure = decomposition_result.plot()
        # Enlarge the figure for better visibility.
        figure.set_size_inches(15, 10)
        figure.suptitle(f'Time Series Decomposition ({period_name} Seasonality)', y=1.02)
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting decomposition ({period_name}): {exc}"
        logger.error(message, exc_info=True)
        # The figure may not exist if plot() itself failed, so fall back to
        # closing every open figure.
        plt.close('all')
        return message
def plot_residuals(residuals: pd.Series, title_suffix: str, output_path: Path) -> Optional[str]:
    """Draw the residual series with a zero reference line and save it.

    Args:
        residuals: Series of residual values to plot against its index.
        title_suffix: Text appended in parentheses to the plot title.
        output_path: Where to save the plot.

    Returns:
        Optional[str]: Error message if something went wrong, None if successful.
    """
    logger.info(f"Generating residuals plot ({title_suffix}) to {output_path}")
    figure, axes = plt.subplots()
    try:
        residuals.plot(ax=axes, title=f'Residuals ({title_suffix})')
        axes.set_xlabel('Time')
        axes.set_ylabel('Residual Value')
        # Dashed red reference line at zero.
        axes.axhline(0, color='r', linestyle='--', alpha=0.7)
        return _save_plot(figure, output_path)
    except Exception as exc:
        message = f"Error plotting residuals ({title_suffix}): {exc}"
        logger.error(message, exc_info=True)
        plt.close(figure)
        return message
def plot_acf_pacf(series: pd.Series, series_name: str, lags: int | None, output_path_base: Path) -> Optional[str]:
    """Plot the ACF and PACF of a series, saving them as separate files.

    The files are written next to ``output_path_base`` as
    ``<stem>_acf.png`` and ``<stem>_pacf.png``. A failure in one plot does
    not prevent the other from being attempted.

    Args:
        series: The time series to analyze.
        series_name: Name of the series for labeling.
        lags: Number of lags forwarded to the statsmodels plotting functions.
        output_path_base: Path whose parent directory and stem name the outputs.

    Returns:
        Optional[str]: The first error message encountered, None if both
        plots succeeded.
    """
    logger.info(f"Generating ACF/PACF plots for {series_name} to {output_path_base.parent}")
    # (plot function, label used in titles/messages, file suffix, extra kwargs)
    plot_specs = (
        (plot_acf, "ACF", "acf", {}),
        # method='ywm' selects the Yule-Walker estimator for the PACF.
        (plot_pacf, "PACF", "pacf", {"method": "ywm"}),
    )
    errors = []
    for plot_func, label, suffix, extra_kwargs in plot_specs:
        fig = None  # Bound before the try so the handler can always test it
        try:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            plot_func(series, lags=lags, ax=ax, title=f'{label} - {series_name}', **extra_kwargs)
            target = output_path_base.with_name(f"{output_path_base.stem}_{suffix}.png")
            err = _save_plot(fig, target)
        except Exception as e:
            err = f"Error plotting {label} for {series_name}: {e}"
            logger.error(err, exc_info=True)
            if fig is not None:
                plt.close(fig)
        if err:
            errors.append(err)
    # Return the first error encountered, or None if both succeeded
    return errors[0] if errors else None
# --- Cross-Correlation Plot ---
def plot_cross_correlation(
    target_series: pd.Series,
    exog_series: pd.Series,
    target_name: str,
    exog_name: str,
    max_lags: int,
    output_path: Path
) -> Optional[str]:
    """
    Generates and saves a cross-correlation plot between a target series and an exogenous series.
    Plots correlation of target_series(t) with exog_series(t-lag) for lags 0..max_lags.

    Args:
        target_series: The main time series to analyze
        exog_series: The exogenous time series to correlate with
        target_name: Name of the target series for labeling
        exog_name: Name of the exogenous series for labeling
        max_lags: Maximum number of lags to compute correlation for
        output_path: Where to save the plot

    Returns:
        Optional[str]: Error message if something went wrong, None if successful
    """
    logger.info(f"Generating cross-correlation plot ({target_name} vs {exog_name}) for lags up to {max_lags} to {output_path}")
    err = None
    fig = None  # Bound up front: failures before plotting must not break the handler
    try:
        # Align the two series on their index and drop rows where either is NaN.
        combined = pd.concat([target_series.rename(target_name), exog_series.rename(exog_name)], axis=1).dropna()
        # Select by position so identical names cannot yield a DataFrame.
        aligned_target = combined.iloc[:, 0]
        aligned_exog = combined.iloc[:, 1]

        # The early returns below are not logged here; presumably the caller
        # reports the returned message.
        if combined.empty or len(combined) <= max_lags:
            return f"Not enough overlapping non-NaN data points between {target_name} and {exog_name} for CCF calculation (need > {max_lags})."
        # Check variation on the aligned data actually used: a constant input
        # has zero standard deviation and would give NaN correlations.
        if aligned_exog.nunique() <= 1:
            return f"Cannot compute cross-correlation: {exog_name} has no variation (all values are the same)."
        if aligned_target.nunique() <= 1:
            return f"Cannot compute cross-correlation: {target_name} has no variation (all values are the same)."

        # ccf(x, y) computes corr(x[t], y[t-lag]); we want
        # corr(target[t], exog[t-lag]), so the order is ccf(target, exog).
        # nlags is the NUMBER of values returned (lag 0 included), hence +1.
        cross_corr_values = np.asarray(
            ccf(aligned_target, aligned_exog, adjusted=False, nlags=max_lags + 1)
        )[: max_lags + 1]
        # Build the x-axis from the values actually returned so both always match.
        lags_range = np.arange(len(cross_corr_values))

        # Plotting
        fig, ax = plt.subplots()
        markerline, stemlines, _baseline = ax.stem(
            lags_range, cross_corr_values, markerfmt='o', basefmt="gray"
        )
        plt.setp(markerline, markersize=5)
        plt.setp(stemlines, linewidth=1)
        # Approximate 95% confidence band: +/- 1.96 / sqrt(N)
        conf_level = 1.96 / np.sqrt(len(combined))
        ax.axhspan(-conf_level, conf_level, alpha=0.2, color='blue', zorder=0)
        ax.set_title(f'Cross-Correlation: {target_name}(t) vs {exog_name}(t-lag)')
        ax.set_xlabel('Lag (k)')
        ax.set_ylabel(f'Corr({target_name}(t), {exog_name}(t-k))')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        err = _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting cross-correlation ({target_name} vs {exog_name}): {e}"
        logger.error(err, exc_info=True)
        if fig is not None:
            plt.close(fig)
    return err
def plot_weekly_autocorrelation(
    series: pd.Series,
    series_name: str,
    output_path: Path,
    max_weeks: int = 4
) -> Optional[str]:
    """
    Generates and saves an autocorrelation plot between a series and its weekly lags.
    This helps identify weekly seasonality patterns.

    Lags are counted in observations, with 168 observations taken as one week;
    this assumes the series is sampled hourly (verify against the caller).

    Args:
        series: The time series to analyze
        series_name: Name of the series for labeling
        output_path: Where to save the plot
        max_weeks: Maximum number of weeks to look back (default: 4)

    Returns:
        Optional[str]: Error message if something went wrong, None if successful
    """
    logger.info(f"Generating weekly autocorrelation plot for {series_name} up to {max_weeks} weeks to {output_path}")
    err = None
    fig = None  # Bound up front: failures before plotting must not break the handler
    try:
        # Ensure series has no NaNs
        series = series.dropna()
        if series.empty:
            err = f"Series {series_name} is empty after dropping NaNs."
            logger.warning(err)
            return err
        # A constant series has zero variance, so every correlation would be NaN.
        if series.nunique() <= 1:
            err = f"Cannot compute autocorrelation: {series_name} has no variation (all values are the same)."
            logger.warning(err)
            return err

        # Weekly lags (168 hours = 1 week)
        hours_per_week = 24 * 7
        max_lags = max_weeks * hours_per_week
        # nlags is the NUMBER of values returned (lag 0 included), so request
        # max_lags + 1 to make the final week's lag available.
        autocorr_values = np.asarray(ccf(series, series, adjusted=False, nlags=max_lags + 1))
        # Keep only whole-week lags that exist in the result (short series
        # return fewer values than requested).
        last_lag = min(max_lags, len(autocorr_values) - 1)
        weekly_lags = np.arange(0, last_lag + 1, hours_per_week)

        # Plotting
        fig, ax = plt.subplots()
        markerline, stemlines, _baseline = ax.stem(
            weekly_lags / hours_per_week,  # Convert to weeks for x-axis
            autocorr_values[weekly_lags],
            markerfmt='o',
            basefmt="gray"
        )
        plt.setp(markerline, markersize=5)
        plt.setp(stemlines, linewidth=1)
        # Approximate 95% confidence band: +/- 1.96 / sqrt(N)
        conf_level = 1.96 / np.sqrt(len(series))
        ax.axhspan(-conf_level, conf_level, alpha=0.2, color='blue', zorder=0)
        ax.set_title(f'Weekly Autocorrelation: {series_name}')
        ax.set_xlabel('Lag (weeks)')
        ax.set_ylabel(f'Corr({series_name}(t), {series_name}(t-lag))')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        # Add vertical lines at each week
        for week in range(max_weeks + 1):
            ax.axvline(x=week, color='gray', linestyle=':', alpha=0.3)
        err = _save_plot(fig, output_path)
    except Exception as e:
        err = f"Error plotting weekly autocorrelation for {series_name}: {e}"
        logger.error(err, exc_info=True)
        if fig is not None:
            plt.close(fig)
    return err