From 4fcecdb3dff5b0c22d66b11d499cd8cc7d068033 Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Mon, 3 Feb 2020 16:18:46 -0800 Subject: [PATCH] Fix bugs that were introduced into R cross validation --- R/R/diagnostics.R | 7 ++++--- R/tests/testthat/test_diagnostics.R | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/R/R/diagnostics.R b/R/R/diagnostics.R index a2ea766..51b46ac 100644 --- a/R/R/diagnostics.R +++ b/R/R/diagnostics.R @@ -138,8 +138,9 @@ cross_validation <- function( future <- df.predict[columns] yhat <- stats::predict(m, future) # Merge yhat, y, and cutoff. - df.c <- dplyr::inner_join(df.predict, yhat$predict_columns, by = "ds") - df.c <- dplyr::select(df.c, y, yhat$predict_columns) + df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds") + df.c <- df.c[c(predict_columns, "y")] + df.c <- dplyr::select(df.c, y, predict_columns) df.c$cutoff <- cutoff predicts <- rbind(predicts, df.c) } @@ -240,7 +241,7 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) { if (is.null(metrics)) { metrics <- valid_metrics } - if (!('yhat_lower' %in% df) | (!('yhat_upper' %in% df)) & ('coverage' %in% metrics)){ + if (!('yhat_lower' %in% colnames(df)) | (!('yhat_upper' %in% colnames(df))) & ('coverage' %in% metrics)){ metrics <- valid_metrics[valid_metrics != 'coverage'] } diff --git a/R/tests/testthat/test_diagnostics.R b/R/tests/testthat/test_diagnostics.R index 29f7d64..757acc4 100644 --- a/R/tests/testthat/test_diagnostics.R +++ b/R/tests/testthat/test_diagnostics.R @@ -26,7 +26,7 @@ test_that("cross_validation", { m, horizon = 4, units = "days", period = 10, initial = 115) expect_equal(length(unique(df.cv$cutoff)), 3) expect_equal(max(df.cv$ds - df.cv$cutoff), horizon) - expect_true(min(df.cv$cutoff) >= ts + initial) + expect_true(as.Date(min(df.cv$cutoff)) >= ts + initial) dc <- diff(df.cv$cutoff) dc <- min(dc[dc > 0]) expect_true(dc >= period) @@ -131,7 +131,7 @@ test_that("performance_metrics", { expect_null(df_horizon) # List of metrics containing non valid metrics expect_error( - performance_metrics(df, metrics = c('mse', 'error_metric')), + performance_metrics(df_cv, metrics = c('mse', 'error_metric')), 'Valid values for metrics are: mse, rmse, mae, mape, coverage' ) })