mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-06 00:03:25 +00:00
Add cross-validation functions in R
This commit is contained in:
parent
509666d1d2
commit
3c09448018
14 changed files with 450 additions and 16 deletions
|
|
@ -3,10 +3,12 @@
|
|||
S3method(plot,prophet)
|
||||
S3method(predict,prophet)
|
||||
export(add_seasonality)
|
||||
export(cross_validation)
|
||||
export(fit.prophet)
|
||||
export(make_future_dataframe)
|
||||
export(predictive_samples)
|
||||
export(prophet)
|
||||
export(prophet_plot_components)
|
||||
export(simulated_historical_forecasts)
|
||||
import(Rcpp)
|
||||
importFrom(dplyr,"%>%")
|
||||
|
|
|
|||
132
R/R/diagnostics.R
Normal file
132
R/R/diagnostics.R
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
## Copyright (c) 2017-present, Facebook, Inc.
|
||||
## All rights reserved.
|
||||
|
||||
## This source code is licensed under the BSD-style license found in the
|
||||
## LICENSE file in the root directory of this source tree. An additional grant
|
||||
## of patent rights can be found in the PATENTS file in the same directory.
|
||||
|
||||
## Makes R CMD CHECK happy due to dplyr syntax below
|
||||
globalVariables(c(
|
||||
"ds", "y", "cap", "yhat", "yhat_lower", "yhat_upper"))
|
||||
|
||||
#' Generate cutoff dates
#'
#' Starts from the latest possible cutoff (latest date in the history minus
#' the horizon) and steps backwards by \code{period} until \code{k} cutoffs
#' have been collected or the history is exhausted.
#'
#' @param df Dataframe with historical data; its \code{ds} column is used.
#' @param horizon timediff forecast horizon
#' @param k integer number of forecast points
#' @param period timediff Simulated forecasts are done with this period.
#'
#' @return Array of datetimes in increasing order. Contains fewer than
#'  \code{k} entries (with a warning) when the history is too short.
#'
#' @keywords internal
generate_cutoffs <- function(df, horizon, k, period) {
  first.date <- min(df$ds)
  # Last cutoff is (latest date in data) - (horizon).
  cutoff <- max(df$ds) - horizon
  if (cutoff < first.date) {
    stop('Less data than horizon.')
  }
  tzone <- attr(cutoff, "tzone")  # Timezone is wiped by putting in array
  result <- c(cutoff)
  # Guard the loop: 2:k counts *down* when k is 1, which would otherwise add
  # cutoffs that were never requested.
  if (k > 1) {
    for (i in 2:k) {
      cutoff <- cutoff - period
      # If data does not exist in data range (cutoff, cutoff + horizon], move
      # the cutoff back. Only attempted while some data lies at or before the
      # cutoff, so that max() below is never taken over an empty set.
      if (cutoff >= first.date &&
          !any((df$ds > cutoff) & (df$ds <= cutoff + horizon))) {
        # Next cutoff point is 'closest date before cutoff in data - horizon'
        closest.date <- max(df$ds[df$ds <= cutoff])
        cutoff <- closest.date - horizon
      }
      if (cutoff < first.date) {
        # i - 1 cutoffs have been collected at this point.
        warning('Not enough data for requested number of cutoffs! Using ',
                i - 1)
        break
      }
      result <- c(result, cutoff)
    }
  }
  # Reset timezones
  attr(result, "tzone") <- tzone
  return(rev(result))
}
|
||||
|
||||
#' Simulated historical forecasts.
#'
#' Make forecasts from k historical cutoff dates, and compare forecast values
#' to actual values.
#'
#' For every cutoff an unfitted copy of the model is refit on the history up
#' to and including the cutoff, and then used to predict the history dates in
#' the range (cutoff, cutoff + horizon].
#'
#' @param model Fitted Prophet model.
#' @param horizon Integer size of the horizon
#' @param units String unit of the horizon, e.g., "days", "secs".
#' @param k integer number of forecast points
#' @param period Integer amount of time between cutoff dates. Same units as
#'  horizon. If not provided, will use 0.5 * horizon.
#'
#' @return A dataframe with the forecast, actual value, and cutoff date.
#'
#' @export
simulated_historical_forecasts <- function(model, horizon, units, k,
                                           period = NULL) {
  df <- model$history
  horizon <- as.difftime(horizon, units = units)
  if (is.null(period)) {
    period <- horizon / 2
  } else {
    period <- as.difftime(period, units = units)
  }
  cutoffs <- generate_cutoffs(df, horizon, k, period)
  # Build one data frame per cutoff and bind them once at the end, instead of
  # growing a data frame inside a loop. Cutoffs are accessed by index so each
  # element keeps its datetime class.
  predicts <- lapply(seq_along(cutoffs), function(i) {
    cutoff <- cutoffs[i]
    # Copy the model
    m <- prophet_copy(model, cutoff)
    # Train model
    history.c <- dplyr::filter(df, ds <= cutoff)
    m <- fit.prophet(m, history.c)
    # Calculate yhat
    df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon)
    if (m$growth == 'logistic') {
      # Logistic growth needs the capacity column at prediction time.
      future <- dplyr::select(df.predict, ds, cap)
    } else {
      future <- dplyr::select(df.predict, ds)
    }
    yhat <- stats::predict(m, future)
    # Merge yhat, y, and cutoff.
    df.c <- dplyr::inner_join(df.predict, yhat, by = "ds")
    df.c <- dplyr::select(df.c, ds, y, yhat, yhat_lower, yhat_upper)
    df.c$cutoff <- cutoff
    df.c
  })
  return(do.call(rbind, predicts))
}
|
||||
|
||||
#' Cross-validation for time series.
#'
#' Computes forecast error with cutoffs at the specified period. When the
#' period is the time interval of the data, this is the procedure described in
#' https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
#' a cutoff every "period" amount of time, going back to "initial".
#'
#' @param model Fitted Prophet model.
#' @param horizon Integer size of the horizon
#' @param units String unit of the horizon, e.g., "days", "secs".
#' @param period Integer amount of time between cutoff dates. Same units as
#'  horizon.
#' @param initial Integer size of the first training period. If not provided,
#'  3 * horizon is used. Same units as horizon.
#'
#' @return A dataframe with the forecast, actual value, and cutoff date.
#'
#' @export
cross_validation <- function(model, horizon, units, period, initial = NULL) {
  history.dates <- model$history$ds
  if (is.null(initial)) {
    initial <- 3 * horizon
  }
  # All three quantities are expressed in the caller-supplied units.
  to.difftime <- function(amount) as.difftime(amount, units = units)
  # Latest admissible cutoff and the end of the initial training window.
  last.cutoff <- max(history.dates) - to.difftime(horizon)
  initial.end <- min(history.dates) + to.difftime(initial)
  # Number of cutoffs that fit between the two, one per period (rounded up).
  span.secs <- as.double(last.cutoff - initial.end, units = 'secs')
  period.secs <- as.double(to.difftime(period), units = 'secs')
  k <- ceiling(span.secs / period.secs)
  if (k < 1) {
    stop('Not enough data for specified horizon and initial.')
  }
  return(simulated_historical_forecasts(model, horizon, units, k, period))
}
|
||||
|
|
@ -109,6 +109,7 @@ prophet <- function(df = NULL,
|
|||
mcmc.samples = mcmc.samples,
|
||||
interval.width = interval.width,
|
||||
uncertainty.samples = uncertainty.samples,
|
||||
specified.changepoints = !is.null(changepoints),
|
||||
start = NULL, # This and following attributes are set during fitting
|
||||
y.scale = NULL,
|
||||
t.scale = NULL,
|
||||
|
|
@ -240,6 +241,7 @@ set_date <- function(ds = NULL, tz = "GMT") {
|
|||
} else {
|
||||
ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
|
||||
}
|
||||
attr(ds, "tzone") <- tz
|
||||
return(ds)
|
||||
}
|
||||
|
||||
|
|
@ -1411,4 +1413,42 @@ plot_seasonality <- function(m, name, uncertainty = TRUE) {
|
|||
return(gg.s)
|
||||
}
|
||||
|
||||
#' Copy Prophet object.
#'
#' Builds a new, unfitted model from the constructor settings stored on
#' \code{m}. Changepoints are carried over only when they were specified
#' explicitly on the input model.
#'
#' @param m Prophet model object.
#' @param cutoff Date, possibly as string. Changepoints are only retained if
#'  changepoints <= cutoff.
#'
#' @return An unfitted Prophet model object with the same parameters as the
#'  input model.
#'
#' @keywords internal
prophet_copy <- function(m, cutoff = NULL) {
  if (m$specified.changepoints) {
    changepoints <- m$changepoints
    if (!is.null(cutoff)) {
      # Normalise the cutoff with set_date before comparing it against the
      # stored changepoints.
      cutoff <- set_date(cutoff)
      changepoints <- changepoints[changepoints <= cutoff]
    }
  } else {
    # Changepoints were not user-specified, so none are passed to the copy.
    changepoints <- NULL
  }
  # No trailing comma after the last argument: it would pass an extra, empty
  # argument to prophet().
  return(prophet(
    growth = m$growth,
    changepoints = changepoints,
    n.changepoints = m$n.changepoints,
    yearly.seasonality = m$yearly.seasonality,
    weekly.seasonality = m$weekly.seasonality,
    daily.seasonality = m$daily.seasonality,
    holidays = m$holidays,
    seasonality.prior.scale = m$seasonality.prior.scale,
    changepoint.prior.scale = m$changepoint.prior.scale,
    holidays.prior.scale = m$holidays.prior.scale,
    mcmc.samples = m$mcmc.samples,
    interval.width = m$interval.width,
    uncertainty.samples = m$uncertainty.samples,
    fit = FALSE
  ))
}
|
||||
|
||||
# fb-block 3
|
||||
|
|
|
|||
|
|
@ -21,5 +21,6 @@ The prophet model with the seasonality added.
|
|||
}
|
||||
\description{
|
||||
Increasing the number of Fourier components allows the seasonality to change
|
||||
more quickly (at risk of overfitting).
|
||||
more quickly (at risk of overfitting). Default values for yearly and weekly
|
||||
seasonalities are 10 and 3 respectively.
|
||||
}
|
||||
|
|
|
|||
35
R/man/cross_validation.Rd
Normal file
35
R/man/cross_validation.Rd
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/diagnostics.R
|
||||
\name{cross_validation}
|
||||
\alias{cross_validation}
|
||||
\title{Cross-validation for time series.
|
||||
Computes forecast error with cutoffs at the specified period. When the
|
||||
period is the time interval of the data, is the procedure described in
|
||||
https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
|
||||
a cutoff every "period" amount of time, going back to "initial".}
|
||||
\usage{
|
||||
cross_validation(model, horizon, units, period, initial = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{model}{Fitted Prophet model.}
|
||||
|
||||
\item{horizon}{Integer size of the horizon}
|
||||
|
||||
\item{units}{String unit of the horizon, e.g., "days", "secs".}
|
||||
|
||||
\item{period}{Integer amount of time between cutoff dates. Same units as
|
||||
horizon.}
|
||||
|
||||
\item{initial}{Integer size of the first training period. If not provided,
|
||||
3 * horizon is used. Same units as horizon.}
|
||||
}
|
||||
\value{
|
||||
A dataframe with the forecast, actual value, and cutoff date.
|
||||
}
|
||||
\description{
|
||||
Cross-validation for time series.
|
||||
Computes forecast error with cutoffs at the specified period. When the
|
||||
period is the time interval of the data, is the procedure described in
|
||||
https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
|
||||
a cutoff every "period" amount of time, going back to "initial".
|
||||
}
|
||||
24
R/man/generate_cutoffs.Rd
Normal file
24
R/man/generate_cutoffs.Rd
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/diagnostics.R
|
||||
\name{generate_cutoffs}
|
||||
\alias{generate_cutoffs}
|
||||
\title{Generate cutoff dates}
|
||||
\usage{
|
||||
generate_cutoffs(df, horizon, k, period)
|
||||
}
|
||||
\arguments{
|
||||
\item{df}{Dataframe with historical data}
|
||||
|
||||
\item{horizon}{timediff forecast horizon}
|
||||
|
||||
\item{k}{integer number of forecast points}
|
||||
|
||||
\item{period}{timediff Simulated forecasts are done with this period.}
|
||||
}
|
||||
\value{
|
||||
Array of datetimes
|
||||
}
|
||||
\description{
|
||||
Generate cutoff dates
|
||||
}
|
||||
\keyword{internal}
|
||||
22
R/man/prophet_copy.Rd
Normal file
22
R/man/prophet_copy.Rd
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/prophet.R
|
||||
\name{prophet_copy}
|
||||
\alias{prophet_copy}
|
||||
\title{Copy Prophet object.}
|
||||
\usage{
|
||||
prophet_copy(m, cutoff = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{m}{Prophet model object.}
|
||||
|
||||
\item{cutoff}{Date, possibly as string. Changepoints are only retained if
|
||||
changepoints <= cutoff.}
|
||||
}
|
||||
\value{
|
||||
An unfitted Prophet model object with the same parameters as the
|
||||
input model.
|
||||
}
|
||||
\description{
|
||||
Copy Prophet object.
|
||||
}
|
||||
\keyword{internal}
|
||||
30
R/man/simulated_historical_forecasts.Rd
Normal file
30
R/man/simulated_historical_forecasts.Rd
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/diagnostics.R
|
||||
\name{simulated_historical_forecasts}
|
||||
\alias{simulated_historical_forecasts}
|
||||
\title{Simulated historical forecasts.
|
||||
Make forecasts from k historical cutoff dates, and compare forecast values
|
||||
to actual values.}
|
||||
\usage{
|
||||
simulated_historical_forecasts(model, horizon, units, k, period = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{model}{Fitted Prophet model.}
|
||||
|
||||
\item{horizon}{Integer size of the horizon}
|
||||
|
||||
\item{units}{String unit of the horizon, e.g., "days", "secs".}
|
||||
|
||||
\item{k}{integer number of forecast points}
|
||||
|
||||
\item{period}{Integer amount of time between cutoff dates. Same units as
|
||||
horizon. If not provided, will use 0.5 * horizon.}
|
||||
}
|
||||
\value{
|
||||
A dataframe with the forecast, actual value, and cutoff date.
|
||||
}
|
||||
\description{
|
||||
Simulated historical forecasts.
|
||||
Make forecasts from k historical cutoff dates, and compare forecast values
|
||||
to actual values.
|
||||
}
|
||||
86
R/tests/testthat/test_diagnostics.R
Normal file
86
R/tests/testthat/test_diagnostics.R
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
library(prophet)
|
||||
context("Prophet diagnostics tests")
|
||||
|
||||
## Makes R CMD CHECK happy due to dplyr syntax below
|
||||
globalVariables(c("y", "yhat"))
|
||||
|
||||
DATA <- head(read.csv('data.csv'), 100)
|
||||
DATA$ds <- as.Date(DATA$ds)
|
||||
|
||||
test_that("simulated_historical_forecasts", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  k <- 2
  # Every (period, horizon) combination; horizon varies fastest.
  grid <- expand.grid(h = c(1, 3), p = c(1, 10))
  for (row in seq_len(nrow(grid))) {
    h <- grid$h[row]
    p <- grid$p[row]
    horizon.dt <- as.difftime(h, units = 'days')
    period.dt <- as.difftime(p, units = 'days')
    df.shf <- simulated_historical_forecasts(
      m, horizon = h, units = 'days', k = k, period = p)
    # Forecast dates always lie strictly after their cutoff.
    expect_true(all(df.shf$cutoff < df.shf$ds))
    # Exactly k distinct cutoffs are produced.
    expect_equal(length(unique(df.shf$cutoff)), k)
    # The furthest forecast is one full horizon past its cutoff.
    expect_equal(max(df.shf$ds - df.shf$cutoff), horizon.dt)
    # Distinct cutoffs are at least one period apart.
    gaps <- diff(df.shf$cutoff)
    smallest.gap <- min(gaps[gaps > 0])
    expect_true(smallest.gap >= period.dt)
    # Actual values in the output match the model history for the same ds.
    df.merged <- dplyr::left_join(df.shf, m$history, by = "ds")
    expect_equal(sum((df.merged$y.x - df.merged$y.y)^2), 0)
  }
})
|
||||
|
||||
test_that("simulated_historical_forecasts_logistic", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  # Logistic growth requires a capacity column in the history.
  df <- DATA
  df$cap <- 40
  m <- prophet(df, growth = 'logistic')
  n.cutoffs <- 2
  df.shf <- simulated_historical_forecasts(
    m, horizon = 3, units = 'days', k = n.cutoffs, period = 3)
  # Forecast dates always lie strictly after their cutoff.
  expect_true(all(df.shf$cutoff < df.shf$ds))
  # Exactly the requested number of distinct cutoffs is produced.
  expect_equal(length(unique(df.shf$cutoff)), n.cutoffs)
  # Actual values in the output match the model history for the same ds.
  df.merged <- dplyr::left_join(df.shf, m$history, by = "ds")
  squared.diff <- (df.merged$y.x - df.merged$y.y)^2
  expect_equal(sum(squared.diff), 0)
})
|
||||
|
||||
test_that("simulated_historical_forecasts_default_value_check", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  # Omitting period must behave like passing half the horizon explicitly.
  df.default <- simulated_historical_forecasts(
    m, horizon = 10, units = 'days', k = 1)
  df.explicit <- simulated_historical_forecasts(
    m, horizon = 10, units = 'days', k = 1, period = 5)
  delta <- df.default - df.explicit
  expect_equal(sum(dplyr::select(delta, y, yhat)), 0)
})
|
||||
|
||||
test_that("cross_validation", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  horizon <- as.difftime(4, units = "days")
  period <- as.difftime(10, units = "days")
  # Expected number of cutoffs for the settings used below.
  k <- 5
  df.cv <- cross_validation(
    m, horizon = 4, units = "days", period = 10, initial = 90)
  # The unique size of output cutoff should be equal to 'k'
  expect_equal(length(unique(df.cv$cutoff)), k)
  # The furthest forecast is one full horizon past its cutoff.
  expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
  # Distinct cutoffs are at least one period apart.
  dc <- diff(df.cv$cutoff)
  dc <- min(dc[dc > 0])
  expect_true(dc >= period)
})
|
||||
|
||||
test_that("cross_validation_default_value_check", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  # Omitting initial must behave like passing 3 * horizon explicitly.
  df.default <- cross_validation(
    m, horizon = 32, units = "days", period = 10)
  df.explicit <- cross_validation(
    m, horizon = 32, units = 'days', period = 10, initial = 96)
  delta <- df.default - df.explicit
  expect_equal(sum(dplyr::select(delta, y, yhat)), 0)
})
|
||||
|
|
@ -330,3 +330,57 @@ test_that("custom_seasonality", {
|
|||
m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
|
||||
expect_equal(m$seasonalities[['monthly']], c(30, 5))
|
||||
})
|
||||
|
||||
test_that("copy", {
  # c(NULL, x) collapses to x, so the "argument not supplied" case for the
  # optional arguments is encoded with a 'none' sentinel and mapped back to
  # NULL inside the loop.
  inputs <- list(
    growth = c('linear', 'logistic'),
    changepoints = c('none', '2016-12-25'),
    n.changepoints = c(3),
    yearly.seasonality = c(TRUE, FALSE),
    weekly.seasonality = c(TRUE, FALSE),
    daily.seasonality = c(TRUE, FALSE),
    holidays = c('none', 'insert_dataframe'),
    seasonality.prior.scale = c(1.1),
    holidays.prior.scale = c(1.1),
    changepoint.prior.scale = c(0.1),
    mcmc.samples = c(100),
    interval.width = c(0.9),
    uncertainty.samples = c(200)
  )
  # Keep strings as strings rather than factors.
  products <- expand.grid(inputs, stringsAsFactors = FALSE)
  # Iterate over rows (combinations); length() of a data frame counts columns.
  for (i in seq_len(nrow(products))) {
    if (products$changepoints[i] == 'none') {
      changepoints <- NULL
    } else {
      changepoints <- products$changepoints[i]
    }
    if (products$holidays[i] == 'insert_dataframe') {
      holidays <- data.frame(ds=c('2016-12-25'), holiday=c('x'))
    } else {
      holidays <- NULL
    }
    m1 <- prophet(
      growth = products$growth[i],
      changepoints = changepoints,
      n.changepoints = products$n.changepoints[i],
      yearly.seasonality = products$yearly.seasonality[i],
      weekly.seasonality = products$weekly.seasonality[i],
      daily.seasonality = products$daily.seasonality[i],
      holidays = holidays,
      seasonality.prior.scale = products$seasonality.prior.scale[i],
      holidays.prior.scale = products$holidays.prior.scale[i],
      changepoint.prior.scale = products$changepoint.prior.scale[i],
      mcmc.samples = products$mcmc.samples[i],
      interval.width = products$interval.width[i],
      uncertainty.samples = products$uncertainty.samples[i],
      fit = FALSE
    )
    m2 <- prophet:::prophet_copy(m1)
    # Values should be copied correctly
    for (arg in names(inputs)) {
      expect_equal(m1[[arg]], m2[[arg]])
    }
  }
  # Check for cutoff
  changepoints <- seq.Date(as.Date('2012-06-15'), as.Date('2012-09-15'), by='d')
  cutoff <- as.Date('2012-07-25')
  m1 <- prophet(DATA, changepoints = changepoints)
  m2 <- prophet:::prophet_copy(m1, cutoff)
  # Only changepoints at or before the cutoff survive the copy.
  changepoints <- changepoints[changepoints <= cutoff]
  expect_equal(prophet:::set_date(changepoints), m2$changepoints)
})
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ def _cutoffs(df, horizon, k, period):
|
|||
"""
|
||||
# Last cutoff is 'latest date in data - horizon' date
|
||||
cutoff = df['ds'].max() - horizon
|
||||
if cutoff < df['ds'].min():
|
||||
raise ValueError('Less data than horizon.')
|
||||
result = [cutoff]
|
||||
|
||||
for i in range(1, k):
|
||||
|
|
@ -48,7 +50,7 @@ def _cutoffs(df, horizon, k, period):
|
|||
closest_date = df[df['ds'] <= cutoff].max()['ds']
|
||||
cutoff = closest_date - horizon
|
||||
if cutoff < df['ds'].min():
|
||||
logger.warning('Not enough data for requested number of cutoffs! Using {}.'.format(k))
|
||||
logger.warning('Not enough data for requested number of cutoffs! Using {}.'.format(i))
|
||||
break
|
||||
result.append(cutoff)
|
||||
|
||||
|
|
@ -127,5 +129,7 @@ def cross_validation(model, horizon, period, initial=None):
|
|||
horizon = pd.Timedelta(horizon)
|
||||
period = pd.Timedelta(period)
|
||||
initial = 3 * horizon if initial is None else pd.Timedelta(initial)
|
||||
k = int(np.floor(((te - horizon) - (ts + initial)) / period))
|
||||
k = int(np.ceil(((te - horizon) - (ts + initial)) / period))
|
||||
if k < 1:
|
||||
raise ValueError('Not enough data for specified horizon and initial.')
|
||||
return simulated_historical_forecasts(model, horizon, k, period)
|
||||
|
|
|
|||
|
|
@ -100,8 +100,10 @@ class Prophet(object):
|
|||
self.changepoints = pd.to_datetime(changepoints)
|
||||
if self.changepoints is not None:
|
||||
self.n_changepoints = len(self.changepoints)
|
||||
self.specified_changepoints = True
|
||||
else:
|
||||
self.n_changepoints = n_changepoints
|
||||
self.specified_changepoints = False
|
||||
|
||||
self.yearly_seasonality = yearly_seasonality
|
||||
self.weekly_seasonality = weekly_seasonality
|
||||
|
|
@ -1420,21 +1422,24 @@ class Prophet(object):
|
|||
----------
|
||||
cutoff: pd.Timestamp or None, default None.
|
||||
cuttoff Timestamp for changepoints member variable.
|
||||
changepoints are only remained if 'changepoints <= cutoff'
|
||||
changepoints are only retained if 'changepoints <= cutoff'
|
||||
|
||||
Returns
|
||||
-------
|
||||
Prophet class object with the same parameter with model variable
|
||||
"""
|
||||
if self.changepoints is not None and cutoff is not None:
|
||||
# Filter change points '<= cutoff'
|
||||
self.changepoints = self.changepoints[self.changepoints <= cutoff]
|
||||
self.n_changepoints = len(self.changepoints)
|
||||
if self.specified_changepoints:
|
||||
changepoints = self.changepoints
|
||||
if cutoff is not None:
|
||||
# Filter change points '<= cutoff'
|
||||
changepoints = changepoints[changepoints <= cutoff]
|
||||
else:
|
||||
changepoints = None
|
||||
|
||||
return Prophet(
|
||||
growth=self.growth,
|
||||
n_changepoints=self.n_changepoints,
|
||||
changepoints=self.changepoints,
|
||||
changepoints=changepoints,
|
||||
yearly_seasonality=self.yearly_seasonality,
|
||||
weekly_seasonality=self.weekly_seasonality,
|
||||
daily_seasonality=self.daily_seasonality,
|
||||
|
|
@ -1446,5 +1451,3 @@ class Prophet(object):
|
|||
interval_width=self.interval_width,
|
||||
uncertainty_samples=self.uncertainty_samples
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -77,9 +77,9 @@ class TestDiagnostics(TestCase):
|
|||
ts = self.__df['ds'].min()
|
||||
horizon = pd.Timedelta('4 days')
|
||||
period = pd.Timedelta('10 days')
|
||||
initial = pd.Timedelta('90 days')
|
||||
k = int(np.floor(((te - horizon) - (ts + initial)) / period))
|
||||
df_cv = diagnostics.cross_validation(m, horizon=horizon, period=period, initial=initial)
|
||||
k = 5
|
||||
df_cv = diagnostics.cross_validation(
|
||||
m, horizon='4 days', period='10 days', initial='90 days')
|
||||
# The unique size of output cutoff should be equal to 'k'
|
||||
self.assertEqual(len(np.unique(df_cv['cutoff'])), k)
|
||||
self.assertEqual(max(df_cv['ds'] - df_cv['cutoff']), horizon)
|
||||
|
|
|
|||
|
|
@ -490,9 +490,10 @@ class TestProphet(TestCase):
|
|||
self.assertEqual(m1.uncertainty_samples, m2.uncertainty_samples)
|
||||
|
||||
# Check for cutoff
|
||||
changepoints = pd.date_range('2016-12-15', '2017-01-15')
|
||||
cutoff = pd.Timestamp('2016-12-25')
|
||||
changepoints = pd.date_range('2012-06-15', '2012-09-15')
|
||||
cutoff = pd.Timestamp('2012-07-25')
|
||||
m1 = Prophet(changepoints=changepoints)
|
||||
m1.fit(DATA)
|
||||
m2 = m1.copy(cutoff=cutoff)
|
||||
changepoints = changepoints[changepoints <= cutoff]
|
||||
self.assertTrue((changepoints == m2.changepoints).all())
|
||||
|
|
|
|||
Loading…
Reference in a new issue