Fix copy with extra seasonalities / regressors R

This commit is contained in:
bletham 2017-11-04 21:40:45 -07:00
parent 5dbffbaa18
commit 0addabcad7
2 changed files with 50 additions and 27 deletions

View file

@ -1701,6 +1701,10 @@ plot_seasonality <- function(m, name, uncertainty = TRUE) {
#'
#' @keywords internal
prophet_copy <- function(m, cutoff = NULL) {
if (is.null(m$history)) {
stop("This is for copying a fitted Prophet object.")
}
if (m$specified.changepoints) {
changepoints <- m$changepoints
if (!is.null(cutoff)) {
@ -1710,13 +1714,15 @@ prophet_copy <- function(m, cutoff = NULL) {
} else {
changepoints <- NULL
}
return(prophet(
# Auto seasonalities are set to FALSE because they are already set in
# m$seasonalities.
m2 <- prophet(
growth = m$growth,
changepoints = changepoints,
n.changepoints = m$n.changepoints,
yearly.seasonality = m$yearly.seasonality,
weekly.seasonality = m$weekly.seasonality,
daily.seasonality = m$daily.seasonality,
yearly.seasonality = FALSE,
weekly.seasonality = FALSE,
daily.seasonality = FALSE,
holidays = m$holidays,
seasonality.prior.scale = m$seasonality.prior.scale,
changepoint.prior.scale = m$changepoint.prior.scale,
@ -1724,8 +1730,11 @@ prophet_copy <- function(m, cutoff = NULL) {
mcmc.samples = m$mcmc.samples,
interval.width = m$interval.width,
uncertainty.samples = m$uncertainty.samples,
fit = FALSE,
))
fit = FALSE
)
m2$extra_regressors <- m$extra_regressors
m2$seasonalities <- m$seasonalities
return(m2)
}
# fb-block 3

View file

@ -520,20 +520,15 @@ test_that("added_regressors", {
test_that("copy", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
df <- DATA
df$cap <- 200.
df$binary_feature <- c(rep(0, 255), rep(1, 255))
inputs <- list(
growth = c('linear', 'logistic'),
changepoints = c(NULL, c('2016-12-25')),
n.changepoints = c(3),
yearly.seasonality = c(TRUE, FALSE),
weekly.seasonality = c(TRUE, FALSE),
daily.seasonality = c(TRUE, FALSE),
holidays = c(NULL, 'insert_dataframe'),
seasonality.prior.scale = c(1.1),
holidays.prior.scale = c(1.1),
changepoints.prior.scale = c(0.1),
mcmc.samples = c(100),
interval.width = c(0.9),
uncertainty.samples = c(200)
holidays = c('null', 'insert_dataframe')
)
products <- expand.grid(inputs)
for (i in 1:length(products)) {
@ -543,32 +538,51 @@ test_that("copy", {
holidays <- NULL
}
m1 <- prophet(
growth = products$growth[i],
changepoints = products$changepoints[i],
n.changepoints = products$n.changepoints[i],
growth = as.character(products$growth[i]),
changepoints = NULL,
n.changepoints = 3,
yearly.seasonality = products$yearly.seasonality[i],
weekly.seasonality = products$weekly.seasonality[i],
daily.seasonality = products$daily.seasonality[i],
holidays = holidays,
seasonality.prior.scale = products$seasonality.prior.scale[i],
holidays.prior.scale = products$holidays.prior.scale[i],
changepoints.prior.scale = products$changepoints.prior.scale[i],
mcmc.samples = products$mcmc.samples[i],
interval.width = products$interval.width[i],
uncertainty.samples = products$uncertainty.samples[i],
seasonality.prior.scale = 1.1,
holidays.prior.scale = 1.1,
changepoints.prior.scale = 0.1,
mcmc.samples = 100,
interval.width = 0.9,
uncertainty.samples = 200,
fit = FALSE
)
out <- prophet:::setup_dataframe(m1, df, initialize_scales = TRUE)
m1 <- out$m
m1$history <- out$df
m1 <- prophet:::set_auto_seasonalities(m1)
m2 <- prophet:::prophet_copy(m1)
# Values should be copied correctly
for (arg in names(inputs)) {
args <- c('growth', 'changepoints', 'n.changepoints', 'holidays',
'seasonality.prior.scale', 'holidays.prior.scale',
'changepoints.prior.scale', 'mcmc.samples', 'interval.width',
'uncertainty.samples')
for (arg in args) {
expect_equal(m1[[arg]], m2[[arg]])
}
expect_equal(FALSE, m2$yearly.seasonality)
expect_equal(FALSE, m2$weekly.seasonality)
expect_equal(FALSE, m2$daily.seasonality)
expect_equal(m1$yearly.seasonality, 'yearly' %in% names(m2$seasonalities))
expect_equal(m1$weekly.seasonality, 'weekly' %in% names(m2$seasonalities))
expect_equal(m1$daily.seasonality, 'daily' %in% names(m2$seasonalities))
}
# Check for cutoff
# Check for cutoff and custom seasonality and extra regressors
changepoints <- seq.Date(as.Date('2012-06-15'), as.Date('2012-09-15'), by='d')
cutoff <- as.Date('2012-07-25')
m1 <- prophet(DATA, changepoints = changepoints)
m1 <- prophet(changepoints = changepoints)
m1 <- add_seasonality(m1, 'custom', 10, 5)
m1 <- add_regressor(m1, 'binary_feature')
m1 <- fit.prophet(m1, df)
m2 <- prophet:::prophet_copy(m1, cutoff)
changepoints <- changepoints[changepoints <= cutoff]
expect_equal(prophet:::set_date(changepoints), m2$changepoints)
expect_true('custom' %in% names(m2$seasonalities))
expect_true('binary_feature' %in% names(m2$extra_regressors))
})