Fix bugs that were introduced into R cross validation

This commit is contained in:
Ben Letham 2020-02-03 16:18:46 -08:00
parent aeaf0ad2e1
commit 4fcecdb3df
2 changed files with 6 additions and 5 deletions

View file

@ -138,8 +138,9 @@ cross_validation <- function(
future <- df.predict[columns]
yhat <- stats::predict(m, future)
# Merge yhat, y, and cutoff.
df.c <- dplyr::inner_join(df.predict, yhat$predict_columns, by = "ds")
df.c <- dplyr::select(df.c, y, yhat$predict_columns)
df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds")
df.c <- df.c[c(predict_columns, "y")]
df.c <- dplyr::select(df.c, y, predict_columns)
df.c$cutoff <- cutoff
predicts <- rbind(predicts, df.c)
}
@ -240,7 +241,7 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
if (is.null(metrics)) {
metrics <- valid_metrics
}
if (!('yhat_lower' %in% df) | (!('yhat_upper' %in% df)) & ('coverage' %in% metrics)){
if (!('yhat_lower' %in% colnames(df)) | (!('yhat_upper' %in% colnames(df))) & ('coverage' %in% metrics)){
metrics <- valid_metrics[valid_metrics != 'coverage']
}

View file

@ -26,7 +26,7 @@ test_that("cross_validation", {
m, horizon = 4, units = "days", period = 10, initial = 115)
expect_equal(length(unique(df.cv$cutoff)), 3)
expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
expect_true(min(df.cv$cutoff) >= ts + initial)
expect_true(as.Date(min(df.cv$cutoff)) >= ts + initial)
dc <- diff(df.cv$cutoff)
dc <- min(dc[dc > 0])
expect_true(dc >= period)
@ -131,7 +131,7 @@ test_that("performance_metrics", {
expect_null(df_horizon)
# List of metrics containing non valid metrics
expect_error(
performance_metrics(df, metrics = c('mse', 'error_metric')),
performance_metrics(df_cv, metrics = c('mse', 'error_metric')),
'Valid values for metrics are: mse, rmse, mae, mape, coverage'
)
})