diff --git a/R/R/diagnostics.R b/R/R/diagnostics.R index bb7ee01..88f2d8b 100644 --- a/R/R/diagnostics.R +++ b/R/R/diagnostics.R @@ -144,3 +144,50 @@ cross_validation <- function( } return(simulated_historical_forecasts(model, horizon, units, k, period)) } + +#' Copy Prophet object. +#' +#' @param m Prophet model object. +#' @param cutoff Date, possibly as string. Changepoints are only retained if +#' changepoints <= cutoff. +#' +#' @return An unfitted Prophet model object with the same parameters as the +#' input model. +#' +#' @keywords internal +prophet_copy <- function(m, cutoff = NULL) { + if (is.null(m$history)) { + stop("This is for copying a fitted Prophet object.") + } + + if (m$specified.changepoints) { + changepoints <- m$changepoints + if (!is.null(cutoff)) { + cutoff <- set_date(cutoff) + changepoints <- changepoints[changepoints <= cutoff] + } + } else { + changepoints <- NULL + } + # Auto seasonalities are set to FALSE because they are already set in + # m$seasonalities. + m2 <- prophet( + growth = m$growth, + changepoints = changepoints, + n.changepoints = m$n.changepoints, + yearly.seasonality = FALSE, + weekly.seasonality = FALSE, + daily.seasonality = FALSE, + holidays = m$holidays, + seasonality.prior.scale = m$seasonality.prior.scale, + changepoint.prior.scale = m$changepoint.prior.scale, + holidays.prior.scale = m$holidays.prior.scale, + mcmc.samples = m$mcmc.samples, + interval.width = m$interval.width, + uncertainty.samples = m$uncertainty.samples, + fit = FALSE + ) + m2$extra_regressors <- m$extra_regressors + m2$seasonalities <- m$seasonalities + return(m2) +} diff --git a/R/R/prophet.R b/R/R/prophet.R index 0eb4525..e996dd9 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -1374,51 +1374,4 @@ make_future_dataframe <- function(m, periods, freq = 'day', return(data.frame(ds = dates)) } -#' Copy Prophet object. -#' -#' @param m Prophet model object. -#' @param cutoff Date, possibly as string. Changepoints are only retained if -#' changepoints <= cutoff. -#' -#' @return An unfitted Prophet model object with the same parameters as the -#' input model. -#' -#' @keywords internal -prophet_copy <- function(m, cutoff = NULL) { - if (is.null(m$history)) { - stop("This is for copying a fitted Prophet object.") - } - - if (m$specified.changepoints) { - changepoints <- m$changepoints - if (!is.null(cutoff)) { - cutoff <- set_date(cutoff) - changepoints <- changepoints[changepoints <= cutoff] - } - } else { - changepoints <- NULL - } - # Auto seasonalities are set to FALSE because they are already set in - # m$seasonalities. - m2 <- prophet( - growth = m$growth, - changepoints = changepoints, - n.changepoints = m$n.changepoints, - yearly.seasonality = FALSE, - weekly.seasonality = FALSE, - daily.seasonality = FALSE, - holidays = m$holidays, - seasonality.prior.scale = m$seasonality.prior.scale, - changepoint.prior.scale = m$changepoint.prior.scale, - holidays.prior.scale = m$holidays.prior.scale, - mcmc.samples = m$mcmc.samples, - interval.width = m$interval.width, - uncertainty.samples = m$uncertainty.samples, - fit = FALSE - ) - m2$extra_regressors <- m$extra_regressors - m2$seasonalities <- m$seasonalities - return(m2) -} - # fb-block 3 diff --git a/R/R/utils.R b/R/R/utils.R deleted file mode 100644 index e69de29..0000000 diff --git a/R/man/prophet_copy.Rd b/R/man/prophet_copy.Rd index 59704aa..0631c8c 100644 --- a/R/man/prophet_copy.Rd +++ b/R/man/prophet_copy.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/prophet.R +% Please edit documentation in R/diagnostics.R \name{prophet_copy} \alias{prophet_copy} \title{Copy Prophet object.} diff --git a/python/fbprophet/diagnostics.py b/python/fbprophet/diagnostics.py index d29c254..e0dc470 100644 --- a/python/fbprophet/diagnostics.py +++ b/python/fbprophet/diagnostics.py @@ -10,13 +10,15 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals +from copy import deepcopy +from functools import reduce import logging -logger = logging.getLogger(__name__) - import numpy as np import pandas as pd -from functools import reduce + + +logger = logging.getLogger(__name__) def _cutoffs(df, horizon, k, period): @@ -88,7 +90,7 @@ def simulated_historical_forecasts(model, horizon, k, period=None): predicts = [] for cutoff in cutoffs: # Generate new object with copying fitting options - m = model.copy(cutoff) + m = prophet_copy(model, cutoff) # Train model m.fit(df[df['ds'] <= cutoff]) # Calculate yhat @@ -146,6 +148,54 @@ def cross_validation(model, horizon, period=None, initial=None): 'Not enough data for specified horizon, period, and initial.') return simulated_historical_forecasts(model, horizon, k, period) + +def prophet_copy(m, cutoff=None): + """Copy Prophet object + + Parameters + ---------- + m: Prophet model. + cutoff: pd.Timestamp or None, default None. + cuttoff Timestamp for changepoints member variable. + changepoints are only retained if 'changepoints <= cutoff' + + Returns + ------- + Prophet class object with the same parameter with model variable + """ + if m.history is None: + raise Exception('This is for copying a fitted Prophet object.') + + if m.specified_changepoints: + changepoints = m.changepoints + if cutoff is not None: + # Filter change points '<= cutoff' + changepoints = changepoints[changepoints <= cutoff] + else: + changepoints = None + + # Auto seasonalities are set to False because they are already set in + # m.seasonalities. + m2 = m.__class__( + growth=m.growth, + n_changepoints=m.n_changepoints, + changepoints=changepoints, + yearly_seasonality=False, + weekly_seasonality=False, + daily_seasonality=False, + holidays=m.holidays, + seasonality_prior_scale=m.seasonality_prior_scale, + changepoint_prior_scale=m.changepoint_prior_scale, + holidays_prior_scale=m.holidays_prior_scale, + mcmc_samples=m.mcmc_samples, + interval_width=m.interval_width, + uncertainty_samples=m.uncertainty_samples, + ) + m2.extra_regressors = deepcopy(m.extra_regressors) + m2.seasonalities = deepcopy(m.seasonalities) + return m2 + + def me(df): return((df['yhat'] - df['y']).sum()/len(df['yhat'])) def mse(df): @@ -209,4 +259,4 @@ def all_metrics(model, df_cv = None): 'MAE': mae(df), 'MPE': mpe(df), 'MAPE': mape(df) - } \ No newline at end of file + } diff --git a/python/fbprophet/forecaster.py b/python/fbprophet/forecaster.py index 943af40..7856229 100644 --- a/python/fbprophet/forecaster.py +++ b/python/fbprophet/forecaster.py @@ -11,7 +11,6 @@ from __future__ import print_function from __future__ import unicode_literals from collections import defaultdict -from copy import deepcopy from datetime import timedelta import logging import warnings @@ -30,6 +29,7 @@ from fbprophet.plot import ( plot_yearly, plot_seasonality, ) +from fbprophet.diagnostics import prophet_copy # fb-block 1 end logging.basicConfig() @@ -1340,46 +1340,9 @@ class Prophet(object): ) def copy(self, cutoff=None): - """Copy Prophet object - - Parameters - ---------- - cutoff: pd.Timestamp or None, default None. - cuttoff Timestamp for changepoints member variable. - changepoints are only retained if 'changepoints <= cutoff' - - Returns - ------- - Prophet class object with the same parameter with model variable - """ - if self.history is None: - raise Exception('This is for copying a fitted Prophet object.') - - if self.specified_changepoints: - changepoints = self.changepoints - if cutoff is not None: - # Filter change points '<= cutoff' - changepoints = changepoints[changepoints <= cutoff] - else: - changepoints = None - - # Auto seasonalities are set to False because they are already set in - # self.seasonalities. - m = Prophet( - growth=self.growth, - n_changepoints=self.n_changepoints, - changepoints=changepoints, - yearly_seasonality=False, - weekly_seasonality=False, - daily_seasonality=False, - holidays=self.holidays, - seasonality_prior_scale=self.seasonality_prior_scale, - changepoint_prior_scale=self.changepoint_prior_scale, - holidays_prior_scale=self.holidays_prior_scale, - mcmc_samples=self.mcmc_samples, - interval_width=self.interval_width, - uncertainty_samples=self.uncertainty_samples, + warnings.warn( + 'This method will be removed in the next version. ' + 'Please use fbprophet.diagnostics.prophet_copy. ', + DeprecationWarning, ) - m.extra_regressors = deepcopy(self.extra_regressors) - m.seasonalities = deepcopy(self.seasonalities) - return m + return prophet_copy(m=self, cutoff=cutoff)