mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-24 22:25:13 +00:00
Added SAC, TD3, A2C
Missing CEMRL
This commit is contained in:
parent
775a50cc5c
commit
2d72f6d1b5
7 changed files with 72 additions and 62 deletions
|
|
@ -11,6 +11,8 @@ from torchy_baselines.common.identity_env import IdentityEnvBox
|
|||
MODEL_LIST = [
|
||||
PPO,
|
||||
A2C,
|
||||
TD3,
|
||||
SAC,
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -52,7 +54,6 @@ def test_save_load(model_class):
|
|||
model.save("test_save.zip")
|
||||
del model
|
||||
model = model_class.load("test_save")
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = model.get_policy_parameters()
|
||||
|
|
@ -66,4 +67,9 @@ def test_save_load(model_class):
|
|||
# check if keys are the same
|
||||
assert opt_params.keys() == new_opt_params.keys()
|
||||
# check if values are the same: don't know how to to that
|
||||
|
||||
# check if learn still works
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
# clear file from os
|
||||
os.remove("test_save.zip")
|
||||
|
|
|
|||
|
|
@ -124,16 +124,3 @@ class A2C(PPO):
|
|||
return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
|
||||
eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps)
|
||||
|
||||
def save(self, path):
|
||||
"""
|
||||
saves all the params from init and pytorch params in a file for continous learning
|
||||
|
||||
:param path: path to the file where the data should be safed
|
||||
:return:
|
||||
"""
|
||||
|
||||
data = self.__dict__
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)
|
||||
|
|
|
|||
|
|
@ -157,16 +157,3 @@ class CEMRL(TD3):
|
|||
self.es.tell(self.es_params, self.fitnesses)
|
||||
timesteps_since_eval += actor_steps
|
||||
return self
|
||||
|
||||
def save(self, path):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
th.save(self.policy.state_dict(), path)
|
||||
|
||||
def load(self, path, env=None, **_kwargs):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
if env is not None:
|
||||
pass
|
||||
self.policy.load_state_dict(th.load(path))
|
||||
self._create_aliases()
|
||||
|
|
|
|||
|
|
@ -511,3 +511,16 @@ class BaseRLModel(object):
|
|||
for file_name, dict in opt_params.items():
|
||||
with file_.open(file_name + '.pth', mode="w") as opt_param_file:
|
||||
th.save(dict, opt_param_file)
|
||||
|
||||
|
||||
def save(self, path):
|
||||
"""
|
||||
saves all the params from init and pytorch params in a file for continous learning
|
||||
|
||||
:param path: path to the file where the data should be safed
|
||||
:return:
|
||||
"""
|
||||
data = self.__dict__
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save)
|
||||
|
|
|
|||
|
|
@ -314,15 +314,3 @@ class PPO(BaseRLModel):
|
|||
"""
|
||||
self.policy.optimizer.load_state_dict(opt_params["opt"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
||||
def save(self, path):
|
||||
"""
|
||||
saves all the params from init and pytorch params in a file for continous learning
|
||||
|
||||
:param path: path to the file where the data should be safed
|
||||
:return:
|
||||
"""
|
||||
data = self.__dict__
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save)
|
||||
|
|
@ -274,15 +274,31 @@ class SAC(BaseRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def save(self, path):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
th.save(self.policy.state_dict(), path)
|
||||
def get_opt_parameters(self):
|
||||
"""
|
||||
returns a dict of all the optimizers and their parameters
|
||||
|
||||
def load(self, path, env=None, **_kwargs):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
if env is not None:
|
||||
pass
|
||||
self.policy.load_state_dict(th.load(path))
|
||||
self._create_aliases()
|
||||
:return: (Dict) of optimizer names and their state_dict
|
||||
"""
|
||||
if self.ent_coef_optimizer is not None:
|
||||
return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict(),"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()}
|
||||
else:
|
||||
return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()}
|
||||
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
"""
|
||||
Load model parameters and optimizer parameters from a dictionary
|
||||
|
||||
Dictionary should be of shape torch model.state_dict()
|
||||
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
|
||||
"""
|
||||
self.actor.optimizer.load_state_dict(opt_params["actor"])
|
||||
self.critic.optimizer.load_state_dict(opt_params["critic"])
|
||||
if "ent_coef_optimizer" in opt_params:
|
||||
self.ent_coef_optimizer.load_state_dict(opt_params["ent_coef_optimizer"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ class TD3(BaseRLModel):
|
|||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3,
|
||||
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
|
||||
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
|
||||
|
|
@ -148,7 +149,8 @@ class TD3(BaseRLModel):
|
|||
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
|
||||
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
|
||||
|
||||
def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005, tau_critic: object = 0.005,
|
||||
def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005,
|
||||
tau_critic: object = 0.005,
|
||||
replay_data: object = None) -> object:
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.actor.optimizer)
|
||||
|
|
@ -235,15 +237,26 @@ class TD3(BaseRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def save(self, path):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
th.save(self.policy.state_dict(), path)
|
||||
def get_opt_parameters(self):
|
||||
"""
|
||||
returns a dict of all the optimizers and their parameters
|
||||
|
||||
def load(self, path, env=None, **_kwargs):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
if env is not None:
|
||||
pass
|
||||
self.policy.load_state_dict(th.load(path))
|
||||
self._create_aliases()
|
||||
:return: (Dict) of optimizer names and their state_dict
|
||||
"""
|
||||
return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()}
|
||||
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
"""
|
||||
Load model parameters and optimizer parameters from a dictionary
|
||||
|
||||
Dictionary should be of shape torch model.state_dict()
|
||||
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
|
||||
"""
|
||||
self.actor.optimizer.load_state_dict(opt_params["actor"])
|
||||
self.critic.optimizer.load_state_dict(opt_params["critic"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
|
|
|||
Loading…
Reference in a new issue