From 0addabcad7953b1620c4d46ab5f0edc7e0fa4fbe Mon Sep 17 00:00:00 2001 From: bletham Date: Sat, 4 Nov 2017 21:40:45 -0700 Subject: [PATCH] Fix copy with extra seasonalities / regressors R --- R/R/prophet.R | 21 +++++++++---- R/tests/testthat/test_prophet.R | 56 ++++++++++++++++++++------------- 2 files changed, 50 insertions(+), 27 deletions(-) diff --git a/R/R/prophet.R b/R/R/prophet.R index 506615a..219d86c 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -1701,6 +1701,10 @@ plot_seasonality <- function(m, name, uncertainty = TRUE) { #' #' @keywords internal prophet_copy <- function(m, cutoff = NULL) { + if (is.null(m$history)) { + stop("This is for copying a fitted Prophet object.") + } + if (m$specified.changepoints) { changepoints <- m$changepoints if (!is.null(cutoff)) { @@ -1710,13 +1714,15 @@ prophet_copy <- function(m, cutoff = NULL) { } else { changepoints <- NULL } - return(prophet( + # Auto seasonalities are set to FALSE because they are already set in + # m$seasonalities. + m2 <- prophet( growth = m$growth, changepoints = changepoints, n.changepoints = m$n.changepoints, - yearly.seasonality = m$yearly.seasonality, - weekly.seasonality = m$weekly.seasonality, - daily.seasonality = m$daily.seasonality, + yearly.seasonality = FALSE, + weekly.seasonality = FALSE, + daily.seasonality = FALSE, holidays = m$holidays, seasonality.prior.scale = m$seasonality.prior.scale, changepoint.prior.scale = m$changepoint.prior.scale, @@ -1724,8 +1730,11 @@ prophet_copy <- function(m, cutoff = NULL) { mcmc.samples = m$mcmc.samples, interval.width = m$interval.width, uncertainty.samples = m$uncertainty.samples, - fit = FALSE, - )) + fit = FALSE + ) + m2$extra_regressors <- m$extra_regressors + m2$seasonalities <- m$seasonalities + return(m2) } # fb-block 3 diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index ad72585..351082f 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -520,20 +520,15 @@ test_that("added_regressors", { test_that("copy", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') + df <- DATA + df$cap <- 200. + df$binary_feature <- c(rep(0, 255), rep(1, 255)) inputs <- list( growth = c('linear', 'logistic'), - changepoints = c(NULL, c('2016-12-25')), - n.changepoints = c(3), yearly.seasonality = c(TRUE, FALSE), weekly.seasonality = c(TRUE, FALSE), daily.seasonality = c(TRUE, FALSE), - holidays = c(NULL, 'insert_dataframe'), - seasonality.prior.scale = c(1.1), - holidays.prior.scale = c(1.1), - changepoints.prior.scale = c(0.1), - mcmc.samples = c(100), - interval.width = c(0.9), - uncertainty.samples = c(200) + holidays = c('null', 'insert_dataframe') ) products <- expand.grid(inputs) for (i in 1:length(products)) { @@ -543,32 +538,51 @@ test_that("copy", { holidays <- NULL } m1 <- prophet( - growth = products$growth[i], - changepoints = products$changepoints[i], - n.changepoints = products$n.changepoints[i], + growth = as.character(products$growth[i]), + changepoints = NULL, + n.changepoints = 3, yearly.seasonality = products$yearly.seasonality[i], weekly.seasonality = products$weekly.seasonality[i], daily.seasonality = products$daily.seasonality[i], holidays = holidays, - seasonality.prior.scale = products$seasonality.prior.scale[i], - holidays.prior.scale = products$holidays.prior.scale[i], - changepoints.prior.scale = products$changepoints.prior.scale[i], - mcmc.samples = products$mcmc.samples[i], - interval.width = products$interval.width[i], - uncertainty.samples = products$uncertainty.samples[i], + seasonality.prior.scale = 1.1, + holidays.prior.scale = 1.1, + changepoints.prior.scale = 0.1, + mcmc.samples = 100, + interval.width = 0.9, + uncertainty.samples = 200, fit = FALSE ) + out <- prophet:::setup_dataframe(m1, df, initialize_scales = TRUE) + m1 <- out$m + m1$history <- out$df + m1 <- prophet:::set_auto_seasonalities(m1) m2 <- prophet:::prophet_copy(m1) # Values should be copied correctly - for (arg in names(inputs)) { + args <- c('growth', 'changepoints', 'n.changepoints', 'holidays', + 'seasonality.prior.scale', 'holidays.prior.scale', + 'changepoints.prior.scale', 'mcmc.samples', 'interval.width', + 'uncertainty.samples') + for (arg in args) { expect_equal(m1[[arg]], m2[[arg]]) } + expect_equal(FALSE, m2$yearly.seasonality) + expect_equal(FALSE, m2$weekly.seasonality) + expect_equal(FALSE, m2$daily.seasonality) + expect_equal(m1$yearly.seasonality, 'yearly' %in% names(m2$seasonalities)) + expect_equal(m1$weekly.seasonality, 'weekly' %in% names(m2$seasonalities)) + expect_equal(m1$daily.seasonality, 'daily' %in% names(m2$seasonalities)) } - # Check for cutoff + # Check for cutoff and custom seasonality and extra regressors changepoints <- seq.Date(as.Date('2012-06-15'), as.Date('2012-09-15'), by='d') cutoff <- as.Date('2012-07-25') - m1 <- prophet(DATA, changepoints = changepoints) + m1 <- prophet(changepoints = changepoints) + m1 <- add_seasonality(m1, 'custom', 10, 5) + m1 <- add_regressor(m1, 'binary_feature') + m1 <- fit.prophet(m1, df) m2 <- prophet:::prophet_copy(m1, cutoff) changepoints <- changepoints[changepoints <= cutoff] expect_equal(prophet:::set_date(changepoints), m2$changepoints) + expect_true('custom' %in% names(m2$seasonalities)) + expect_true('binary_feature' %in% names(m2$extra_regressors)) })