From e4ec600da47fb7e6210be7cab21f940430ae1a3d Mon Sep 17 00:00:00 2001 From: bletham Date: Sat, 19 Aug 2017 14:03:00 -0700 Subject: [PATCH] Handle constant y in history --- R/R/prophet.R | 10 +++++++++- R/tests/testthat/test_prophet.R | 13 +++++++++++++ python/fbprophet/forecaster.py | 10 +++++++++- python/fbprophet/tests/test_prophet.py | 16 ++++++++++++++++ 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/R/R/prophet.R b/R/R/prophet.R index 139b7c9..ca565a0 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -286,6 +286,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) { if (initialize_scales) { m$y.scale <- max(abs(df$y)) + if (m$y.scale == 0) { + m$y.scale <- 1 + } m$start <- min(df$ds) m$t.scale <- time_diff(max(df$ds), m$start, "secs") } @@ -703,7 +706,12 @@ fit.prophet <- function(m, df, ...) { ) } - if (m$mcmc.samples > 0) { + if (min(history$y) == max(history$y)) { + # Nothing to fit. + m$params <- stan_init() + m$params$sigma_obs <- 0. + n.iteration <- 1. + } else if (m$mcmc.samples > 0) { stan.fit <- rstan::sampling( model, data = dat, diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index 997ad87..520dd0a 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -46,6 +46,19 @@ test_that("fit_predict_duplicates", { expect_error(predict(m, future), NA) }) +test_that("fit_predict_constant_history", { + skip_if_not(Sys.getenv('R_ARCH') != '/i386') + train2 <- train + train2$y <- 20 + m <- prophet(train2) + fcst <- predict(m, future) + expect_equal(tail(fcst$yhat, 1), 20) + train2$y <- 0 + m <- prophet(train2) + fcst <- predict(m, future) + expect_equal(tail(fcst$yhat, 1), 0) +}) + test_that("setup_dataframe", { history <- train m <- prophet(history, fit = FALSE) diff --git a/python/fbprophet/forecaster.py b/python/fbprophet/forecaster.py index 988ad67..235874e 100644 --- a/python/fbprophet/forecaster.py +++ b/python/fbprophet/forecaster.py @@ -227,6 +227,8 @@ class Prophet(object): if initialize_scales: self.y_scale = df['y'].abs().max() + if self.y_scale == 0: + self.y_scale = 1 self.start = df['ds'].min() self.t_scale = df['ds'].max() - self.start for name, props in self.extra_regressors.items(): @@ -726,7 +728,13 @@ class Prophet(object): 'sigma_obs': 1, } - if self.mcmc_samples > 0: + if history['y'].min() == history['y'].max(): + # Nothing to fit. + self.params = stan_init() + self.params['sigma_obs'] = 0. + for par in self.params: + self.params[par] = np.array([self.params[par]]) + elif self.mcmc_samples > 0: stan_fit = model.sampling( dat, init=stan_init, diff --git a/python/fbprophet/tests/test_prophet.py b/python/fbprophet/tests/test_prophet.py index 05b88e3..cd31999 100644 --- a/python/fbprophet/tests/test_prophet.py +++ b/python/fbprophet/tests/test_prophet.py @@ -79,6 +79,22 @@ class TestProphet(TestCase): forecaster.fit(train) forecaster.predict(future) + def test_fit_predict_constant_history(self): + N = DATA.shape[0] + train = DATA.head(N // 2).copy() + train['y'] = 20 + future = pd.DataFrame({'ds': DATA['ds'].tail(N // 2)}) + m = Prophet() + m.fit(train) + fcst = m.predict(future) + self.assertEqual(fcst['yhat'].values[-1], 20) + train['y'] = 0 + future = pd.DataFrame({'ds': DATA['ds'].tail(N // 2)}) + m = Prophet() + m.fit(train) + fcst = m.predict(future) + self.assertEqual(fcst['yhat'].values[-1], 0) + def test_setup_dataframe(self): m = Prophet() N = DATA.shape[0]