From 23d8bc25dcbbb07e9f76c7a5c26e70e9cde2c408 Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Wed, 30 Aug 2017 17:10:43 -0700 Subject: [PATCH] R unit tests for added regressors --- R/R/prophet.R | 4 +- R/tests/testthat/test_prophet.R | 67 +++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/R/R/prophet.R b/R/R/prophet.R index b001fdb..b024c1f 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -625,8 +625,8 @@ make_all_seasonality_features <- function(m, df) { # Additional regressors for (name in names(m$extra_regressors)) { - seasonal.features <- cbind(seasonal.features, df[[name]]) - prior.scales <- cbind(prior.scales, m$extra_regressors[[name]]$prior.scale) + seasonal.features[[name]] <- df[[name]] + prior.scales <- c(prior.scales, m$extra_regressors[[name]]$prior.scale) } if (ncol(seasonal.features) == 0) { diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index 38be545..149731e 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -329,6 +329,73 @@ test_that("custom_seasonality", { m <- prophet(holidays=holidays) m <- add_seasonality(m, name='monthly', period=30, fourier.order=5) expect_equal(m$seasonalities[['monthly']], c(30, 5)) + expect_error( + add_seasonality(m, name='special_day', period=30, fourier_order=5) + ) + expect_error( + add_seasonality(m, name='trend', period=30, fourier_order=5) + ) + m <- add_seasonality(m, name='weekly', period=30, fourier.order=5) +}) + +test_that("added_regressors", { + skip_if_not(Sys.getenv('R_ARCH') != '/i386') + m <- prophet() + m <- add_regressor(m, 'binary_feature', prior.scale=0.2) + m <- add_regressor(m, 'numeric_feature', prior.scale=0.5) + m <- add_regressor(m, 'binary_feature2', standardize=TRUE) + df <- DATA + df$binary_feature <- c(rep(0, 255), rep(1, 255)) + df$numeric_feature <- 0:509 + # Require all regressors in df + expect_error( + fit.prophet(m, df) + ) + df$binary_feature2 <- c(rep(1, 100), rep(0, 410)) + m <- fit.prophet(m, df) + # Check that standardizations are correctly set + true <- list(prior.scale = 0.2, mu = 0, std = 1, standardize = 'auto') + for (name in names(true)) { + expect_equal(true[[name]], m$extra_regressors$binary_feature[[name]]) + } + true <- list(prior.scale = 0.5, mu = 254.5, std = 147.368585) + for (name in names(true)) { + expect_equal(true[[name]], m$extra_regressors$numeric_feature[[name]], + tolerance = 1e-5) + } + true <- list(prior.scale = 10., mu = 0.1960784, std = 0.3974183) + for (name in names(true)) { + expect_equal(true[[name]], m$extra_regressors$binary_feature2[[name]], + tolerance = 1e-5) + } + # Check that standardization is done correctly + df2 <- prophet:::setup_dataframe(m, df)$df + expect_equal(df2$binary_feature[1], 0) + expect_equal(df2$numeric_feature[1], -1.726962, tolerance = 1e-4) + expect_equal(df2$binary_feature2[1], 2.022859, tolerance = 1e-4) + # Check that feature matrix and prior scales are correctly constructed + out <- prophet:::make_all_seasonality_features(m, df2) + seasonal.features <- out$seasonal.features + prior.scales <- out$prior.scales + expect_true('binary_feature' %in% colnames(seasonal.features)) + expect_true('numeric_feature' %in% colnames(seasonal.features)) + expect_true('binary_feature2' %in% colnames(seasonal.features)) + expect_equal(ncol(seasonal.features), 29) + expect_true(all(sort(prior.scales[27:29]) == c(0.2, 0.5, 10.))) + # Check that forecast components are reasonable + future <- data.frame( + ds = c('2014-06-01'), binary_feature = c(0), numeric_feature = c(10)) + expect_error(predict(m, future)) + future$binary_feature2 <- 0. + fcst <- predict(m, future) + expect_equal(ncol(fcst), 31) + expect_equal(fcst$binary_feature[1], 0) + expect_equal(fcst$extra_regressors[1], + fcst$numeric_feature[1] + fcst$binary_feature2[1]) + expect_equal(fcst$seasonalities[1], fcst$yearly[1] + fcst$weekly[1]) + expect_equal(fcst$seasonal[1], + fcst$seasonalities[1] + fcst$extra_regressors[1]) + expect_equal(fcst$yhat[1], fcst$trend[1] + fcst$seasonal[1]) }) test_that("copy", {