Adapt model cv with 0 uncertainty samples and fix performance metrics

This commit is contained in:
Ryan Nazareth 2019-10-03 17:10:37 +01:00 committed by Ben Letham
parent e1a2b9c297
commit eb797eefaa

View file

@ -99,6 +99,11 @@ cross_validation <- function(
'is larger than initial window. Consider increasing initial.'))
}
}
predict_columns <- c('ds', 'yhat')
if (model$uncertainty_samples){
predict_columns <- append(predict_columns, c('yhat_lower', 'yhat_upper'))
}
cutoffs <- generate_cutoffs(df, horizon.dt, initial.dt, period.dt)
@ -133,8 +138,8 @@ cross_validation <- function(
future <- df.predict[columns]
yhat <- stats::predict(m, future)
# Merge yhat, y, and cutoff.
df.c <- dplyr::inner_join(df.predict, yhat, by = "ds")
df.c <- dplyr::select(df.c, ds, y, yhat, yhat_lower, yhat_upper)
df.c <- dplyr::inner_join(df.predict, yhat$predict_columns, by = "ds")
df.c <- dplyr::select(df.c, y, yhat$predict_columns)
df.c$cutoff <- cutoff
predicts <- rbind(predicts, df.c)
}
@ -235,6 +240,10 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
if (is.null(metrics)) {
metrics <- valid_metrics
}
if (!('yhat_lower' %in% df) | (!('yhat_upper' %in% df)) & ('coverage' %in% metrics)){
metrics <- valid_metrics[valid_metrics != 'coverage']
}
if (length(metrics) != length(unique(metrics))) {
stop('Input metrics must be an array of unique values.')
}