mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-26 22:35:48 +00:00
Update performance_metrics() in diagnostics.py (#1710)
* Update performance_metrics() in diagnostics.py

  Include smape as a valid metric. Add support for monthly horizons.

* Update performance_metrics() docstring
* Update performance_metrics() docstring
* Include smape in test_diagnostics.py
* sMAPE code formatting improvement
This commit is contained in:
parent
0a33f381ba
commit
97bb057de4
2 changed files with 14 additions and 7 deletions
|
|
@@ -301,7 +301,7 @@ def prophet_copy(m, cutoff=None):
|
|||
return m2
|
||||
|
||||
|
||||
def performance_metrics(df, metrics=None, rolling_window=0.1):
|
||||
def performance_metrics(df, metrics=None, rolling_window=0.1, monthly=False):
|
||||
"""Compute performance metrics from cross-validation results.
|
||||
|
||||
Computes a suite of performance metrics on the output of cross-validation.
|
||||
|
|
@@ -311,6 +311,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.1):
|
|||
'mae': mean absolute error
|
||||
'mape': mean absolute percent error
|
||||
'mdape': median absolute percent error
|
||||
'smape': symmetric mean absolute percentage error
|
||||
'coverage': coverage of the upper and lower intervals
|
||||
|
||||
A subset of these can be specified by passing a list of names as the
|
||||
|
|
@@ -337,15 +338,17 @@ def performance_metrics(df, metrics=None, rolling_window=0.1):
|
|||
----------
|
||||
df: The dataframe returned by cross_validation.
|
||||
metrics: A list of performance metrics to compute. If not provided, will
|
||||
use ['mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage'].
|
||||
use ['mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'coverage'].
|
||||
rolling_window: Proportion of data to use in each rolling window for
|
||||
computing the metrics. Should be in [0, 1] to average
|
||||
computing the metrics. Should be in [0, 1] to average.
|
||||
monthly: monthly=True will compute horizons as numbers of calendar months
|
||||
from the cutoff date, starting from 0 for the cutoff month.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataframe with a column for each metric, and column 'horizon'
|
||||
"""
|
||||
valid_metrics = ['mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage']
|
||||
valid_metrics = ['mse', 'rmse', 'mae', 'mape', 'mdape', 'smape', 'coverage']
|
||||
if metrics is None:
|
||||
metrics = valid_metrics
|
||||
if ('yhat_lower' not in df or 'yhat_upper' not in df) and ('coverage' in metrics):
|
||||
|
|
@@ -357,7 +360,10 @@ def performance_metrics(df, metrics=None, rolling_window=0.1):
|
|||
'Valid values for metrics are: {}'.format(valid_metrics)
|
||||
)
|
||||
df_m = df.copy()
|
||||
df_m['horizon'] = df_m['ds'] - df_m['cutoff']
|
||||
if monthly:
|
||||
df_m['horizon'] = df_m['ds'].dt.to_period('M').astype(int) - df_m['cutoff'].dt.to_period('M').astype(int)
|
||||
else:
|
||||
df_m['horizon'] = df_m['ds'] - df_m['cutoff']
|
||||
df_m.sort_values('horizon', inplace=True)
|
||||
if 'mape' in metrics and df_m['y'].abs().min() < 1e-8:
|
||||
logger.info('Skipping MAPE because y close to 0')
|
||||
|
|
@@ -590,6 +596,7 @@ def mdape(df, w):
|
|||
|
||||
def smape(df, w):
|
||||
"""Symmetric mean absolute percentage error
|
||||
based on Chen and Yang (2004) formula
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@@ -600,7 +607,7 @@ def smape(df, w):
|
|||
-------
|
||||
Dataframe with columns horizon and smape.
|
||||
"""
|
||||
sape = np.abs(df['yhat']-df['y']) / ((np.abs(df['y']) + np.abs(df['yhat'])) /2)
|
||||
sape = np.abs(df['y'] - df['yhat']) / ((np.abs(df['y']) + np.abs(df['yhat'])) / 2)
|
||||
if w < 0:
|
||||
return pd.DataFrame({'horizon': df['horizon'], 'smape': sape})
|
||||
return rolling_mean_by_h(
|
||||
|
|
|
|||
|
|
@@ -190,7 +190,7 @@ class TestDiagnostics(TestCase):
|
|||
df_none = diagnostics.performance_metrics(df_cv, rolling_window=-1)
|
||||
self.assertEqual(
|
||||
set(df_none.columns),
|
||||
{'horizon', 'coverage', 'mae', 'mape', 'mdape', 'mse', 'rmse'},
|
||||
{'horizon', 'coverage', 'mae', 'mape', 'mdape', 'mse', 'rmse', 'smape'},
|
||||
)
|
||||
self.assertEqual(df_none.shape[0], 16)
|
||||
# Aggregation level 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue