mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-05-18 21:21:22 +00:00
Parallel Cross Validation (#1434)
* API: Refactor to parallel / cf * Added Dask-based parallelism * fix test * cover bad parallel * avoid multiprocess issue under setuptools tests * Update notebook docs * fix docstring * install note * arbitrary object * fixups * fixups * fixups * remove unused import
This commit is contained in:
parent
59e374b1ad
commit
5fe3be86c5
4 changed files with 106 additions and 17 deletions
|
|
@ -5,7 +5,7 @@ python:
|
|||
|
||||
install:
|
||||
- pip install --upgrade pip
|
||||
- pip install -U -r python/requirements.txt
|
||||
- pip install -U -r python/requirements.txt dask[dataframe] distributed
|
||||
|
||||
script:
|
||||
- cd python && python setup.py develop test
|
||||
|
|
|
|||
|
|
@ -263,7 +263,22 @@
|
|||
"source": [
|
||||
"In R, the argument `units` must be a type accepted by `as.difftime`, which is weeks or shorter. In Python, the string for `initial`, `period`, and `horizon` should be in the format used by Pandas Timedelta, which accepts units of days or shorter.\n",
|
||||
"\n",
|
||||
"Cross-validation can also be run in multiprocessing mode in Python, by setting the `multiprocess` argument to `True`\n",
|
||||
"Cross-validation can also be run in parallel mode in Python, by setting specifying the `parallel` keyword. Three modes are supported\n",
|
||||
"\n",
|
||||
"* `parallel=\"processes\"`\n",
|
||||
"* `parallel=\"threads\"`\n",
|
||||
"* `parallel=\"dask\"`\n",
|
||||
"\n",
|
||||
"For problems that aren't too big, we recommend using `parallel=\"processes\"`. It will achieve the highest performance when the parallel cross validation can be done on a single machine. For large problems, a [Dask](https://dask.org) cluster can be used to do the cross validation on many machines. You will need to [install Dask](https://docs.dask.org/en/latest/install.html) separately, as it will not be installed with `fbprophet`.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"from dask.distributed import Client\n",
|
||||
"\n",
|
||||
"client = Client() # connect to the cluster\n",
|
||||
"df_cv = cross_validation(m, initial='730 days', period='180 days', horizon='365 days',\n",
|
||||
" parallel=\"dask\")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The `performance_metrics` utility can be used to compute some useful statistics of the prediction performance (`yhat`, `yhat_lower`, and `yhat_upper` compared to `y`), as a function of the distance from the cutoff (how far into the future the prediction was). The statistics computed are mean squared error (MSE), root mean squared error (RMSE), mean absolute error (MAE), mean absolute percent error (MAPE), and coverage of the `yhat_lower` and `yhat_upper` estimates. These are computed on a rolling window of the predictions in `df_cv` after sorting by horizon (`ds` minus `cutoff`). By default 10% of the predictions will be included in each window, but this can be changed with the `rolling_window` argument."
|
||||
]
|
||||
|
|
@ -475,7 +490,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.0"
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
|
|
@ -9,11 +9,10 @@ from __future__ import absolute_import, division, print_function
|
|||
import logging
|
||||
from tqdm.autonotebook import tqdm
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
import concurrent.futures
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from multiprocessing import Pool
|
||||
|
||||
logger = logging.getLogger('fbprophet')
|
||||
|
||||
|
|
@ -59,7 +58,7 @@ def generate_cutoffs(df, horizon, initial, period):
|
|||
return list(reversed(result))
|
||||
|
||||
|
||||
def cross_validation(model, horizon, period=None, initial=None, multiprocess=False, cutoffs=None):
|
||||
def cross_validation(model, horizon, period=None, initial=None, parallel=None, cutoffs=None):
|
||||
"""Cross-Validation for time series.
|
||||
|
||||
Computes forecasts from historical cutoff points, which user can input.
|
||||
|
|
@ -82,9 +81,33 @@ def cross_validation(model, horizon, period=None, initial=None, multiprocess=Fal
|
|||
cross-validtation. If not provided works beginning from
|
||||
(end - horizon), works backwards making cutoffs with a spacing of period
|
||||
until initial is reached.
|
||||
multiprocess: True, False, Optional (defaults to False). If `True`, use the
|
||||
`multiprocessing` module to distribute each task to a different processor
|
||||
core.
|
||||
parallel : {None, 'processes', 'threads', 'dask', object}
|
||||
|
||||
How to parallelize the forecast computation. By default no parallelism
|
||||
is used.
|
||||
|
||||
* None : No parallelism.
|
||||
* 'processes' : Parallelize with concurrent.futures.ProcessPoolExectuor.
|
||||
* 'threads' : Parallelize with concurrent.futures.ThreadPoolExecutor.
|
||||
Note that some operations currently hold Python's Global Interpreter
|
||||
Lock, so parallelizing with threads may be slower than training
|
||||
sequentially.
|
||||
* 'dask': Parallelize with Dask.
|
||||
This requires that a dask.distributed Client be created.
|
||||
* object : Any instance with a `.map` method. This method will
|
||||
be called with :func:`single_cutoff_forecast` and a sequence of
|
||||
iterables where each element is the tuple of arguments to pass to
|
||||
:func:`single_cutoff_forecast`
|
||||
|
||||
.. code-block::
|
||||
|
||||
class MyBackend:
|
||||
def map(self, func, *iterables):
|
||||
results = [
|
||||
func(*args)
|
||||
for args in zip(*iterables)
|
||||
]
|
||||
return results
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -127,11 +150,39 @@ def cross_validation(model, horizon, period=None, initial=None, multiprocess=Fal
|
|||
msg += 'Consider increasing initial.'
|
||||
logger.warning(msg)
|
||||
|
||||
if multiprocess is True:
|
||||
with Pool() as pool:
|
||||
logger.info('Running cross validation in multiprocessing mode')
|
||||
input_df = ((df, model, cutoff, horizon, predict_columns) for cutoff in cutoffs)
|
||||
predicts = pool.starmap(single_cutoff_forecast, input_df)
|
||||
if parallel:
|
||||
valid = {"threads", "processes", "dask"}
|
||||
|
||||
if parallel == "threads":
|
||||
pool = concurrent.futures.ThreadPoolExecutor()
|
||||
elif parallel == "processes":
|
||||
pool = concurrent.futures.ProcessPoolExecutor()
|
||||
elif parallel == "dask":
|
||||
try:
|
||||
from dask.distributed import get_client
|
||||
except ImportError as e:
|
||||
raise ImportError("parallel='dask' requies the optional "
|
||||
"dependency dask.") from e
|
||||
pool = get_client()
|
||||
# delay df and model to avoid large objects in task graph.
|
||||
df, model = pool.scatter([df, model])
|
||||
elif hasattr(parallel, "map"):
|
||||
pool = parallel
|
||||
else:
|
||||
msg = ("'parallel' should be one of {} for an instance with a "
|
||||
"'map' method".format(', '.join(valid)))
|
||||
raise ValueError(msg)
|
||||
|
||||
iterables = ((df, model, cutoff, horizon, predict_columns)
|
||||
for cutoff in cutoffs)
|
||||
iterables = zip(*iterables)
|
||||
|
||||
logger.info("Applying in parallel with %s", pool)
|
||||
predicts = pool.map(single_cutoff_forecast, *iterables)
|
||||
if parallel == "dask":
|
||||
# convert Futures to DataFrames
|
||||
predicts = pool.gather(predicts)
|
||||
|
||||
else:
|
||||
predicts = [
|
||||
single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,12 @@ DATA_all = pd.read_csv(
|
|||
DATA = DATA_all.head(100)
|
||||
|
||||
|
||||
class CustomParallelBackend:
|
||||
def map(self, func, *iterables):
|
||||
results = [func(*args) for args in zip(*iterables)]
|
||||
return results
|
||||
|
||||
|
||||
class TestDiagnostics(TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -40,11 +46,19 @@ class TestDiagnostics(TestCase):
|
|||
horizon = pd.Timedelta('4 days')
|
||||
period = pd.Timedelta('10 days')
|
||||
initial = pd.Timedelta('115 days')
|
||||
# Run for both cases of multiprocess on or off
|
||||
for multiprocess in [False, True]:
|
||||
methods = [None, 'processes', 'threads', CustomParallelBackend()]
|
||||
|
||||
try:
|
||||
from dask.distributed import Client
|
||||
client = Client(processes=False) # noqa
|
||||
methods.append("dask")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
for parallel in methods:
|
||||
df_cv = diagnostics.cross_validation(
|
||||
m, horizon='4 days', period='10 days', initial='115 days',
|
||||
multiprocess=multiprocess)
|
||||
parallel=parallel)
|
||||
self.assertEqual(len(np.unique(df_cv['cutoff'])), 3)
|
||||
self.assertEqual(max(df_cv['ds'] - df_cv['cutoff']), horizon)
|
||||
self.assertTrue(min(df_cv['cutoff']) >= min(self.__df['ds']) + initial)
|
||||
|
|
@ -63,6 +77,15 @@ class TestDiagnostics(TestCase):
|
|||
diagnostics.cross_validation(
|
||||
m, horizon='10 days', period='10 days', initial='140 days')
|
||||
|
||||
# invalid alias
|
||||
with self.assertRaises(ValueError, match="'parallel' should be one"):
|
||||
diagnostics.cross_validation(m, horizon="4 days", parallel="bad")
|
||||
|
||||
# no map method
|
||||
with self.assertRaises(ValueError, match="'parallel' should be one"):
|
||||
diagnostics.cross_validation(m, horizon="4 days", parallel=object())
|
||||
|
||||
|
||||
def test_check_single_cutoff_forecast_func_calls(self):
|
||||
m = Prophet()
|
||||
m.fit(self.__df)
|
||||
|
|
|
|||
Loading…
Reference in a new issue