From 29f14172f0ed7dc9371017544c3e32a2a1c3ec30 Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Wed, 3 Mar 2021 15:23:12 -0800 Subject: [PATCH] Handle numpy fit_kwargs when serializing (#1701) --- python/fbprophet/serialize.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/python/fbprophet/serialize.py b/python/fbprophet/serialize.py index 64f6c79..c369615 100644 --- a/python/fbprophet/serialize.py +++ b/python/fbprophet/serialize.py @@ -7,6 +7,7 @@ from __future__ import absolute_import, division, print_function from collections import OrderedDict +from copy import deepcopy import json import numpy as np @@ -21,7 +22,7 @@ SIMPLE_ATTRIBUTES = [ 'yearly_seasonality', 'weekly_seasonality', 'daily_seasonality', 'seasonality_mode', 'seasonality_prior_scale', 'changepoint_prior_scale', 'holidays_prior_scale', 'mcmc_samples', 'interval_width', 'uncertainty_samples', - 'y_scale', 'logistic_floor', 'country_holidays', 'component_modes', 'fit_kwargs' + 'y_scale', 'logistic_floor', 'country_holidays', 'component_modes' ] PD_SERIES = ['changepoints', 'history_dates', 'train_holiday_names'] @@ -84,6 +85,18 @@ def model_to_json(model): list(getattr(model, attribute).keys()), getattr(model, attribute), ] + # Other attributes with special handling + # fit_kwargs -> Transform any numpy types before serializing. + # They do not need to be transformed back on deserializing. + fit_kwargs = deepcopy(model.fit_kwargs) + if 'init' in fit_kwargs: + for k, v in fit_kwargs['init'].items(): + if isinstance(v, np.ndarray): + fit_kwargs['init'][k] = v.tolist() + elif isinstance(v, np.floating): + fit_kwargs['init'][k] = float(v) + model_json['fit_kwargs'] = fit_kwargs + # Params (Dict[str, np.ndarray]) model_json['params'] = {k: v.tolist() for k, v in model.params.items()} # Attributes that are skipped: stan_fit, stan_backend @@ -141,6 +154,9 @@ def model_from_json(model_json): for key in key_list: od[key] = unordered_dict[key] setattr(model, attribute, od) + # Other attributes with special handling + # fit_kwargs + model.fit_kwargs = attr_dict['fit_kwargs'] # Params (Dict[str, np.ndarray]) model.params = {k: np.array(v) for k, v in attr_dict['params'].items()} # Skipped attributes