mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-26 03:01:03 +00:00
Move copy from method to function in diagnostics file
This commit is contained in:
parent
3da46503ed
commit
3afdaaf4e1
6 changed files with 109 additions and 96 deletions
|
|
@ -144,3 +144,50 @@ cross_validation <- function(
|
|||
}
|
||||
return(simulated_historical_forecasts(model, horizon, units, k, period))
|
||||
}
|
||||
|
||||
#' Copy Prophet object.
|
||||
#'
|
||||
#' @param m Prophet model object.
|
||||
#' @param cutoff Date, possibly as string. Changepoints are only retained if
|
||||
#' changepoints <= cutoff.
|
||||
#'
|
||||
#' @return An unfitted Prophet model object with the same parameters as the
|
||||
#' input model.
|
||||
#'
|
||||
#' @keywords internal
|
||||
prophet_copy <- function(m, cutoff = NULL) {
|
||||
if (is.null(m$history)) {
|
||||
stop("This is for copying a fitted Prophet object.")
|
||||
}
|
||||
|
||||
if (m$specified.changepoints) {
|
||||
changepoints <- m$changepoints
|
||||
if (!is.null(cutoff)) {
|
||||
cutoff <- set_date(cutoff)
|
||||
changepoints <- changepoints[changepoints <= cutoff]
|
||||
}
|
||||
} else {
|
||||
changepoints <- NULL
|
||||
}
|
||||
# Auto seasonalities are set to FALSE because they are already set in
|
||||
# m$seasonalities.
|
||||
m2 <- prophet(
|
||||
growth = m$growth,
|
||||
changepoints = changepoints,
|
||||
n.changepoints = m$n.changepoints,
|
||||
yearly.seasonality = FALSE,
|
||||
weekly.seasonality = FALSE,
|
||||
daily.seasonality = FALSE,
|
||||
holidays = m$holidays,
|
||||
seasonality.prior.scale = m$seasonality.prior.scale,
|
||||
changepoint.prior.scale = m$changepoint.prior.scale,
|
||||
holidays.prior.scale = m$holidays.prior.scale,
|
||||
mcmc.samples = m$mcmc.samples,
|
||||
interval.width = m$interval.width,
|
||||
uncertainty.samples = m$uncertainty.samples,
|
||||
fit = FALSE
|
||||
)
|
||||
m2$extra_regressors <- m$extra_regressors
|
||||
m2$seasonalities <- m$seasonalities
|
||||
return(m2)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1374,51 +1374,4 @@ make_future_dataframe <- function(m, periods, freq = 'day',
|
|||
return(data.frame(ds = dates))
|
||||
}
|
||||
|
||||
#' Copy Prophet object.
|
||||
#'
|
||||
#' @param m Prophet model object.
|
||||
#' @param cutoff Date, possibly as string. Changepoints are only retained if
|
||||
#' changepoints <= cutoff.
|
||||
#'
|
||||
#' @return An unfitted Prophet model object with the same parameters as the
|
||||
#' input model.
|
||||
#'
|
||||
#' @keywords internal
|
||||
prophet_copy <- function(m, cutoff = NULL) {
|
||||
if (is.null(m$history)) {
|
||||
stop("This is for copying a fitted Prophet object.")
|
||||
}
|
||||
|
||||
if (m$specified.changepoints) {
|
||||
changepoints <- m$changepoints
|
||||
if (!is.null(cutoff)) {
|
||||
cutoff <- set_date(cutoff)
|
||||
changepoints <- changepoints[changepoints <= cutoff]
|
||||
}
|
||||
} else {
|
||||
changepoints <- NULL
|
||||
}
|
||||
# Auto seasonalities are set to FALSE because they are already set in
|
||||
# m$seasonalities.
|
||||
m2 <- prophet(
|
||||
growth = m$growth,
|
||||
changepoints = changepoints,
|
||||
n.changepoints = m$n.changepoints,
|
||||
yearly.seasonality = FALSE,
|
||||
weekly.seasonality = FALSE,
|
||||
daily.seasonality = FALSE,
|
||||
holidays = m$holidays,
|
||||
seasonality.prior.scale = m$seasonality.prior.scale,
|
||||
changepoint.prior.scale = m$changepoint.prior.scale,
|
||||
holidays.prior.scale = m$holidays.prior.scale,
|
||||
mcmc.samples = m$mcmc.samples,
|
||||
interval.width = m$interval.width,
|
||||
uncertainty.samples = m$uncertainty.samples,
|
||||
fit = FALSE
|
||||
)
|
||||
m2$extra_regressors <- m$extra_regressors
|
||||
m2$seasonalities <- m$seasonalities
|
||||
return(m2)
|
||||
}
|
||||
|
||||
# fb-block 3
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
% Please edit documentation in R/diagnostics.R
|
||||
\name{prophet_copy}
|
||||
\alias{prophet_copy}
|
||||
\title{Copy Prophet object.}
|
||||
|
|
|
|||
|
|
@ -10,13 +10,15 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from functools import reduce
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cutoffs(df, horizon, k, period):
|
||||
|
|
@ -88,7 +90,7 @@ def simulated_historical_forecasts(model, horizon, k, period=None):
|
|||
predicts = []
|
||||
for cutoff in cutoffs:
|
||||
# Generate new object with copying fitting options
|
||||
m = model.copy(cutoff)
|
||||
m = prophet_copy(model, cutoff)
|
||||
# Train model
|
||||
m.fit(df[df['ds'] <= cutoff])
|
||||
# Calculate yhat
|
||||
|
|
@ -146,6 +148,54 @@ def cross_validation(model, horizon, period=None, initial=None):
|
|||
'Not enough data for specified horizon, period, and initial.')
|
||||
return simulated_historical_forecasts(model, horizon, k, period)
|
||||
|
||||
|
||||
def prophet_copy(m, cutoff=None):
|
||||
"""Copy Prophet object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
m: Prophet model.
|
||||
cutoff: pd.Timestamp or None, default None.
|
||||
cuttoff Timestamp for changepoints member variable.
|
||||
changepoints are only retained if 'changepoints <= cutoff'
|
||||
|
||||
Returns
|
||||
-------
|
||||
Prophet class object with the same parameter with model variable
|
||||
"""
|
||||
if m.history is None:
|
||||
raise Exception('This is for copying a fitted Prophet object.')
|
||||
|
||||
if m.specified_changepoints:
|
||||
changepoints = m.changepoints
|
||||
if cutoff is not None:
|
||||
# Filter change points '<= cutoff'
|
||||
changepoints = changepoints[changepoints <= cutoff]
|
||||
else:
|
||||
changepoints = None
|
||||
|
||||
# Auto seasonalities are set to False because they are already set in
|
||||
# m.seasonalities.
|
||||
m2 = m.__class__(
|
||||
growth=m.growth,
|
||||
n_changepoints=m.n_changepoints,
|
||||
changepoints=changepoints,
|
||||
yearly_seasonality=False,
|
||||
weekly_seasonality=False,
|
||||
daily_seasonality=False,
|
||||
holidays=m.holidays,
|
||||
seasonality_prior_scale=m.seasonality_prior_scale,
|
||||
changepoint_prior_scale=m.changepoint_prior_scale,
|
||||
holidays_prior_scale=m.holidays_prior_scale,
|
||||
mcmc_samples=m.mcmc_samples,
|
||||
interval_width=m.interval_width,
|
||||
uncertainty_samples=m.uncertainty_samples,
|
||||
)
|
||||
m2.extra_regressors = deepcopy(m.extra_regressors)
|
||||
m2.seasonalities = deepcopy(m.seasonalities)
|
||||
return m2
|
||||
|
||||
|
||||
def me(df):
|
||||
return((df['yhat'] - df['y']).sum()/len(df['yhat']))
|
||||
def mse(df):
|
||||
|
|
@ -209,4 +259,4 @@ def all_metrics(model, df_cv = None):
|
|||
'MAE': mae(df),
|
||||
'MPE': mpe(df),
|
||||
'MAPE': mape(df)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
import warnings
|
||||
|
|
@ -30,6 +29,7 @@ from fbprophet.plot import (
|
|||
plot_yearly,
|
||||
plot_seasonality,
|
||||
)
|
||||
from fbprophet.diagnostics import prophet_copy
|
||||
# fb-block 1 end
|
||||
|
||||
logging.basicConfig()
|
||||
|
|
@ -1340,46 +1340,9 @@ class Prophet(object):
|
|||
)
|
||||
|
||||
def copy(self, cutoff=None):
|
||||
"""Copy Prophet object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cutoff: pd.Timestamp or None, default None.
|
||||
cuttoff Timestamp for changepoints member variable.
|
||||
changepoints are only retained if 'changepoints <= cutoff'
|
||||
|
||||
Returns
|
||||
-------
|
||||
Prophet class object with the same parameter with model variable
|
||||
"""
|
||||
if self.history is None:
|
||||
raise Exception('This is for copying a fitted Prophet object.')
|
||||
|
||||
if self.specified_changepoints:
|
||||
changepoints = self.changepoints
|
||||
if cutoff is not None:
|
||||
# Filter change points '<= cutoff'
|
||||
changepoints = changepoints[changepoints <= cutoff]
|
||||
else:
|
||||
changepoints = None
|
||||
|
||||
# Auto seasonalities are set to False because they are already set in
|
||||
# self.seasonalities.
|
||||
m = Prophet(
|
||||
growth=self.growth,
|
||||
n_changepoints=self.n_changepoints,
|
||||
changepoints=changepoints,
|
||||
yearly_seasonality=False,
|
||||
weekly_seasonality=False,
|
||||
daily_seasonality=False,
|
||||
holidays=self.holidays,
|
||||
seasonality_prior_scale=self.seasonality_prior_scale,
|
||||
changepoint_prior_scale=self.changepoint_prior_scale,
|
||||
holidays_prior_scale=self.holidays_prior_scale,
|
||||
mcmc_samples=self.mcmc_samples,
|
||||
interval_width=self.interval_width,
|
||||
uncertainty_samples=self.uncertainty_samples,
|
||||
warnings.warn(
|
||||
'This method will be removed in the next version. '
|
||||
'Please use fbprophet.diagnostics.prophet_copy. ',
|
||||
DeprecationWarning,
|
||||
)
|
||||
m.extra_regressors = deepcopy(self.extra_regressors)
|
||||
m.seasonalities = deepcopy(self.seasonalities)
|
||||
return m
|
||||
return prophet_copy(m=self, cutoff=cutoff)
|
||||
|
|
|
|||
Loading…
Reference in a new issue