Move copy from method to function in diagnostics file

This commit is contained in:
Ben Letham 2018-05-03 11:42:10 -07:00
parent 3da46503ed
commit 3afdaaf4e1
6 changed files with 109 additions and 96 deletions

View file

@ -144,3 +144,50 @@ cross_validation <- function(
}
return(simulated_historical_forecasts(model, horizon, units, k, period))
}
#' Copy Prophet object.
#'
#' @param m Prophet model object.
#' @param cutoff Date, possibly as string. Changepoints are only retained if
#' changepoints <= cutoff.
#'
#' @return An unfitted Prophet model object with the same parameters as the
#' input model.
#'
#' @keywords internal
prophet_copy <- function(m, cutoff = NULL) {
if (is.null(m$history)) {
stop("This is for copying a fitted Prophet object.")
}
if (m$specified.changepoints) {
changepoints <- m$changepoints
if (!is.null(cutoff)) {
cutoff <- set_date(cutoff)
changepoints <- changepoints[changepoints <= cutoff]
}
} else {
changepoints <- NULL
}
# Auto seasonalities are set to FALSE because they are already set in
# m$seasonalities.
m2 <- prophet(
growth = m$growth,
changepoints = changepoints,
n.changepoints = m$n.changepoints,
yearly.seasonality = FALSE,
weekly.seasonality = FALSE,
daily.seasonality = FALSE,
holidays = m$holidays,
seasonality.prior.scale = m$seasonality.prior.scale,
changepoint.prior.scale = m$changepoint.prior.scale,
holidays.prior.scale = m$holidays.prior.scale,
mcmc.samples = m$mcmc.samples,
interval.width = m$interval.width,
uncertainty.samples = m$uncertainty.samples,
fit = FALSE
)
m2$extra_regressors <- m$extra_regressors
m2$seasonalities <- m$seasonalities
return(m2)
}

View file

@ -1374,51 +1374,4 @@ make_future_dataframe <- function(m, periods, freq = 'day',
return(data.frame(ds = dates))
}
#' Copy Prophet object.
#'
#' @param m Prophet model object.
#' @param cutoff Date, possibly as string. Changepoints are only retained if
#' changepoints <= cutoff.
#'
#' @return An unfitted Prophet model object with the same parameters as the
#' input model.
#'
#' @keywords internal
prophet_copy <- function(m, cutoff = NULL) {
if (is.null(m$history)) {
stop("This is for copying a fitted Prophet object.")
}
if (m$specified.changepoints) {
changepoints <- m$changepoints
if (!is.null(cutoff)) {
cutoff <- set_date(cutoff)
changepoints <- changepoints[changepoints <= cutoff]
}
} else {
changepoints <- NULL
}
# Auto seasonalities are set to FALSE because they are already set in
# m$seasonalities.
m2 <- prophet(
growth = m$growth,
changepoints = changepoints,
n.changepoints = m$n.changepoints,
yearly.seasonality = FALSE,
weekly.seasonality = FALSE,
daily.seasonality = FALSE,
holidays = m$holidays,
seasonality.prior.scale = m$seasonality.prior.scale,
changepoint.prior.scale = m$changepoint.prior.scale,
holidays.prior.scale = m$holidays.prior.scale,
mcmc.samples = m$mcmc.samples,
interval.width = m$interval.width,
uncertainty.samples = m$uncertainty.samples,
fit = FALSE
)
m2$extra_regressors <- m$extra_regressors
m2$seasonalities <- m$seasonalities
return(m2)
}
# fb-block 3

View file

View file

@ -1,5 +1,5 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
% Please edit documentation in R/diagnostics.R
\name{prophet_copy}
\alias{prophet_copy}
\title{Copy Prophet object.}

View file

