From 0b4ec4a9b3c6d91e2ae70dbc02dd37d25b9758e7 Mon Sep 17 00:00:00 2001 From: bletham Date: Sat, 19 Aug 2017 11:20:53 -0700 Subject: [PATCH] Limit n_changepoints to number of observations. --- R/R/prophet.R | 12 +++++++--- R/tests/testthat/test_prophet.R | 15 ++++++++++++ python/fbprophet/forecaster.py | 32 ++++++++++++++++---------- python/fbprophet/tests/test_prophet.py | 12 ++++++++++ 4 files changed, 56 insertions(+), 15 deletions(-) diff --git a/R/R/prophet.R b/R/R/prophet.R index 5864087..139b7c9 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -329,10 +329,16 @@ set_changepoints <- function(m) { } } } else { + # Place potential changepoints evenly through the first 80 pcnt of + # the history. + hist.size <- floor(nrow(m$history) * .8) + if (m$n.changepoints + 1 > hist.size) { + m$n.changepoints <- hist.size - 1 + warning('n.changepoints greater than number of observations. Using ', + m$n.changepoints) + } if (m$n.changepoints > 0) { - # Place potential changepoints evenly through the first 80 pcnt of - # the history. - cp.indexes <- round(seq.int(1, floor(nrow(m$history) * .8), + cp.indexes <- round(seq.int(1, hist.size, length.out = (m$n.changepoints + 1))[-1]) m$changepoints <- m$history$ds[cp.indexes] } else { diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index 4c6ff26..997ad87 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -101,6 +101,21 @@ test_that("get_zero_changepoints", { expect_equal(ncol(mat), 1) }) +test_that("override_n_changepoints", { + history <- train[1:20,] + m <- prophet(history, fit = FALSE) + + out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE) + m <- out$m + history <- out$df + m$history <- history + + m <- prophet:::set_changepoints(m) + expect_equal(m$n.changepoints, 15) + cp <- m$changepoints.t + expect_equal(length(cp), 15) +}) + test_that("fourier_series_weekly", { mat <- prophet:::fourier_series(DATA$ds, 7, 3) true.values <- c(0.9165623, 0.3998920, 0.7330519, -0.6801727, -0.3302791, diff --git a/python/fbprophet/forecaster.py b/python/fbprophet/forecaster.py index b0b561e..988ad67 100644 --- a/python/fbprophet/forecaster.py +++ b/python/fbprophet/forecaster.py @@ -277,19 +277,27 @@ class Prophet(object): too_low = min(self.changepoints) < self.history['ds'].min() too_high = max(self.changepoints) > self.history['ds'].max() if too_low or too_high: - raise ValueError('Changepoints must fall within training data.') - elif self.n_changepoints > 0: - # Place potential changepoints evenly through first 80% of history - max_ix = np.floor(self.history.shape[0] * 0.8) - cp_indexes = ( - np.linspace(0, max_ix, self.n_changepoints + 1) - .round() - .astype(np.int) - ) - self.changepoints = self.history.ix[cp_indexes]['ds'].tail(-1) + raise ValueError( + 'Changepoints must fall within training data.') else: - # set empty changepoints - self.changepoints = [] + # Place potential changepoints evenly through first 80% of history + hist_size = np.floor(self.history.shape[0] * 0.8) + if self.n_changepoints + 1 > hist_size: + self.n_changepoints = hist_size - 1 + logger.info( + 'n_changepoints greater than number of observations.' + 'Using {}.'.format(self.n_changepoints) + ) + if self.n_changepoints > 0: + cp_indexes = ( + np.linspace(0, hist_size, self.n_changepoints + 1) + .round() + .astype(np.int) + ) + self.changepoints = self.history.ix[cp_indexes]['ds'].tail(-1) + else: + # set empty changepoints + self.changepoints = [] if len(self.changepoints) > 0: self.changepoints_t = np.sort(np.array( (self.changepoints - self.start) / self.t_scale)) diff --git a/python/fbprophet/tests/test_prophet.py b/python/fbprophet/tests/test_prophet.py index 0555aab..05b88e3 100644 --- a/python/fbprophet/tests/test_prophet.py +++ b/python/fbprophet/tests/test_prophet.py @@ -130,6 +130,18 @@ class TestProphet(TestCase): self.assertEqual(mat.shape[0], N // 2) self.assertEqual(mat.shape[1], 1) + def test_override_n_changepoints(self): + m = Prophet() + history = DATA.head(20).copy() + + history = m.setup_dataframe(history, initialize_scales=True) + m.history = history + + m.set_changepoints() + self.assertEqual(m.n_changepoints, 15) + cp = m.changepoints_t + self.assertEqual(cp.shape[0], 15) + def test_fourier_series_weekly(self): mat = Prophet.fourier_series(DATA['ds'], 7, 3) # These are from the R forecast package directly.