mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-22 22:01:14 +00:00
R unit tests for added regressors
This commit is contained in:
parent
57c97f2e5e
commit
23d8bc25dc
2 changed files with 69 additions and 2 deletions
|
|
@ -625,8 +625,8 @@ make_all_seasonality_features <- function(m, df) {
|
|||
|
||||
# Additional regressors
|
||||
for (name in names(m$extra_regressors)) {
|
||||
seasonal.features <- cbind(seasonal.features, df[[name]])
|
||||
prior.scales <- cbind(prior.scales, m$extra_regressors[[name]]$prior.scale)
|
||||
seasonal.features[[name]] <- df[[name]]
|
||||
prior.scales <- c(prior.scales, m$extra_regressors[[name]]$prior.scale)
|
||||
}
|
||||
|
||||
if (ncol(seasonal.features) == 0) {
|
||||
|
|
|
|||
|
|
@ -329,6 +329,73 @@ test_that("custom_seasonality", {
|
|||
m <- prophet(holidays=holidays)
|
||||
m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
|
||||
expect_equal(m$seasonalities[['monthly']], c(30, 5))
|
||||
expect_error(
|
||||
add_seasonality(m, name='special_day', period=30, fourier_order=5)
|
||||
)
|
||||
expect_error(
|
||||
add_seasonality(m, name='trend', period=30, fourier_order=5)
|
||||
)
|
||||
m <- add_seasonality(m, name='weekly', period=30, fourier.order=5)
|
||||
})
|
||||
|
||||
test_that("added_regressors", {
|
||||
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
|
||||
m <- prophet()
|
||||
m <- add_regressor(m, 'binary_feature', prior.scale=0.2)
|
||||
m <- add_regressor(m, 'numeric_feature', prior.scale=0.5)
|
||||
m <- add_regressor(m, 'binary_feature2', standardize=TRUE)
|
||||
df <- DATA
|
||||
df$binary_feature <- c(rep(0, 255), rep(1, 255))
|
||||
df$numeric_feature <- 0:509
|
||||
# Require all regressors in df
|
||||
expect_error(
|
||||
fit.prophet(m, df)
|
||||
)
|
||||
df$binary_feature2 <- c(rep(1, 100), rep(0, 410))
|
||||
m <- fit.prophet(m, df)
|
||||
# Check that standardizations are correctly set
|
||||
true <- list(prior.scale = 0.2, mu = 0, std = 1, standardize = 'auto')
|
||||
for (name in names(true)) {
|
||||
expect_equal(true[[name]], m$extra_regressors$binary_feature[[name]])
|
||||
}
|
||||
true <- list(prior.scale = 0.5, mu = 254.5, std = 147.368585)
|
||||
for (name in names(true)) {
|
||||
expect_equal(true[[name]], m$extra_regressors$numeric_feature[[name]],
|
||||
tolerance = 1e-5)
|
||||
}
|
||||
true <- list(prior.scale = 10., mu = 0.1960784, std = 0.3974183)
|
||||
for (name in names(true)) {
|
||||
expect_equal(true[[name]], m$extra_regressors$binary_feature2[[name]],
|
||||
tolerance = 1e-5)
|
||||
}
|
||||
# Check that standardization is done correctly
|
||||
df2 <- prophet:::setup_dataframe(m, df)$df
|
||||
expect_equal(df2$binary_feature[1], 0)
|
||||
expect_equal(df2$numeric_feature[1], -1.726962, tolerance = 1e-4)
|
||||
expect_equal(df2$binary_feature2[1], 2.022859, tolerance = 1e-4)
|
||||
# Check that feature matrix and prior scales are correctly constructed
|
||||
out <- prophet:::make_all_seasonality_features(m, df2)
|
||||
seasonal.features <- out$seasonal.features
|
||||
prior.scales <- out$prior.scales
|
||||
expect_true('binary_feature' %in% colnames(seasonal.features))
|
||||
expect_true('numeric_feature' %in% colnames(seasonal.features))
|
||||
expect_true('binary_feature2' %in% colnames(seasonal.features))
|
||||
expect_equal(ncol(seasonal.features), 29)
|
||||
expect_true(all(sort(prior.scales[27:29]) == c(0.2, 0.5, 10.)))
|
||||
# Check that forecast components are reasonable
|
||||
future <- data.frame(
|
||||
ds = c('2014-06-01'), binary_feature = c(0), numeric_feature = c(10))
|
||||
expect_error(predict(m, future))
|
||||
future$binary_feature2 <- 0.
|
||||
fcst <- predict(m, future)
|
||||
expect_equal(ncol(fcst), 31)
|
||||
expect_equal(fcst$binary_feature[1], 0)
|
||||
expect_equal(fcst$extra_regressors[1],
|
||||
fcst$numeric_feature[1] + fcst$binary_feature2[1])
|
||||
expect_equal(fcst$seasonalities[1], fcst$yearly[1] + fcst$weekly[1])
|
||||
expect_equal(fcst$seasonal[1],
|
||||
fcst$seasonalities[1] + fcst$extra_regressors[1])
|
||||
expect_equal(fcst$yhat[1], fcst$trend[1] + fcst$seasonal[1])
|
||||
})
|
||||
|
||||
test_that("copy", {
|
||||
|
|
|
|||
Loading…
Reference in a new issue