mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-06 00:03:25 +00:00
Handle numpy fit_kwargs when serializing (#1701)
This commit is contained in:
parent
9e4e87af9b
commit
29f14172f0
1 changed files with 17 additions and 1 deletions
|
|
@ -7,6 +7,7 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -21,7 +22,7 @@ SIMPLE_ATTRIBUTES = [
|
|||
'yearly_seasonality', 'weekly_seasonality', 'daily_seasonality',
|
||||
'seasonality_mode', 'seasonality_prior_scale', 'changepoint_prior_scale',
|
||||
'holidays_prior_scale', 'mcmc_samples', 'interval_width', 'uncertainty_samples',
|
||||
'y_scale', 'logistic_floor', 'country_holidays', 'component_modes', 'fit_kwargs'
|
||||
'y_scale', 'logistic_floor', 'country_holidays', 'component_modes'
|
||||
]
|
||||
|
||||
PD_SERIES = ['changepoints', 'history_dates', 'train_holiday_names']
|
||||
|
|
@ -84,6 +85,18 @@ def model_to_json(model):
|
|||
list(getattr(model, attribute).keys()),
|
||||
getattr(model, attribute),
|
||||
]
|
||||
# Other attributes with special handling
|
||||
# fit_kwargs -> Transform any numpy types before serializing.
|
||||
# They do not need to be transformed back on deserializing.
|
||||
fit_kwargs = deepcopy(model.fit_kwargs)
|
||||
if 'init' in fit_kwargs:
|
||||
for k, v in fit_kwargs['init'].items():
|
||||
if isinstance(v, np.ndarray):
|
||||
fit_kwargs['init'][k] = v.tolist()
|
||||
elif isinstance(v, np.floating):
|
||||
fit_kwargs['init'][k] = float(v)
|
||||
model_json['fit_kwargs'] = fit_kwargs
|
||||
|
||||
# Params (Dict[str, np.ndarray])
|
||||
model_json['params'] = {k: v.tolist() for k, v in model.params.items()}
|
||||
# Attributes that are skipped: stan_fit, stan_backend
|
||||
|
|
@ -141,6 +154,9 @@ def model_from_json(model_json):
|
|||
for key in key_list:
|
||||
od[key] = unordered_dict[key]
|
||||
setattr(model, attribute, od)
|
||||
# Other attributes with special handling
|
||||
# fit_kwargs
|
||||
model.fit_kwargs = attr_dict['fit_kwargs']
|
||||
# Params (Dict[str, np.ndarray])
|
||||
model.params = {k: np.array(v) for k, v in attr_dict['params'].items()}
|
||||
# Skipped attributes
|
||||
|
|
|
|||
Loading…
Reference in a new issue