Handle numpy fit_kwargs when serializing (#1701)

This commit is contained in:
Ben Letham 2021-03-03 15:23:12 -08:00
parent 9e4e87af9b
commit 29f14172f0

View file

@ -7,6 +7,7 @@
from __future__ import absolute_import, division, print_function
from collections import OrderedDict
from copy import deepcopy
import json
import numpy as np
@ -21,7 +22,7 @@ SIMPLE_ATTRIBUTES = [
'yearly_seasonality', 'weekly_seasonality', 'daily_seasonality',
'seasonality_mode', 'seasonality_prior_scale', 'changepoint_prior_scale',
'holidays_prior_scale', 'mcmc_samples', 'interval_width', 'uncertainty_samples',
'y_scale', 'logistic_floor', 'country_holidays', 'component_modes', 'fit_kwargs'
'y_scale', 'logistic_floor', 'country_holidays', 'component_modes'
]
PD_SERIES = ['changepoints', 'history_dates', 'train_holiday_names']
@ -84,6 +85,18 @@ def model_to_json(model):
list(getattr(model, attribute).keys()),
getattr(model, attribute),
]
# Other attributes with special handling
# fit_kwargs -> Transform any numpy types before serializing.
# They do not need to be transformed back on deserializing.
fit_kwargs = deepcopy(model.fit_kwargs)
if 'init' in fit_kwargs:
for k, v in fit_kwargs['init'].items():
if isinstance(v, np.ndarray):
fit_kwargs['init'][k] = v.tolist()
elif isinstance(v, np.floating):
fit_kwargs['init'][k] = float(v)
model_json['fit_kwargs'] = fit_kwargs
# Params (Dict[str, np.ndarray])
model_json['params'] = {k: v.tolist() for k, v in model.params.items()}
# Attributes that are skipped: stan_fit, stan_backend
@ -141,6 +154,9 @@ def model_from_json(model_json):
for key in key_list:
od[key] = unordered_dict[key]
setattr(model, attribute, od)
# Other attributes with special handling
# fit_kwargs
model.fit_kwargs = attr_dict['fit_kwargs']
# Params (Dict[str, np.ndarray])
model.params = {k: np.array(v) for k, v in attr_dict['params'].items()}
# Skipped attributes