Add cross-validation functions in R

This commit is contained in:
bletham 2017-08-26 14:31:33 -07:00
parent 509666d1d2
commit 3c09448018
14 changed files with 450 additions and 16 deletions

View file

@ -3,10 +3,12 @@
S3method(plot,prophet)
S3method(predict,prophet)
export(add_seasonality)
export(cross_validation)
export(fit.prophet)
export(make_future_dataframe)
export(predictive_samples)
export(prophet)
export(prophet_plot_components)
export(simulated_historical_forecasts)
import(Rcpp)
importFrom(dplyr,"%>%")

132
R/R/diagnostics.R Normal file
View file

@ -0,0 +1,132 @@
## Copyright (c) 2017-present, Facebook, Inc.
## All rights reserved.
## This source code is licensed under the BSD-style license found in the
## LICENSE file in the root directory of this source tree. An additional grant
## of patent rights can be found in the PATENTS file in the same directory.
## Makes R CMD CHECK happy due to dplyr syntax below: these names appear as
## bare column names inside the dplyr::filter / dplyr::select calls in this
## file, so they are declared here to keep them from being reported as
## undefined global variables.
globalVariables(c(
"ds", "y", "cap", "yhat", "yhat_lower", "yhat_upper"))
#' Generate cutoff dates
#'
#' The last cutoff is (latest date in data) - horizon. Earlier cutoffs are
#' found by repeatedly stepping back by \code{period}; when the window
#' (cutoff, cutoff + horizon] contains no data, the cutoff is moved back to
#' (closest earlier date in data) - horizon. Generation stops early, with a
#' warning, once a cutoff would fall before the first date in the data.
#'
#' @param df Dataframe with historical data; must have a datetime column ds.
#' @param horizon difftime forecast horizon
#' @param k integer number of forecast points. Values below 1 are treated
#'  as 1, since the last cutoff is always returned.
#' @param period difftime Simulated forecasts are done with this period.
#'
#' @return Array of datetimes in increasing order, with at most k elements.
#'
#' @keywords internal
generate_cutoffs <- function(df, horizon, k, period) {
  first.date <- min(df$ds)
  # Last cutoff is (latest date in data) - (horizon).
  cutoff <- max(df$ds) - horizon
  if (cutoff < first.date) {
    stop('Less data than horizon.')
  }
  tzone <- attr(cutoff, "tzone")  # Timezone is wiped by putting in array
  result <- c(cutoff)
  # seq_len() yields an empty sequence when only one cutoff is requested,
  # whereas 2:k would count down (2, 1) and add two unwanted cutoffs.
  for (i in seq_len(max(k - 1, 0))) {
    cutoff <- cutoff - period
    # If data does not exist in data range (cutoff, cutoff + horizon].
    # Only look for an earlier date while the cutoff is still inside the
    # data range; otherwise max() would be taken over an empty vector.
    if (cutoff >= first.date &&
        !any((df$ds > cutoff) & (df$ds <= cutoff + horizon))) {
      # Next cutoff point is 'closest date before cutoff in data - horizon'
      closest.date <- max(df$ds[df$ds <= cutoff])
      cutoff <- closest.date - horizon
    }
    if (cutoff < first.date) {
      # Report how many cutoffs were actually produced.
      warning('Not enough data for requested number of cutoffs! Using ',
              length(result))
      break
    }
    result <- c(result, cutoff)
  }
  # Reset timezones
  attr(result, "tzone") <- tzone
  return(rev(result))
}
#' Simulated historical forecasts.
#' Make forecasts from k historical cutoff dates, and compare forecast values
#' to actual values.
#'
#' For every cutoff, an unfitted copy of the model is fitted on the history up
#' to and including the cutoff, and predictions are made for the history dates
#' in (cutoff, cutoff + horizon].
#'
#' @param model Fitted Prophet model.
#' @param horizon Integer size of the horizon
#' @param units String unit of the horizon, e.g., "days", "secs".
#' @param k integer number of forecast points
#' @param period Integer amount of time between cutoff dates. Same units as
#'  horizon. If not provided, will use 0.5 * horizon.
#'
#' @return A dataframe with the forecast, actual value, and cutoff date
#'  (columns ds, y, yhat, yhat_lower, yhat_upper, cutoff).
#'
#' @export
simulated_historical_forecasts <- function(model, horizon, units, k,
                                           period = NULL) {
  df <- model$history
  horizon <- as.difftime(horizon, units = units)
  # Cutoffs are spaced half a horizon apart unless a period is given.
  period <- if (is.null(period)) {
    horizon / 2
  } else {
    as.difftime(period, units = units)
  }
  cutoffs <- generate_cutoffs(df, horizon, k, period)

  # Fits a fresh copy of the model up to one cutoff and returns the forecast
  # joined with the observed values for the following horizon.
  forecast_from <- function(idx) {
    # Index rather than iterate so the cutoff keeps its datetime class.
    cutoff <- cutoffs[idx]
    # Copy the model, then train it on data up to the cutoff
    m <- prophet_copy(model, cutoff)
    history.c <- dplyr::filter(df, ds <= cutoff)
    m <- fit.prophet(m, history.c)
    # Dates to predict: observed dates within one horizon after the cutoff
    df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon)
    # Logistic growth needs the capacity column to predict
    future <- if (m$growth == 'logistic') {
      dplyr::select(df.predict, ds, cap)
    } else {
      dplyr::select(df.predict, ds)
    }
    yhat <- stats::predict(m, future)
    # Merge yhat, y, and cutoff.
    df.c <- dplyr::inner_join(df.predict, yhat, by = "ds")
    df.c <- dplyr::select(df.c, ds, y, yhat, yhat_lower, yhat_upper)
    df.c$cutoff <- cutoff
    df.c
  }

  per.cutoff <- lapply(seq_along(cutoffs), forecast_from)
  # Stack the per-cutoff frames in cutoff order, starting from an empty frame
  return(Reduce(rbind, per.cutoff, data.frame()))
}
#' Cross-validation for time series.
#' Computes forecast error with cutoffs at the specified period. When the
#' period is the time interval of the data, is the procedure described in
#' https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
#' a cutoff every "period" amount of time, going back to "initial".
#'
#' The number of cutoffs is the length of the span between
#' (first date + initial) and (last date - horizon), divided by the period
#' and rounded up; the forecasts themselves are produced by
#' simulated_historical_forecasts.
#'
#' @param model Fitted Prophet model.
#' @param horizon Integer size of the horizon
#' @param units String unit of the horizon, e.g., "days", "secs".
#' @param period Integer amount of time between cutoff dates. Same units as
#'  horizon.
#' @param initial Integer size of the first training period. If not provided,
#'  3 * horizon is used. Same units as horizon.
#'
#' @return A dataframe with the forecast, actual value, and cutoff date.
#'
#' @export
cross_validation <- function(model, horizon, units, period, initial = NULL) {
  history.ds <- model$history$ds
  te <- max(history.ds)
  ts <- min(history.ds)
  if (is.null(initial)) {
    initial <- 3 * horizon
  }
  horizon.dt <- as.difftime(horizon, units = units)
  initial.dt <- as.difftime(initial, units = units)
  period.dt <- as.difftime(period, units = units)
  # Time available for cutoffs: from the end of the initial training window
  # to the last date that still leaves a full horizon of data after it.
  span <- (te - horizon.dt) - (ts + initial.dt)
  span.secs <- as.double(span, units = 'secs')
  period.secs <- as.double(period.dt, units = 'secs')
  k <- ceiling(span.secs / period.secs)
  if (k < 1) {
    stop('Not enough data for specified horizon and initial.')
  }
  return(simulated_historical_forecasts(model, horizon, units, k, period))
}

