saving all variables now added a2c support

This commit is contained in:
Noah Dormann 2019-11-21 16:24:18 +01:00
parent 526c37bf1f
commit 775a50cc5c
4 changed files with 15 additions and 31 deletions

View file

@ -9,7 +9,8 @@ from torchy_baselines.common.vec_env import DummyVecEnv
from torchy_baselines.common.identity_env import IdentityEnvBox
MODEL_LIST = [
PPO
PPO,
A2C,
]
@ -51,6 +52,7 @@ def test_save_load(model_class):
model.save("test_save.zip")
del model
model = model_class.load("test_save")
model.learn(total_timesteps=1000, eval_freq=500)
# check if params are still the same after load
new_params = model.get_policy_parameters()

View file

@ -126,13 +126,14 @@ class A2C(PPO):
tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps)
def save(self, path):
if not path.endswith('.pth'):
path += '.pth'
th.save(self.policy.state_dict(), path)
"""
saves all the params from init and pytorch params in a file for continous learning
def load(self, path, env=None, **_kwargs):
if not path.endswith('.pth'):
path += '.pth'
if env is not None:
pass
self.policy.load_state_dict(th.load(path))
:param path: path to the file where the data should be safed
:return:
"""
data = self.__dict__
params_to_save = self.get_policy_parameters()
opt_params_to_save = self.get_opt_parameters()
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)

View file

@ -296,7 +296,7 @@ class BaseRLModel(object):
"Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
kwargs['policy_kwargs']))
model = cls(policy=data["policy"], env=None, _init_setup_model=False)
model = cls(policy=data["policy"], env=data["env"], _init_setup_model=False)
model.__dict__.update(data)
model.__dict__.update(kwargs)
model.set_env(env)

View file

@ -322,26 +322,7 @@ class PPO(BaseRLModel):
:param path: path to the file where the data should be safed
:return:
"""
data = {
"gamma": self.gamma,
"n_steps": self.n_steps,
"vf_coef": self.vf_coef,
"ent_coef": self.ent_coef,
"max_grad_norm": self.max_grad_norm,
"learning_rate": self.learning_rate,
"gae_lambda": self.gae_lambda,
"n_epochs": self.n_epochs,
"clip_range": self.clip_range,
"clip_range_vf": self.clip_range_vf,
"batch_size": self.batch_size,
"target_kl": self.target_kl,
"tensorboard_log": self.tensorboard_log,
"policy_kwargs": self.policy_kwargs,
"policy": self.policy,
}
data = self.__dict__
params_to_save = self.get_policy_parameters()
opt_params_to_save = self.get_opt_parameters()
self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save)