From 8d0f23f8beeaa4195f38613c7fcbcdda9aa2c602 Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Tue, 4 Feb 2020 13:05:28 -0800 Subject: [PATCH] Add unit tests for disabling uncertainty estimation in R --- R/tests/testthat/test_diagnostics.R | 12 ++++++++++++ R/tests/testthat/test_prophet.R | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/R/tests/testthat/test_diagnostics.R b/R/tests/testthat/test_diagnostics.R index 757acc4..9e8a9f6 100644 --- a/R/tests/testthat/test_diagnostics.R +++ b/R/tests/testthat/test_diagnostics.R @@ -90,6 +90,18 @@ test_that("cross_validation_default_value_check", { expect_equal(sum(dplyr::select(df.cv1 - df.cv2, y, yhat)), 0) }) +test_that("cross_validation_uncertainty_disabled", { + skip_if_not(Sys.getenv('R_ARCH') != '/i386') + for (uncertainty in c(0, FALSE)) { + m <- prophet(uncertainty.samples = uncertainty) + m <- fit.prophet(m = m, df = DATA, algorithm = "Newton") + df.cv <- cross_validation( + m, horizon = 4, units = "days", period = 4, initial = 115) + expected.cols <- c('y', 'ds', 'yhat', 'cutoff') + expect_equal(expected.cols, colnames(df.cv)) + } +}) + test_that("performance_metrics", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') m <- prophet(DATA) diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index 8dd6288..ae90b79 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -73,6 +73,16 @@ test_that("fit_predict_constant_history", { expect_equal(tail(fcst$yhat, 1), 0) }) +test_that("fit_predict_uncertainty_disabled", { + skip_if_not(Sys.getenv('R_ARCH') != '/i386') + for (uncertainty in c(0, FALSE)) { + m <- prophet(train, uncertainty.samples = uncertainty) + fcst <- predict(m, future) + expected.cols <- c('ds', 'trend', 'additive_terms', 'weekly', 'multiplicative_terms', 'yhat') + expect_equal(expected.cols, colnames(fcst)) + } +}) + test_that("setup_dataframe", { history <- train m <- prophet(history, fit = FALSE)