Sanitize warm start parameters (#2342)

This commit is contained in:
Cuong Duong 2023-01-15 00:05:40 +11:00 committed by GitHub
parent 0a5becb777
commit 8b3d09caf7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 67 deletions

View file

@ -82,11 +82,12 @@ A common setting for forecasting is fitting models that need to be updated as ad
```python
# Python
def get_stan_init(m):
"""Retrieve parameters from a trained model.
Retrieve parameters from a trained model in the format
used to initialize a new Stan model.
def warm_start_params(m):
"""
Retrieve parameters from a trained model in the format used to initialize a new Stan model.
Note that the new Stan model must have these same settings:
n_changepoints, seasonality features, mcmc sampling
for the retrieved parameters to be valid for the new model.
Parameters
----------
@ -115,7 +116,7 @@ m1 = Prophet().fit(df1) # A model fit to all data except the last day
%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch
%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1
%timeit m2 = Prophet().fit(df, init=warm_start_params(m1)) # Adding the last day, warm-starting from m1
```
1.33 s ± 55.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
185 ms ± 4.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

View file

@ -216,11 +216,12 @@
}
],
"source": [
"def get_stan_init(m):\n",
" \"\"\"Retrieve parameters from a trained model.\n",
"\n",
" Retrieve parameters from a trained model in the format\n",
" used to initialize a new Stan model.\n",
"def warm_start_params(m):\n",
" \"\"\"\n",
" Retrieve parameters from a trained model in the format used to initialize a new Stan model.\n",
" Note that the new Stan model must have these same settings:\n",
" n_changepoints, seasonality features, mcmc sampling\n",
" for the retrieved parameters to be valid for the new model.\n",
"\n",
" Parameters\n",
" ----------\n",
@ -249,7 +250,7 @@
"\n",
"\n",
"%timeit m2 = Prophet().fit(df) # Adding the last day, fitting from scratch\n",
"%timeit m2 = Prophet().fit(df, init=get_stan_init(m1)) # Adding the last day, warm-starting from m1"
"%timeit m2 = Prophet().fit(df, init=warm_start_params(m1)) # Adding the last day, warm-starting from m1"
]
},
{
@ -275,7 +276,7 @@
"metadata": {
"celltoolbar": "Edit Metadata",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -289,7 +290,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.10 (default, Mar 25 2022, 22:18:25) \n[Clang 12.0.5 (clang-1205.0.22.11)]"
},
"vscode": {
"interpreter": {
"hash": "9f764ff66e3236555d51a00749f5293db82082e341e153963b8b2deea93b52fc"
}
}
},
"nbformat": 4,

View file

