[BUG] Ensure regressor/seasonality names are valid; fixes #996

Ensures that `add_regressor` and `add_seasonality` are valid column names
to R, to ensure that the generated columns are then used downstream to fit
the model.

Why not put it in `validate_column_names`? Because `validate_column_names` is
also used to validate if holiday names (which can be scalar values in columns) are
valid. We want to allow `c('seans-bday', 'Xmas')` as a valid holiday input, so we
cannot then put it there.

Tested these changes by using devtools::testthat().

Resolves: #996
This commit is contained in:
Jireh Tan 2019-06-22 10:46:59 -07:00 committed by Ben Letham
parent 4225bb5fc1
commit e8ddded4fe
2 changed files with 22 additions and 0 deletions

View file

@ -689,6 +689,11 @@ add_regressor <- function(
if (!is.null(m$history)) {
stop('Regressors must be added prior to model fitting.')
}
if (make.names(name, allow_ = TRUE) != name) {
stop("You have provided a name that is not syntactically valid in R, ", name, ". ",
"A syntactically valid name consists of letters, numbers and the dot or underline, ",
"characters and starts with a letter or the dot not followed by a number.")
}
validate_column_name(m, name, check_regressors = FALSE)
if (is.null(prior.scale)) {
prior.scale <- m$holidays.prior.scale
@ -752,6 +757,11 @@ add_seasonality <- function(
}
if (!(name %in% c('daily', 'weekly', 'yearly'))) {
# Allow overriding built-in seasonalities
if (make.names(name, allow_ = TRUE) != name) {
stop("You have provided a name that is not syntactically valid in R, ", name, ". ",
"A syntactically valid name consists of letters, numbers and the dot or underline, ",
"characters and starts with a letter or the dot not followed by a number.")
}
validate_column_name(m, name, check_seasonalities = FALSE)
}
if (is.null(prior.scale)) {

View file

@ -88,6 +88,18 @@ test_that("setup_dataframe", {
expect_equal(max(history$y_scaled), 1)
})
test_that("setup_names_errors", {
m <- prophet()
expect_error(
m <- add_seasonality(m, "3monthly"),
"You have provided a name that is not syntactically valid in R, 3monthly"
)
expect_error(
m <- add_regressor(m, "2monthsale"),
"You have provided a name that is not syntactically valid in R, 2monthsale"
)
})
test_that("logistic_floor", {
skip_if_not(Sys.getenv('R_ARCH') != '/i386')
skip_on_os('mac') # Resolves mysterious CRAN build issue