View file

@ -109,6 +109,7 @@ prophet <- function(df = NULL,
mcmc.samples = mcmc.samples,
interval.width = interval.width,
uncertainty.samples = uncertainty.samples,
specified.changepoints = !is.null(changepoints),
start = NULL, # This and following attributes are set during fitting
y.scale = NULL,
t.scale = NULL,
@ -240,6 +241,7 @@ set_date <- function(ds = NULL, tz = "GMT") {
} else {
ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
}
attr(ds, "tzone") <- tz
return(ds)
}
@ -1411,4 +1413,42 @@ plot_seasonality <- function(m, name, uncertainty = TRUE) {
return(gg.s)
}
#' Copy Prophet object.
#'
#' Builds a new, unfitted model from the constructor arguments stored on m.
#' Changepoints are carried over only when they were specified by the user
#' when m was created; otherwise the copy is given no changepoints.
#'
#' @param m Prophet model object.
#' @param cutoff Date, possibly as string. Changepoints are only retained if
#'  changepoints <= cutoff. If NULL, all specified changepoints are retained.
#'
#' @return An unfitted Prophet model object with the same parameters as the
#'  input model.
#'
#' @keywords internal
prophet_copy <- function(m, cutoff = NULL) {
  # isTRUE() treats a missing flag (NULL) as "not specified" instead of
  # failing with "argument is of length zero".
  if (isTRUE(m$specified.changepoints)) {
    changepoints <- m$changepoints
    if (!is.null(cutoff)) {
      # Pass the cutoff through set_date before comparing with changepoints
      cutoff <- set_date(cutoff)
      changepoints <- changepoints[changepoints <= cutoff]
    }
  } else {
    changepoints <- NULL
  }
  # No trailing comma after the last argument: it would pass an empty
  # argument, which fails as soon as the callee evaluates its extra arguments.
  return(prophet(
    growth = m$growth,
    changepoints = changepoints,
    n.changepoints = m$n.changepoints,
    yearly.seasonality = m$yearly.seasonality,
    weekly.seasonality = m$weekly.seasonality,
    daily.seasonality = m$daily.seasonality,
    holidays = m$holidays,
    seasonality.prior.scale = m$seasonality.prior.scale,
    changepoint.prior.scale = m$changepoint.prior.scale,
    holidays.prior.scale = m$holidays.prior.scale,
    mcmc.samples = m$mcmc.samples,
    interval.width = m$interval.width,
    uncertainty.samples = m$uncertainty.samples,
    fit = FALSE
  ))
}
# fb-block 3

