change init into inits for CmdStanPyBackend

This commit is contained in:
loulo1 2021-04-14 15:39:02 +02:00 committed by GitHub
parent fc8fa49aac
commit 3a0061e8e1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -87,20 +87,18 @@ class CmdStanPyBackend(IStanBackend):
def fit(self, stan_init, stan_data, **kwargs):
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'init' in kwargs:
kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0]
if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
args = dict(
data=stan_data,
init=stan_init,
inits=stan_init,
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
iter=int(1e4),
)
args.update(kwargs)
args['inits'] = args['init']
del args['init']
try:
self.stan_fit = self.model.optimize(**args)
except RuntimeError as e:
@ -122,12 +120,13 @@ class CmdStanPyBackend(IStanBackend):
def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'init' in kwargs:
kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0]
if 'inits' not in kwargs and 'init' in kwargs:
kwargs['inits'] = self.prepare_data(kwargs['init'], stan_data)[0]
args = dict(
data=stan_data,
init=stan_init,
inits=stan_init,
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
)
@ -140,9 +139,6 @@ class CmdStanPyBackend(IStanBackend):
args.update(kwargs)
args['inits'] = args['init']
del args['init']
self.stan_fit = self.model.sample(**args)
res = self.stan_fit.sample
(samples, c, columns) = res.shape