From ba1bac834e864ccf736397194847b3ac95cff8cf Mon Sep 17 00:00:00 2001 From: Marc Ferradou Date: Fri, 25 May 2018 18:53:19 -0400 Subject: [PATCH] Adding changepoint threshold (#299) --- R/R/prophet.R | 15 ++++++++++++--- python/fbprophet/forecaster.py | 17 ++++++++++++++--- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/R/R/prophet.R b/R/R/prophet.R index 92ed985..2c48d90 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -28,7 +28,9 @@ utils::globalVariables(c( #' @param n.changepoints Number of potential changepoints to include. Not used #' if input `changepoints` is supplied. If `changepoints` is not supplied, #' then n.changepoints potential changepoints are selected uniformly from the -#' first 80 percent of df$ds. +#' first `changepoint.threshold` percent of df$ds. +#' @param changepoint.threshold Parameter controling where to select the changepoints. +#' Not used if input `changepoints` is supplied. #' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE, #' FALSE, or a number of Fourier terms to generate. #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE, @@ -79,6 +81,7 @@ prophet <- function(df = NULL, growth = 'linear', changepoints = NULL, n.changepoints = 25, + changepoint.threshold = 0.8, yearly.seasonality = 'auto', weekly.seasonality = 'auto', daily.seasonality = 'auto', @@ -103,6 +106,7 @@ prophet <- function(df = NULL, growth = growth, changepoints = changepoints, n.changepoints = n.changepoints, + changepoint.threshold = changepoint.threshold, yearly.seasonality = yearly.seasonality, weekly.seasonality = weekly.seasonality, daily.seasonality = daily.seasonality, @@ -451,9 +455,14 @@ set_changepoints <- function(m) { } } } else { - # Place potential changepoints evenly through the first 80 pcnt of + # Place potential changepoints evenly through the first changepoint.threshold pcnt of # the history. - hist.size <- floor(nrow(m$history) * .8) + if (m$changepoint.threshold > 1 || m$changepoint.threshold <= 0){ + m$changepoint.threshold <- .8 + message('changepoint.threshold greater than 1 or less than equal to 0. Using ', + m$changepoint.threshold) + } + hist.size <- floor(nrow(m$history) * m$changepoint.threshold) if (m$n.changepoints + 1 > hist.size) { m$n.changepoints <- hist.size - 1 message('n.changepoints greater than number of observations. Using ', diff --git a/python/fbprophet/forecaster.py b/python/fbprophet/forecaster.py index f9a7984..bd9aa9a 100644 --- a/python/fbprophet/forecaster.py +++ b/python/fbprophet/forecaster.py @@ -57,7 +57,9 @@ class Prophet(object): n_changepoints: Number of potential changepoints to include. Not used if input `changepoints` is supplied. If `changepoints` is not supplied, then n_changepoints potential changepoints are selected uniformly from - the first 80 percent of the history. + the first `changepoint_threshold` percent of the history. + changepoint_threshold: Parameter controling where to select the changepoints. + Not used if input `changepoints` is supplied. yearly_seasonality: Fit yearly seasonality. Can be 'auto', True, False, or a number of Fourier terms to generate. weekly_seasonality: Fit weekly seasonality. @@ -97,6 +99,7 @@ class Prophet(object): growth='linear', changepoints=None, n_changepoints=25, + changepoint_threshold=0.8, yearly_seasonality='auto', weekly_seasonality='auto', daily_seasonality='auto', @@ -119,6 +122,7 @@ class Prophet(object): self.n_changepoints = n_changepoints self.specified_changepoints = False + self.changepoint_threshold = changepoint_threshold self.yearly_seasonality = yearly_seasonality self.weekly_seasonality = weekly_seasonality self.daily_seasonality = daily_seasonality @@ -332,8 +336,15 @@ class Prophet(object): raise ValueError( 'Changepoints must fall within training data.') else: - # Place potential changepoints evenly through first 80% of history - hist_size = np.floor(self.history.shape[0] * 0.8) + # Place potential changepoints evenly through first changepoint_threshold + # of history + if (self.changepoint_threshold > 1 or self.changepoint_threshold <= 0): + self.changepoint_threshold = 0.8 + logger.info( + 'changepoint_threshold greater than 1 or less than equal to 0.' + 'Using {}.'.format(self.changepoint_threshold) + ) + hist_size = np.floor(self.history.shape[0] * self.changepoint_threshold) if self.n_changepoints + 1 > hist_size: self.n_changepoints = hist_size - 1 logger.info(