mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-17 21:10:45 +00:00
182 lines
6.6 KiB
R
182 lines
6.6 KiB
R
library(prophet)
|
|
context("Prophet diagnostics tests")

## Makes R CMD CHECK happy due to dplyr syntax below
globalVariables(c("y", "yhat"))

# Full example series shared by the tests; the `ds` column is parsed to Date.
DATA_all <- local({
  raw <- read.csv('data.csv')
  raw[["ds"]] <- as.Date(raw[["ds"]])
  raw
})

# First 100 rows of DATA_all; most tests below fit their models on this subset.
DATA <- head(DATA_all, n = 100)
|
|
|
|
test_that("cross_validation", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  # Window sizes in days. The same numbers drive both the cross_validation()
  # call and the difftime values the expectations compare against, so the
  # two cannot drift apart.
  horizon_days <- 4
  period_days <- 10
  initial_days <- 115
  horizon <- as.difftime(horizon_days, units = "days")
  period <- as.difftime(period_days, units = "days")
  initial <- as.difftime(initial_days, units = "days")
  ts <- min(DATA$ds)

  df.cv <- cross_validation(
    m, horizon = horizon_days, units = "days", period = period_days,
    initial = initial_days)
  # Expected number of cutoff points for this history and these windows
  expect_equal(length(unique(df.cv$cutoff)), 3)
  # The furthest prediction is exactly one horizon past its cutoff
  expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
  # The first cutoff leaves at least `initial` of history before it
  expect_true(min(df.cv$cutoff) >= ts + initial)
  # Distinct cutoffs are at least `period` apart. Rows sharing a cutoff give
  # zero gaps, which are dropped before taking the minimum.
  dc <- diff(df.cv$cutoff)
  dc <- min(dc[dc > 0])
  expect_true(dc >= period)
  # Every prediction lies strictly after its cutoff
  expect_true(all(df.cv$cutoff < df.cv$ds))
  # Each y in df.cv and DATA with same ds should be equal
  df.merged <- dplyr::left_join(df.cv, m$history, by = "ds")
  expect_equal(sum((df.merged$y.x - df.merged$y.y)^2), 0)

  # A longer initial window leaves room for a single cutoff
  df.cv <- cross_validation(
    m, horizon = horizon_days, units = "days", period = period_days,
    initial = 135)
  expect_equal(length(unique(df.cv$cutoff)), 1)

  # Expected to fail: a 140-day initial window presumably leaves too little
  # data for a 10-day horizon (reason inferred; the message is not pinned).
  expect_error(
    cross_validation(
      m, horizon = 10, units = "days", period = 10, initial = 140)
  )
})
|
|
|
|
test_that("cross_validation_logistic", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  # Logistic growth is fitted with a constant capacity column.
  capped <- DATA
  capped$cap <- 40
  model <- prophet(capped, growth = 'logistic')

  cv <- cross_validation(
    model, horizon = 1, units = "days", period = 1, initial = 140)
  expect_equal(length(unique(cv$cutoff)), 2)
  # Predictions are made only for dates after their cutoff.
  expect_true(all(cv$cutoff < cv$ds))

  # Observed values in the CV output must match the model history on `ds`.
  joined <- dplyr::left_join(cv, model$history, by = "ds")
  squared_gaps <- (joined$y.x - joined$y.y)^2
  expect_equal(sum(squared_gaps), 0)
})
|
|
|
|
test_that("cross_validation_extra_regressors", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  # Training data plus a deterministic extra regressor counting 0, 1, 2, ...
  train <- DATA
  train$extra <- seq(0, nrow(train) - 1)

  # Register the custom seasonality and the regressor before fitting.
  model <- prophet()
  model <- add_seasonality(
    model, name = 'monthly', period = 30.5, fourier.order = 5)
  model <- add_regressor(model, 'extra')
  model <- fit.prophet(model, train)

  cv <- cross_validation(
    model, horizon = 4, units = "days", period = 4, initial = 135)
  expect_equal(length(unique(cv$cutoff)), 2)

  # Distinct cutoffs are at least `min_gap` apart. Zero gaps come from rows
  # sharing a cutoff and are ignored.
  min_gap <- as.difftime(4, units = "days")
  gaps <- diff(cv$cutoff)
  expect_true(min(gaps[gaps > 0]) >= min_gap)
  expect_true(all(cv$cutoff < cv$ds))

  # Observed values in the CV output must match the model history on `ds`.
  joined <- dplyr::left_join(cv, model$history, by = "ds")
  expect_equal(sum((joined$y.x - joined$y.y)^2), 0)
})
|
|
|
|
test_that("cross_validation_default_value_check", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  # Omitting `initial` should behave like passing initial = 96, i.e.
  # 3 * horizon (the documented default).
  df.cv1 <- cross_validation(
    m, horizon = 32, units = "days", period = 10)
  df.cv2 <- cross_validation(
    m, horizon = 32, units = 'days', period = 10, initial = 96)
  # Compare the columns directly rather than summing signed differences:
  # a sum lets positive and negative discrepancies cancel out. A direct
  # comparison also reports a row-count mismatch clearly.
  expect_equal(df.cv1$y, df.cv2$y)
  expect_equal(df.cv1$yhat, df.cv2$yhat)
})
|
|
|
|
test_that("performance_metrics", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  m <- prophet(DATA)
  df_cv <- cross_validation(
    m, horizon = 4, units = "days", period = 10, initial = 90)
  # Column sets are checked with expect_equal() on sorted names rather than
  # expect_true(all(a == b)). `==` recycles mismatched lengths with only a
  # warning, and all() of an empty comparison is TRUE, so a missing or extra
  # column could otherwise slip through.

  # Aggregation level none
  df_none <- performance_metrics(df_cv, rolling_window = 0)
  expect_equal(
    sort(colnames(df_none)),
    sort(c('horizon', 'coverage', 'mae', 'mape', 'mse', 'rmse'))
  )
  expect_equal(nrow(df_none), 16)

  # Aggregation level 0.2
  df_horizon <- performance_metrics(df_cv, rolling_window = 0.2)
  expect_equal(length(unique(df_horizon$horizon)), 4)
  expect_equal(nrow(df_horizon), 14)

  # Aggregation level all: a single row whose metrics equal the plain mean of
  # the per-point values. rmse is left out because the root of a mean is not
  # the mean of the roots.
  df_all <- performance_metrics(df_cv, rolling_window = 1)
  expect_equal(nrow(df_all), 1)
  for (metric in c('mse', 'mape', 'mae', 'coverage')) {
    expect_equal(df_all[[metric]][1], mean(df_none[[metric]]))
  }

  # Custom list of metrics
  df_horizon <- performance_metrics(df_cv, metrics = c('coverage', 'mse'))
  expect_equal(
    sort(colnames(df_horizon)),
    sort(c('coverage', 'mse', 'horizon'))
  )
})
|
|
|
|
test_that("copy", {
  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  df <- DATA_all
  # Capacity column for logistic growth, and a regressor used further down.
  # The 255 + 255 split assumes data.csv has 510 rows.
  df$cap <- 200.
  df$binary_feature <- c(rep(0, 255), rep(1, 255))
  inputs <- list(
    growth = c('linear', 'logistic'),
    yearly.seasonality = c(TRUE, FALSE),
    weekly.seasonality = c(TRUE, FALSE),
    daily.seasonality = c(TRUE, FALSE),
    holidays = c('null', 'insert_dataframe'),
    seasonality.mode = c('additive', 'multiplicative')
  )
  # One row per combination of the settings above (2^6 = 64 rows). Strings
  # are kept as character rather than factor.
  products <- expand.grid(inputs, stringsAsFactors = FALSE)
  # Iterate over ROWS. length() of a data frame counts its columns, so
  # `1:length(products)` visited only the first 6 of the 64 combinations
  # (never a holidays or multiplicative one).
  for (i in seq_len(nrow(products))) {
    if (products$holidays[i] == 'insert_dataframe') {
      holidays <- data.frame(ds = c('2016-12-25'), holiday = c('x'))
    } else {
      holidays <- NULL
    }
    m1 <- prophet(
      growth = as.character(products$growth[i]),
      changepoints = NULL,
      n.changepoints = 3,
      changepoint.range = 0.9,
      yearly.seasonality = products$yearly.seasonality[i],
      weekly.seasonality = products$weekly.seasonality[i],
      daily.seasonality = products$daily.seasonality[i],
      holidays = holidays,
      # Previously part of the grid but never passed, so it was not tested.
      seasonality.mode = as.character(products$seasonality.mode[i]),
      seasonality.prior.scale = 1.1,
      holidays.prior.scale = 1.1,
      # Singular "changepoint". The old "changepoints.prior.scale" spelling
      # did not set this value, so the copy check compared NULL with NULL.
      changepoint.prior.scale = 0.1,
      mcmc.samples = 100,
      interval.width = 0.9,
      uncertainty.samples = 200,
      fit = FALSE
    )
    out <- prophet:::setup_dataframe(m1, df, initialize_scales = TRUE)
    m1 <- out$m
    m1$history <- out$df
    m1 <- prophet:::set_auto_seasonalities(m1)
    m2 <- prophet:::prophet_copy(m1)
    # Values should be copied correctly
    args <- c('growth', 'changepoints', 'n.changepoints', 'holidays',
              'seasonality.prior.scale', 'holidays.prior.scale',
              'changepoint.prior.scale', 'mcmc.samples', 'interval.width',
              'uncertainty.samples', 'seasonality.mode', 'changepoint.range')
    for (arg in args) {
      expect_equal(m1[[arg]], m2[[arg]])
    }
    # The copy turns the auto-seasonality flags off and instead carries over
    # the seasonalities the original actually added.
    expect_equal(FALSE, m2$yearly.seasonality)
    expect_equal(FALSE, m2$weekly.seasonality)
    expect_equal(FALSE, m2$daily.seasonality)
    expect_equal(m1$yearly.seasonality, 'yearly' %in% names(m2$seasonalities))
    expect_equal(m1$weekly.seasonality, 'weekly' %in% names(m2$seasonalities))
    expect_equal(m1$daily.seasonality, 'daily' %in% names(m2$seasonalities))
  }
  # Check for cutoff and custom seasonality and extra regressors
  changepoints <- seq.Date(as.Date('2012-06-15'), as.Date('2012-09-15'), by='d')
  cutoff <- as.Date('2012-07-25')
  m1 <- prophet(changepoints = changepoints)
  m1 <- add_seasonality(m1, 'custom', 10, 5)
  m1 <- add_regressor(m1, 'binary_feature')
  m1 <- fit.prophet(m1, df)
  m2 <- prophet:::prophet_copy(m1, cutoff)
  # Only changepoints up to the cutoff should survive the copy.
  changepoints <- changepoints[changepoints <= cutoff]
  expect_equal(prophet:::set_date(changepoints), m2$changepoints)
  expect_true('custom' %in% names(m2$seasonalities))
  expect_true('binary_feature' %in% names(m2$extra_regressors))
})
|