adapt cross validation to a model with 0 uncertainty samples

This commit is contained in:
Alexander Gawrilow 2019-07-31 15:05:36 +02:00 committed by Ben Letham
parent a6a1381a0a
commit 7f214f2bc7

View file

@ -101,6 +101,10 @@ def cross_validation(model, horizon, period=None, initial=None):
msg += 'Consider increasing initial.'
logger.warning(msg)
predict_columns = ['ds', 'yhat']
if model.uncertainty_samples:
predict_columns.extend(['yhat_lower', 'yhat_upper'])
cutoffs = generate_cutoffs(df, horizon, initial, period)
predicts = []
for cutoff in cutoffs:
@ -130,7 +134,7 @@ def cross_validation(model, horizon, period=None, initial=None):
yhat = m.predict(df[index_predicted][columns])
# Merge yhat(predicts), y(df, original data) and cutoff
predicts.append(pd.concat([
yhat[['ds', 'yhat', 'yhat_lower', 'yhat_upper']],
yhat[predict_columns],
df[index_predicted][['y']].reset_index(drop=True),
pd.DataFrame({'cutoff': [cutoff] * len(yhat)})
], axis=1))
@ -234,6 +238,8 @@ def performance_metrics(df, metrics=None, rolling_window=0.1):
valid_metrics = ['mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage']
if metrics is None:
metrics = valid_metrics
if ('yhat_lower' not in df) or ('yhat_upper' not in df) and ('coverage' in metrics):
metrics.remove('coverage')
if len(set(metrics)) != len(metrics):
raise ValueError('Input metrics must be a list of unique values')
if not set(metrics).issubset(set(valid_metrics)):