From fcbd957bcc638862f4665fc6afbb01181d29a82c Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Wed, 8 Mar 2017 17:38:56 +0200 Subject: [PATCH] Refactor R component plotting to match #84 --- R/R/prophet.R | 205 ++++++++++++++++++++++++----------------- R/man/plot_holidays.Rd | 22 +++++ R/man/plot_trend.Rd | 20 ++++ R/man/plot_weekly.Rd | 20 ++++ R/man/plot_yearly.Rd | 20 ++++ 5 files changed, 204 insertions(+), 83 deletions(-) create mode 100644 R/man/plot_holidays.Rd create mode 100644 R/man/plot_trend.Rd create mode 100644 R/man/plot_weekly.Rd create mode 100644 R/man/plot_yearly.Rd diff --git a/R/R/prophet.R b/R/R/prophet.R index 270950d..92cbe05 100644 --- a/R/R/prophet.R +++ b/R/R/prophet.R @@ -873,7 +873,6 @@ df_for_plotting <- function(m, fcst) { plot.prophet <- function(x, fcst, uncertainty = TRUE, xlabel = 'ds', ylabel = 'y', ...) { df <- df_for_plotting(x, fcst) - forecast.color <- "#0072B2" gg <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y)) + ggplot2::labs(x = xlabel, y = ylabel) if (exists('cap', where = df)) { @@ -884,12 +883,12 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, xlabel = 'ds', gg <- gg + ggplot2::geom_ribbon(ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper), alpha = 0.2, - fill = forecast.color, + fill = "#0072B2", na.rm = TRUE) } gg <- gg + ggplot2::geom_point(na.rm=TRUE) + - ggplot2::geom_line(ggplot2::aes(y = yhat), color = forecast.color, + ggplot2::geom_line(ggplot2::aes(y = yhat), color = "#0072B2", na.rm = TRUE) + ggplot2::theme(aspect.ratio = 3 / 5) return(gg) @@ -908,95 +907,19 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, xlabel = 'ds', #' @importFrom dplyr "%>%" prophet_plot_components <- function(m, fcst, uncertainty = TRUE) { df <- df_for_plotting(m, fcst) - forecast.color <- "#0072B2" # Plot the trend - gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend)) + - ggplot2::geom_line(color = forecast.color, na.rm = TRUE) - if (exists('cap', where = df)) { - gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap), - linetype = 'dashed', - na.rm = TRUE) - } - if (uncertainty) { - gg.trend <- gg.trend + - ggplot2::geom_ribbon(ggplot2::aes(ymin = trend_lower, - ymax = trend_upper), - alpha = 0.2, - fill = forecast.color, - na.rm = TRUE) - } - panels <- list(gg.trend) + panels <- list(plot_trend(df, uncertainty)) # Plot holiday components, if present. if (!is.null(m$holidays)) { - holiday.comps <- unique(m$holidays$holiday) %>% as.character() - df.s <- data.frame(ds = df$ds, - holidays = rowSums(df[, holiday.comps, drop = FALSE]), - holidays_lower = rowSums(df[, paste0(holiday.comps, - "_lower"), drop = FALSE]), - holidays_upper = rowSums(df[, paste0(holiday.comps, - "_upper"), drop = FALSE])) - # NOTE the above CI calculation is incorrect if holidays overlap in time. - # Since it is just for the visualization we will not worry about it now. - gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) + - ggplot2::geom_line(color = forecast.color, na.rm = TRUE) - if (uncertainty) { - gg.holidays <- gg.holidays + - ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower, - ymax = holidays_upper), - alpha = 0.2, - fill = forecast.color, - na.rm = TRUE) - } - panels[[length(panels) + 1]] <- gg.holidays + panels[[length(panels) + 1]] <- plot_holidays(m, df, uncertainty) } # Plot weekly seasonality, if present if ("weekly" %in% colnames(df)) { - # Get weekday names in current locale - days <- weekdays(seq.Date(as.Date('2017-01-01'), by='d', length.out=7)) - df.s <- df %>% - dplyr::mutate(dow = factor(weekdays(ds), levels = days)) %>% - dplyr::group_by(dow) %>% - dplyr::slice(1) %>% - dplyr::ungroup() %>% - dplyr::arrange(dow) - gg.weekly <- ggplot2::ggplot(df.s, ggplot2::aes(x = dow, y = weekly, - group = 1)) + - ggplot2::geom_line(color = forecast.color, na.rm = TRUE) + - ggplot2::labs(x = "Day of week") - if (uncertainty) { - gg.weekly <- gg.weekly + - ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower, - ymax = weekly_upper), - alpha = 0.2, - fill = forecast.color, - na.rm = TRUE) - } - panels[[length(panels) + 1]] <- gg.weekly + panels[[length(panels) + 1]] <- plot_weekly(df, uncertainty) } # Plot yearly seasonality, if present if ("yearly" %in% colnames(df)) { - # Drop year from the dates - df.s <- df %>% - dplyr::mutate(doy = strftime(ds, format = "2000-%m-%d")) %>% - dplyr::group_by(doy) %>% - dplyr::slice(1) %>% - dplyr::ungroup() %>% - dplyr::mutate(doy = zoo::as.Date(doy)) %>% - dplyr::arrange(doy) - gg.yearly <- ggplot2::ggplot(df.s, ggplot2::aes(x = doy, y = yearly, - group = 1)) + - ggplot2::geom_line(color = forecast.color, na.rm = TRUE) + - ggplot2::scale_x_date(labels = scales::date_format('%B %d')) + - ggplot2::labs(x = "Day of year") - if (uncertainty) { - gg.yearly <- gg.yearly + - ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower, - ymax = yearly_upper), - alpha = 0.2, - fill = forecast.color, - na.rm = TRUE) - } - panels[[length(panels) + 1]] = gg.yearly + panels[[length(panels) + 1]] <- plot_yearly(df, uncertainty) } # Make the plot. grid::grid.newpage() @@ -1008,4 +931,120 @@ prophet_plot_components <- function(m, fcst, uncertainty = TRUE) { } } +#' Plot the prophet trend. +#' +#' @param df Forecast dataframe for plotting. +#' @param uncertainty Boolean to plot uncertainty intervals. +#' +#' @return A ggplot2 plot. +plot_trend <- function(df, uncertainty = TRUE) { + gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend)) + + ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) + if (exists('cap', where = df)) { + gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap), + linetype = 'dashed', + na.rm = TRUE) + } + if (uncertainty) { + gg.trend <- gg.trend + + ggplot2::geom_ribbon(ggplot2::aes(ymin = trend_lower, + ymax = trend_upper), + alpha = 0.2, + fill = "#0072B2", + na.rm = TRUE) + } + return(gg.trend) +} + +#' Plot the holidays component of the forecast. +#' +#' @param m Prophet model +#' @param df Forecast dataframe for plotting. +#' @param uncertainty Boolean to plot uncertainty intervals. +#' +#' @return A ggplot2 plot. +plot_holidays <- function(m, df, uncertainty = TRUE) { + holiday.comps <- unique(m$holidays$holiday) %>% as.character() + df.s <- data.frame(ds = df$ds, + holidays = rowSums(df[, holiday.comps, drop = FALSE]), + holidays_lower = rowSums(df[, paste0(holiday.comps, + "_lower"), drop = FALSE]), + holidays_upper = rowSums(df[, paste0(holiday.comps, + "_upper"), drop = FALSE])) + # NOTE the above CI calculation is incorrect if holidays overlap in time. + # Since it is just for the visualization we will not worry about it now. + gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) + + ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) + if (uncertainty) { + gg.holidays <- gg.holidays + + ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower, + ymax = holidays_upper), + alpha = 0.2, + fill = "#0072B2", + na.rm = TRUE) + } + return(gg.holidays) +} + +#' Plot the weekly component of the forecast. +#' +#' @param df Forecast dataframe for plotting. +#' @param uncertainty Boolean to plot uncertainty intervals. +#' +#' @return A ggplot2 plot. +plot_weekly <- function(df, uncertainty = TRUE) { + # Get weekday names in current locale + days <- weekdays(seq.Date(as.Date('2017-01-01'), by='d', length.out=7)) + df.s <- df %>% + dplyr::mutate(dow = factor(weekdays(ds), levels = days)) %>% + dplyr::group_by(dow) %>% + dplyr::slice(1) %>% + dplyr::ungroup() %>% + dplyr::arrange(dow) + gg.weekly <- ggplot2::ggplot(df.s, ggplot2::aes(x = dow, y = weekly, + group = 1)) + + ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) + + ggplot2::labs(x = "Day of week") + if (uncertainty) { + gg.weekly <- gg.weekly + + ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower, + ymax = weekly_upper), + alpha = 0.2, + fill = "#0072B2", + na.rm = TRUE) + } + return(gg.weekly) +} + +#' Plot the yearly component of the forecast. +#' +#' @param df Forecast dataframe for plotting. +#' @param uncertainty Boolean to plot uncertainty intervals. +#' +#' @return A ggplot2 plot. +plot_yearly <- function(df, uncertainty = TRUE) { + # Drop year from the dates + df.s <- df %>% + dplyr::mutate(doy = strftime(ds, format = "2000-%m-%d")) %>% + dplyr::group_by(doy) %>% + dplyr::slice(1) %>% + dplyr::ungroup() %>% + dplyr::mutate(doy = zoo::as.Date(doy)) %>% + dplyr::arrange(doy) + gg.yearly <- ggplot2::ggplot(df.s, ggplot2::aes(x = doy, y = yearly, + group = 1)) + + ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) + + ggplot2::scale_x_date(labels = scales::date_format('%B %d')) + + ggplot2::labs(x = "Day of year") + if (uncertainty) { + gg.yearly <- gg.yearly + + ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower, + ymax = yearly_upper), + alpha = 0.2, + fill = "#0072B2", + na.rm = TRUE) + } + return(gg.yearly) +} + # fb-block 3 diff --git a/R/man/plot_holidays.Rd b/R/man/plot_holidays.Rd new file mode 100644 index 0000000..75420a7 --- /dev/null +++ b/R/man/plot_holidays.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/prophet.R +\name{plot_holidays} +\alias{plot_holidays} +\title{Plot the holidays component of the forecast.} +\usage{ +plot_holidays(m, df, uncertainty = TRUE) +} +\arguments{ +\item{m}{Prophet model} + +\item{df}{Forecast dataframe for plotting.} + +\item{uncertainty}{Boolean to plot uncertainty intervals.} +} +\value{ +A ggplot2 plot. +} +\description{ +Plot the holidays component of the forecast. +} + diff --git a/R/man/plot_trend.Rd b/R/man/plot_trend.Rd new file mode 100644 index 0000000..19ecc78 --- /dev/null +++ b/R/man/plot_trend.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/prophet.R +\name{plot_trend} +\alias{plot_trend} +\title{Plot the prophet trend.} +\usage{ +plot_trend(df, uncertainty = TRUE) +} +\arguments{ +\item{df}{Forecast dataframe for plotting.} + +\item{uncertainty}{Boolean to plot uncertainty intervals.} +} +\value{ +A ggplot2 plot. +} +\description{ +Plot the prophet trend. +} + diff --git a/R/man/plot_weekly.Rd b/R/man/plot_weekly.Rd new file mode 100644 index 0000000..dca10e6 --- /dev/null +++ b/R/man/plot_weekly.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/prophet.R +\name{plot_weekly} +\alias{plot_weekly} +\title{Plot the weekly component of the forecast.} +\usage{ +plot_weekly(df, uncertainty = TRUE) +} +\arguments{ +\item{df}{Forecast dataframe for plotting.} + +\item{uncertainty}{Boolean to plot uncertainty intervals.} +} +\value{ +A ggplot2 plot. +} +\description{ +Plot the weekly component of the forecast. +} + diff --git a/R/man/plot_yearly.Rd b/R/man/plot_yearly.Rd new file mode 100644 index 0000000..c33f32d --- /dev/null +++ b/R/man/plot_yearly.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/prophet.R +\name{plot_yearly} +\alias{plot_yearly} +\title{Plot the yearly component of the forecast.} +\usage{ +plot_yearly(df, uncertainty = TRUE) +} +\arguments{ +\item{df}{Forecast dataframe for plotting.} + +\item{uncertainty}{Boolean to plot uncertainty intervals.} +} +\value{ +A ggplot2 plot. +} +\description{ +Plot the yearly component of the forecast. +} +