@ -80,14 +80,14 @@ class CmdStanPyBackend(IStanBackend):
def fit(self, stan_init, stan_data, **kwargs):
if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data)
stan_init = self.sanitize_custom_inits(stan_init, kwargs['init'])
del kwargs['init']
stan_init, stan_data = self.prepare_data(stan_init, stan_data)
inits_list, data_list = self.prepare_data(stan_init, stan_data)
args = dict(
data=stan_data,
inits=stan_init,
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
data=data_list,
inits=inits_list,
algorithm='Newton' if data_list['T'] < 100 else 'LBFGS',
iter=int(1e4),
)
args.update(kwargs)
@ -109,22 +109,20 @@ class CmdStanPyBackend(IStanBackend):
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'], _ = self.prepare_data(kwargs['init'], stan_data)
stan_init = self.sanitize_custom_inits(stan_init, kwargs['init'])
del kwargs['init']
stan_init, stan_data = self.prepare_data(stan_init, stan_data)
inits_list, data_list = self.prepare_data(stan_init, stan_data)
args = dict(
data=stan_data,
inits=stan_init,
data=data_list,
inits=inits_list,
)
if 'chains' not in kwargs:
kwargs['chains'] = 4
iter_half = samples // 2
kwargs['iter_sampling'] = iter_half
if 'iter_warmup' not in kwargs:
kwargs['iter_warmup'] = iter_half
args.update(kwargs)
self.stan_fit = self.model.sample(**args)
@ -143,8 +141,25 @@ class CmdStanPyBackend(IStanBackend):
return params
@staticmethod
def sanitize_custom_inits(default_inits, custom_inits):
"""Validate that custom inits have the correct type and shape, otherwise use defaults."""
sanitized = {}
for param in ['k', 'm', 'sigma_obs']:
try:
sanitized[param] = float(custom_inits.get(param))
except Exception:
sanitized[param] = default_inits[param]
for param in ['delta', 'beta']:
if default_inits[param].shape == custom_inits[param].shape:
sanitized[param] = custom_inits[param]
else:
sanitized[param] = default_inits[param]
return sanitized
@staticmethod
def prepare_data(init, data) -> Tuple[dict, dict]:
"""Converts np.ndarrays to lists that can be read by cmdstanpy."""
cmdstanpy_data = {
'T': data['T'],
'S': data['S'],

View file

@ -15,6 +15,7 @@ from unittest import TestCase, skipUnless
import numpy as np
import pandas as pd
from prophet import Prophet
from prophet.utilities import warm_start_params
DATA = pd.read_csv(
@ -892,50 +893,12 @@ class TestProphet(TestCase):
},
)
@staticmethod
def get_stan_init(m):
"""Retrieve parameters from a trained model.
Retrieve parameters from a trained model in the format
used to initialize a new Stan model.
Parameters
----------
m: A trained model of the Prophet class.
Returns
-------
A Dictionary containing retrieved parameters of m.
"""
res = {}
for pname in ['k', 'm', 'sigma_obs']:
if m.mcmc_samples == 0:
res[pname] = m.params[pname][0][0]
else:
res[pname] = np.mean(m.params[pname])
for pname in ['delta', 'beta']:
if m.mcmc_samples == 0:
res[pname] = m.params[pname][0]
else:
res[pname] = np.mean(m.params[pname], axis=0)
return res
def test_fit_warm_start(self):
previous_df = DATA.iloc[:500]
df = DATA.iloc[:510]
m = Prophet()
m = m.fit(previous_df)
m_params = self.get_stan_init(m)
m2 = Prophet()
m2 = m2.fit(df, init=m_params)
m = Prophet().fit(DATA.iloc[:500])
m2 = Prophet().fit(DATA.iloc[:510], init=warm_start_params(m))
self.assertEqual(len(m2.params['delta'][0]), 25)
def test_sampling_warm_start(self):
previous_df = DATA.iloc[:500]
df = DATA.iloc[:510]
m = Prophet(mcmc_samples=100)
m = m.fit(previous_df, show_progress=False)
m_params = self.get_stan_init(m)
m2 = Prophet(mcmc_samples=100)
m2 = m2.fit(df, init=m_params, show_progress=False)
m = Prophet(mcmc_samples=100).fit(DATA.iloc[:500], show_progress=False)
m2 = Prophet(mcmc_samples=100).fit(DATA.iloc[:510], init=warm_start_params(m), show_progress=False)
self.assertEqual(m2.params['delta'].shape, (200, 25))

View file

@ -75,3 +75,31 @@ def regressor_coefficients(m):
coefs.append(record)
return pd.DataFrame(coefs)
def warm_start_params(m):
"""
Retrieve parameters from a trained model in the format used to initialize a new Stan model.
Note that the new Stan model must have these same settings:
n_changepoints, seasonality features, mcmc sampling
for the retrieved parameters to be valid for the new model.
Parameters
----------
m: A trained model of the Prophet class.
Returns
-------
A Dictionary containing retrieved parameters of m.
"""
res = {}
for pname in ['k', 'm', 'sigma_obs']:
if m.mcmc_samples == 0:
res[pname] = m.params[pname][0][0]
else:
res[pname] = np.mean(m.params[pname])
for pname in ['delta', 'beta']:
if m.mcmc_samples == 0:
res[pname] = m.params[pname][0]
else:
res[pname] = np.mean(m.params[pname], axis=0)
return res