From e8ddded4fec128191f08be83ae9b5acd7db6f563 Mon Sep 17 00:00:00 2001 From: Jireh Tan Date: Sat, 22 Jun 2019 10:46:59 -0700 Subject: [PATCH] [BUG] Ensure regressor/seasonality names are valid; fixes #996 Ensures that `add_regressor` and `add_seasonality` are valid column names to R, to ensure that the generated columns are then used downstream to fit the model. Why not put it in `validate_column_names`? Because `validate_column_names` is also used to validate if holiday names (which can be scalar values in columns) are valid. We want to allow `c('seans-bday', 'Xmas')` as a valid holiday input, so we cannot then put it there. Tested these changes by using devtools::testthat(). Resolves: #996 --- R/R/prophet.R | 10 ++++++++++ R/tests/testthat/test_prophet.R | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/R/R/prophet.R b/R/R/prophet.R index 9bb0deb..92ae3fa 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -689,6 +689,11 @@ add_regressor <- function( if (!is.null(m$history)) { stop('Regressors must be added prior to model fitting.') } + if (make.names(name, allow_ = TRUE) != name) { + stop("You have provided a name that is not syntactically valid in R, ", name, ". ", + "A syntactically valid name consists of letters, numbers and the dot or underline, ", + "characters and starts with a letter or the dot not followed by a number.") + } validate_column_name(m, name, check_regressors = FALSE) if (is.null(prior.scale)) { prior.scale <- m$holidays.prior.scale @@ -752,6 +757,11 @@ add_seasonality <- function( } if (!(name %in% c('daily', 'weekly', 'yearly'))) { # Allow overriding built-in seasonalities + if (make.names(name, allow_ = TRUE) != name) { + stop("You have provided a name that is not syntactically valid in R, ", name, ". ", + "A syntactically valid name consists of letters, numbers and the dot or underline, ", + "characters and starts with a letter or the dot not followed by a number.") + } validate_column_name(m, name, check_seasonalities = FALSE) } if (is.null(prior.scale)) { diff --git a/R/tests/testthat/test_prophet.R b/R/tests/testthat/test_prophet.R index 65d5a09..c1f895f 100644 --- a/R/tests/testthat/test_prophet.R +++ b/R/tests/testthat/test_prophet.R @@ -88,6 +88,18 @@ test_that("setup_dataframe", { expect_equal(max(history$y_scaled), 1) }) +test_that("setup_names_errors", { + m <- prophet() + expect_error( + m <- add_seasonality(m, "3monthly"), + "You have provided a name that is not syntactically valid in R, 3monthly" + ) + expect_error( + m <- add_regressor(m, "2monthsale"), + "You have provided a name that is not syntactically valid in R, 2monthsale" + ) +}) + test_that("logistic_floor", { skip_if_not(Sys.getenv('R_ARCH') != '/i386') skip_on_os('mac') # Resolves mysterious CRAN build issue