mirror of
https://github.com/saymrwulf/prophet.git
synced 2026-06-01 23:30:43 +00:00
change init into inits for CmdStanPyBackend
This commit is contained in:
parent
fc8fa49aac
commit
3a0061e8e1
1 changed files with 8 additions and 12 deletions
|
|
@ -87,20 +87,18 @@ class CmdStanPyBackend(IStanBackend):
|
|||
|
||||
def fit(self, stan_init, stan_data, **kwargs):
|
||||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
|
||||
if 'init' in kwargs:
|
||||
kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0]
|
||||
|
||||
if 'inits' not in kwargs and 'init' in kwargs:
|
||||
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
|
||||
|
||||
args = dict(
|
||||
data=stan_data,
|
||||
init=stan_init,
|
||||
inits=stan_init,
|
||||
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
|
||||
iter=int(1e4),
|
||||
)
|
||||
args.update(kwargs)
|
||||
|
||||
args['inits'] = args['init']
|
||||
del args['init']
|
||||
|
||||
try:
|
||||
self.stan_fit = self.model.optimize(**args)
|
||||
except RuntimeError as e:
|
||||
|
|
@ -122,12 +120,13 @@ class CmdStanPyBackend(IStanBackend):
|
|||
|
||||
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
|
||||
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
|
||||
if 'init' in kwargs:
|
||||
kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0]
|
||||
|
||||
if 'inits' not in kwargs and 'init' in kwargs:
|
||||
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
|
||||
|
||||
args = dict(
|
||||
data=stan_data,
|
||||
init=stan_init,
|
||||
inits=stan_init,
|
||||
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
|
||||
)
|
||||
|
||||
|
|
@ -140,9 +139,6 @@ class CmdStanPyBackend(IStanBackend):
|
|||
|
||||
args.update(kwargs)
|
||||
|
||||
args['inits'] = args['init']
|
||||
del args['init']
|
||||
|
||||
self.stan_fit = self.model.sample(**args)
|
||||
res = self.stan_fit.sample
|
||||
(samples, c, columns) = res.shape
|
||||
|
|
|
|||
Loading…
Reference in a new issue