mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
saving all variables now; added a2c support
This commit is contained in:
parent
526c37bf1f
commit
775a50cc5c
4 changed files with 15 additions and 31 deletions
|
|
@ -9,7 +9,8 @@ from torchy_baselines.common.vec_env import DummyVecEnv
|
|||
from torchy_baselines.common.identity_env import IdentityEnvBox
|
||||
|
||||
MODEL_LIST = [
|
||||
PPO
|
||||
PPO,
|
||||
A2C,
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -51,6 +52,7 @@ def test_save_load(model_class):
|
|||
model.save("test_save.zip")
|
||||
del model
|
||||
model = model_class.load("test_save")
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = model.get_policy_parameters()
|
||||
|
|
|
|||
|
|
@ -126,13 +126,14 @@ class A2C(PPO):
|
|||
tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps)
|
||||
|
||||
def save(self, path):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
th.save(self.policy.state_dict(), path)
|
||||
"""
|
||||
saves all the params from init and pytorch params in a file for continuous learning
|
||||
|
||||
def load(self, path, env=None, **_kwargs):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
if env is not None:
|
||||
pass
|
||||
self.policy.load_state_dict(th.load(path))
|
||||
:param path: path to the file where the data should be saved
|
||||
:return:
|
||||
"""
|
||||
|
||||
data = self.__dict__
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)
|
||||
|
|
|
|||
|
|
@ -296,7 +296,7 @@ class BaseRLModel(object):
|
|||
"Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
|
||||
kwargs['policy_kwargs']))
|
||||
|
||||
model = cls(policy=data["policy"], env=None, _init_setup_model=False)
|
||||
model = cls(policy=data["policy"], env=data["env"], _init_setup_model=False)
|
||||
model.__dict__.update(data)
|
||||
model.__dict__.update(kwargs)
|
||||
model.set_env(env)
|
||||
|
|
|
|||
|
|
@ -322,26 +322,7 @@ class PPO(BaseRLModel):
|
|||
:param path: path to the file where the data should be saved
|
||||
:return:
|
||||
"""
|
||||
|
||||
data = {
|
||||
"gamma": self.gamma,
|
||||
"n_steps": self.n_steps,
|
||||
"vf_coef": self.vf_coef,
|
||||
"ent_coef": self.ent_coef,
|
||||
"max_grad_norm": self.max_grad_norm,
|
||||
"learning_rate": self.learning_rate,
|
||||
"gae_lambda": self.gae_lambda,
|
||||
"n_epochs": self.n_epochs,
|
||||
"clip_range": self.clip_range,
|
||||
"clip_range_vf": self.clip_range_vf,
|
||||
"batch_size": self.batch_size,
|
||||
"target_kl": self.target_kl,
|
||||
"tensorboard_log": self.tensorboard_log,
|
||||
"policy_kwargs": self.policy_kwargs,
|
||||
"policy": self.policy,
|
||||
|
||||
}
|
||||
|
||||
data = self.__dict__
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save)
|
||||
Loading…
Reference in a new issue