@ -10,13 +10,15 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from copy import deepcopy
from functools import reduce
import logging
logger = logging.getLogger(__name__)
import numpy as np
import pandas as pd
from functools import reduce
logger = logging.getLogger(__name__)
def _cutoffs(df, horizon, k, period):
@ -88,7 +90,7 @@ def simulated_historical_forecasts(model, horizon, k, period=None):
predicts = []
for cutoff in cutoffs:
# Generate new object with copying fitting options
m = model.copy(cutoff)
m = prophet_copy(model, cutoff)
# Train model
m.fit(df[df['ds'] <= cutoff])
# Calculate yhat
@ -146,6 +148,54 @@ def cross_validation(model, horizon, period=None, initial=None):
'Not enough data for specified horizon, period, and initial.')
return simulated_historical_forecasts(model, horizon, k, period)
def prophet_copy(m, cutoff=None):
"""Copy Prophet object
Parameters
----------
m: Prophet model.
cutoff: pd.Timestamp or None, default None.
cuttoff Timestamp for changepoints member variable.
changepoints are only retained if 'changepoints <= cutoff'
Returns
-------
Prophet class object with the same parameter with model variable
"""
if m.history is None:
raise Exception('This is for copying a fitted Prophet object.')
if m.specified_changepoints:
changepoints = m.changepoints
if cutoff is not None:
# Filter change points '<= cutoff'
changepoints = changepoints[changepoints <= cutoff]
else:
changepoints = None
# Auto seasonalities are set to False because they are already set in
# m.seasonalities.
m2 = m.__class__(
growth=m.growth,
n_changepoints=m.n_changepoints,
changepoints=changepoints,
yearly_seasonality=False,
weekly_seasonality=False,
daily_seasonality=False,
holidays=m.holidays,
seasonality_prior_scale=m.seasonality_prior_scale,
changepoint_prior_scale=m.changepoint_prior_scale,
holidays_prior_scale=m.holidays_prior_scale,
mcmc_samples=m.mcmc_samples,
interval_width=m.interval_width,
uncertainty_samples=m.uncertainty_samples,
)
m2.extra_regressors = deepcopy(m.extra_regressors)
m2.seasonalities = deepcopy(m.seasonalities)
return m2
def me(df):
return((df['yhat'] - df['y']).sum()/len(df['yhat']))
def mse(df):
@ -209,4 +259,4 @@ def all_metrics(model, df_cv = None):
'MAE': mae(df),
'MPE': mpe(df),
'MAPE': mape(df)
}
}

View file

@ -11,7 +11,6 @@ from __future__ import print_function
from __future__ import unicode_literals
from collections import defaultdict
from copy import deepcopy
from datetime import timedelta
import logging
import warnings
@ -30,6 +29,7 @@ from fbprophet.plot import (
plot_yearly,
plot_seasonality,
)
from fbprophet.diagnostics import prophet_copy
# fb-block 1 end
logging.basicConfig()
@ -1340,46 +1340,9 @@ class Prophet(object):
)
def copy(self, cutoff=None):
"""Copy Prophet object
Parameters
----------
cutoff: pd.Timestamp or None, default None.
cuttoff Timestamp for changepoints member variable.
changepoints are only retained if 'changepoints <= cutoff'
Returns
-------
Prophet class object with the same parameter with model variable
"""
if self.history is None:
raise Exception('This is for copying a fitted Prophet object.')
if self.specified_changepoints:
changepoints = self.changepoints
if cutoff is not None:
# Filter change points '<= cutoff'
changepoints = changepoints[changepoints <= cutoff]
else:
changepoints = None
# Auto seasonalities are set to False because they are already set in
# self.seasonalities.
m = Prophet(
growth=self.growth,
n_changepoints=self.n_changepoints,
changepoints=changepoints,
yearly_seasonality=False,
weekly_seasonality=False,
daily_seasonality=False,
holidays=self.holidays,
seasonality_prior_scale=self.seasonality_prior_scale,
changepoint_prior_scale=self.changepoint_prior_scale,
holidays_prior_scale=self.holidays_prior_scale,
mcmc_samples=self.mcmc_samples,
interval_width=self.interval_width,
uncertainty_samples=self.uncertainty_samples,
warnings.warn(
'This method will be removed in the next version. '
'Please use fbprophet.diagnostics.prophet_copy. ',
DeprecationWarning,
)
m.extra_regressors = deepcopy(self.extra_regressors)
m.seasonalities = deepcopy(self.seasonalities)
return m
return prophet_copy(m=self, cutoff=cutoff)