From 2d72f6d1b5b2b9b707fc03b6546ca5becd443eac Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 16:46:53 +0100 Subject: [PATCH] Added SAC, TD3, A2C Missing CEMRL --- tests/test_save_load.py | 8 +++++- torchy_baselines/a2c/a2c.py | 13 --------- torchy_baselines/cem_rl/cem_rl.py | 13 --------- torchy_baselines/common/base_class.py | 13 +++++++++ torchy_baselines/ppo/ppo.py | 12 --------- torchy_baselines/sac/sac.py | 38 +++++++++++++++++++-------- torchy_baselines/td3/td3.py | 37 +++++++++++++++++--------- 7 files changed, 72 insertions(+), 62 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 1eb4fa0..6368127 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -11,6 +11,8 @@ from torchy_baselines.common.identity_env import IdentityEnvBox MODEL_LIST = [ PPO, A2C, + TD3, + SAC, ] @@ -52,7 +54,6 @@ def test_save_load(model_class): model.save("test_save.zip") del model model = model_class.load("test_save") - model.learn(total_timesteps=1000, eval_freq=500) # check if params are still the same after load new_params = model.get_policy_parameters() @@ -66,4 +67,9 @@ def test_save_load(model_class): # check if keys are the same assert opt_params.keys() == new_opt_params.keys() # check if values are the same: don't know how to to that + + # check if learn still works + model.learn(total_timesteps=1000, eval_freq=500) + + # clear file from os os.remove("test_save.zip") diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 8811356..5babed5 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -124,16 +124,3 @@ class A2C(PPO): return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps) - - def save(self, path): - """ - saves all the params from init and pytorch params in a file for continous learning - - :param path: path to the file where the data should be safed - :return: - """ - - data = self.__dict__ - params_to_save = self.get_policy_parameters() - opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py index ad798ba..f35fa43 100644 --- a/torchy_baselines/cem_rl/cem_rl.py +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -157,16 +157,3 @@ class CEMRL(TD3): self.es.tell(self.es_params, self.fitnesses) timesteps_since_eval += actor_steps return self - - def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) - - def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path)) - self._create_aliases() diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 5eb7ff0..b2c6056 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -511,3 +511,16 @@ class BaseRLModel(object): for file_name, dict in opt_params.items(): with file_.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) + + + def save(self, path): + """ + saves all the params from init and pytorch params in a file for continous learning + + :param path: path to the file where the data should be safed + :return: + """ + data = self.__dict__ + params_to_save = self.get_policy_parameters() + opt_params_to_save = self.get_opt_parameters() + self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 3af5b08..3b4734e 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -314,15 +314,3 @@ class PPO(BaseRLModel): """ self.policy.optimizer.load_state_dict(opt_params["opt"]) self.policy.load_state_dict(load_dict) - - def save(self, path): - """ - saves all the params from init and pytorch params in a file for continous learning - - :param path: path to the file where the data should be safed - :return: - """ - data = self.__dict__ - params_to_save = self.get_policy_parameters() - opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) \ No newline at end of file diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index bad470a..2497523 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -274,15 +274,31 @@ class SAC(BaseRLModel): return self - def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) + def get_opt_parameters(self): + """ + returns a dict of all the optimizers and their parameters - def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path)) - self._create_aliases() + :return: (Dict) of optimizer names and their state_dict + """ + if self.ent_coef_optimizer is not None: + return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict(),"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()} + else: + return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} + + def load_parameters(self, load_dict, opt_params): + """ + Load model parameters and optimizer parameters from a dictionary + + Dictionary should be of shape torch model.state_dict() + + This does not load agent's hyper-parameters. + + + :param load_dict: (dict) dict of parameters from model.state_dict() + :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class + """ + self.actor.optimizer.load_state_dict(opt_params["actor"]) + self.critic.optimizer.load_state_dict(opt_params["critic"]) + if "ent_coef_optimizer" in opt_params: + self.ent_coef_optimizer.load_state_dict(opt_params["ent_coef_optimizer"]) + self.policy.load_state_dict(load_dict) diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 66ea72e..035d344 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -46,6 +46,7 @@ class TD3(BaseRLModel): Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ + def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3, policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100, train_freq=-1, gradient_steps=-1, n_episodes_rollout=1, @@ -148,7 +149,8 @@ class TD3(BaseRLModel): for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) - def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005, tau_critic: object = 0.005, + def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005, + tau_critic: object = 0.005, replay_data: object = None) -> object: # Update optimizer learning rate self._update_learning_rate(self.actor.optimizer) @@ -235,15 +237,26 @@ class TD3(BaseRLModel): return self - def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) + def get_opt_parameters(self): + """ + returns a dict of all the optimizers and their parameters - def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path)) - self._create_aliases() + :return: (Dict) of optimizer names and their state_dict + """ + return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} + + def load_parameters(self, load_dict, opt_params): + """ + Load model parameters and optimizer parameters from a dictionary + + Dictionary should be of shape torch model.state_dict() + + This does not load agent's hyper-parameters. + + + :param load_dict: (dict) dict of parameters from model.state_dict() + :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class + """ + self.actor.optimizer.load_state_dict(opt_params["actor"]) + self.critic.optimizer.load_state_dict(opt_params["critic"]) + self.policy.load_state_dict(load_dict)