View file

@ -21,5 +21,6 @@ The prophet model with the seasonality added.
}
\description{
Increasing the number of Fourier components allows the seasonality to change
more quickly (at risk of overfitting).
more quickly (at risk of overfitting). Default values for yearly and weekly
seasonalities are 10 and 3 respectively.
}

35
R/man/cross_validation.Rd Normal file
View file

@ -0,0 +1,35 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/diagnostics.R
\name{cross_validation}
\alias{cross_validation}
\title{Cross-validation for time series.
Computes forecast error with cutoffs at the specified period. When the
period is the time interval of the data, is the procedure described in
https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
a cutoff every "period" amount of time, going back to "initial".}
\usage{
cross_validation(model, horizon, units, period, initial = NULL)
}
\arguments{
\item{model}{Fitted Prophet model.}
\item{horizon}{Integer size of the horizon}
\item{units}{String unit of the horizon, e.g., "days", "secs".}
\item{period}{Integer amount of time between cutoff dates. Same units as
horizon.}
\item{initial}{Integer size of the first training period. If not provided,
3 * horizon is used. Same units as horizon.}
}
\value{
A dataframe with the forecast, actual value, and cutoff date.
}
\description{
Cross-validation for time series.
Computes forecast error with cutoffs at the specified period. When the
period is the time interval of the data, is the procedure described in
https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
a cutoff every "period" amount of time, going back to "initial".
}

24
R/man/generate_cutoffs.Rd Normal file
View file

@ -0,0 +1,24 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/diagnostics.R
\name{generate_cutoffs}
\alias{generate_cutoffs}
\title{Generate cutoff dates}
\usage{
generate_cutoffs(df, horizon, k, period)
}
\arguments{
\item{df}{Dataframe with historical data}
\item{horizon}{timediff forecast horizon}
\item{k}{integer number of forecast points}
\item{period}{timediff Simulated forecasts are done with this period.}
}
\value{
Array of datetimes
}
\description{
Generate cutoff dates
}
\keyword{internal}

22
R/man/prophet_copy.Rd Normal file
View file

@ -0,0 +1,22 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/prophet.R
\name{prophet_copy}
\alias{prophet_copy}
\title{Copy Prophet object.}
\usage{
prophet_copy(m, cutoff = NULL)
}
\arguments{
\item{m}{Prophet model object.}
\item{cutoff}{Date, possibly as string. Changepoints are only retained if
changepoints <= cutoff.}
}
\value{
An unfitted Prophet model object with the same parameters as the
input model.
}
\description{
Copy Prophet object.
}
\keyword{internal}

