mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Merge pull request #3 from Antonin-Raffin/safe_load_modules
Save and load methods
This commit is contained in:
commit
2690fa4fba
15 changed files with 731 additions and 118 deletions
|
|
@ -22,6 +22,7 @@ sys.path.insert(0, os.path.abspath('..'))
|
|||
|
||||
class Mock(MagicMock):
|
||||
__subclasses__ = []
|
||||
|
||||
@classmethod
|
||||
def __getattr__(cls, name):
|
||||
return MagicMock()
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import pytest
|
|||
|
||||
from torchy_baselines import PPO
|
||||
|
||||
|
||||
@pytest.mark.parametrize('net_arch', [
|
||||
[12, dict(vf=[16], pi=[8])],
|
||||
[4],
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import numpy as np
|
|||
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
|
||||
from torchy_baselines.common.noise import NormalActionNoise
|
||||
|
||||
|
||||
action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
|
||||
|
||||
|
||||
|
|
@ -16,7 +15,7 @@ def test_td3():
|
|||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.pth")
|
||||
os.remove("test_save.zip")
|
||||
|
||||
|
||||
def test_cemrl():
|
||||
|
|
@ -25,7 +24,7 @@ def test_cemrl():
|
|||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.pth")
|
||||
os.remove("test_save.zip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
|
|
@ -33,12 +32,16 @@ def test_cemrl():
|
|||
def test_onpolicy(model_class, env_id):
|
||||
model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
# model.save("test_save")
|
||||
# model.load("test_save")
|
||||
# os.remove("test_save.pth")
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.zip")
|
||||
|
||||
|
||||
def test_sac():
|
||||
model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100, verbose=1, create_eval_env=True, ent_coef='auto',
|
||||
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)))
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.zip")
|
||||
|
|
|
|||
156
tests/test_save_load.py
Normal file
156
tests/test_save_load.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
import os
|
||||
import pytest
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv
|
||||
from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv
|
||||
|
||||
MODEL_LIST = [
|
||||
CEMRL,
|
||||
PPO,
|
||||
A2C,
|
||||
TD3,
|
||||
SAC,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_save_load(model_class):
|
||||
"""
|
||||
Test if 'save' and 'load' saves and loads model correctly
|
||||
and if 'load_parameters' and 'get_policy_parameters' work correctly
|
||||
|
||||
''warning does not test function of optimizer parameter load
|
||||
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
"""
|
||||
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
|
||||
# create model
|
||||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
env.reset()
|
||||
observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
|
||||
observations = np.squeeze(observations)
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = deepcopy(model.get_policy_parameters())
|
||||
opt_params = deepcopy(model.get_opt_parameters())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
|
||||
# Update model parameters with the new random values
|
||||
model.load_parameters(random_params, opt_params)
|
||||
|
||||
new_params = model.get_policy_parameters()
|
||||
# Check that all params are different now
|
||||
for k in params:
|
||||
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
|
||||
|
||||
params = new_params
|
||||
|
||||
# get selected actions
|
||||
selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
|
||||
|
||||
# Check
|
||||
model.save("test_save.zip")
|
||||
del model
|
||||
model = model_class.load("test_save", env=env)
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = model.get_policy_parameters()
|
||||
|
||||
# Check that all params are the same as before save load procedure now
|
||||
for k in params:
|
||||
assert th.allclose(params[k], new_params[k]), "Model parameters not the same after save and load."
|
||||
|
||||
# check if optimizer params are still the same after load
|
||||
new_opt_params = model.get_opt_parameters()
|
||||
# check if keys are the same
|
||||
assert opt_params.keys() == new_opt_params.keys()
|
||||
# check if values are the same: only tested for Adam and RMSProp so far
|
||||
# comparing states not implemented so far. hashes of state_entries are not the same for equal tensors
|
||||
# comparing every sub_entry does not work because of bool value of Tensor with more than one value is ambiguous
|
||||
# so far only comparing param_lists
|
||||
for optimizer, opt_state in opt_params.items():
|
||||
for param_group_idx, param_group in enumerate(opt_state['param_groups']):
|
||||
for param_key, param_value in param_group.items():
|
||||
if param_key == 'params': # don't know how to handle params correctly, therefore only check if we have the same amount
|
||||
assert len(param_value) == len(
|
||||
new_opt_params[optimizer]['param_groups'][param_group_idx][param_key])
|
||||
else:
|
||||
assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key]
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
|
||||
# for i in range(len(selected_actionsselected_actions)):
|
||||
assert np.allclose(selected_actions, new_selected_actions)
|
||||
|
||||
# check if learn still works
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
# clear file from os
|
||||
os.remove("test_save.zip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_set_env(model_class):
|
||||
"""
|
||||
Test if set_env function does work correct
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
"""
|
||||
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
env2 = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
env3 = IdentityEnvBox(10)
|
||||
|
||||
# create model
|
||||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), create_eval_env=True)
|
||||
# learn
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
# change env
|
||||
model.set_env(env2)
|
||||
# learn again
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
# change env test wrapping
|
||||
model.set_env(env3)
|
||||
# learn again
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_exclude_include_saved_params(model_class):
|
||||
"""
|
||||
Test if exclude and include parameters of save() work
|
||||
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
"""
|
||||
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
|
||||
# create model, set verbose as 2, which is not standard
|
||||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2, create_eval_env=True)
|
||||
|
||||
# Check if exclude works
|
||||
model.save("test_save.zip", exclude=["verbose"])
|
||||
del model
|
||||
model = model_class.load("test_save")
|
||||
# check if verbose was not saved
|
||||
assert model.verbose != 2
|
||||
|
||||
# set verbose as something different then standard settings
|
||||
model.verbose = 2
|
||||
# Check if include works
|
||||
model.save("test_save.zip", exclude=["verbose"], include=["verbose"])
|
||||
del model
|
||||
model = model_class.load("test_save")
|
||||
assert model.verbose == 2
|
||||
|
||||
# clear file from os
|
||||
os.remove("test_save.zip")
|
||||
|
|
@ -59,8 +59,10 @@ class CustomGymEnv(gym.Env):
|
|||
@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
|
||||
def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
|
||||
"""Test access to methods/attributes of vectorized environments"""
|
||||
|
||||
def make_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
|
||||
|
||||
if vec_env_wrapper is not None:
|
||||
|
|
@ -92,7 +94,6 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
|
|||
assert (env_method_subset[1] == np.ones((1, 3))).all()
|
||||
assert len(env_method_subset) == 2
|
||||
|
||||
|
||||
# Test to change value for all the environments
|
||||
setattr_result = vec_env.set_attr('current_step', 42, indices=None)
|
||||
getattr_result = vec_env.get_attr('current_step')
|
||||
|
|
@ -193,8 +194,10 @@ SPACES = collections.OrderedDict([
|
|||
('continuous', gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
|
||||
])
|
||||
|
||||
|
||||
def check_vecenv_spaces(vec_env_class, space, obs_assert):
|
||||
"""Helper method to check observation spaces in vectorized environments."""
|
||||
|
||||
def make_env():
|
||||
return CustomGymEnv(space)
|
||||
|
||||
|
|
@ -228,6 +231,7 @@ def test_vecenv_single_space(vec_env_class, space):
|
|||
|
||||
class _UnorderedDictSpace(gym.spaces.Dict):
|
||||
"""Like DictSpace, but returns an unordered dict when sampling."""
|
||||
|
||||
def sample(self):
|
||||
return dict(super().sample())
|
||||
|
||||
|
|
@ -301,14 +305,17 @@ class CustomWrapperB(VecNormalize):
|
|||
def name_test(self):
|
||||
return self.__class__
|
||||
|
||||
|
||||
class CustomWrapperBB(CustomWrapperB):
|
||||
def __init__(self, venv):
|
||||
CustomWrapperB.__init__(self, venv)
|
||||
self.var_bb = 'bb'
|
||||
|
||||
|
||||
def test_vecenv_wrapper_getattr():
|
||||
def make_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
|
||||
wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
|
||||
assert wrapped.var_a == 'a'
|
||||
|
|
|
|||
|
|
@ -130,5 +130,5 @@ class A2C(PPO):
|
|||
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="A2C", reset_num_timesteps=True):
|
||||
|
||||
return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
|
||||
eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps)
|
||||
eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps)
|
||||
|
|
|
|||
|
|
@ -78,7 +78,8 @@ class CEMRL(TD3):
|
|||
# set params
|
||||
self.actor.load_from_vector(self.es_params[i])
|
||||
self.actor_target.load_from_vector(self.es_params[i])
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=self.learning_rate(self._current_progress))
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(),
|
||||
lr=self.learning_rate(self._current_progress))
|
||||
|
||||
# In the paper: 2 * actor_steps // self.n_grad
|
||||
# In the original implementation: actor_steps // self.n_grad
|
||||
|
|
@ -157,16 +158,3 @@ class CEMRL(TD3):
|
|||
self.es.tell(self.es_params, self.fitnesses)
|
||||
timesteps_since_eval += actor_steps
|
||||
return self
|
||||
|
||||
def save(self, path):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
th.save(self.policy.state_dict(), path)
|
||||
|
||||
def load(self, path, env=None, **_kwargs):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
if env is not None:
|
||||
pass
|
||||
self.policy.load_state_dict(th.load(path))
|
||||
self._create_aliases()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
import time
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import deque
|
||||
import os
|
||||
import io
|
||||
import zipfile
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -11,6 +14,7 @@ from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, upda
|
|||
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
from torchy_baselines.common import logger
|
||||
from torchy_baselines.common.save_util import data_to_json, json_to_data
|
||||
|
||||
|
||||
class BaseRLModel(object):
|
||||
|
|
@ -34,9 +38,9 @@ class BaseRLModel(object):
|
|||
verbose=0, device='auto', support_multi_env=False,
|
||||
create_eval_env=False, monitor_wrapper=True, seed=None):
|
||||
if isinstance(policy, str) and policy_base is not None:
|
||||
self.policy = get_policy_from_name(policy_base, policy)
|
||||
self.policy_class = get_policy_from_name(policy_base, policy)
|
||||
else:
|
||||
self.policy = policy
|
||||
self.policy_class = policy
|
||||
|
||||
if device == 'auto':
|
||||
device = 'cuda' if th.cuda.is_available() else 'cpu'
|
||||
|
|
@ -52,7 +56,6 @@ class BaseRLModel(object):
|
|||
self.action_space = None
|
||||
self.n_envs = None
|
||||
self.num_timesteps = 0
|
||||
self.params = None
|
||||
self.eval_env = None
|
||||
self.replay_buffer = None
|
||||
self.seed = seed
|
||||
|
|
@ -113,7 +116,7 @@ class BaseRLModel(object):
|
|||
(no need for symmetric action space)
|
||||
"""
|
||||
low, high = self.action_space.low, self.action_space.high
|
||||
return low + (0.5 * (scaled_action + 1.0) * (high - low))
|
||||
return low + (0.5 * (scaled_action + 1.0) * (high - low))
|
||||
|
||||
def _setup_learning_rate(self):
|
||||
"""Transform to callable if needed."""
|
||||
|
|
@ -163,29 +166,63 @@ class BaseRLModel(object):
|
|||
"""
|
||||
return self.env
|
||||
|
||||
@staticmethod
|
||||
def check_env(env, observation_space, action_space):
|
||||
"""
|
||||
Checks the validity of the environment and returns if it is coherent
|
||||
Checked parameters:
|
||||
- observation_space
|
||||
- action_space
|
||||
:return: (bool) True if environment seems to be coherent
|
||||
"""
|
||||
if observation_space != env.observation_space:
|
||||
return False
|
||||
if action_space != env.action_space:
|
||||
return False
|
||||
# return true if no check failed
|
||||
return True
|
||||
|
||||
def set_env(self, env):
|
||||
"""
|
||||
Checks the validity of the environment, and if it is coherent, set it as the current environment.
|
||||
Furthermore wrap any non vectorized env into a vectorized
|
||||
checked parameters:
|
||||
- observation_space
|
||||
- action_space
|
||||
|
||||
:param env: (Gym Environment) The environment for learning a policy
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_parameter_list(self):
|
||||
"""
|
||||
Get pytorch Variables of model's parameters
|
||||
|
||||
This includes all variables necessary for continuing training (saving / loading).
|
||||
|
||||
:return: (list) List of pytorch Variables
|
||||
"""
|
||||
pass
|
||||
if self.check_env(env, self.observation_space, self.action_space) is False:
|
||||
raise ValueError("The given environment is not compatible with model: observation and action spaces do not match")
|
||||
# it must be coherent now
|
||||
# if it is not a VecEnv, make it a VecEnv
|
||||
if not isinstance(env, VecEnv):
|
||||
if self.verbose >= 1:
|
||||
print("Wrapping the env in a DummyVecEnv.")
|
||||
env = DummyVecEnv([lambda: env])
|
||||
self.n_envs = env.num_envs
|
||||
self.env = env
|
||||
|
||||
def get_parameters(self):
|
||||
"""
|
||||
Get current model parameters as dictionary of variable name -> ndarray.
|
||||
Returns policy and optimizer parameters as a tuple
|
||||
:return: (dict,dict) policy_parameters, opt_parameters
|
||||
"""
|
||||
return self.get_policy_parameters(), self.get_opt_parameters()
|
||||
|
||||
:return: (OrderedDict) Dictionary of variable name -> ndarray of model's parameters.
|
||||
def get_policy_parameters(self):
|
||||
"""
|
||||
Get current model policy parameters as dictionary of variable name -> tensors.
|
||||
|
||||
:return: (dict) Dictionary of variable name -> tensor of model's policy parameters.
|
||||
"""
|
||||
return self.policy.state_dict()
|
||||
|
||||
@abstractmethod
|
||||
def get_opt_parameters(self):
|
||||
"""
|
||||
Get current model optimizer parameters as dictionary of variable names -> tensors
|
||||
:return: (dict) Dictionary of variable name -> tensor of model's optimizer parameters
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
@ -237,50 +274,121 @@ class BaseRLModel(object):
|
|||
"""
|
||||
pass
|
||||
|
||||
def load_parameters(self, load_path_or_dict, exact_match=True):
|
||||
def load_parameters(self, load_dict, opt_params=None):
|
||||
"""
|
||||
Load model parameters from a file or a dictionary
|
||||
Load model parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
If opt_params are given this does also load agent's optimizer-parameters, but can only be handled in child classes.
|
||||
|
||||
Dictionary keys should be tensorflow variable names, which can be obtained
|
||||
with ``get_parameters`` function. If ``exact_match`` is True, dictionary
|
||||
should contain keys for all model's parameters, otherwise RunTimeError
|
||||
is raised. If False, only variables included in the dictionary will be updated.
|
||||
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
.. warning::
|
||||
This function does not update trainer/optimizer variables (e.g. momentum).
|
||||
As such training after using this function may lead to less-than-optimal results.
|
||||
|
||||
:param load_path_or_dict: (str or file-like or dict) Save parameter location
|
||||
or dict of parameters as variable.name -> ndarrays to be loaded.
|
||||
:param exact_match: (bool) If True, expects load dictionary to contain keys for
|
||||
all variables in the model. If False, loads parameters only for variables
|
||||
mentioned in the dictionary. Defaults to True.
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def save(self, save_path):
|
||||
"""
|
||||
Save the current parameters to file
|
||||
|
||||
:param save_path: (str or file-like object) the save location
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
if opt_params is not None:
|
||||
raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class")
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls, load_path, env=None, **kwargs):
|
||||
"""
|
||||
Load the model from file
|
||||
Load the model from a zip-file
|
||||
|
||||
:param load_path: (str or file-like) the saved parameter location
|
||||
:param load_path: (str) the location of the saved data
|
||||
:param env: (Gym Envrionment) the new environment to run the loaded model on
|
||||
(can be None if you only need prediction from a trained model)
|
||||
(can be None if you only need prediction from a trained model) has priority over any saved environment
|
||||
:param kwargs: extra arguments to change the model when loading
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
data, params, opt_params = cls._load_from_file(load_path)
|
||||
|
||||
if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
|
||||
raise ValueError("The specified policy kwargs do not equal the stored policy kwargs."
|
||||
"Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
|
||||
kwargs['policy_kwargs']))
|
||||
|
||||
# check if observation space and action space are part of the saved parameters
|
||||
if ("observation_space" not in data or "action_space" not in data) and "env" not in data:
|
||||
raise ValueError("The observation_space and action_space was not given, can't verify new environments")
|
||||
# check if given env is valid
|
||||
if env is not None and cls.check_env(env, data["observation_space"], data["action_space"]) is False:
|
||||
raise ValueError("The given environment does not comply to the model")
|
||||
# if no new env was given use stored env if possible
|
||||
if env is None and "env" in data:
|
||||
env = data["env"]
|
||||
|
||||
# first create model, but only setup if a env was given
|
||||
model = cls(policy=data["policy_class"], env=env, _init_setup_model=env is not None)
|
||||
|
||||
# load parameters
|
||||
model.__dict__.update(data)
|
||||
model.__dict__.update(kwargs)
|
||||
model.load_parameters(params, opt_params)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _load_from_file(load_path, load_data=True):
|
||||
""" Load model data from a .zip archive
|
||||
|
||||
:param load_path: (str) Where to load the model from
|
||||
:param load_data: (bool) Whether we should load and return data
|
||||
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
|
||||
:return: (dict),(dict),(dict) Class parameters, model parameters (state_dict) and dict of optimizer parameters (dict of state_dict)
|
||||
"""
|
||||
# Check if file exists if load_path is a string
|
||||
if isinstance(load_path, str):
|
||||
if not os.path.exists(load_path):
|
||||
if os.path.exists(load_path + ".zip"):
|
||||
load_path += ".zip"
|
||||
else:
|
||||
raise ValueError("Error: the file {} could not be found".format(load_path))
|
||||
|
||||
# Open the zip archive and load data
|
||||
try:
|
||||
with zipfile.ZipFile(load_path, "r") as archive:
|
||||
namelist = archive.namelist()
|
||||
# If data or parameters is not in the
|
||||
# zip archive, assume they were stored
|
||||
# as None (_save_to_file_zip allows this).
|
||||
data = None
|
||||
params = None
|
||||
opt_params = None
|
||||
if "data" in namelist and load_data:
|
||||
# Load class parameters and convert to string
|
||||
json_data = archive.read("data").decode()
|
||||
data = json_to_data(json_data)
|
||||
|
||||
if "params.pth" in namelist:
|
||||
# Load parameters with build in torch function
|
||||
with archive.open("params.pth", mode="r") as param_file:
|
||||
# File has to be seekable, but param_file is not, so load in BytesIO first
|
||||
# fixed in python >= 3.7
|
||||
file_content = io.BytesIO()
|
||||
file_content.write(param_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
params = th.load(file_content)
|
||||
|
||||
# check for all other .pth files
|
||||
other_files = [file_name for file_name in namelist if
|
||||
os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"]
|
||||
# if there are any other files which end with .pth and aren't "params.pth"
|
||||
# assume that they each are optimizer parameters
|
||||
if len(other_files) > 0:
|
||||
opt_params = dict()
|
||||
for file_path in other_files:
|
||||
with archive.open(file_path, mode="r") as opt_param_file:
|
||||
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
|
||||
# fixed in python >= 3.7
|
||||
file_content = io.BytesIO()
|
||||
file_content.write(opt_param_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
# save the parameters in dict with file name but trim file ending
|
||||
opt_params[os.path.splitext(file_path)[0]] = th.load(file_content)
|
||||
except zipfile.BadZipFile:
|
||||
# load_path wasn't a zip file
|
||||
raise ValueError("Error: the file {} wasn't a zip-file".format(load_path))
|
||||
|
||||
return data, params, opt_params
|
||||
|
||||
def set_random_seed(self, seed=None):
|
||||
"""
|
||||
|
|
@ -383,7 +491,8 @@ class BaseRLModel(object):
|
|||
action_noise.reset()
|
||||
|
||||
# Display training infos
|
||||
if self.verbose >= 1 and log_interval is not None and (episode_num + total_episodes) % log_interval == 0:
|
||||
if self.verbose >= 1 and log_interval is not None and (
|
||||
episode_num + total_episodes) % log_interval == 0:
|
||||
fps = int(num_timesteps / (time.time() - self.start_time))
|
||||
logger.logkv("episodes", episode_num + total_episodes)
|
||||
# logger.logkv("mean 100 episode reward", mean_reward)
|
||||
|
|
@ -398,5 +507,80 @@ class BaseRLModel(object):
|
|||
logger.dumpkvs()
|
||||
|
||||
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
|
||||
|
||||
return mean_reward, total_steps, total_episodes, obs
|
||||
|
||||
@staticmethod
|
||||
def _save_to_file_zip(save_path, data=None, params=None, opt_params=None):
|
||||
"""Save model to a zip archive
|
||||
|
||||
:param save_path: (str) Where to store the model
|
||||
:param data: (dict) Class parameters being stored
|
||||
:param params: (dict) Model parameters being stored expected to be state_dict
|
||||
:param opt_params: (dict) Optimizer parameters being stored expected to contain an entry for every
|
||||
optimizer with its name and the state_dict
|
||||
"""
|
||||
|
||||
# data/params can be None, so do not
|
||||
# try to serialize them blindly
|
||||
if data is not None:
|
||||
serialized_data = data_to_json(data)
|
||||
|
||||
# Check postfix if save_path is a string
|
||||
if isinstance(save_path, str):
|
||||
_, ext = os.path.splitext(save_path)
|
||||
if ext == "":
|
||||
save_path += ".zip"
|
||||
|
||||
# Create a zip-archive and write our objects
|
||||
# there. This works when save_path is either
|
||||
# str or a file-like
|
||||
with zipfile.ZipFile(save_path, "w") as archive:
|
||||
# Do not try to save "None" elements
|
||||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if params is not None:
|
||||
with archive.open('params.pth', mode="w") as param_file:
|
||||
th.save(params, param_file)
|
||||
if opt_params is not None:
|
||||
for file_name, dict in opt_params.items():
|
||||
with archive.open(file_name + '.pth', mode="w") as opt_param_file:
|
||||
th.save(dict, opt_param_file)
|
||||
|
||||
@staticmethod
|
||||
def excluded_save_params():
|
||||
"""
|
||||
Returns the names of the parameters that should be excluded by default
|
||||
when saving the model.
|
||||
|
||||
:return: ([str]) List of parameters that should be excluded from save
|
||||
"""
|
||||
return ["env", "eval_env", "replay_buffer", "rollout_buffer"]
|
||||
|
||||
def save(self, path, exclude=None, include=None):
|
||||
"""
|
||||
Save all the attributes of the object and the model parameters in a zip-file.
|
||||
|
||||
:param path: (str) path to the file where the rl agent should be saved
|
||||
:param exclude: ([str]) name of parameters that should be excluded in addition to the default one
|
||||
:param include: ([str]) name of parameters that might be excluded but should be included anyway
|
||||
"""
|
||||
# copy parameter list so we don't mutate the original dict
|
||||
data = self.__dict__.copy()
|
||||
# use standard list of excluded parameters if none given
|
||||
if exclude is None:
|
||||
exclude = self.excluded_save_params()
|
||||
else:
|
||||
# append standard exclude params to the given params
|
||||
exclude.extend([param for param in self.excluded_save_params() if param not in exclude])
|
||||
# do not exclude params if they are specifically included
|
||||
if include is not None:
|
||||
exclude = [param_name for param_name in exclude if param_name not in include]
|
||||
|
||||
# remove parameter entries of parameters which are to be excluded
|
||||
for param_name in exclude:
|
||||
if param_name in data:
|
||||
data.pop(param_name, None)
|
||||
|
||||
params_to_save = self.get_policy_parameters()
|
||||
opt_params_to_save = self.get_opt_parameters()
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from torch.distributions import Normal, Categorical
|
|||
import torch.nn.functional as F
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class Distribution(object):
|
||||
def __init__(self):
|
||||
super(Distribution, self).__init__()
|
||||
|
|
@ -144,7 +145,8 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
self.gaussian_action = None
|
||||
|
||||
def proba_distribution(self, mean_actions, log_std, deterministic=False):
|
||||
action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std, deterministic)
|
||||
action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std,
|
||||
deterministic)
|
||||
return action, self
|
||||
|
||||
def mode(self):
|
||||
|
|
|
|||
105
torchy_baselines/common/identity_env.py
Normal file
105
torchy_baselines/common/identity_env.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
import numpy as np
|
||||
|
||||
from gym import Env
|
||||
from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box
|
||||
|
||||
|
||||
class IdentityEnv(Env):
|
||||
def __init__(self, dim, ep_length=100):
|
||||
"""
|
||||
Identity environment for testing purposes
|
||||
|
||||
:param dim: (int) the size of the dimensions you want to learn
|
||||
:param ep_length: (int) the length of each episodes in timesteps
|
||||
"""
|
||||
self.action_space = Discrete(dim)
|
||||
self.observation_space = self.action_space
|
||||
self.ep_length = ep_length
|
||||
self.current_step = 0
|
||||
self.dim = dim
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.current_step = 0
|
||||
self._choose_next_state()
|
||||
return self.state
|
||||
|
||||
def step(self, action):
|
||||
reward = self._get_reward(action)
|
||||
self._choose_next_state()
|
||||
self.current_step += 1
|
||||
done = self.current_step >= self.ep_length
|
||||
return self.state, reward, done, {}
|
||||
|
||||
def _choose_next_state(self):
|
||||
self.state = self.action_space.sample()
|
||||
|
||||
def _get_reward(self, action):
|
||||
return 1 if np.all(self.state == action) else 0
|
||||
|
||||
def render(self, mode='human'):
|
||||
pass
|
||||
|
||||
|
||||
class IdentityEnvBox(IdentityEnv):
|
||||
def __init__(self, low=-1, high=1, eps=0.05, ep_length=100):
|
||||
"""
|
||||
Identity environment for testing purposes
|
||||
|
||||
:param dim: (int) the size of the dimensions you want to learn
|
||||
:param low: (float) the lower bound of the box dim
|
||||
:param high: (float) the upper bound of the box dim
|
||||
:param eps: (float) the epsilon bound for correct value
|
||||
:param ep_length: (int) the length of each episodes in timesteps
|
||||
"""
|
||||
super(IdentityEnvBox, self).__init__(1, ep_length)
|
||||
self.action_space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
|
||||
self.observation_space = self.action_space
|
||||
self.eps = eps
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.current_step = 0
|
||||
self._choose_next_state()
|
||||
return self.state
|
||||
|
||||
def step(self, action):
|
||||
reward = self._get_reward(action)
|
||||
self._choose_next_state()
|
||||
self.current_step += 1
|
||||
done = self.current_step >= self.ep_length
|
||||
return self.state, reward, done, {}
|
||||
|
||||
def _choose_next_state(self):
|
||||
self.state = self.observation_space.sample()
|
||||
|
||||
def _get_reward(self, action):
|
||||
return 1 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0
|
||||
|
||||
|
||||
class IdentityEnvMultiDiscrete(IdentityEnv):
|
||||
def __init__(self, dim, ep_length=100):
|
||||
"""
|
||||
Identity environment for testing purposes
|
||||
|
||||
:param dim: (int) the size of the dimensions you want to learn
|
||||
:param ep_length: (int) the length of each episodes in timesteps
|
||||
"""
|
||||
super(IdentityEnvMultiDiscrete, self).__init__(dim, ep_length)
|
||||
self.action_space = MultiDiscrete([dim, dim])
|
||||
self.observation_space = self.action_space
|
||||
self.reset()
|
||||
|
||||
|
||||
class IdentityEnvMultiBinary(IdentityEnv):
|
||||
def __init__(self, dim, ep_length=100):
|
||||
"""
|
||||
Identity environment for testing purposes
|
||||
|
||||
:param dim: (int) the size of the dimensions you want to learn
|
||||
:param ep_length: (int) the length of each episodes in timesteps
|
||||
"""
|
||||
super(IdentityEnvMultiBinary, self).__init__(dim, ep_length)
|
||||
self.action_space = MultiBinary(dim)
|
||||
self.observation_space = self.action_space
|
||||
self.reset()
|
||||
134
torchy_baselines/common/save_util.py
Normal file
134
torchy_baselines/common/save_util.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Save util taken from stable_baselines
|
||||
used to serialize data (class parameters) of model classes
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
import base64
|
||||
import pickle
|
||||
import cloudpickle
|
||||
|
||||
|
||||
def is_json_serializable(item):
|
||||
"""
|
||||
Test if an object is serializable into JSON
|
||||
|
||||
:param item: (object) The object to be tested for JSON serialization.
|
||||
:return: (bool) True if object is JSON serializable, false otherwise.
|
||||
"""
|
||||
# Try with try-except struct.
|
||||
json_serializable = True
|
||||
try:
|
||||
_ = json.dumps(item)
|
||||
except TypeError:
|
||||
json_serializable = False
|
||||
return json_serializable
|
||||
|
||||
|
||||
def data_to_json(data):
|
||||
"""
|
||||
Turn data (class parameters) into a JSON string for storing
|
||||
|
||||
:param data: (Dict) Dictionary of class parameters to be
|
||||
stored. Items that are not JSON serializable will be
|
||||
pickled with Cloudpickle and stored as bytearray in
|
||||
the JSON file
|
||||
:return: (str) JSON string of the data serialized.
|
||||
"""
|
||||
# First, check what elements can not be JSONfied,
|
||||
# and turn them into byte-strings
|
||||
serializable_data = {}
|
||||
for data_key, data_item in data.items():
|
||||
# See if object is JSON serializable
|
||||
if is_json_serializable(data_item):
|
||||
# All good, store as it is
|
||||
serializable_data[data_key] = data_item
|
||||
else:
|
||||
# Not serializable, cloudpickle it into
|
||||
# bytes and convert to base64 string for storing.
|
||||
# Also store type of the class for consumption
|
||||
# from other languages/humans, so we have an
|
||||
# idea what was being stored.
|
||||
base64_encoded = base64.b64encode(
|
||||
cloudpickle.dumps(data_item)
|
||||
).decode()
|
||||
|
||||
# Use ":" to make sure we do
|
||||
# not override these keys
|
||||
# when we include variables of the object later
|
||||
cloudpickle_serialization = {
|
||||
":type:": str(type(data_item)),
|
||||
":serialized:": base64_encoded
|
||||
}
|
||||
|
||||
# Add first-level JSON-serializable items of the
|
||||
# object for further details (but not deeper than this to
|
||||
# avoid deep nesting).
|
||||
# First we check that object has attributes (not all do,
|
||||
# e.g. numpy scalars)
|
||||
if hasattr(data_item, "__dict__") or isinstance(data_item, dict):
|
||||
# Take elements from __dict__ for custom classes
|
||||
item_generator = (
|
||||
data_item.items if isinstance(data_item, dict) else data_item.__dict__.items
|
||||
)
|
||||
for variable_name, variable_item in item_generator():
|
||||
# Check if serializable. If not, just include the
|
||||
# string-representation of the object.
|
||||
if is_json_serializable(variable_item):
|
||||
cloudpickle_serialization[variable_name] = variable_item
|
||||
else:
|
||||
cloudpickle_serialization[variable_name] = str(variable_item)
|
||||
|
||||
serializable_data[data_key] = cloudpickle_serialization
|
||||
json_string = json.dumps(serializable_data, indent=4)
|
||||
return json_string
|
||||
|
||||
|
||||
def json_to_data(json_string, custom_objects=None):
|
||||
"""
|
||||
Turn JSON serialization of class-parameters back into dictionary.
|
||||
|
||||
:param json_string: (str) JSON serialization of the class-parameters
|
||||
that should be loaded.
|
||||
:param custom_objects: (dict) Dictionary of objects to replace
|
||||
upon loading. If a variable is present in this dictionary as a
|
||||
key, it will not be deserialized and the corresponding item
|
||||
will be used instead. Similar to custom_objects in
|
||||
`keras.models.load_model`. Useful when you have an object in
|
||||
file that can not be deserialized.
|
||||
:return: (dict) Loaded class parameters.
|
||||
"""
|
||||
if custom_objects is not None and not isinstance(custom_objects, dict):
|
||||
raise ValueError("custom_objects argument must be a dict or None")
|
||||
|
||||
json_dict = json.loads(json_string)
|
||||
# This will be filled with deserialized data
|
||||
return_data = {}
|
||||
for data_key, data_item in json_dict.items():
|
||||
if custom_objects is not None and data_key in custom_objects.keys():
|
||||
# If item is provided in custom_objects, replace
|
||||
# the one from JSON with the one in custom_objects
|
||||
return_data[data_key] = custom_objects[data_key]
|
||||
elif isinstance(data_item, dict) and ":serialized:" in data_item.keys():
|
||||
# If item is dictionary with ":serialized:"
|
||||
# key, this means it is serialized with cloudpickle.
|
||||
serialization = data_item[":serialized:"]
|
||||
# Try-except deserialization in case we run into
|
||||
# errors. If so, we can tell bit more information to
|
||||
# user.
|
||||
try:
|
||||
deserialized_object = cloudpickle.loads(
|
||||
base64.b64decode(serialization.encode())
|
||||
)
|
||||
except pickle.UnpicklingError:
|
||||
raise RuntimeError(
|
||||
"Could not deserialize object {}. ".format(data_key) +
|
||||
"Consider using `custom_objects` argument to replace " +
|
||||
"this object."
|
||||
)
|
||||
return_data[data_key] = deserialized_object
|
||||
else:
|
||||
# Read as it is
|
||||
return_data[data_key] = data_item
|
||||
return return_data
|
||||
|
|
@ -42,6 +42,7 @@ def _worker(remote, parent_remote, env_fn_wrapper):
|
|||
except EOFError:
|
||||
break
|
||||
|
||||
|
||||
def tile_images(img_nhwc):
|
||||
"""
|
||||
Tile N images into one big PxQ image
|
||||
|
|
@ -68,7 +69,6 @@ def tile_images(img_nhwc):
|
|||
return out_image
|
||||
|
||||
|
||||
|
||||
class SubprocVecEnv(VecEnv):
|
||||
"""
|
||||
Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import gym
|
|||
from gym import spaces
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Check if tensorboard is available for pytorch
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
|
@ -118,9 +119,9 @@ class PPO(BaseRLModel):
|
|||
|
||||
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.learning_rate, use_sde=self.use_sde, device=self.device,
|
||||
**self.policy_kwargs)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, use_sde=self.use_sde, device=self.device,
|
||||
**self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
self.clip_range = get_schedule_fn(self.clip_range)
|
||||
|
|
@ -193,7 +194,6 @@ class PPO(BaseRLModel):
|
|||
clip_range_vf = self.clip_range_vf(self._current_progress)
|
||||
logger.logkv("clip_range_vf", clip_range_vf)
|
||||
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
approx_kl_divs = []
|
||||
# Sample replay buffer
|
||||
|
|
@ -227,7 +227,6 @@ class PPO(BaseRLModel):
|
|||
# Value loss using the TD(gae_lambda) target
|
||||
value_loss = F.mse_loss(return_batch, values_pred)
|
||||
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
|
|
@ -314,14 +313,22 @@ class PPO(BaseRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def save(self, path):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
th.save(self.policy.state_dict(), path)
|
||||
def get_opt_parameters(self):
|
||||
"""
|
||||
Returns a dict of all the optimizers and their parameters
|
||||
|
||||
:return: (dict) of optimizer names and their state_dict
|
||||
"""
|
||||
return {"opt": self.policy.optimizer.state_dict()}
|
||||
|
||||
def load(self, path, env=None, **_kwargs):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
if env is not None:
|
||||
pass
|
||||
self.policy.load_state_dict(th.load(path))
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
"""
|
||||
Load model parameters and optimizer parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts
|
||||
"""
|
||||
self.policy.optimizer.load_state_dict(opt_params["opt"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class SAC(BaseRLModel):
|
|||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
def __init__(self, policy, env, learning_rate=3e-4, buffer_size=int(1e6),
|
||||
learning_starts=100, batch_size=64,
|
||||
tau=0.005, ent_coef='auto', target_update_interval=1,
|
||||
|
|
@ -123,8 +124,8 @@ class SAC(BaseRLModel):
|
|||
self.ent_coef = float(self.ent_coef)
|
||||
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
||||
|
|
@ -274,15 +275,28 @@ class SAC(BaseRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def save(self, path):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
th.save(self.policy.state_dict(), path)
|
||||
def get_opt_parameters(self):
|
||||
"""
|
||||
Returns a dict of all the optimizers and their parameters
|
||||
|
||||
def load(self, path, env=None, **_kwargs):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
if env is not None:
|
||||
pass
|
||||
self.policy.load_state_dict(th.load(path))
|
||||
self._create_aliases()
|
||||
:return: (Dict) of optimizer names and their state_dict
|
||||
"""
|
||||
opt_dict = {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()}
|
||||
if self.ent_coef_optimizer is not None:
|
||||
opt_dict.update({"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()})
|
||||
return opt_dict
|
||||
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
"""
|
||||
Load model parameters and optimizer parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
|
||||
"""
|
||||
self.actor.optimizer.load_state_dict(opt_params["actor"])
|
||||
self.critic.optimizer.load_state_dict(opt_params["critic"])
|
||||
if "ent_coef_optimizer" in opt_params:
|
||||
self.ent_coef_optimizer.load_state_dict(opt_params["ent_coef_optimizer"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ class TD3(BaseRLModel):
|
|||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3,
|
||||
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
|
||||
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
|
||||
|
|
@ -79,8 +80,8 @@ class TD3(BaseRLModel):
|
|||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
||||
|
|
@ -148,7 +149,9 @@ class TD3(BaseRLModel):
|
|||
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
|
||||
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
|
||||
|
||||
def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005, tau_critic=0.005, replay_data=None):
|
||||
def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005,
|
||||
tau_critic=0.005,
|
||||
replay_data=None):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.actor.optimizer)
|
||||
|
||||
|
|
@ -234,15 +237,23 @@ class TD3(BaseRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def save(self, path):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
th.save(self.policy.state_dict(), path)
|
||||
def get_opt_parameters(self):
|
||||
"""
|
||||
Returns a dict of all the optimizers and their parameters
|
||||
|
||||
def load(self, path, env=None, **_kwargs):
|
||||
if not path.endswith('.pth'):
|
||||
path += '.pth'
|
||||
if env is not None:
|
||||
pass
|
||||
self.policy.load_state_dict(th.load(path))
|
||||
self._create_aliases()
|
||||
:return: (Dict) of optimizer names and their state_dict
|
||||
"""
|
||||
return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()}
|
||||
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
"""
|
||||
Load model parameters and optimizer parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
This does not load agent's hyper-parameters.
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
|
||||
"""
|
||||
self.actor.optimizer.load_state_dict(opt_params["actor"])
|
||||
self.critic.optimizer.load_state_dict(opt_params["critic"])
|
||||
self.policy.load_state_dict(load_dict)
|
||||
|
|
|
|||
Loading…
Reference in a new issue