From d8dd2f53b73d346480929bacec5e2b50d600c2c0 Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Wed, 4 Mar 2020 16:01:02 -0800 Subject: [PATCH] Add progress bar to cross_validation (#1338) * tqdm * Added progress bar to the crossvalidation In order to improve the user experiance a progress bar is added to the crossvalidation loop. * Update requirements.txt * Update python/fbprophet/diagnostics.py * updated further * Update requirements.txt --- python/fbprophet/diagnostics.py | 5 +++-- python/requirements.txt | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/fbprophet/diagnostics.py b/python/fbprophet/diagnostics.py index 1461d0f..93a6421 100644 --- a/python/fbprophet/diagnostics.py +++ b/python/fbprophet/diagnostics.py @@ -7,6 +7,7 @@ from __future__ import absolute_import, division, print_function import logging +from tqdm.autonotebook import tqdm from copy import deepcopy from functools import reduce @@ -54,7 +55,7 @@ def generate_cutoffs(df, horizon, initial, period): logger.info('Making {} forecasts with cutoffs between {} and {}'.format( len(result), result[-1], result[0] )) - return reversed(result) + return list(reversed(result)) def cross_validation(model, horizon, period=None, initial=None): @@ -107,7 +108,7 @@ def cross_validation(model, horizon, period=None, initial=None): cutoffs = generate_cutoffs(df, horizon, initial, period) predicts = [] - for cutoff in cutoffs: + for cutoff in tqdm(cutoffs): # Generate new object with copying fitting options m = prophet_copy(model, cutoff) # Train model diff --git a/python/requirements.txt b/python/requirements.txt index baf7893..4b5fd33 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -9,3 +9,4 @@ convertdate>=2.1.2 holidays>=0.9.5 setuptools-git>=1.2 python-dateutil>=2.8.0 +tqdm>=4.42.1