mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-03 23:49:47 +00:00
adapt cross validation to a model with 0 uncertainty samples
This commit is contained in:
parent
a6a1381a0a
commit
7f214f2bc7
1 changed files with 7 additions and 1 deletions
|
|
@ -101,6 +101,10 @@ def cross_validation(model, horizon, period=None, initial=None):
|
|||
msg += 'Consider increasing initial.'
|
||||
logger.warning(msg)
|
||||
|
||||
predict_columns = ['ds', 'yhat']
|
||||
if model.uncertainty_samples:
|
||||
predict_columns.extend(['yhat_lower', 'yhat_upper'])
|
||||
|
||||
cutoffs = generate_cutoffs(df, horizon, initial, period)
|
||||
predicts = []
|
||||
for cutoff in cutoffs:
|
||||
|
|
@ -130,7 +134,7 @@ def cross_validation(model, horizon, period=None, initial=None):
|
|||
yhat = m.predict(df[index_predicted][columns])
|
||||
# Merge yhat(predicts), y(df, original data) and cutoff
|
||||
predicts.append(pd.concat([
|
||||
yhat[['ds', 'yhat', 'yhat_lower', 'yhat_upper']],
|
||||
yhat[predict_columns],
|
||||
df[index_predicted][['y']].reset_index(drop=True),
|
||||
pd.DataFrame({'cutoff': [cutoff] * len(yhat)})
|
||||
], axis=1))
|
||||
|
|
@ -234,6 +238,8 @@ def performance_metrics(df, metrics=None, rolling_window=0.1):
|
|||
valid_metrics = ['mse', 'rmse', 'mae', 'mape', 'mdape', 'coverage']
|
||||
if metrics is None:
|
||||
metrics = valid_metrics
|
||||
if ('yhat_lower' not in df) or ('yhat_upper' not in df) and ('coverage' in metrics):
|
||||
metrics.remove('coverage')
|
||||
if len(set(metrics)) != len(metrics):
|
||||
raise ValueError('Input metrics must be a list of unique values')
|
||||
if not set(metrics).issubset(set(valid_metrics)):
|
||||
|
|
|
|||
Loading…
Reference in a new issue