diff --git a/R/R/diagnostics.R b/R/R/diagnostics.R index b420dbd..3633a18 100644 --- a/R/R/diagnostics.R +++ b/R/R/diagnostics.R @@ -99,6 +99,11 @@ cross_validation <- function( 'is larger than initial window. Consider increasing initial.')) } } + + predict_columns <- c('ds', 'yhat') + if (model$uncertainty_samples){ + predict_columns <- append(predict_columns, c('yhat_lower', 'yhat_upper')) + } cutoffs <- generate_cutoffs(df, horizon.dt, initial.dt, period.dt) @@ -133,8 +138,8 @@ cross_validation <- function( future <- df.predict[columns] yhat <- stats::predict(m, future) # Merge yhat, y, and cutoff. - df.c <- dplyr::inner_join(df.predict, yhat, by = "ds") - df.c <- dplyr::select(df.c, ds, y, yhat, yhat_lower, yhat_upper) + df.c <- dplyr::inner_join(df.predict, yhat$predict_columns, by = "ds") + df.c <- dplyr::select(df.c, y, yhat$predict_columns) df.c$cutoff <- cutoff predicts <- rbind(predicts, df.c) } @@ -235,6 +240,10 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) { if (is.null(metrics)) { metrics <- valid_metrics } + if (!('yhat_lower' %in% df) | (!('yhat_upper' %in% df)) & ('coverage' %in% metrics)){ + metrics <- valid_metrics[valid_metrics != 'coverage'] + } + if (length(metrics) != length(unique(metrics))) { stop('Input metrics must be an array of unique values.') }