View file

@ -0,0 +1,30 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/diagnostics.R
\name{simulated_historical_forecasts}
\alias{simulated_historical_forecasts}
\title{Simulated historical forecasts.
Make forecasts from k historical cutoff dates, and compare forecast values
to actual values.}
\usage{
simulated_historical_forecasts(model, horizon, units, k, period = NULL)
}
\arguments{
\item{model}{Fitted Prophet model.}
\item{horizon}{Integer size of the horizon}
\item{units}{String unit of the horizon, e.g., "days", "secs".}
\item{k}{integer number of forecast points}
\item{period}{Integer amount of time between cutoff dates. Same units as
horizon. If not provided, will use 0.5 * horizon.}
}
\value{
A dataframe with the forecast, actual value, and cutoff date.
}
\description{
Simulated historical forecasts.
Make forecasts from k historical cutoff dates, and compare forecast values
to actual values.
}

View file

@ -0,0 +1,86 @@
library(prophet)
context("Prophet diagnostics tests")
## Makes R CMD CHECK happy due to dplyr syntax below
globalVariables(c("y", "yhat"))
## Shared fixture for the tests in this file: the first 100 rows of data.csv,
## with the ds column converted to Date.
DATA <- head(read.csv('data.csv'), 100)
DATA$ds <- as.Date(DATA$ds)
test_that("simulated_historical_forecasts", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  k <- 2
  # Runs all checks for one (period, horizon) pair, both given in days.
  check_case <- function(p, h) {
    df.shf <- simulated_historical_forecasts(
      m, horizon = h, units = 'days', k = k, period = p)
    # Every forecast date must lie strictly after its cutoff
    expect_true(all(df.shf$ds > df.shf$cutoff))
    # Exactly k distinct cutoffs are expected in the output
    n.cutoffs <- length(unique(df.shf$cutoff))
    expect_equal(n.cutoffs, k)
    # The furthest forecast is exactly one horizon after its cutoff
    furthest <- max(df.shf$ds - df.shf$cutoff)
    expect_equal(furthest, as.difftime(h, units = 'days'))
    # Distinct cutoffs are at least one period apart
    gaps <- diff(df.shf$cutoff)
    smallest.gap <- min(gaps[gaps > 0])
    expect_true(smallest.gap >= as.difftime(p, units = 'days'))
    # Each y in df.shf must match the history value with the same ds
    df.merged <- dplyr::left_join(df.shf, m$history, by = "ds")
    expect_equal(sum((df.merged$y.x - df.merged$y.y)^2), 0)
  }
  for (p in c(1, 10)) {
    for (h in c(1, 3)) {
      check_case(p, h)
    }
  }
})
test_that("simulated_historical_forecasts_logistic", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  # Logistic growth requires a capacity column in the training data
  df <- DATA
  df$cap <- 40
  m <- prophet(df, growth = 'logistic')
  n.cutoffs <- 2
  df.shf <- simulated_historical_forecasts(
    m, horizon = 3, units = 'days', k = n.cutoffs, period = 3)
  # Every forecast date must lie strictly after its cutoff
  expect_true(all(df.shf$ds > df.shf$cutoff))
  # The number of distinct cutoffs in the output equals the requested k
  expect_equal(length(unique(df.shf$cutoff)), n.cutoffs)
  # Each y in df.shf must match the history value with the same ds
  df.merged <- dplyr::left_join(df.shf, m$history, by = "ds")
  squared.diff <- (df.merged$y.x - df.merged$y.y)^2
  expect_equal(sum(squared.diff), 0)
})
test_that("simulated_historical_forecasts_default_value_check", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  # Leaving period unset must match passing half of the horizon explicitly
  df.default <- simulated_historical_forecasts(
    m, horizon = 10, units = 'days', k = 1)
  df.explicit <- simulated_historical_forecasts(
    m, horizon = 10, units = 'days', k = 1, period = 5)
  delta <- dplyr::select(df.default - df.explicit, y, yhat)
  expect_equal(sum(delta), 0)
})
test_that("cross_validation", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  horizon <- as.difftime(4, units = "days")
  period <- as.difftime(10, units = "days")
  # Expected number of cutoff points; hard-coded for this fixture and these
  # horizon / period / initial values.
  k <- 5
  df.cv <- cross_validation(
    m, horizon = 4, units = "days", period = 10, initial = 90)
  # The number of distinct cutoffs in the output should be equal to 'k'
  expect_equal(length(unique(df.cv$cutoff)), k)
  # The furthest forecast is exactly one horizon after its cutoff
  expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
  # Distinct cutoffs are at least one period apart
  dc <- diff(df.cv$cutoff)
  dc <- min(dc[dc > 0])
  expect_true(dc >= period)
})
test_that("cross_validation_default_value_check", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  # Leaving initial unset must match passing 3 * horizon (96) explicitly
  df.default <- cross_validation(
    m, horizon = 32, units = "days", period = 10)
  df.explicit <- cross_validation(
    m, horizon = 32, units = 'days', period = 10, initial = 96)
  delta <- dplyr::select(df.default - df.explicit, y, yhat)
  expect_equal(sum(delta), 0)
})

