From be2537209046a847c8cb8bc79b0a89245138bb7d Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Tue, 4 Feb 2020 13:22:08 -0800 Subject: [PATCH] improvements in docstrings and testing for disabling uncertainty --- R/R/diagnostics.R | 4 ++-- R/R/plot.R | 8 ++++---- R/tests/testthat/test_diagnostics.R | 2 ++ python/fbprophet/tests/test_diagnostics.py | 2 ++ 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/R/R/diagnostics.R b/R/R/diagnostics.R index 2a52f52..20b4bb8 100644 --- a/R/R/diagnostics.R +++ b/R/R/diagnostics.R @@ -242,8 +242,8 @@ performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) { if (is.null(metrics)) { metrics <- valid_metrics } - if (!('yhat_lower' %in% colnames(df)) | (!('yhat_upper' %in% colnames(df))) & ('coverage' %in% metrics)){ - metrics <- valid_metrics[valid_metrics != 'coverage'] + if ((!('yhat_lower' %in% colnames(df)) | !('yhat_upper' %in% colnames(df))) & ('coverage' %in% metrics)){ + metrics <- metrics[metrics != 'coverage'] } if (length(metrics) != length(unique(metrics))) { diff --git a/R/R/plot.R b/R/R/plot.R index 0c7b675..f7c3e43 100644 --- a/R/R/plot.R +++ b/R/R/plot.R @@ -167,7 +167,7 @@ prophet_plot_components <- function( #' @param fcst Dataframe output of `predict`. #' @param name String name of the component to plot (column of fcst). #' @param uncertainty Optional boolean to plot uncertainty intervals, which will -#' only be done if m$uncertainty.samples > 0. +#' only be done if m$uncertainty.samples > 0. #' @param plot_cap Boolean indicating if the capacity should be shown in the #' figure, if available. #' @@ -233,7 +233,7 @@ seasonality_plot_df <- function(m, ds) { #' #' @param m Prophet model object #' @param uncertainty Optional boolean to plot uncertainty intervals, which will -#' only be done if m$uncertainty.samples > 0. +#' only be done if m$uncertainty.samples > 0. #' @param weekly_start Integer specifying the start day of the weekly #' seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day #' to Monday, and so on. @@ -276,7 +276,7 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0, #' #' @param m Prophet model object. #' @param uncertainty Optional boolean to plot uncertainty intervals, which -#' will only be done if m$uncertainty.samples > 0. +#' will only be done if m$uncertainty.samples > 0. #' @param yearly_start Integer specifying the start day of the yearly #' seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day #' to Jan 2, and so on. @@ -321,7 +321,7 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0, #' @param m Prophet model object. #' @param name String name of the seasonality. #' @param uncertainty Optional boolean to plot uncertainty intervals, which -#' will only be done if m$uncertainty.samples > 0. +#' will only be done if m$uncertainty.samples > 0. #' #' @return A ggplot2 plot. #' diff --git a/R/tests/testthat/test_diagnostics.R b/R/tests/testthat/test_diagnostics.R index 9e8a9f6..ed3b0e1 100644 --- a/R/tests/testthat/test_diagnostics.R +++ b/R/tests/testthat/test_diagnostics.R @@ -99,6 +99,8 @@ test_that("cross_validation_uncertainty_disabled", { m, horizon = 4, units = "days", period = 4, initial = 115) expected.cols <- c('y', 'ds', 'yhat', 'cutoff') expect_equal(expected.cols, colnames(df.cv)) + df.p <- performance_metrics(df.cv) + expect_false('coverage' %in% colnames(df.p)) } }) diff --git a/python/fbprophet/tests/test_diagnostics.py b/python/fbprophet/tests/test_diagnostics.py index 07402c8..5c760bd 100644 --- a/python/fbprophet/tests/test_diagnostics.py +++ b/python/fbprophet/tests/test_diagnostics.py @@ -114,6 +114,8 @@ class TestDiagnostics(TestCase): m, horizon='4 days', period='4 days', initial='115 days') expected_cols = ['ds', 'yhat', 'y', 'cutoff'] self.assertTrue(all(col in expected_cols for col in df_cv.columns.tolist())) + df_p = diagnostics.performance_metrics(df_cv) + self.assertTrue('coverage' not in df_p.columns) def test_performance_metrics(self): m = Prophet()