From 7f214f2bc764e5a2de5ae3ddf9de93346829e9fd Mon Sep 17 00:00:00 2001 From: Alexander Gawrilow Date: Wed, 31 Jul 2019 15:05:36 +0200 Subject: [PATCH] adapt cross validation to a model with 0 uncertainty samples --- python/fbprophet/diagnostics.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/fbprophet/diagnostics.py b/python/fbprophet/diagnostics.py index f8d8df4..d5dffc6 100644 --- a/python/fbprophet/diagnostics.py +++ b/python/fbprophet/diagnostics.py @@ -101,6 +101,10 @@ def cross_validation(model, horizon, period=None, initial=None): msg += 'Consider increasing initial.' logger.warning(msg) + predict_columns = ['ds', 'yhat'] + if model.uncertainty_samples: + predict_columns.extend(['yhat_lower', 'yhat_upper']) + cutoffs = generate_cutoffs(df, horizon, initial, period) predicts = [] for cutoff in cutoffs: @@ -130,7 +134,7 @@ def cross_validation(model, horizon, period=None, initial=None): yhat = m.predict(df[index_predicted][columns]) # Merge yhat(predicts), y(df, original data) and cutoff predicts.append(pd.concat([ - yhat[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], + yhat[predict_columns], df[index_predicted][['y']].reset_index(drop=True), pd.DataFrame({'cutoff': [cutoff] * len(yhat)}) ], axis=1)) @@ -234,6 +238,8 @@ def performance_metrics(df, metrics=None, rolling_window=0.1): valid_metrics = ['mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage'] if metrics is None: metrics = valid_metrics + if ('yhat_lower' not in df) or ('yhat_upper' not in df) and ('coverage' in metrics): + metrics.remove('coverage') if len(set(metrics)) != len(metrics): raise ValueError('Input metrics must be a list of unique values') if not set(metrics).issubset(set(valid_metrics)):