View file

@ -330,3 +330,57 @@ test_that("custom_seasonality", {
m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
expect_equal(m$seasonalities[['monthly']], c(30, 5))
})
test_that("copy", {
  # expand.grid cannot hold NULL, and c(NULL, x) silently drops the NULL, so
  # NA (changepoints) and 'none' (holidays) stand in for "not specified" and
  # are mapped back to NULL inside the loop.
  inputs <- list(
    growth = c('linear', 'logistic'),
    changepoints = c(NA, '2016-12-25'),
    n.changepoints = c(3),
    yearly.seasonality = c(TRUE, FALSE),
    weekly.seasonality = c(TRUE, FALSE),
    daily.seasonality = c(TRUE, FALSE),
    holidays = c('none', 'insert_dataframe'),
    seasonality.prior.scale = c(1.1),
    holidays.prior.scale = c(1.1),
    changepoint.prior.scale = c(0.1),
    mcmc.samples = c(100),
    interval.width = c(0.9),
    uncertainty.samples = c(200)
  )
  # Keep strings as character so they reach prophet() as plain strings
  products <- expand.grid(inputs, stringsAsFactors = FALSE)
  # Iterate over the rows (parameter combinations); length() of a data frame
  # is its number of columns.
  for (i in seq_len(nrow(products))) {
    if (products$holidays[i] == 'insert_dataframe') {
      holidays <- data.frame(ds=c('2016-12-25'), holiday=c('x'))
    } else {
      holidays <- NULL
    }
    if (is.na(products$changepoints[i])) {
      changepoints <- NULL
    } else {
      changepoints <- products$changepoints[i]
    }
    m1 <- prophet(
      growth = products$growth[i],
      changepoints = changepoints,
      n.changepoints = products$n.changepoints[i],
      yearly.seasonality = products$yearly.seasonality[i],
      weekly.seasonality = products$weekly.seasonality[i],
      daily.seasonality = products$daily.seasonality[i],
      holidays = holidays,
      seasonality.prior.scale = products$seasonality.prior.scale[i],
      holidays.prior.scale = products$holidays.prior.scale[i],
      changepoint.prior.scale = products$changepoint.prior.scale[i],
      mcmc.samples = products$mcmc.samples[i],
      interval.width = products$interval.width[i],
      uncertainty.samples = products$uncertainty.samples[i],
      fit = FALSE
    )
    m2 <- prophet:::prophet_copy(m1)
    # Values should be copied correctly
    for (arg in names(inputs)) {
      expect_equal(m1[[arg]], m2[[arg]])
    }
  }
  # Check for cutoff: only changepoints on or before the cutoff are kept
  changepoints <- seq.Date(as.Date('2012-06-15'), as.Date('2012-09-15'), by='d')
  cutoff <- as.Date('2012-07-25')
  m1 <- prophet(DATA, changepoints = changepoints)
  m2 <- prophet:::prophet_copy(m1, cutoff)
  changepoints <- changepoints[changepoints <= cutoff]
  expect_equal(prophet:::set_date(changepoints), m2$changepoints)
})

View file

