diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 81cfe52..1eb4fa0 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -9,7 +9,8 @@ from torchy_baselines.common.vec_env import DummyVecEnv from torchy_baselines.common.identity_env import IdentityEnvBox MODEL_LIST = [ - PPO + PPO, + A2C, ] @@ -51,6 +52,7 @@ def test_save_load(model_class): model.save("test_save.zip") del model model = model_class.load("test_save") + model.learn(total_timesteps=1000, eval_freq=500) # check if params are still the same after load new_params = model.get_policy_parameters() diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 1c066cf..8811356 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -126,13 +126,14 @@ class A2C(PPO): tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps) def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) + """ + saves all the params from init and pytorch params in a file for continous learning - def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path)) + :param path: path to the file where the data should be safed + :return: + """ + + data = self.__dict__ + params_to_save = self.get_policy_parameters() + opt_params_to_save = self.get_opt_parameters() + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 6955362..5eb7ff0 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -296,7 +296,7 @@ class BaseRLModel(object): "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], kwargs['policy_kwargs'])) - model = cls(policy=data["policy"], env=None, _init_setup_model=False) + model = cls(policy=data["policy"], env=data["env"], _init_setup_model=False) model.__dict__.update(data) model.__dict__.update(kwargs) model.set_env(env) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index fed5969..3af5b08 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -322,26 +322,7 @@ class PPO(BaseRLModel): :param path: path to the file where the data should be safed :return: """ - - data = { - "gamma": self.gamma, - "n_steps": self.n_steps, - "vf_coef": self.vf_coef, - "ent_coef": self.ent_coef, - "max_grad_norm": self.max_grad_norm, - "learning_rate": self.learning_rate, - "gae_lambda": self.gae_lambda, - "n_epochs": self.n_epochs, - "clip_range": self.clip_range, - "clip_range_vf": self.clip_range_vf, - "batch_size": self.batch_size, - "target_kl": self.target_kl, - "tensorboard_log": self.tensorboard_log, - "policy_kwargs": self.policy_kwargs, - "policy": self.policy, - - } - + data = self.__dict__ params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) \ No newline at end of file