@ -38,6 +38,8 @@ def _cutoffs(df, horizon, k, period):
"""
# Last cutoff is 'latest date in data - horizon' date
cutoff = df['ds'].max() - horizon
if cutoff < df['ds'].min():
raise ValueError('Less data than horizon.')
result = [cutoff]
for i in range(1, k):
@ -48,7 +50,7 @@ def _cutoffs(df, horizon, k, period):
closest_date = df[df['ds'] <= cutoff].max()['ds']
cutoff = closest_date - horizon
if cutoff < df['ds'].min():
logger.warning('Not enough data for requested number of cutoffs! Using {}.'.format(k))
logger.warning('Not enough data for requested number of cutoffs! Using {}.'.format(i))
break
result.append(cutoff)
@ -127,5 +129,7 @@ def cross_validation(model, horizon, period, initial=None):
horizon = pd.Timedelta(horizon)
period = pd.Timedelta(period)
initial = 3 * horizon if initial is None else pd.Timedelta(initial)
k = int(np.floor(((te - horizon) - (ts + initial)) / period))
k = int(np.ceil(((te - horizon) - (ts + initial)) / period))
if k < 1:
raise ValueError('Not enough data for specified horizon and initial.')
return simulated_historical_forecasts(model, horizon, k, period)

View file

@ -100,8 +100,10 @@ class Prophet(object):
self.changepoints = pd.to_datetime(changepoints)
if self.changepoints is not None:
self.n_changepoints = len(self.changepoints)
self.specified_changepoints = True
else:
self.n_changepoints = n_changepoints
self.specified_changepoints = False
self.yearly_seasonality = yearly_seasonality
self.weekly_seasonality = weekly_seasonality
@ -1420,21 +1422,24 @@ class Prophet(object):
----------
cutoff: pd.Timestamp or None, default None.
cutoff Timestamp for changepoints member variable.
changepoints are only remained if 'changepoints <= cutoff'
changepoints are only retained if 'changepoints <= cutoff'
Returns
-------
Prophet class object with the same parameter with model variable
"""
if self.changepoints is not None and cutoff is not None:
# Filter change points '<= cutoff'
self.changepoints = self.changepoints[self.changepoints <= cutoff]
self.n_changepoints = len(self.changepoints)
if self.specified_changepoints:
changepoints = self.changepoints
if cutoff is not None:
# Filter change points '<= cutoff'
changepoints = changepoints[changepoints <= cutoff]
else:
changepoints = None
return Prophet(
growth=self.growth,
n_changepoints=self.n_changepoints,
changepoints=self.changepoints,
changepoints=changepoints,
yearly_seasonality=self.yearly_seasonality,
weekly_seasonality=self.weekly_seasonality,
daily_seasonality=self.daily_seasonality,
@ -1446,5 +1451,3 @@ class Prophet(object):
interval_width=self.interval_width,
uncertainty_samples=self.uncertainty_samples
)

View file

@ -77,9 +77,9 @@ class TestDiagnostics(TestCase):
ts = self.__df['ds'].min()
horizon = pd.Timedelta('4 days')
period = pd.Timedelta('10 days')
initial = pd.Timedelta('90 days')
k = int(np.floor(((te - horizon) - (ts + initial)) / period))
df_cv = diagnostics.cross_validation(m, horizon=horizon, period=period, initial=initial)
k = 5
df_cv = diagnostics.cross_validation(
m, horizon='4 days', period='10 days', initial='90 days')
# The unique size of output cutoff should be equal to 'k'
self.assertEqual(len(np.unique(df_cv['cutoff'])), k)
self.assertEqual(max(df_cv['ds'] - df_cv['cutoff']), horizon)

View file

@ -490,9 +490,10 @@ class TestProphet(TestCase):
self.assertEqual(m1.uncertainty_samples, m2.uncertainty_samples)
# Check for cutoff
changepoints = pd.date_range('2016-12-15', '2017-01-15')
cutoff = pd.Timestamp('2016-12-25')
changepoints = pd.date_range('2012-06-15', '2012-09-15')
cutoff = pd.Timestamp('2012-07-25')
m1 = Prophet(changepoints=changepoints)
m1.fit(DATA)
m2 = m1.copy(cutoff=cutoff)
changepoints = changepoints[changepoints <= cutoff]
self.assertTrue((changepoints == m2.changepoints).all())