From cc744a48b5047edca8c02268d62f08e2e8f0f582 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Tue, 12 Nov 2019 17:03:57 +0100 Subject: [PATCH 01/50] first save and load features --- tests/test_run.py | 10 ++ tests/test_save_load.py | 50 +++++++++ torchy_baselines/a2c/a2c.py | 12 +++ torchy_baselines/common/base_class.py | 132 ++++++++++++++++++++--- torchy_baselines/common/identity_env.py | 105 +++++++++++++++++++ torchy_baselines/common/save_util.py | 134 ++++++++++++++++++++++++ torchy_baselines/ppo/ppo.py | 41 ++++++-- 7 files changed, 459 insertions(+), 25 deletions(-) create mode 100644 tests/test_save_load.py create mode 100644 torchy_baselines/common/identity_env.py create mode 100644 torchy_baselines/common/save_util.py diff --git a/tests/test_run.py b/tests/test_run.py index 32a4b30..b578824 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -19,6 +19,16 @@ def test_td3(): os.remove("test_save.pth") + + +def test_a2c(): + model = A2C('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + model.learn(total_timesteps=1000, eval_freq=500) + model.save("test_save") + model.load("test_save") + os.remove("test_save.pth") + + def test_cemrl(): model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1, learning_starts=100, verbose=1, create_eval_env=True, action_noise=action_noise) diff --git a/tests/test_save_load.py b/tests/test_save_load.py new file mode 100644 index 0000000..ca3f637 --- /dev/null +++ b/tests/test_save_load.py @@ -0,0 +1,50 @@ +import os + +import pytest +import numpy as np + +from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 +from torchy_baselines.common.noise import NormalActionNoise +from torchy_baselines.common.vec_env import DummyVecEnv +from torchy_baselines.common.identity_env import IdentityEnvBox + +MODEL_LIST = [ + PPO +] + + +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_save_load(model_class): + """ + Test if 'save' and 'load' saves and loads model correctly + + :param model_class: (BaseRLModel) A RL model + """ + env = DummyVecEnv([lambda: IdentityEnvBox(10)]) + + # create model + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + + # test action probability for given (obs, action) pair + env = model.get_env() + obs = env.reset() + observations = np.array([obs for _ in range(10)]) + observations = np.squeeze(observations) + + #actions = np.array([env.action_space.sample() for _ in range(10)]) + + # Get dictionary of current parameters + params = model.get_parameters() + + # Modify all parameters to be random values + random_params = dict((param_name,np.random.random(size=param.shape)) for param_name, param in params.items()) + # Update model parameters with the new zeroed values + model.load_parameters(random_params) + # Get new action probas + #... + + # Check + model.learn(total_timesteps=1000, eval_freq=500) + model.save("test_save.zip") + model = model.load("test_save") + os.remove("test_save.zip") diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 6ee6f4a..57faabd 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -124,3 +124,15 @@ class A2C(PPO): return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps) + + def save(self, path): + if not path.endswith('.pth'): + path += '.pth' + th.save(self.policy.state_dict(), path) + + def load(self, path, env=None, **_kwargs): + if not path.endswith('.pth'): + path += '.pth' + if env is not None: + pass + self.policy.load_state_dict(th.load(path)) \ No newline at end of file diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 6adc45c..7daffe2 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -12,6 +12,12 @@ from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv from torchy_baselines.common.monitor import Monitor from torchy_baselines.common import logger +# for storing and loging +import os +import io +import zipfile +from torchy_baselines.common.save_util import (data_to_json, json_to_data) + class BaseRLModel(object): """ @@ -57,6 +63,7 @@ class BaseRLModel(object): self.replay_buffer = None self.seed = seed self.action_noise = None + self.params = None # Track the training progress (from 1 to 0) # this is used to update the learning rate self._current_progress = 1 @@ -113,7 +120,7 @@ class BaseRLModel(object): (no need for symmetric action space) """ low, high = self.action_space.low, self.action_space.high - return low + (0.5 * (scaled_action + 1.0) * (high - low)) + return low + (0.5 * (scaled_action + 1.0) * (high - low)) def _setup_learning_rate(self): """Transform to callable if needed.""" @@ -179,7 +186,7 @@ class BaseRLModel(object): :return: (list) List of pytorch Variables """ - pass + return self.params def get_parameters(self): """ @@ -187,7 +194,7 @@ class BaseRLModel(object): :return: (OrderedDict) Dictionary of variable name -> ndarray of model's parameters. """ - raise NotImplementedError() + return self.policy.state_dict() def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, adam_epsilon=1e-8, val_interval=None): @@ -237,14 +244,11 @@ class BaseRLModel(object): """ pass - def load_parameters(self, load_path_or_dict, exact_match=True): + def load_parameters(self, load_dict, exact_match=True): """ - Load model parameters from a file or a dictionary + Load model parameters from a dictionary - Dictionary keys should be tensorflow variable names, which can be obtained - with ``get_parameters`` function. If ``exact_match`` is True, dictionary - should contain keys for all model's parameters, otherwise RunTimeError - is raised. If False, only variables included in the dictionary will be updated. + Dictionary should be of shape torch model.state_dict() This does not load agent's hyper-parameters. @@ -252,13 +256,10 @@ class BaseRLModel(object): This function does not update trainer/optimizer variables (e.g. momentum). As such training after using this function may lead to less-than-optimal results. - :param load_path_or_dict: (str or file-like or dict) Save parameter location - or dict of parameters as variable.name -> ndarrays to be loaded. - :param exact_match: (bool) If True, expects load dictionary to contain keys for - all variables in the model. If False, loads parameters only for variables - mentioned in the dictionary. Defaults to True. + + :param load_path_or_dict: (dict) dict of parameters from model.state_dict() """ - raise NotImplementedError() + self.policy.load_state_dict(load_dict) @abstractmethod def save(self, save_path): @@ -280,7 +281,103 @@ class BaseRLModel(object): (can be None if you only need prediction from a trained model) :param kwargs: extra arguments to change the model when loading """ - raise NotImplementedError() + data, params = cls._load_from_file(load_path) + + if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: + raise ValueError("The specified policy kwargs do not equal the stored policy kwargs." + "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],kwargs['policy_kwargs'])) + + model = cls(policy=data["policy"],env=None, _init_setup_model=False) + model.__dict__.update(data) + model.__dict__.update(kwargs) + model.set_env(env) + model.load_parameters(params) + + return model + + + @staticmethod + def _save_to_file_zip(save_path, data=None, params=None): + """Save model to a zip archive + + :param save_path: (str or file-like) Where to store the model + :param data: (OrderedDict) Class parameters being stored + :param params: (OrderedDict) Model parameters being stored expexted to be state_dict + """ + + # data/params can be None, so do not + # try to serialize them blindly + if data is not None: + serialized_data = data_to_json(data) + + # Check postfix if save_path is a string + if isinstance(save_path, str): + _, ext = os.path.splitext(save_path) + if ext == "": + save_path += ".zip" + + # Create a zip-archive and write our objects + # there. This works when save_path is either + # str or a file-like + with zipfile.ZipFile(save_path, "w") as file_: + # Do not try to save "None" elements + if data is not None: + file_.writestr("data",serialized_data) + if params is not None: + with file_.open('param.pth', mode="w") as param_file: + th.save(params,param_file) + + @staticmethod + def _load_from_file(load_path, load_data = True): + """ Load model data from a .zip archive + + :param load_path: (str or file-like) Where to load the model from + :param load_data: (bool) Whether we should load and return data + (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) + :return: (dict. OrderedDict),(dict. OrderedDict) Class parameters and model parameters (state_dict) + """ + # Check if file exists if load_path is a string + if isinstance(load_path, str): + if not os.path.exists(load_path): + if os.path.exists(load_path + ".zip"): + load_path += ".zip" + else: + raise ValueError("Error: the file {} could not be found".format(load_path)) + + # Open the zip archive and load data + try: + with zipfile.ZipFile(load_path,"r") as file_: + namelist = file_.namelist() + # If data or parameters is not in the + # zip archive, assume they were stored + # as None (_save_to_file_zip allows this). + data = None + params = None + if "data" in namelist and load_data: + # Load class parameters and convert to string + json_data = file_.read("data").decode() + data = json_to_data(json_data) + + if "param.pth" in namelist: + # Load parameters with build in torch function + with file_.open("param.pth", mode="r") as param_file: + # File has to be seekable so load in BytesIO first + file_content = io.BytesIO() + file_content.write(param_file.read()) + # go to start of file + file_content.seek(0) + params = th.load(file_content) + except zipfile.BadZipFile: + # load_path wasn't a zip file + raise ValueError("Error: the file {} wasn't a zip-file".format(load_path)) + + return data, params + + + + + + def set_random_seed(self, seed=0): set_random_seed(seed, using_cuda=self.device == th.device('cuda')) @@ -375,7 +472,8 @@ class BaseRLModel(object): action_noise.reset() # Display training infos - if self.verbose >= 1 and log_interval is not None and (episode_num + total_episodes) % log_interval == 0: + if self.verbose >= 1 and log_interval is not None and ( + episode_num + total_episodes) % log_interval == 0: fps = int(num_timesteps / (time.time() - self.start_time)) logger.logkv("episodes", episode_num + total_episodes) # logger.logkv("mean 100 episode reward", mean_reward) diff --git a/torchy_baselines/common/identity_env.py b/torchy_baselines/common/identity_env.py new file mode 100644 index 0000000..d815220 --- /dev/null +++ b/torchy_baselines/common/identity_env.py @@ -0,0 +1,105 @@ +import numpy as np + +from gym import Env +from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box + + +class IdentityEnv(Env): + def __init__(self, dim, ep_length=100): + """ + Identity environment for testing purposes + + :param dim: (int) the size of the dimensions you want to learn + :param ep_length: (int) the length of each episodes in timesteps + """ + self.action_space = Discrete(dim) + self.observation_space = self.action_space + self.ep_length = ep_length + self.current_step = 0 + self.dim = dim + self.reset() + + def reset(self): + self.current_step = 0 + self._choose_next_state() + return self.state + + def step(self, action): + reward = self._get_reward(action) + self._choose_next_state() + self.current_step += 1 + done = self.current_step >= self.ep_length + return self.state, reward, done, {} + + def _choose_next_state(self): + self.state = self.action_space.sample() + + def _get_reward(self, action): + return 1 if np.all(self.state == action) else 0 + + def render(self, mode='human'): + pass + + +class IdentityEnvBox(IdentityEnv): + def __init__(self, low=-1, high=1, eps=0.05, ep_length=100): + """ + Identity environment for testing purposes + + :param dim: (int) the size of the dimensions you want to learn + :param low: (float) the lower bound of the box dim + :param high: (float) the upper bound of the box dim + :param eps: (float) the epsilon bound for correct value + :param ep_length: (int) the length of each episodes in timesteps + """ + super(IdentityEnvBox, self).__init__(1, ep_length) + self.action_space = Box(low=low, high=high, shape=(1,), dtype=np.float32) + self.observation_space = self.action_space + self.eps = eps + self.reset() + + def reset(self): + self.current_step = 0 + self._choose_next_state() + return self.state + + def step(self, action): + reward = self._get_reward(action) + self._choose_next_state() + self.current_step += 1 + done = self.current_step >= self.ep_length + return self.state, reward, done, {} + + def _choose_next_state(self): + self.state = self.observation_space.sample() + + def _get_reward(self, action): + return 1 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0 + + +class IdentityEnvMultiDiscrete(IdentityEnv): + def __init__(self, dim, ep_length=100): + """ + Identity environment for testing purposes + + :param dim: (int) the size of the dimensions you want to learn + :param ep_length: (int) the length of each episodes in timesteps + """ + super(IdentityEnvMultiDiscrete, self).__init__(dim, ep_length) + self.action_space = MultiDiscrete([dim, dim]) + self.observation_space = self.action_space + self.reset() + + +class IdentityEnvMultiBinary(IdentityEnv): + def __init__(self, dim, ep_length=100): + """ + Identity environment for testing purposes + + :param dim: (int) the size of the dimensions you want to learn + :param ep_length: (int) the length of each episodes in timesteps + """ + super(IdentityEnvMultiBinary, self).__init__(dim, ep_length) + self.action_space = MultiBinary(dim) + self.observation_space = self.action_space + self.reset() diff --git a/torchy_baselines/common/save_util.py b/torchy_baselines/common/save_util.py new file mode 100644 index 0000000..30dd0b2 --- /dev/null +++ b/torchy_baselines/common/save_util.py @@ -0,0 +1,134 @@ +""" +Save util taken from stable_baselines +used to serialize data (class parameters) of model classes +""" + + +import json +import base64 +import pickle +import cloudpickle + + +def is_json_serializable(item): + """ + Test if an object is serializable into JSON + + :param item: (object) The object to be tested for JSON serialization. + :return: (bool) True if object is JSON serializable, false otherwise. + """ + # Try with try-except struct. + json_serializable = True + try: + _ = json.dumps(item) + except TypeError: + json_serializable = False + return json_serializable + + +def data_to_json(data): + """ + Turn data (class parameters) into a JSON string for storing + + :param data: (Dict) Dictionary of class parameters to be + stored. Items that are not JSON serializable will be + pickled with Cloudpickle and stored as bytearray in + the JSON file + :return: (str) JSON string of the data serialized. + """ + # First, check what elements can not be JSONfied, + # and turn them into byte-strings + serializable_data = {} + for data_key, data_item in data.items(): + # See if object is JSON serializable + if is_json_serializable(data_item): + # All good, store as it is + serializable_data[data_key] = data_item + else: + # Not serializable, cloudpickle it into + # bytes and convert to base64 string for storing. + # Also store type of the class for consumption + # from other languages/humans, so we have an + # idea what was being stored. + base64_encoded = base64.b64encode( + cloudpickle.dumps(data_item) + ).decode() + + # Use ":" to make sure we do + # not override these keys + # when we include variables of the object later + cloudpickle_serialization = { + ":type:": str(type(data_item)), + ":serialized:": base64_encoded + } + + # Add first-level JSON-serializable items of the + # object for further details (but not deeper than this to + # avoid deep nesting). + # First we check that object has attributes (not all do, + # e.g. numpy scalars) + if hasattr(data_item, "__dict__") or isinstance(data_item, dict): + # Take elements from __dict__ for custom classes + item_generator = ( + data_item.items if isinstance(data_item, dict) else data_item.__dict__.items + ) + for variable_name, variable_item in item_generator(): + # Check if serializable. If not, just include the + # string-representation of the object. + if is_json_serializable(variable_item): + cloudpickle_serialization[variable_name] = variable_item + else: + cloudpickle_serialization[variable_name] = str(variable_item) + + serializable_data[data_key] = cloudpickle_serialization + json_string = json.dumps(serializable_data, indent=4) + return json_string + + +def json_to_data(json_string, custom_objects=None): + """ + Turn JSON serialization of class-parameters back into dictionary. + + :param json_string: (str) JSON serialization of the class-parameters + that should be loaded. + :param custom_objects: (dict) Dictionary of objects to replace + upon loading. If a variable is present in this dictionary as a + key, it will not be deserialized and the corresponding item + will be used instead. Similar to custom_objects in + `keras.models.load_model`. Useful when you have an object in + file that can not be deserialized. + :return: (dict) Loaded class parameters. + """ + if custom_objects is not None and not isinstance(custom_objects, dict): + raise ValueError("custom_objects argument must be a dict or None") + + json_dict = json.loads(json_string) + # This will be filled with deserialized data + return_data = {} + for data_key, data_item in json_dict.items(): + if custom_objects is not None and data_key in custom_objects.keys(): + # If item is provided in custom_objects, replace + # the one from JSON with the one in custom_objects + return_data[data_key] = custom_objects[data_key] + elif isinstance(data_item, dict) and ":serialized:" in data_item.keys(): + # If item is dictionary with ":serialized:" + # key, this means it is serialized with cloudpickle. + serialization = data_item[":serialized:"] + # Try-except deserialization in case we run into + # errors. If so, we can tell bit more information to + # user. + try: + deserialized_object = cloudpickle.loads( + base64.b64decode(serialization.encode()) + ) + except pickle.UnpicklingError: + raise RuntimeError( + "Could not deserialize object {}. ".format(data_key) + + "Consider using `custom_objects` argument to replace " + + "this object." + ) + return_data[data_key] = deserialized_object + else: + # Read as it is + return_data[data_key] = data_item + return return_data \ No newline at end of file diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 13a1634..e6467f3 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -6,6 +6,7 @@ import gym from gym import spaces import torch as th import torch.nn.functional as F + # Check if tensorboard is available for pytorch try: from torch.utils.tensorboard import SummaryWriter @@ -185,7 +186,6 @@ class PPO(BaseRLModel): clip_range_vf = self.clip_range_vf(self._current_progress) logger.logkv("clip_range_vf", clip_range_vf) - for gradient_step in range(gradient_steps): approx_kl_divs = [] # Sample replay buffer @@ -219,7 +219,6 @@ class PPO(BaseRLModel): # Value loss using the TD(gae_lambda) target value_loss = F.mse_loss(return_batch, values_pred) - # Entropy loss favor exploration entropy_loss = -th.mean(entropy) @@ -234,7 +233,7 @@ class PPO(BaseRLModel): approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy()) if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl: - print("Early stopping at step {} due to reaching max kl: {:.2f}".format(it, np.mean(approx_kl_divs))) + print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step, np.mean(approx_kl_divs))) break # print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(), @@ -294,13 +293,39 @@ class PPO(BaseRLModel): return self def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) + """ + saves all the params from init and pytorch params in a file for continous learning - def load(self, path, env=None, **_kwargs): + :param path: path to the file where the data should be safed + :return: + """ + + data = { + "gamma": self.gamma, + "n_steps": self.n_steps, + "vf_coef": self.vf_coef, + "ent_coef": self.ent_coef, + "max_grad_norm": self.max_grad_norm, + "learning_rate": self.learning_rate, + "gae_lambda": self.gae_lambda, + "n_epochs": self.n_epochs, + "clip_range": self.clip_range, + "clip_range_vf": self.clip_range_vf, + "batch_size": self.batch_size, + "target_kl": self.target_kl, + "tensorboard_log": self.tensorboard_log, + "policy_kwargs": self.policy_kwargs, + "policy": self.policy, + + } + + params_to_save = self.get_parameters() + + self._save_to_file_zip(path, data=data, params=params_to_save) + + """def load(self, path, env=None, **_kwargs): if not path.endswith('.pth'): path += '.pth' if env is not None: pass - self.policy.load_state_dict(th.load(path)) + self.policy.load_state_dict(th.load(path))""" From 6cf80ccfe2e20cb0e68c582afbfa543096a1ed04 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Tue, 12 Nov 2019 17:12:10 +0100 Subject: [PATCH 02/50] reordered imports --- torchy_baselines/common/base_class.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 7daffe2..fd85b0a 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -1,6 +1,9 @@ import time from abc import ABCMeta, abstractmethod from collections import deque +import os +import io +import zipfile import gym import torch as th @@ -11,12 +14,7 @@ from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, upda from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv from torchy_baselines.common.monitor import Monitor from torchy_baselines.common import logger - -# for storing and loging -import os -import io -import zipfile -from torchy_baselines.common.save_util import (data_to_json, json_to_data) +from torchy_baselines.common.save_util import data_to_json, json_to_data class BaseRLModel(object): From 4b6234a1c84384aa3a5c3be377006b1a4f749cb2 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 11:39:47 +0100 Subject: [PATCH 03/50] finished test_save_load.py test --- tests/test_save_load.py | 31 +++++++++++++++++---------- torchy_baselines/common/base_class.py | 13 ++++++++--- torchy_baselines/ppo/ppo.py | 2 +- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index ca3f637..9941e15 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,8 +1,9 @@ import os import pytest +import copy import numpy as np - +import torch from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 from torchy_baselines.common.noise import NormalActionNoise from torchy_baselines.common.vec_env import DummyVecEnv @@ -17,6 +18,7 @@ MODEL_LIST = [ def test_save_load(model_class): """ Test if 'save' and 'load' saves and loads model correctly + and if 'load_parameters' and 'get_policy_parameters' work correctly :param model_class: (BaseRLModel) A RL model """ @@ -26,25 +28,32 @@ def test_save_load(model_class): model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) # test action probability for given (obs, action) pair - env = model.get_env() - obs = env.reset() - observations = np.array([obs for _ in range(10)]) - observations = np.squeeze(observations) - - #actions = np.array([env.action_space.sample() for _ in range(10)]) # Get dictionary of current parameters - params = model.get_parameters() + params = copy.deepcopy(model.get_policy_parameters()) # Modify all parameters to be random values - random_params = dict((param_name,np.random.random(size=param.shape)) for param_name, param in params.items()) + random_params = dict((param_name, torch.rand_like(param)) for param_name, param in params.items()) # Update model parameters with the new zeroed values model.load_parameters(random_params) - # Get new action probas - #... + + # shared items + new_params = model.get_policy_parameters() + shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))} + # Check that at least some actions are chosen different now + assert not len(shared_items) == len(new_params), "Selected actions did not change " \ + "after changing model parameters." + + params = new_params # Check model.learn(total_timesteps=1000, eval_freq=500) model.save("test_save.zip") model = model.load("test_save") + + #check if params are still the same after load + new_params = model.get_policy_parameters() + shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))} + # Check that at least some actions are chosen different now + assert len(shared_items) == len(new_params), "Parameters not the same after save and load." os.remove("test_save.zip") diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index fd85b0a..d0c0d5d 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -186,14 +186,21 @@ class BaseRLModel(object): """ return self.params - def get_parameters(self): + def get_policy_parameters(self): """ - Get current model parameters as dictionary of variable name -> ndarray. + Get current model policy parameters as dictionary of variable name -> tensors. - :return: (OrderedDict) Dictionary of variable name -> ndarray of model's parameters. + :return: (OrderedDict) Dictionary of variable name -> tensor of model's policy parameters. """ return self.policy.state_dict() + def get_optim_parameters(self): + """ + Get current model optimizer parameters as dictionary of variable names -> tensors + :return: (OrderedDict) Dictionary of variable name -> tensor of model's optimizer parameters + """ + raise NotImplementedError() + def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, adam_epsilon=1e-8, val_interval=None): """ diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index e6467f3..f22297b 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -319,7 +319,7 @@ class PPO(BaseRLModel): } - params_to_save = self.get_parameters() + params_to_save = self.get_policy_parameters() self._save_to_file_zip(path, data=data, params=params_to_save) From 5bca52a87dc3068184afc1612e8f4bc85973e632 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 11:44:37 +0100 Subject: [PATCH 04/50] rearranged imports --- tests/test_save_load.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9941e15..aff6481 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,11 +1,10 @@ import os - import pytest -import copy -import numpy as np -import torch +from copy import deepcopy + +import torch as th + from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 -from torchy_baselines.common.noise import NormalActionNoise from torchy_baselines.common.vec_env import DummyVecEnv from torchy_baselines.common.identity_env import IdentityEnvBox @@ -30,16 +29,16 @@ def test_save_load(model_class): # test action probability for given (obs, action) pair # Get dictionary of current parameters - params = copy.deepcopy(model.get_policy_parameters()) + params = deepcopy(model.get_policy_parameters()) # Modify all parameters to be random values - random_params = dict((param_name, torch.rand_like(param)) for param_name, param in params.items()) + random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) # Update model parameters with the new zeroed values model.load_parameters(random_params) # shared items new_params = model.get_policy_parameters() - shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))} + shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} # Check that at least some actions are chosen different now assert not len(shared_items) == len(new_params), "Selected actions did not change " \ "after changing model parameters." @@ -53,7 +52,7 @@ def test_save_load(model_class): #check if params are still the same after load new_params = model.get_policy_parameters() - shared_items = {k: params[k] for k in params if k in new_params and torch.all(torch.eq(params[k], new_params[k]))} + shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} # Check that at least some actions are chosen different now assert len(shared_items) == len(new_params), "Parameters not the same after save and load." os.remove("test_save.zip") From b20b70db48caa8995df5c90596f7ff73e5eb4a72 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 11:51:47 +0100 Subject: [PATCH 05/50] Clean reformat --- tests/test_save_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index aff6481..01b0a25 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -50,7 +50,7 @@ def test_save_load(model_class): model.save("test_save.zip") model = model.load("test_save") - #check if params are still the same after load + # check if params are still the same after load new_params = model.get_policy_parameters() shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} # Check that at least some actions are chosen different now From a7655ca6e1050a009a16a85f0df5e3879b603c82 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 13:01:03 +0100 Subject: [PATCH 06/50] Reformated every file with PEP 8 errors --- docs/conf.py | 12 +-- tests/test_custom_policy.py | 1 + tests/test_run.py | 4 +- tests/test_vec_envs.py | 9 ++- torchy_baselines/a2c/a2c.py | 10 +-- torchy_baselines/common/base_class.py | 80 +++++++++---------- torchy_baselines/common/distributions.py | 4 +- .../common/vec_env/subproc_vec_env.py | 2 +- torchy_baselines/ppo/ppo.py | 2 +- torchy_baselines/td3/td3.py | 3 +- 10 files changed, 62 insertions(+), 65 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 606a195..0352081 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,12 +16,14 @@ import os import sys from unittest.mock import MagicMock +import torchy_baselines # source code directory, relative to this file, for sphinx-autobuild sys.path.insert(0, os.path.abspath('..')) class Mock(MagicMock): __subclasses__ = [] + @classmethod def __getattr__(cls, name): return MagicMock() @@ -42,10 +44,6 @@ MOCK_MODULES = ['joblib', 'scipy', 'scipy.signal', 'gym.wrappers', 'gym.wrappers.monitoring', 'zmq'] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) - -import torchy_baselines - - # -- Project information ----------------------------------------------------- project = 'Torchy Baselines' @@ -57,7 +55,6 @@ version = 'master (' + torchy_baselines.__version__ + ' )' # The full version, including alpha/beta/rc tags release = torchy_baselines.__version__ - # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. @@ -102,7 +99,6 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' - # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for @@ -121,6 +117,7 @@ html_logo = '_static/img/logo.png' def setup(app): app.add_stylesheet("css/baselines_theme.css") + # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. @@ -148,7 +145,6 @@ html_static_path = ['_static'] # Output file base name for HTML help builder. htmlhelp_basename = 'TorchyBaselinesdoc' - # -- Options for LaTeX output ------------------------------------------------ latex_elements = { @@ -177,7 +173,6 @@ latex_documents = [ 'Torchy Baselines Contributors', 'manual'), ] - # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples @@ -187,7 +182,6 @@ man_pages = [ [author], 1) ] - # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 50de59c..d45be88 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -5,6 +5,7 @@ import pytest from torchy_baselines import PPO + @pytest.mark.parametrize('net_arch', [ [12, dict(vf=[16], pi=[8])], [4], diff --git a/tests/test_run.py b/tests/test_run.py index b578824..7ca986f 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -6,7 +6,6 @@ import numpy as np from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 from torchy_baselines.common.noise import NormalActionNoise - action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) @@ -19,8 +18,6 @@ def test_td3(): os.remove("test_save.pth") - - def test_a2c(): model = A2C('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=1000, eval_freq=500) @@ -47,6 +44,7 @@ def test_onpolicy(model_class, env_id): # model.load("test_save") # os.remove("test_save.pth") + def test_sac(): model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), learning_starts=100, verbose=1, create_eval_env=True, ent_coef='auto', diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 2147e78..efa5119 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -59,8 +59,10 @@ class CustomGymEnv(gym.Env): @pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS) def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper): """Test access to methods/attributes of vectorized environments""" + def make_env(): return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) if vec_env_wrapper is not None: @@ -92,7 +94,6 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper): assert (env_method_subset[1] == np.ones((1, 3))).all() assert len(env_method_subset) == 2 - # Test to change value for all the environments setattr_result = vec_env.set_attr('current_step', 42, indices=None) getattr_result = vec_env.get_attr('current_step') @@ -193,8 +194,10 @@ SPACES = collections.OrderedDict([ ('continuous', gym.spaces.Box(low=np.zeros(2), high=np.ones(2))), ]) + def check_vecenv_spaces(vec_env_class, space, obs_assert): """Helper method to check observation spaces in vectorized environments.""" + def make_env(): return CustomGymEnv(space) @@ -228,6 +231,7 @@ def test_vecenv_single_space(vec_env_class, space): class _UnorderedDictSpace(gym.spaces.Dict): """Like DictSpace, but returns an unordered dict when sampling.""" + def sample(self): return dict(super().sample()) @@ -301,14 +305,17 @@ class CustomWrapperB(VecNormalize): def name_test(self): return self.__class__ + class CustomWrapperBB(CustomWrapperB): def __init__(self, venv): CustomWrapperB.__init__(self, venv) self.var_bb = 'bb' + def test_vecenv_wrapper_getattr(): def make_env(): return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)]) wrapped = CustomWrapperA(CustomWrapperBB(vec_env)) assert wrapped.var_a == 'a' diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 57faabd..1c066cf 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -115,15 +115,15 @@ class A2C(PPO): self.policy.optimizer.step() # approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy()) - # print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(), - # self.rollout_buffer.values.flatten().cpu().numpy())) + # print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(), + # self.rollout_buffer.values.flatten().cpu().numpy())) def learn(self, total_timesteps, callback=None, log_interval=100, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="A2C", reset_num_timesteps=True): return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, - eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, - tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps) + eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps) def save(self, path): if not path.endswith('.pth'): @@ -135,4 +135,4 @@ class A2C(PPO): path += '.pth' if env is not None: pass - self.policy.load_state_dict(th.load(path)) \ No newline at end of file + self.policy.load_state_dict(th.load(path)) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index d0c0d5d..3bf1eed 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -290,9 +290,10 @@ class BaseRLModel(object): if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: raise ValueError("The specified policy kwargs do not equal the stored policy kwargs." - "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],kwargs['policy_kwargs'])) + "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], + kwargs['policy_kwargs'])) - model = cls(policy=data["policy"],env=None, _init_setup_model=False) + model = cls(policy=data["policy"], env=None, _init_setup_model=False) model.__dict__.update(data) model.__dict__.update(kwargs) model.set_env(env) @@ -300,40 +301,8 @@ class BaseRLModel(object): return model - @staticmethod - def _save_to_file_zip(save_path, data=None, params=None): - """Save model to a zip archive - - :param save_path: (str or file-like) Where to store the model - :param data: (OrderedDict) Class parameters being stored - :param params: (OrderedDict) Model parameters being stored expexted to be state_dict - """ - - # data/params can be None, so do not - # try to serialize them blindly - if data is not None: - serialized_data = data_to_json(data) - - # Check postfix if save_path is a string - if isinstance(save_path, str): - _, ext = os.path.splitext(save_path) - if ext == "": - save_path += ".zip" - - # Create a zip-archive and write our objects - # there. This works when save_path is either - # str or a file-like - with zipfile.ZipFile(save_path, "w") as file_: - # Do not try to save "None" elements - if data is not None: - file_.writestr("data",serialized_data) - if params is not None: - with file_.open('param.pth', mode="w") as param_file: - th.save(params,param_file) - - @staticmethod - def _load_from_file(load_path, load_data = True): + def _load_from_file(load_path, load_data=True): """ Load model data from a .zip archive :param load_path: (str or file-like) Where to load the model from @@ -351,7 +320,7 @@ class BaseRLModel(object): # Open the zip archive and load data try: - with zipfile.ZipFile(load_path,"r") as file_: + with zipfile.ZipFile(load_path, "r") as file_: namelist = file_.namelist() # If data or parameters is not in the # zip archive, assume they were stored @@ -378,12 +347,6 @@ class BaseRLModel(object): return data, params - - - - - - def set_random_seed(self, seed=0): set_random_seed(seed, using_cuda=self.device == th.device('cuda')) self.action_space.seed(seed) @@ -478,7 +441,7 @@ class BaseRLModel(object): # Display training infos if self.verbose >= 1 and log_interval is not None and ( - episode_num + total_episodes) % log_interval == 0: + episode_num + total_episodes) % log_interval == 0: fps = int(num_timesteps / (time.time() - self.start_time)) logger.logkv("episodes", episode_num + total_episodes) # logger.logkv("mean 100 episode reward", mean_reward) @@ -495,3 +458,34 @@ class BaseRLModel(object): mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0 return mean_reward, total_steps, total_episodes, obs + + +def _save_to_file_zip(save_path, data=None, params=None): + """Save model to a zip archive + + :param save_path: (str or file-like) Where to store the model + :param data: (OrderedDict) Class parameters being stored + :param params: (OrderedDict) Model parameters being stored expexted to be state_dict + """ + + # data/params can be None, so do not + # try to serialize them blindly + if data is not None: + serialized_data = data_to_json(data) + + # Check postfix if save_path is a string + if isinstance(save_path, str): + _, ext = os.path.splitext(save_path) + if ext == "": + save_path += ".zip" + + # Create a zip-archive and write our objects + # there. This works when save_path is either + # str or a file-like + with zipfile.ZipFile(save_path, "w") as file_: + # Do not try to save "None" elements + if data is not None: + file_.writestr("data", serialized_data) + if params is not None: + with file_.open('param.pth', mode="w") as param_file: + th.save(params, param_file) diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index a384ce3..2c5672a 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -4,6 +4,7 @@ import torch.nn as nn from torch.distributions import Normal, Categorical from gym import spaces + class Distribution(object): def __init__(self): super(Distribution, self).__init__() @@ -97,7 +98,8 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): self.gaussian_action = None def proba_distribution(self, mean_actions, log_std, deterministic=False): - action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std, deterministic) + action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std, + deterministic) return action, self def mode(self): diff --git a/torchy_baselines/common/vec_env/subproc_vec_env.py b/torchy_baselines/common/vec_env/subproc_vec_env.py index 9fc57b1..b00e465 100644 --- a/torchy_baselines/common/vec_env/subproc_vec_env.py +++ b/torchy_baselines/common/vec_env/subproc_vec_env.py @@ -42,6 +42,7 @@ def _worker(remote, parent_remote, env_fn_wrapper): except EOFError: break + def tile_images(img_nhwc): """ Tile N images into one big PxQ image @@ -68,7 +69,6 @@ def tile_images(img_nhwc): return out_image - class SubprocVecEnv(VecEnv): """ Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index f22297b..f4464ff 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -321,7 +321,7 @@ class PPO(BaseRLModel): params_to_save = self.get_policy_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save) + _save_to_file_zip(path, data=data, params=params_to_save) """def load(self, path, env=None, **_kwargs): if not path.endswith('.pth'): diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 49cf16e..66ea72e 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -148,7 +148,8 @@ class TD3(BaseRLModel): for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) - def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005, tau_critic=0.005, replay_data=None): + def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005, tau_critic: object = 0.005, + replay_data: object = None) -> object: # Update optimizer learning rate self._update_learning_rate(self.actor.optimizer) From fb5f192fc41f61c0c902db105446b8ab0968c6c3 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 14:39:44 +0100 Subject: [PATCH 07/50] Implemented Changes suggested from Antonin-Raffin Added Optimizer saving --- tests/test_run.py | 14 +--- tests/test_save_load.py | 22 +++--- torchy_baselines/common/base_class.py | 96 ++++++++++++++++----------- torchy_baselines/ppo/ppo.py | 34 +++++++--- 4 files changed, 101 insertions(+), 65 deletions(-) diff --git a/tests/test_run.py b/tests/test_run.py index 7ca986f..54d28f8 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -18,14 +18,6 @@ def test_td3(): os.remove("test_save.pth") -def test_a2c(): - model = A2C('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) - model.learn(total_timesteps=1000, eval_freq=500) - model.save("test_save") - model.load("test_save") - os.remove("test_save.pth") - - def test_cemrl(): model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1, learning_starts=100, verbose=1, create_eval_env=True, action_noise=action_noise) @@ -40,9 +32,9 @@ def test_cemrl(): def test_onpolicy(model_class, env_id): model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=1000, eval_freq=500) - # model.save("test_save") - # model.load("test_save") - # os.remove("test_save.pth") + model.save("test_save") + model.load("test_save") + #os.remove("test_save.pth") def test_sac(): diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 01b0a25..6500664 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -18,6 +18,8 @@ def test_save_load(model_class): """ Test if 'save' and 'load' saves and loads model correctly and if 'load_parameters' and 'get_policy_parameters' work correctly + + ''warning does not test function of optimizer parameter load :param model_class: (BaseRLModel) A RL model """ @@ -26,20 +28,23 @@ def test_save_load(model_class): # create model model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) - # test action probability for given (obs, action) pair - # Get dictionary of current parameters params = deepcopy(model.get_policy_parameters()) + opt_params = deepcopy(model.get_opt_parameters()) # Modify all parameters to be random values random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) - # Update model parameters with the new zeroed values - model.load_parameters(random_params) - # shared items + # Update model parameters with the new random values + model.load_parameters(random_params, opt_params) + + # Get items that are the same in params and new_params new_params = model.get_policy_parameters() shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} - # Check that at least some actions are chosen different now + + # Check that the there are at least some parameters new random parameters + #for k in params.key(): + # assert not th.allclose(params[k], new_params[k]) assert not len(shared_items) == len(new_params), "Selected actions did not change " \ "after changing model parameters." @@ -48,9 +53,10 @@ def test_save_load(model_class): # Check model.learn(total_timesteps=1000, eval_freq=500) model.save("test_save.zip") - model = model.load("test_save") + del model + model = model_class.load("test_save") - # check if params are still the same after load + #check if params are still the same after load new_params = model.get_policy_parameters() shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} # Check that at least some actions are chosen different now diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 3bf1eed..c3b2ddd 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -194,7 +194,8 @@ class BaseRLModel(object): """ return self.policy.state_dict() - def get_optim_parameters(self): + @abstractmethod + def get_opt_parameters(self): """ Get current model optimizer parameters as dictionary of variable names -> tensors :return: (OrderedDict) Dictionary of variable name -> tensor of model's optimizer parameters @@ -249,7 +250,7 @@ class BaseRLModel(object): """ pass - def load_parameters(self, load_dict, exact_match=True): + def load_parameters(self, load_dict, opt_params=None, exact_match=True): """ Load model parameters from a dictionary @@ -262,8 +263,11 @@ class BaseRLModel(object): As such training after using this function may lead to less-than-optimal results. - :param load_path_or_dict: (dict) dict of parameters from model.state_dict() + :param load_dict: (dict) dict of parameters from model.state_dict() + :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class """ + if opt_params is not None: + raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class") self.policy.load_state_dict(load_dict) @abstractmethod @@ -276,7 +280,6 @@ class BaseRLModel(object): raise NotImplementedError() @classmethod - @abstractmethod def load(cls, load_path, env=None, **kwargs): """ Load the model from file @@ -286,7 +289,7 @@ class BaseRLModel(object): (can be None if you only need prediction from a trained model) :param kwargs: extra arguments to change the model when loading """ - data, params = cls._load_from_file(load_path) + data, params, opt_params = cls._load_from_file(load_path) if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: raise ValueError("The specified policy kwargs do not equal the stored policy kwargs." @@ -297,7 +300,7 @@ class BaseRLModel(object): model.__dict__.update(data) model.__dict__.update(kwargs) model.set_env(env) - model.load_parameters(params) + model.load_parameters(params, opt_params) return model @@ -308,7 +311,7 @@ class BaseRLModel(object): :param load_path: (str or file-like) Where to load the model from :param load_data: (bool) Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) - :return: (dict. OrderedDict),(dict. OrderedDict) Class parameters and model parameters (state_dict) + :return: (dict. OrderedDict),(dict. OrderedDict),(dict. OrderedDict) Class parameters, model parameters (state_dict) and dict of optimizer parameters (dict of state_dict) """ # Check if file exists if load_path is a string if isinstance(load_path, str): @@ -327,6 +330,7 @@ class BaseRLModel(object): # as None (_save_to_file_zip allows this). data = None params = None + opt_params = None if "data" in namelist and load_data: # Load class parameters and convert to string json_data = file_.read("data").decode() @@ -341,11 +345,24 @@ class BaseRLModel(object): # go to start of file file_content.seek(0) params = th.load(file_content) + # check for all other .pth files + other_files = [file_name for file_name in namelist if + os.path.splitext(file_name)[1] == ".pth" and file_name != "param.pth"] + if len(other_files) > 0: + opt_params = dict() + for file in other_files: + with file_.open(file, mode="r") as opt_param_file: + # File has to be seekable so load in BytesIO first + file_content = io.BytesIO() + file_content.write(opt_param_file.read()) + # go to start of file + file_content.seek(0) + opt_params[os.path.splitext(file)[0]] = th.load(file_content) except zipfile.BadZipFile: # load_path wasn't a zip file raise ValueError("Error: the file {} wasn't a zip-file".format(load_path)) - return data, params + return data, params, opt_params def set_random_seed(self, seed=0): set_random_seed(seed, using_cuda=self.device == th.device('cuda')) @@ -456,36 +473,41 @@ class BaseRLModel(object): logger.dumpkvs() mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0 - return mean_reward, total_steps, total_episodes, obs + @staticmethod + def _save_to_file_zip(save_path, data=None, params=None, opt_params=None): + """Save model to a zip archive + + :param save_path: (str or file-like) Where to store the model + :param data: (OrderedDict) Class parameters being stored + :param params: (OrderedDict) Model parameters being stored expected to be state_dict + :param opt_params: (OrderedDict) Optimizer parameters being stored expected to contain an entry for every + optimizer with its name and the state_dict + """ -def _save_to_file_zip(save_path, data=None, params=None): - """Save model to a zip archive - - :param save_path: (str or file-like) Where to store the model - :param data: (OrderedDict) Class parameters being stored - :param params: (OrderedDict) Model parameters being stored expexted to be state_dict - """ - - # data/params can be None, so do not - # try to serialize them blindly - if data is not None: - serialized_data = data_to_json(data) - - # Check postfix if save_path is a string - if isinstance(save_path, str): - _, ext = os.path.splitext(save_path) - if ext == "": - save_path += ".zip" - - # Create a zip-archive and write our objects - # there. This works when save_path is either - # str or a file-like - with zipfile.ZipFile(save_path, "w") as file_: - # Do not try to save "None" elements + # data/params can be None, so do not + # try to serialize them blindly if data is not None: - file_.writestr("data", serialized_data) - if params is not None: - with file_.open('param.pth', mode="w") as param_file: - th.save(params, param_file) + serialized_data = data_to_json(data) + + # Check postfix if save_path is a string + if isinstance(save_path, str): + _, ext = os.path.splitext(save_path) + if ext == "": + save_path += ".zip" + + # Create a zip-archive and write our objects + # there. This works when save_path is either + # str or a file-like + with zipfile.ZipFile(save_path, "w") as file_: + # Do not try to save "None" elements + if data is not None: + file_.writestr("data", serialized_data) + if params is not None: + with file_.open('param.pth', mode="w") as param_file: + th.save(params, param_file) + if opt_params is not None: + for file_name, dict in opt_params.items(): + with file_.open(file_name + '.pth', mode="w") as opt_param_file: + th.save(dict, opt_param_file) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index f4464ff..fed5969 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -292,6 +292,29 @@ class PPO(BaseRLModel): return self + def get_opt_parameters(self): + """ + returns a dict of all the optimizers and their parameters + + :return: (Dict) of optimizer names and their state_dict + """ + return {"opt": self.policy.optimizer.state_dict()} + + def load_parameters(self, load_dict, opt_params): + """ + Load model parameters and optimizer parameters from a dictionary + + Dictionary should be of shape torch model.state_dict() + + This does not load agent's hyper-parameters. + + + :param load_dict: (dict) dict of parameters from model.state_dict() + :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class + """ + self.policy.optimizer.load_state_dict(opt_params["opt"]) + self.policy.load_state_dict(load_dict) + def save(self, path): """ saves all the params from init and pytorch params in a file for continous learning @@ -320,12 +343,5 @@ class PPO(BaseRLModel): } params_to_save = self.get_policy_parameters() - - _save_to_file_zip(path, data=data, params=params_to_save) - - """def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path))""" + opt_params_to_save = self.get_opt_parameters() + self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) \ No newline at end of file From 17f84053b3c391ba382d75da148d706f244f052e Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 14:44:02 +0100 Subject: [PATCH 08/50] save implementation for a2c needed before uncommenting save and load test in test_run.py::test_onpolicy --- tests/test_run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_run.py b/tests/test_run.py index 54d28f8..445947b 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -32,8 +32,8 @@ def test_cemrl(): def test_onpolicy(model_class, env_id): model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=1000, eval_freq=500) - model.save("test_save") - model.load("test_save") + #model.save("test_save") + #model.load("test_save") #os.remove("test_save.pth") From d31a39914067be00a1e30d1578e1e8ea8535f5c5 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 14:52:29 +0100 Subject: [PATCH 09/50] undo changes to conf.py --- docs/conf.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 0352081..4e4a2de 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,6 @@ import os import sys from unittest.mock import MagicMock -import torchy_baselines # source code directory, relative to this file, for sphinx-autobuild sys.path.insert(0, os.path.abspath('..')) @@ -44,6 +43,10 @@ MOCK_MODULES = ['joblib', 'scipy', 'scipy.signal', 'gym.wrappers', 'gym.wrappers.monitoring', 'zmq'] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) + +import torchy_baselines + + # -- Project information ----------------------------------------------------- project = 'Torchy Baselines' @@ -55,6 +58,7 @@ version = 'master (' + torchy_baselines.__version__ + ' )' # The full version, including alpha/beta/rc tags release = torchy_baselines.__version__ + # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. @@ -99,6 +103,7 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' + # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for @@ -117,7 +122,6 @@ html_logo = '_static/img/logo.png' def setup(app): app.add_stylesheet("css/baselines_theme.css") - # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. @@ -145,6 +149,7 @@ html_static_path = ['_static'] # Output file base name for HTML help builder. htmlhelp_basename = 'TorchyBaselinesdoc' + # -- Options for LaTeX output ------------------------------------------------ latex_elements = { @@ -173,6 +178,7 @@ latex_documents = [ 'Torchy Baselines Contributors', 'manual'), ] + # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples @@ -182,6 +188,7 @@ man_pages = [ [author], 1) ] + # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples From 26f31fd25bd3a443350c74de8bde5afbd8ec6b86 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 14:55:56 +0100 Subject: [PATCH 10/50] corrected comment sections --- torchy_baselines/common/base_class.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index c3b2ddd..b4f73fc 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -190,7 +190,7 @@ class BaseRLModel(object): """ Get current model policy parameters as dictionary of variable name -> tensors. - :return: (OrderedDict) Dictionary of variable name -> tensor of model's policy parameters. + :return: (dict) Dictionary of variable name -> tensor of model's policy parameters. """ return self.policy.state_dict() @@ -198,7 +198,7 @@ class BaseRLModel(object): def get_opt_parameters(self): """ Get current model optimizer parameters as dictionary of variable names -> tensors - :return: (OrderedDict) Dictionary of variable name -> tensor of model's optimizer parameters + :return: (dict) Dictionary of variable name -> tensor of model's optimizer parameters """ raise NotImplementedError() @@ -254,7 +254,7 @@ class BaseRLModel(object): """ Load model parameters from a dictionary - Dictionary should be of shape torch model.state_dict() + Dictionary should contain all entries of torch model.state_dict() This does not load agent's hyper-parameters. @@ -311,7 +311,7 @@ class BaseRLModel(object): :param load_path: (str or file-like) Where to load the model from :param load_data: (bool) Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) - :return: (dict. OrderedDict),(dict. OrderedDict),(dict. OrderedDict) Class parameters, model parameters (state_dict) and dict of optimizer parameters (dict of state_dict) + :return: (dict),(dict),(dict) Class parameters, model parameters (state_dict) and dict of optimizer parameters (dict of state_dict) """ # Check if file exists if load_path is a string if isinstance(load_path, str): @@ -336,9 +336,9 @@ class BaseRLModel(object): json_data = file_.read("data").decode() data = json_to_data(json_data) - if "param.pth" in namelist: + if "params.pth" in namelist: # Load parameters with build in torch function - with file_.open("param.pth", mode="r") as param_file: + with file_.open("params.pth", mode="r") as param_file: # File has to be seekable so load in BytesIO first file_content = io.BytesIO() file_content.write(param_file.read()) @@ -347,7 +347,7 @@ class BaseRLModel(object): params = th.load(file_content) # check for all other .pth files other_files = [file_name for file_name in namelist if - os.path.splitext(file_name)[1] == ".pth" and file_name != "param.pth"] + os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] if len(other_files) > 0: opt_params = dict() for file in other_files: @@ -480,9 +480,9 @@ class BaseRLModel(object): """Save model to a zip archive :param save_path: (str or file-like) Where to store the model - :param data: (OrderedDict) Class parameters being stored - :param params: (OrderedDict) Model parameters being stored expected to be state_dict - :param opt_params: (OrderedDict) Optimizer parameters being stored expected to contain an entry for every + :param data: (dict) Class parameters being stored + :param params: (dict) Model parameters being stored expected to be state_dict + :param opt_params: (dict) Optimizer parameters being stored expected to contain an entry for every optimizer with its name and the state_dict """ From 526c37bf1f02c4899ea5a138f5eb14a921080f36 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 15:44:57 +0100 Subject: [PATCH 11/50] refactored the assets in test_save_load fixed base_class 'params.pth' --- tests/test_save_load.py | 27 +++++++++++++++------------ torchy_baselines/common/base_class.py | 2 +- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 6500664..81cfe52 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -38,15 +38,11 @@ def test_save_load(model_class): # Update model parameters with the new random values model.load_parameters(random_params, opt_params) - # Get items that are the same in params and new_params new_params = model.get_policy_parameters() - shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} - - # Check that the there are at least some parameters new random parameters - #for k in params.key(): - # assert not th.allclose(params[k], new_params[k]) - assert not len(shared_items) == len(new_params), "Selected actions did not change " \ - "after changing model parameters." + # Check that all params are different now + for k in params: + assert not th.allclose(params[k], new_params[k]), "Selected actions did not change " \ + "after changing model parameters." params = new_params @@ -56,9 +52,16 @@ def test_save_load(model_class): del model model = model_class.load("test_save") - #check if params are still the same after load + # check if params are still the same after load new_params = model.get_policy_parameters() - shared_items = {k: params[k] for k in params if k in new_params and th.all(th.eq(params[k], new_params[k]))} - # Check that at least some actions are chosen different now - assert len(shared_items) == len(new_params), "Parameters not the same after save and load." + + # Check that all params are the same as before save load procedure now + for k in params: + assert th.allclose(params[k], new_params[k]), "Model parameters not the same after save and load." + + # check if optimizer params are still the same after load + new_opt_params = model.get_opt_parameters() + # check if keys are the same + assert opt_params.keys() == new_opt_params.keys() + # check if values are the same: don't know how to to that os.remove("test_save.zip") diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index b4f73fc..6955362 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -505,7 +505,7 @@ class BaseRLModel(object): if data is not None: file_.writestr("data", serialized_data) if params is not None: - with file_.open('param.pth', mode="w") as param_file: + with file_.open('params.pth', mode="w") as param_file: th.save(params, param_file) if opt_params is not None: for file_name, dict in opt_params.items(): From 775a50cc5cbefcd30191f5a3c7245ba0ad304f95 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 16:24:18 +0100 Subject: [PATCH 12/50] saving all variables now added a2c support --- tests/test_save_load.py | 4 +++- torchy_baselines/a2c/a2c.py | 19 ++++++++++--------- torchy_baselines/common/base_class.py | 2 +- torchy_baselines/ppo/ppo.py | 21 +-------------------- 4 files changed, 15 insertions(+), 31 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 81cfe52..1eb4fa0 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -9,7 +9,8 @@ from torchy_baselines.common.vec_env import DummyVecEnv from torchy_baselines.common.identity_env import IdentityEnvBox MODEL_LIST = [ - PPO + PPO, + A2C, ] @@ -51,6 +52,7 @@ def test_save_load(model_class): model.save("test_save.zip") del model model = model_class.load("test_save") + model.learn(total_timesteps=1000, eval_freq=500) # check if params are still the same after load new_params = model.get_policy_parameters() diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 1c066cf..8811356 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -126,13 +126,14 @@ class A2C(PPO): tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps) def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) + """ + saves all the params from init and pytorch params in a file for continous learning - def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path)) + :param path: path to the file where the data should be safed + :return: + """ + + data = self.__dict__ + params_to_save = self.get_policy_parameters() + opt_params_to_save = self.get_opt_parameters() + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 6955362..5eb7ff0 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -296,7 +296,7 @@ class BaseRLModel(object): "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], kwargs['policy_kwargs'])) - model = cls(policy=data["policy"], env=None, _init_setup_model=False) + model = cls(policy=data["policy"], env=data["env"], _init_setup_model=False) model.__dict__.update(data) model.__dict__.update(kwargs) model.set_env(env) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index fed5969..3af5b08 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -322,26 +322,7 @@ class PPO(BaseRLModel): :param path: path to the file where the data should be safed :return: """ - - data = { - "gamma": self.gamma, - "n_steps": self.n_steps, - "vf_coef": self.vf_coef, - "ent_coef": self.ent_coef, - "max_grad_norm": self.max_grad_norm, - "learning_rate": self.learning_rate, - "gae_lambda": self.gae_lambda, - "n_epochs": self.n_epochs, - "clip_range": self.clip_range, - "clip_range_vf": self.clip_range_vf, - "batch_size": self.batch_size, - "target_kl": self.target_kl, - "tensorboard_log": self.tensorboard_log, - "policy_kwargs": self.policy_kwargs, - "policy": self.policy, - - } - + data = self.__dict__ params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) \ No newline at end of file From 2d72f6d1b5b2b9b707fc03b6546ca5becd443eac Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 16:46:53 +0100 Subject: [PATCH 13/50] Added SAC, TD3, A2C Missing CEMRL --- tests/test_save_load.py | 8 +++++- torchy_baselines/a2c/a2c.py | 13 --------- torchy_baselines/cem_rl/cem_rl.py | 13 --------- torchy_baselines/common/base_class.py | 13 +++++++++ torchy_baselines/ppo/ppo.py | 12 --------- torchy_baselines/sac/sac.py | 38 +++++++++++++++++++-------- torchy_baselines/td3/td3.py | 37 +++++++++++++++++--------- 7 files changed, 72 insertions(+), 62 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 1eb4fa0..6368127 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -11,6 +11,8 @@ from torchy_baselines.common.identity_env import IdentityEnvBox MODEL_LIST = [ PPO, A2C, + TD3, + SAC, ] @@ -52,7 +54,6 @@ def test_save_load(model_class): model.save("test_save.zip") del model model = model_class.load("test_save") - model.learn(total_timesteps=1000, eval_freq=500) # check if params are still the same after load new_params = model.get_policy_parameters() @@ -66,4 +67,9 @@ def test_save_load(model_class): # check if keys are the same assert opt_params.keys() == new_opt_params.keys() # check if values are the same: don't know how to to that + + # check if learn still works + model.learn(total_timesteps=1000, eval_freq=500) + + # clear file from os os.remove("test_save.zip") diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 8811356..5babed5 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -124,16 +124,3 @@ class A2C(PPO): return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps) - - def save(self, path): - """ - saves all the params from init and pytorch params in a file for continous learning - - :param path: path to the file where the data should be safed - :return: - """ - - data = self.__dict__ - params_to_save = self.get_policy_parameters() - opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py index ad798ba..f35fa43 100644 --- a/torchy_baselines/cem_rl/cem_rl.py +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -157,16 +157,3 @@ class CEMRL(TD3): self.es.tell(self.es_params, self.fitnesses) timesteps_since_eval += actor_steps return self - - def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) - - def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path)) - self._create_aliases() diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 5eb7ff0..b2c6056 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -511,3 +511,16 @@ class BaseRLModel(object): for file_name, dict in opt_params.items(): with file_.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) + + + def save(self, path): + """ + saves all the params from init and pytorch params in a file for continous learning + + :param path: path to the file where the data should be safed + :return: + """ + data = self.__dict__ + params_to_save = self.get_policy_parameters() + opt_params_to_save = self.get_opt_parameters() + self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 3af5b08..3b4734e 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -314,15 +314,3 @@ class PPO(BaseRLModel): """ self.policy.optimizer.load_state_dict(opt_params["opt"]) self.policy.load_state_dict(load_dict) - - def save(self, path): - """ - saves all the params from init and pytorch params in a file for continous learning - - :param path: path to the file where the data should be safed - :return: - """ - data = self.__dict__ - params_to_save = self.get_policy_parameters() - opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) \ No newline at end of file diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index bad470a..2497523 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -274,15 +274,31 @@ class SAC(BaseRLModel): return self - def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) + def get_opt_parameters(self): + """ + returns a dict of all the optimizers and their parameters - def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path)) - self._create_aliases() + :return: (Dict) of optimizer names and their state_dict + """ + if self.ent_coef_optimizer is not None: + return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict(),"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()} + else: + return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} + + def load_parameters(self, load_dict, opt_params): + """ + Load model parameters and optimizer parameters from a dictionary + + Dictionary should be of shape torch model.state_dict() + + This does not load agent's hyper-parameters. + + + :param load_dict: (dict) dict of parameters from model.state_dict() + :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class + """ + self.actor.optimizer.load_state_dict(opt_params["actor"]) + self.critic.optimizer.load_state_dict(opt_params["critic"]) + if "ent_coef_optimizer" in opt_params: + self.ent_coef_optimizer.load_state_dict(opt_params["ent_coef_optimizer"]) + self.policy.load_state_dict(load_dict) diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 66ea72e..035d344 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -46,6 +46,7 @@ class TD3(BaseRLModel): Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ + def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3, policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100, train_freq=-1, gradient_steps=-1, n_episodes_rollout=1, @@ -148,7 +149,8 @@ class TD3(BaseRLModel): for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) - def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005, tau_critic: object = 0.005, + def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005, + tau_critic: object = 0.005, replay_data: object = None) -> object: # Update optimizer learning rate self._update_learning_rate(self.actor.optimizer) @@ -235,15 +237,26 @@ class TD3(BaseRLModel): return self - def save(self, path): - if not path.endswith('.pth'): - path += '.pth' - th.save(self.policy.state_dict(), path) + def get_opt_parameters(self): + """ + returns a dict of all the optimizers and their parameters - def load(self, path, env=None, **_kwargs): - if not path.endswith('.pth'): - path += '.pth' - if env is not None: - pass - self.policy.load_state_dict(th.load(path)) - self._create_aliases() + :return: (Dict) of optimizer names and their state_dict + """ + return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} + + def load_parameters(self, load_dict, opt_params): + """ + Load model parameters and optimizer parameters from a dictionary + + Dictionary should be of shape torch model.state_dict() + + This does not load agent's hyper-parameters. + + + :param load_dict: (dict) dict of parameters from model.state_dict() + :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class + """ + self.actor.optimizer.load_state_dict(opt_params["actor"]) + self.critic.optimizer.load_state_dict(opt_params["critic"]) + self.policy.load_state_dict(load_dict) From 03a0d437ef54cb5e9278d7bd0f62a22fc6042496 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 16:49:49 +0100 Subject: [PATCH 14/50] refactor --- torchy_baselines/sac/sac.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 2497523..ad36f64 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -288,12 +288,9 @@ class SAC(BaseRLModel): def load_parameters(self, load_dict, opt_params): """ Load model parameters and optimizer parameters from a dictionary - Dictionary should be of shape torch model.state_dict() - This does not load agent's hyper-parameters. - :param load_dict: (dict) dict of parameters from model.state_dict() :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class """ From 924ba9aea6cccd0d11381b0541cddab9820a120e Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 16:50:59 +0100 Subject: [PATCH 15/50] cleaned comments on model specific get and load functions --- torchy_baselines/ppo/ppo.py | 2 -- torchy_baselines/td3/td3.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 3b4734e..c4e0838 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -303,9 +303,7 @@ class PPO(BaseRLModel): def load_parameters(self, load_dict, opt_params): """ Load model parameters and optimizer parameters from a dictionary - Dictionary should be of shape torch model.state_dict() - This does not load agent's hyper-parameters. diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 035d344..e21c65a 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -248,12 +248,9 @@ class TD3(BaseRLModel): def load_parameters(self, load_dict, opt_params): """ Load model parameters and optimizer parameters from a dictionary - Dictionary should be of shape torch model.state_dict() - This does not load agent's hyper-parameters. - :param load_dict: (dict) dict of parameters from model.state_dict() :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class """ From cfb822aa916b9a8ebdf703ba23af55a5a8409b0e Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 16:54:30 +0100 Subject: [PATCH 16/50] Corrected test_run.py --- tests/test_run.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_run.py b/tests/test_run.py index 445947b..a5569e5 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -15,7 +15,7 @@ def test_td3(): model.learn(total_timesteps=1000, eval_freq=500) model.save("test_save") model.load("test_save") - os.remove("test_save.pth") + os.remove("test_save.zip") def test_cemrl(): @@ -24,7 +24,7 @@ def test_cemrl(): model.learn(total_timesteps=1000, eval_freq=500) model.save("test_save") model.load("test_save") - os.remove("test_save.pth") + os.remove("test_save.zip") @pytest.mark.parametrize("model_class", [A2C, PPO]) @@ -32,9 +32,9 @@ def test_cemrl(): def test_onpolicy(model_class, env_id): model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=1000, eval_freq=500) - #model.save("test_save") - #model.load("test_save") - #os.remove("test_save.pth") + model.save("test_save") + model.load("test_save") + os.remove("test_save.zip") def test_sac(): @@ -42,3 +42,6 @@ def test_sac(): learning_starts=100, verbose=1, create_eval_env=True, ent_coef='auto', action_noise=NormalActionNoise(np.zeros(1), np.zeros(1))) model.learn(total_timesteps=1000, eval_freq=500) + model.save("test_save") + model.load("test_save") + os.remove("test_save.zip") From 4f8f9364516d2fffc7616fe42434e0e2d0c1402b Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 21 Nov 2019 17:27:46 +0100 Subject: [PATCH 17/50] Don't save replay_buffer by default --- torchy_baselines/common/base_class.py | 29 +++++++++++++++------------ 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index b2c6056..57c27b0 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -270,15 +270,6 @@ class BaseRLModel(object): raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class") self.policy.load_state_dict(load_dict) - @abstractmethod - def save(self, save_path): - """ - Save the current parameters to file - - :param save_path: (str or file-like object) the save location - """ - raise NotImplementedError() - @classmethod def load(cls, load_path, env=None, **kwargs): """ @@ -512,15 +503,27 @@ class BaseRLModel(object): with file_.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) - def save(self, path): """ - saves all the params from init and pytorch params in a file for continous learning + saves all the params from init and pytorch params in a file for continuous learning - :param path: path to the file where the data should be safed + :param path: path to the file where the data should be saved + :return: + """ + data = self.__dict__ + data.pop("replay_buffer") + params_to_save = self.get_policy_parameters() + opt_params_to_save = self.get_opt_parameters() + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) + + def save_with_replay_buffer(self, path): + """ + saves all the params from init and pytorch params in a file for continuous learning + + :param path: path to the file where the data should be saved :return: """ data = self.__dict__ params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save,opt_params=opt_params_to_save) + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) From b75ffe166a2c78efa7a9e9c183055c03cce0f033 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 10:36:21 +0100 Subject: [PATCH 18/50] Cleared base_class.load description --- torchy_baselines/common/base_class.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 57c27b0..3384775 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -275,7 +275,7 @@ class BaseRLModel(object): """ Load the model from file - :param load_path: (str or file-like) the saved parameter location + :param load_path: (str) the saved parameter location :param env: (Gym Envrionment) the new environment to run the loaded model on (can be None if you only need prediction from a trained model) :param kwargs: extra arguments to change the model when loading @@ -507,7 +507,7 @@ class BaseRLModel(object): """ saves all the params from init and pytorch params in a file for continuous learning - :param path: path to the file where the data should be saved + :param path: (str) path to the file where the data should be saved :return: """ data = self.__dict__ From 812cab84ac4f1a65b2d9015078bc52032b37d6dc Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 11:20:40 +0100 Subject: [PATCH 19/50] Changed PPO deterministic --- tests/test_save_load.py | 17 ++++++++++++----- torchy_baselines/common/base_class.py | 16 ++-------------- torchy_baselines/ppo/ppo.py | 6 +++--- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 6368127..e184c6e 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -10,9 +10,9 @@ from torchy_baselines.common.identity_env import IdentityEnvBox MODEL_LIST = [ PPO, - A2C, - TD3, - SAC, + #A2C, + #TD3, + #SAC, ] @@ -30,6 +30,7 @@ def test_save_load(model_class): # create model model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + model.learn(total_timesteps=1000, eval_freq=500) # Get dictionary of current parameters params = deepcopy(model.get_policy_parameters()) @@ -50,7 +51,6 @@ def test_save_load(model_class): params = new_params # Check - model.learn(total_timesteps=1000, eval_freq=500) model.save("test_save.zip") del model model = model_class.load("test_save") @@ -66,7 +66,14 @@ def test_save_load(model_class): new_opt_params = model.get_opt_parameters() # check if keys are the same assert opt_params.keys() == new_opt_params.keys() - # check if values are the same: don't know how to to that + # check if values are the same: only tested for Adam and RMSProp so far + for optimizer,opt_state in opt_params.items(): + for step_entry, entry_dict in opt_state['state'].items(): + for value_key,value in entry_dict.items(): + print(value == new_opt_params[optimizer][step_entry][value_key]) + + + # check if learn still works model.learn(total_timesteps=1000, eval_freq=500) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 3384775..597d7fe 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -503,7 +503,7 @@ class BaseRLModel(object): with file_.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) - def save(self, path): + def save(self, path, include=None):#TODO """ saves all the params from init and pytorch params in a file for continuous learning @@ -514,16 +514,4 @@ class BaseRLModel(object): data.pop("replay_buffer") params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) - - def save_with_replay_buffer(self, path): - """ - saves all the params from init and pytorch params in a file for continuous learning - - :param path: path to the file where the data should be saved - :return: - """ - data = self.__dict__ - params_to_save = self.get_policy_parameters() - opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) \ No newline at end of file diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index c4e0838..7582151 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -124,14 +124,14 @@ class PPO(BaseRLModel): if self.clip_range_vf is not None: self.clip_range_vf = get_schedule_fn(self.clip_range_vf) - def select_action(self, observation): + def select_action(self, observation,deterministic=False): # Normally not needed observation = np.array(observation) with th.no_grad(): observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device) - return self.policy.actor_forward(observation, deterministic=False) + return self.policy.actor_forward(observation, deterministic) - def predict(self, observation, state=None, mask=None, deterministic=True): + def predict(self, observation, state=None, mask=None, deterministic=False): """ Get the model's action from an observation From e26564e0ec038c4788c6634596ecaa25a5630696 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 13:35:16 +0100 Subject: [PATCH 20/50] Added function for setting up any attributes that weren't saved and thus not loaded --- tests/test_save_load.py | 29 +++++++++++++----------- torchy_baselines/common/base_class.py | 32 ++++++++++++++++++++++++--- torchy_baselines/sac/sac.py | 9 ++++++++ torchy_baselines/td3/td3.py | 11 ++++++++- 4 files changed, 64 insertions(+), 17 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index e184c6e..4d7861e 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -10,9 +10,9 @@ from torchy_baselines.common.identity_env import IdentityEnvBox MODEL_LIST = [ PPO, - #A2C, - #TD3, - #SAC, + A2C, + TD3, + SAC, ] @@ -30,7 +30,7 @@ def test_save_load(model_class): # create model model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) - model.learn(total_timesteps=1000, eval_freq=500) + model.learn(total_timesteps=500, eval_freq=250) # Get dictionary of current parameters params = deepcopy(model.get_policy_parameters()) @@ -45,8 +45,7 @@ def test_save_load(model_class): new_params = model.get_policy_parameters() # Check that all params are different now for k in params: - assert not th.allclose(params[k], new_params[k]), "Selected actions did not change " \ - "after changing model parameters." + assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." params = new_params @@ -67,13 +66,17 @@ def test_save_load(model_class): # check if keys are the same assert opt_params.keys() == new_opt_params.keys() # check if values are the same: only tested for Adam and RMSProp so far - for optimizer,opt_state in opt_params.items(): - for step_entry, entry_dict in opt_state['state'].items(): - for value_key,value in entry_dict.items(): - print(value == new_opt_params[optimizer][step_entry][value_key]) - - - + # comparing states not implemented so far. hashes of state_entries are not the same for equal tensors + # comparing every sub_entry does not work because of bool value of Tensor with more than one value is ambiguous + # so far only comparing param_lists + for optimizer, opt_state in opt_params.items(): + for param_group_idx, param_group in enumerate(opt_state['param_groups']): + for param_key, param_value in param_group.items(): + if param_key == 'params': # don't know how to handle params correctly, therefore only check if we have the same amount + assert len(param_value) == len( + new_opt_params[optimizer]['param_groups'][param_group_idx][param_key]) + else: + assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key] # check if learn still works model.learn(total_timesteps=1000, eval_freq=500) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index fed3243..970ea5a 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -292,7 +292,8 @@ class BaseRLModel(object): model.__dict__.update(kwargs) model.set_env(env) model.load_parameters(params, opt_params) - + # resetup modul after load + model._resetup_model() return model @staticmethod @@ -511,15 +512,40 @@ class BaseRLModel(object): with file_.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) - def save(self, path, include=None):#TODO + def excluded_save_params(self): + """ + returns the names of the parameters that should be excluded from save + :return: (list) List of parameters that should be excluded from save + """ + return ["replay_buffer"] + + def _resetup_model(self): + """ + Function will be called at the end of load and should resetup anything that might not have been saved + warning: this function should always be in compliance with excluded_save_params + :return: + """ + pass + + def save(self, path, include=None): """ saves all the params from init and pytorch params in a file for continuous learning :param path: (str) path to the file where the data should be saved + :param include: (list) name of parameters that might be excluded but should be included anyway :return: """ data = self.__dict__ - data.pop("replay_buffer") + # get list of params to be excluded + exclude = self.excluded_save_params() + # do not exclude params if they are specifically included + if include is not None: + exclude = [param_name for param_name in exclude if param_name not in include] + + # remove parameter entries of parameters which are to be excluded + for param_name in exclude: + data.pop(param_name, None) + params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) \ No newline at end of file diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index ad36f64..42771bb 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -128,6 +128,15 @@ class SAC(BaseRLModel): self.policy = self.policy.to(self.device) self._create_aliases() + def _resetup_model(self): + """ + method used to resetup anything that was not saved + :return: + """ + if self.replay_buffer is None: + obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] + self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) + def _create_aliases(self): self.actor = self.policy.actor self.critic = self.policy.critic diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index e21c65a..07036cf 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -85,6 +85,15 @@ class TD3(BaseRLModel): self.policy = self.policy.to(self.device) self._create_aliases() + def _resetup_model(self): + """ + method used to resetup anything that was not saved + :return: + """ + if self.replay_buffer is None: + obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] + self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) + def _create_aliases(self): self.actor = self.policy.actor self.actor_target = self.policy.actor_target @@ -256,4 +265,4 @@ class TD3(BaseRLModel): """ self.actor.optimizer.load_state_dict(opt_params["actor"]) self.critic.optimizer.load_state_dict(opt_params["critic"]) - self.policy.load_state_dict(load_dict) + self.policy.load_state_dict(load_dict) \ No newline at end of file From 9ff59eaf3d08b304c65fe21bf6a7a035be3a295b Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 15:25:01 +0100 Subject: [PATCH 21/50] Added attribute self.policy_class to prevent errors when using self.policy as class --- tests/test_save_load.py | 15 +++++++++++++++ torchy_baselines/common/base_class.py | 6 +++--- torchy_baselines/ppo/ppo.py | 2 +- torchy_baselines/sac/sac.py | 11 +---------- torchy_baselines/td3/td3.py | 2 +- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 4d7861e..dec98b4 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,6 +1,7 @@ import os import pytest from copy import deepcopy +import numpy as np import torch as th @@ -32,6 +33,11 @@ def test_save_load(model_class): model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=500, eval_freq=250) + env.reset() + observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)]) + observations = np.squeeze(observations) + + # Get dictionary of current parameters params = deepcopy(model.get_policy_parameters()) opt_params = deepcopy(model.get_opt_parameters()) @@ -49,6 +55,10 @@ def test_save_load(model_class): params = new_params + + #get selected actions + selected_actions = [model.predict(observation, deterministic=True) for observation in observations] + # Check model.save("test_save.zip") del model @@ -78,6 +88,11 @@ def test_save_load(model_class): else: assert param_value == new_opt_params[optimizer]['param_groups'][param_group_idx][param_key] + # check if model still selects the same actions + new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations] + for i in range(len(selected_actions)): + assert selected_actions[i] == new_selected_actions[i] + # check if learn still works model.learn(total_timesteps=1000, eval_freq=500) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 970ea5a..cbdb50b 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -38,9 +38,9 @@ class BaseRLModel(object): verbose=0, device='auto', support_multi_env=False, create_eval_env=False, monitor_wrapper=True, seed=None): if isinstance(policy, str) and policy_base is not None: - self.policy = get_policy_from_name(policy_base, policy) + self.policy_class = get_policy_from_name(policy_base, policy) else: - self.policy = policy + self.policy_class = policy if device == 'auto': device = 'cuda' if th.cuda.is_available() else 'cpu' @@ -293,7 +293,7 @@ class BaseRLModel(object): model.set_env(env) model.load_parameters(params, opt_params) # resetup modul after load - model._resetup_model() + #model._setup_model() return model @staticmethod diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index dd52e8b..a0c2231 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -119,7 +119,7 @@ class PPO(BaseRLModel): self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs) - self.policy = self.policy(self.observation_space, self.action_space, + self.policy = self.policy_class(self.observation_space, self.action_space, self.learning_rate, use_sde=self.use_sde, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 42771bb..cc539d5 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -123,20 +123,11 @@ class SAC(BaseRLModel): self.ent_coef = float(self.ent_coef) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - self.policy = self.policy(self.observation_space, self.action_space, + self.policy = self.policy_class(self.observation_space, self.action_space, self.learning_rate, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases() - def _resetup_model(self): - """ - method used to resetup anything that was not saved - :return: - """ - if self.replay_buffer is None: - obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] - self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - def _create_aliases(self): self.actor = self.policy.actor self.critic = self.policy.critic diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 07036cf..1c3639e 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -80,7 +80,7 @@ class TD3(BaseRLModel): obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] self.set_random_seed(self.seed) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - self.policy = self.policy(self.observation_space, self.action_space, + self.policy = self.policy_class(self.observation_space, self.action_space, self.learning_rate, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases() From 751ccf85e7798139d285240dd3e0305895dd1361 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 15:33:52 +0100 Subject: [PATCH 22/50] _setup_model() is now called when model is loaded --- torchy_baselines/common/base_class.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index cbdb50b..94c1cc8 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -287,13 +287,13 @@ class BaseRLModel(object): "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], kwargs['policy_kwargs'])) - model = cls(policy=data["policy"], env=data["env"], _init_setup_model=False) + model = cls(policy=data["policy_class"], env=data["env"], _init_setup_model=True) model.__dict__.update(data) model.__dict__.update(kwargs) model.set_env(env) model.load_parameters(params, opt_params) # resetup modul after load - #model._setup_model() + # model._setup_model() return model @staticmethod From e95858784ad69d2f39f79e5158e18a56aea97b65 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 15:38:04 +0100 Subject: [PATCH 23/50] Formatted all files --- tests/test_save_load.py | 5 ++--- torchy_baselines/cem_rl/cem_rl.py | 3 ++- torchy_baselines/common/base_class.py | 12 +----------- torchy_baselines/ppo/ppo.py | 5 ++--- torchy_baselines/sac/sac.py | 6 ++++-- torchy_baselines/td3/td3.py | 13 ++----------- 6 files changed, 13 insertions(+), 31 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index dec98b4..f3f8d10 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -10,6 +10,7 @@ from torchy_baselines.common.vec_env import DummyVecEnv from torchy_baselines.common.identity_env import IdentityEnvBox MODEL_LIST = [ + CEMRL, PPO, A2C, TD3, @@ -37,7 +38,6 @@ def test_save_load(model_class): observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)]) observations = np.squeeze(observations) - # Get dictionary of current parameters params = deepcopy(model.get_policy_parameters()) opt_params = deepcopy(model.get_opt_parameters()) @@ -55,8 +55,7 @@ def test_save_load(model_class): params = new_params - - #get selected actions + # get selected actions selected_actions = [model.predict(observation, deterministic=True) for observation in observations] # Check diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py index f35fa43..3305cb5 100644 --- a/torchy_baselines/cem_rl/cem_rl.py +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -78,7 +78,8 @@ class CEMRL(TD3): # set params self.actor.load_from_vector(self.es_params[i]) self.actor_target.load_from_vector(self.es_params[i]) - self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=self.learning_rate(self._current_progress)) + self.actor.optimizer = th.optim.Adam(self.actor.parameters(), + lr=self.learning_rate(self._current_progress)) # In the paper: 2 * actor_steps // self.n_grad # In the original implementation: actor_steps // self.n_grad diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 94c1cc8..8c363c9 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -292,8 +292,6 @@ class BaseRLModel(object): model.__dict__.update(kwargs) model.set_env(env) model.load_parameters(params, opt_params) - # resetup modul after load - # model._setup_model() return model @staticmethod @@ -519,14 +517,6 @@ class BaseRLModel(object): """ return ["replay_buffer"] - def _resetup_model(self): - """ - Function will be called at the end of load and should resetup anything that might not have been saved - warning: this function should always be in compliance with excluded_save_params - :return: - """ - pass - def save(self, path, include=None): """ saves all the params from init and pytorch params in a file for continuous learning @@ -548,4 +538,4 @@ class BaseRLModel(object): params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() - self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) \ No newline at end of file + self._save_to_file_zip(path, data=data, params=params_to_save, opt_params=opt_params_to_save) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index a0c2231..2938c12 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -120,8 +120,8 @@ class PPO(BaseRLModel): self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs) self.policy = self.policy_class(self.observation_space, self.action_space, - self.learning_rate, use_sde=self.use_sde, device=self.device, - **self.policy_kwargs) + self.learning_rate, use_sde=self.use_sde, device=self.device, + **self.policy_kwargs) self.policy = self.policy.to(self.device) self.clip_range = get_schedule_fn(self.clip_range) @@ -227,7 +227,6 @@ class PPO(BaseRLModel): # Value loss using the TD(gae_lambda) target value_loss = F.mse_loss(return_batch, values_pred) - # Entropy loss favor exploration entropy_loss = -th.mean(entropy) diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index cc539d5..06e107f 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -52,6 +52,7 @@ class SAC(BaseRLModel): Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ + def __init__(self, policy, env, learning_rate=3e-4, buffer_size=int(1e6), learning_starts=100, batch_size=64, tau=0.005, ent_coef='auto', target_update_interval=1, @@ -124,7 +125,7 @@ class SAC(BaseRLModel): self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) self.policy = self.policy_class(self.observation_space, self.action_space, - self.learning_rate, device=self.device, **self.policy_kwargs) + self.learning_rate, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases() @@ -281,7 +282,8 @@ class SAC(BaseRLModel): :return: (Dict) of optimizer names and their state_dict """ if self.ent_coef_optimizer is not None: - return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict(),"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()} + return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict(), + "ent_coef_optimizer": self.ent_coef_optimizer.state_dict()} else: return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 1c3639e..5a847ac 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -81,19 +81,10 @@ class TD3(BaseRLModel): self.set_random_seed(self.seed) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) self.policy = self.policy_class(self.observation_space, self.action_space, - self.learning_rate, device=self.device, **self.policy_kwargs) + self.learning_rate, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases() - def _resetup_model(self): - """ - method used to resetup anything that was not saved - :return: - """ - if self.replay_buffer is None: - obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] - self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - def _create_aliases(self): self.actor = self.policy.actor self.actor_target = self.policy.actor_target @@ -265,4 +256,4 @@ class TD3(BaseRLModel): """ self.actor.optimizer.load_state_dict(opt_params["actor"]) self.critic.optimizer.load_state_dict(opt_params["critic"]) - self.policy.load_state_dict(load_dict) \ No newline at end of file + self.policy.load_state_dict(load_dict) From ee6f938ddc0da2e3cec56fcea1d9d3dba794fb74 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 15:42:53 +0100 Subject: [PATCH 24/50] Added option to explicitly specify excluded parameters --- torchy_baselines/common/base_class.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 8c363c9..ecfb31e 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -517,17 +517,19 @@ class BaseRLModel(object): """ return ["replay_buffer"] - def save(self, path, include=None): + def save(self, path, exclude=None, include=None): """ saves all the params from init and pytorch params in a file for continuous learning :param path: (str) path to the file where the data should be saved + :param exclude: (list) name of parameters that should be excluded, use standard exclude params if None :param include: (list) name of parameters that might be excluded but should be included anyway :return: """ data = self.__dict__ - # get list of params to be excluded - exclude = self.excluded_save_params() + # use standard list of excluded parameters if none given + if exclude is None: + exclude = self.excluded_save_params() # do not exclude params if they are specifically included if include is not None: exclude = [param_name for param_name in exclude if param_name not in include] From c82025e673344e76f8015b2da2f0f40e6217c98a Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 16:07:15 +0100 Subject: [PATCH 25/50] Add Test for exclude/include feature of save --- tests/test_save_load.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index f3f8d10..36dfdc0 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -97,3 +97,37 @@ def test_save_load(model_class): # clear file from os os.remove("test_save.zip") + + +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_exclude_include_saved_params(model_class): + """ + Test if exclude and include parameters of save() work + + :param model_class: (BaseRLModel) A RL model + """ + env = DummyVecEnv([lambda: IdentityEnvBox(10)]) + + # create model + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + # set verbose as something different then standard settings + model.verbose = 2 + + # Check if exclude works + model.save("test_save.zip", exclude=["verbose"]) + del model + model = model_class.load("test_save") + # check if verbose was not saved + assert not model.verbose == 2 + + # set verbose as something different then standard settings + model.verbose = 2 + # Check if include works + model.save("test_save.zip", exclude=["verbose"], include=["verbose"]) + del model + model = model_class.load("test_save") + assert model.verbose == 2 + + + # clear file from os + os.remove("test_save.zip") From 7ce610fadea582414ccc6713ee5862eaa2d6f906 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 16:11:35 +0100 Subject: [PATCH 26/50] Deleted exact match parameter of load_parameters --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index ecfb31e..dd26e4f 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -250,7 +250,7 @@ class BaseRLModel(object): """ pass - def load_parameters(self, load_dict, opt_params=None, exact_match=True): + def load_parameters(self, load_dict, opt_params=None): """ Load model parameters from a dictionary From 6928879f5aa638a747c709208d721d7459d31aa6 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 28 Nov 2019 16:30:13 +0100 Subject: [PATCH 27/50] Refactored doc-strings --- torchy_baselines/common/base_class.py | 38 ++++++++------------------- torchy_baselines/ppo/ppo.py | 7 +++-- torchy_baselines/sac/sac.py | 4 +-- torchy_baselines/td3/td3.py | 4 +-- 4 files changed, 18 insertions(+), 35 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index dd26e4f..1df0bac 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -56,12 +56,10 @@ class BaseRLModel(object): self.action_space = None self.n_envs = None self.num_timesteps = 0 - self.params = None self.eval_env = None self.replay_buffer = None self.seed = seed self.action_noise = None - self.params = None # Track the training progress (from 1 to 0) # this is used to update the learning rate self._current_progress = 1 @@ -176,16 +174,6 @@ class BaseRLModel(object): """ pass - def get_parameter_list(self): - """ - Get pytorch Variables of model's parameters - - This includes all variables necessary for continuing training (saving / loading). - - :return: (list) List of pytorch Variables - """ - return self.params - def get_policy_parameters(self): """ Get current model policy parameters as dictionary of variable name -> tensors. @@ -253,14 +241,8 @@ class BaseRLModel(object): def load_parameters(self, load_dict, opt_params=None): """ Load model parameters from a dictionary - - Dictionary should contain all entries of torch model.state_dict() - - This does not load agent's hyper-parameters. - - .. warning:: - This function does not update trainer/optimizer variables (e.g. momentum). - As such training after using this function may lead to less-than-optimal results. + load_dict should contain all keys from torch.model.state_dict() + If opt_params are given this does also load agent's optimizer-parameters, but can only be handled in child classes. :param load_dict: (dict) dict of parameters from model.state_dict() @@ -273,11 +255,11 @@ class BaseRLModel(object): @classmethod def load(cls, load_path, env=None, **kwargs): """ - Load the model from file + Load the model from a zip-file - :param load_path: (str) the saved parameter location + :param load_path: (str) the location of the saved data :param env: (Gym Envrionment) the new environment to run the loaded model on - (can be None if you only need prediction from a trained model) + (can be None if you only need prediction from a trained model) has priority over any saved environment :param kwargs: extra arguments to change the model when loading """ data, params, opt_params = cls._load_from_file(load_path) @@ -287,7 +269,9 @@ class BaseRLModel(object): "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], kwargs['policy_kwargs'])) - model = cls(policy=data["policy_class"], env=data["env"], _init_setup_model=True) + if env is None and "env" in data: + env = data["env"] + model = cls(policy=data["policy_class"], env=env, _init_setup_model=True) model.__dict__.update(data) model.__dict__.update(kwargs) model.set_env(env) @@ -298,7 +282,7 @@ class BaseRLModel(object): def _load_from_file(load_path, load_data=True): """ Load model data from a .zip archive - :param load_path: (str or file-like) Where to load the model from + :param load_path: (str) Where to load the model from :param load_data: (bool) Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) :return: (dict),(dict),(dict) Class parameters, model parameters (state_dict) and dict of optimizer parameters (dict of state_dict) @@ -477,7 +461,7 @@ class BaseRLModel(object): def _save_to_file_zip(save_path, data=None, params=None, opt_params=None): """Save model to a zip archive - :param save_path: (str or file-like) Where to store the model + :param save_path: (str) Where to store the model :param data: (dict) Class parameters being stored :param params: (dict) Model parameters being stored expected to be state_dict :param opt_params: (dict) Optimizer parameters being stored expected to contain an entry for every @@ -519,7 +503,7 @@ class BaseRLModel(object): def save(self, path, exclude=None, include=None): """ - saves all the params from init and pytorch params in a file for continuous learning + saves all the params from init and pytorch params in a zip-file for continuous learning :param path: (str) path to the file where the data should be saved :param exclude: (list) name of parameters that should be excluded, use standard exclude params if None diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 2938c12..bc66171 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -315,7 +315,7 @@ class PPO(BaseRLModel): def get_opt_parameters(self): """ - returns a dict of all the optimizers and their parameters + Returns a dict of all the optimizers and their parameters :return: (Dict) of optimizer names and their state_dict """ @@ -324,12 +324,11 @@ class PPO(BaseRLModel): def load_parameters(self, load_dict, opt_params): """ Load model parameters and optimizer parameters from a dictionary - Dictionary should be of shape torch model.state_dict() + load_dict should contain all keys from torch.model.state_dict() This does not load agent's hyper-parameters. - :param load_dict: (dict) dict of parameters from model.state_dict() - :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class + :param opt_params: (dict of dicts) dict of optimizer state_dicts """ self.policy.optimizer.load_state_dict(opt_params["opt"]) self.policy.load_state_dict(load_dict) diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 06e107f..bfd873b 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -277,7 +277,7 @@ class SAC(BaseRLModel): def get_opt_parameters(self): """ - returns a dict of all the optimizers and their parameters + Returns a dict of all the optimizers and their parameters :return: (Dict) of optimizer names and their state_dict """ @@ -290,7 +290,7 @@ class SAC(BaseRLModel): def load_parameters(self, load_dict, opt_params): """ Load model parameters and optimizer parameters from a dictionary - Dictionary should be of shape torch model.state_dict() + load_dict should contain all keys from torch.model.state_dict() This does not load agent's hyper-parameters. :param load_dict: (dict) dict of parameters from model.state_dict() diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 5a847ac..557e901 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -239,7 +239,7 @@ class TD3(BaseRLModel): def get_opt_parameters(self): """ - returns a dict of all the optimizers and their parameters + Returns a dict of all the optimizers and their parameters :return: (Dict) of optimizer names and their state_dict """ @@ -248,7 +248,7 @@ class TD3(BaseRLModel): def load_parameters(self, load_dict, opt_params): """ Load model parameters and optimizer parameters from a dictionary - Dictionary should be of shape torch model.state_dict() + load_dict should contain all keys from torch.model.state_dict() This does not load agent's hyper-parameters. :param load_dict: (dict) dict of parameters from model.state_dict() From 362bba73ba88b6b7c0eeb893c43eddd1f6f4712d Mon Sep 17 00:00:00 2001 From: "Dormann, Noah" Date: Thu, 5 Dec 2019 08:07:43 +0100 Subject: [PATCH 28/50] adapted common style Co-Authored-By: Raffin, Antonin --- tests/test_save_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 36dfdc0..1ade709 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -118,7 +118,7 @@ def test_exclude_include_saved_params(model_class): del model model = model_class.load("test_save") # check if verbose was not saved - assert not model.verbose == 2 + assert model.verbose != 2 # set verbose as something different then standard settings model.verbose = 2 From 03bf513e5e3be8ae40483e82dc8e92ee1dde601d Mon Sep 17 00:00:00 2001 From: "Dormann, Noah" Date: Thu, 5 Dec 2019 08:08:20 +0100 Subject: [PATCH 29/50] comment refactoring Co-Authored-By: Raffin, Antonin --- torchy_baselines/ppo/ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index bc66171..b0e9de2 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -317,7 +317,7 @@ class PPO(BaseRLModel): """ Returns a dict of all the optimizers and their parameters - :return: (Dict) of optimizer names and their state_dict + :return: (dict) of optimizer names and their state_dict """ return {"opt": self.policy.optimizer.state_dict()} From 85d37432249df7f45b73de02026a51c30c0f1d6b Mon Sep 17 00:00:00 2001 From: "Dormann, Noah" Date: Thu, 5 Dec 2019 08:09:26 +0100 Subject: [PATCH 30/50] added standart exclude parameters Co-Authored-By: Raffin, Antonin --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 1df0bac..199467c 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -499,7 +499,7 @@ class BaseRLModel(object): returns the names of the parameters that should be excluded from save :return: (list) List of parameters that should be excluded from save """ - return ["replay_buffer"] + return ["env", "eval_env", "replay_buffer", "rollout_buffer"] def save(self, path, exclude=None, include=None): """ From aa66d2f82ed175cb08bbc0dd061822c321b2dbc5 Mon Sep 17 00:00:00 2001 From: "Dormann, Noah" Date: Thu, 5 Dec 2019 08:09:51 +0100 Subject: [PATCH 31/50] comment refactoring Co-Authored-By: Raffin, Antonin --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 199467c..842aa5b 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -497,7 +497,7 @@ class BaseRLModel(object): def excluded_save_params(self): """ returns the names of the parameters that should be excluded from save - :return: (list) List of parameters that should be excluded from save + :return: ([str]) List of parameters that should be excluded from save """ return ["env", "eval_env", "replay_buffer", "rollout_buffer"] From fdb544e775c989d6724f8829eab55a136054cf35 Mon Sep 17 00:00:00 2001 From: "Dormann, Noah" Date: Thu, 5 Dec 2019 08:10:14 +0100 Subject: [PATCH 32/50] comment refactoring Co-Authored-By: Raffin, Antonin --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 842aa5b..2554e89 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -503,7 +503,7 @@ class BaseRLModel(object): def save(self, path, exclude=None, include=None): """ - saves all the params from init and pytorch params in a zip-file for continuous learning + Save all the attributes of the object and the model parameters in a zip-file for continuous learning :param path: (str) path to the file where the data should be saved :param exclude: (list) name of parameters that should be excluded, use standard exclude params if None From a756f40223d59f7662f2dc702ff9a0a76b484996 Mon Sep 17 00:00:00 2001 From: "Dormann, Noah" Date: Thu, 5 Dec 2019 08:10:25 +0100 Subject: [PATCH 33/50] comment refactoring Co-Authored-By: Raffin, Antonin --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 2554e89..c25a6bb 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -506,7 +506,7 @@ class BaseRLModel(object): Save all the attributes of the object and the model parameters in a zip-file for continuous learning :param path: (str) path to the file where the data should be saved - :param exclude: (list) name of parameters that should be excluded, use standard exclude params if None + :param exclude: ([str]) name of parameters that should be excluded in addition to the default one :param include: (list) name of parameters that might be excluded but should be included anyway :return: """ From bea279969139a7401c7e3d4ee3bacb7bbd8c0661 Mon Sep 17 00:00:00 2001 From: "Dormann, Noah" Date: Thu, 5 Dec 2019 08:10:39 +0100 Subject: [PATCH 34/50] comment refactoring Co-Authored-By: Raffin, Antonin --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index c25a6bb..0719771 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -507,7 +507,7 @@ class BaseRLModel(object): :param path: (str) path to the file where the data should be saved :param exclude: ([str]) name of parameters that should be excluded in addition to the default one - :param include: (list) name of parameters that might be excluded but should be included anyway + :param include: ([str]) name of parameters that might be excluded but should be included anyway :return: """ data = self.__dict__ From c3b0398d56f57bd2a300ad096ca7c120c79b489a Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 08:40:28 +0100 Subject: [PATCH 35/50] Changed load so it still works when env not saved improved save function --- tests/test_save_load.py | 8 +++----- torchy_baselines/common/base_class.py | 15 +++++++++++---- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 1ade709..bae9e0e 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -61,7 +61,7 @@ def test_save_load(model_class): # Check model.save("test_save.zip") del model - model = model_class.load("test_save") + model = model_class.load("test_save", env=env) # check if params are still the same after load new_params = model.get_policy_parameters() @@ -108,10 +108,8 @@ def test_exclude_include_saved_params(model_class): """ env = DummyVecEnv([lambda: IdentityEnvBox(10)]) - # create model - model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) - # set verbose as something different then standard settings - model.verbose = 2 + # create model, set verbose as 2, which is not standard + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2, create_eval_env=True) # Check if exclude works model.save("test_save.zip", exclude=["verbose"]) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 0719771..a195580 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -271,7 +271,10 @@ class BaseRLModel(object): if env is None and "env" in data: env = data["env"] - model = cls(policy=data["policy_class"], env=env, _init_setup_model=True) + if env is not None: + model = cls(policy=data["policy_class"], env=env, _init_setup_model=True) + else: + model = cls(policy=data["policy_class"], env=env, _init_setup_model=False) model.__dict__.update(data) model.__dict__.update(kwargs) model.set_env(env) @@ -508,19 +511,23 @@ class BaseRLModel(object): :param path: (str) path to the file where the data should be saved :param exclude: ([str]) name of parameters that should be excluded in addition to the default one :param include: ([str]) name of parameters that might be excluded but should be included anyway - :return: """ - data = self.__dict__ + # copy parameter list so we don't mutate the original dict + data = self.__dict__.copy() # use standard list of excluded parameters if none given if exclude is None: exclude = self.excluded_save_params() + else: + # append standard exclude params to the given params + exclude.extend([param for param in self.excluded_save_params() if param not in exclude]) # do not exclude params if they are specifically included if include is not None: exclude = [param_name for param_name in exclude if param_name not in include] # remove parameter entries of parameters which are to be excluded for param_name in exclude: - data.pop(param_name, None) + if param_name in data: + data.pop(param_name, None) params_to_save = self.get_policy_parameters() opt_params_to_save = self.get_opt_parameters() From 4c0f6cbe53c31070d316f5c45d7c568a2e51d9d6 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 08:43:12 +0100 Subject: [PATCH 36/50] update get_opt_parameters to remove duplicate code --- torchy_baselines/sac/sac.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index bfd873b..5222754 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -281,11 +281,10 @@ class SAC(BaseRLModel): :return: (Dict) of optimizer names and their state_dict """ + opt_dict = {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} if self.ent_coef_optimizer is not None: - return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict(), - "ent_coef_optimizer": self.ent_coef_optimizer.state_dict()} - else: - return {"actor": self.actor.optimizer.state_dict(), "critic": self.critic.optimizer.state_dict()} + opt_dict.update({"ent_coef_optimizer": self.ent_coef_optimizer.state_dict()}) + return opt_dict def load_parameters(self, load_dict, opt_params): """ From ff7c4d24f459f830af2b8d0e03b5251e212165ad Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 08:46:26 +0100 Subject: [PATCH 37/50] deleted types in train_actor td3 --- torchy_baselines/td3/td3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 557e901..645ecf2 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -149,9 +149,9 @@ class TD3(BaseRLModel): for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) - def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005, - tau_critic: object = 0.005, - replay_data: object = None) -> object: + def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005, + tau_critic=0.005, + replay_data=None): # Update optimizer learning rate self._update_learning_rate(self.actor.optimizer) From 7c8d375bcb3d30bac3861a69680ffe0da986b4f0 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 08:50:11 +0100 Subject: [PATCH 38/50] added get_parameter_list function --- torchy_baselines/common/base_class.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index a195580..680ac17 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -174,6 +174,13 @@ class BaseRLModel(object): """ pass + def get_parameter_list(self): + """ + Returns policy and optimizer parameters as a tuple + :return: (dict,dict) policy_parameters, opt_parameters + """ + return self.get_policy_parameters(),self.get_opt_parameters() + def get_policy_parameters(self): """ Get current model policy parameters as dictionary of variable name -> tensors. From 6560ae99524b54970285e7395edacf75523bf081 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 08:52:26 +0100 Subject: [PATCH 39/50] using other_file instead of other_files --- torchy_baselines/common/base_class.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 680ac17..8f0acca 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -330,11 +330,11 @@ class BaseRLModel(object): file_content.seek(0) params = th.load(file_content) # check for all other .pth files - other_files = [file_name for file_name in namelist if + other_file = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] - if len(other_files) > 0: + if len(other_file) > 0: opt_params = dict() - for file in other_files: + for file in other_file: with file_.open(file, mode="r") as opt_param_file: # File has to be seekable so load in BytesIO first file_content = io.BytesIO() From 8460bfe397d7339e18f1062ec0d2eb297f136af7 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 08:56:04 +0100 Subject: [PATCH 40/50] added some comments to _load_from_file --- torchy_baselines/common/base_class.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 8f0acca..09fe798 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -329,9 +329,12 @@ class BaseRLModel(object): # go to start of file file_content.seek(0) params = th.load(file_content) + # check for all other .pth files other_file = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] + # if there are any other files which end with .pth and aren't "params.pth" + # assume that they each are optimizer parameters if len(other_file) > 0: opt_params = dict() for file in other_file: @@ -341,6 +344,7 @@ class BaseRLModel(object): file_content.write(opt_param_file.read()) # go to start of file file_content.seek(0) + # save the parameters in dict with file name but trim file ending opt_params[os.path.splitext(file)[0]] = th.load(file_content) except zipfile.BadZipFile: # load_path wasn't a zip file From 4b1bab7f858befa36aef0231b67e62c8f2f7d221 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 09:11:30 +0100 Subject: [PATCH 41/50] implemented set_env method --- torchy_baselines/common/base_class.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 09fe798..384ad27 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -169,10 +169,20 @@ class BaseRLModel(object): def set_env(self, env): """ Checks the validity of the environment, and if it is coherent, set it as the current environment. + checked parameters: + - observation_space + - action_space :param env: (Gym Environment) The environment for learning a policy """ - pass + + if self.observation_space != env.observation_space: + raise ValueError("The given environment has a observation_space that doesn't fit the current model") + + if self.action_space != env.action_space: + raise ValueError("The given environment has a action_space that doesn't fit the current model") + # if all fits save new env + self.env = env def get_parameter_list(self): """ From 8062ed6036a80f594b0956a3b9f68ad538d3ec8b Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 13:36:19 +0100 Subject: [PATCH 42/50] fixed load, to check if environment ist correctly --- tests/test_save_load.py | 10 ++++-- torchy_baselines/common/base_class.py | 48 +++++++++++++++++++-------- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index bae9e0e..eef023d 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -89,8 +89,8 @@ def test_save_load(model_class): # check if model still selects the same actions new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations] - for i in range(len(selected_actions)): - assert selected_actions[i] == new_selected_actions[i] + # for i in range(len(selected_actionsselected_actions)): + assert np.allclose(selected_actions, new_selected_actions) # check if learn still works model.learn(total_timesteps=1000, eval_freq=500) @@ -99,6 +99,11 @@ def test_save_load(model_class): os.remove("test_save.zip") +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_set_env(model_class): + pass + + @pytest.mark.parametrize("model_class", MODEL_LIST) def test_exclude_include_saved_params(model_class): """ @@ -126,6 +131,5 @@ def test_exclude_include_saved_params(model_class): model = model_class.load("test_save") assert model.verbose == 2 - # clear file from os os.remove("test_save.zip") diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 384ad27..d10dec6 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -166,6 +166,22 @@ class BaseRLModel(object): """ return self.env + @staticmethod + def check_env(env, observation_space, action_space): + """ + Checks the validity of the environment and returns if it is coherent + Checked parameters: + - observation_space + - action_space + :return: (bool) True if environment seems to be coherent + """ + if observation_space != env.observation_space: + return False + if action_space != env.action_space: + return False + # return true if no check failed + return True + def set_env(self, env): """ Checks the validity of the environment, and if it is coherent, set it as the current environment. @@ -175,21 +191,20 @@ class BaseRLModel(object): :param env: (Gym Environment) The environment for learning a policy """ - - if self.observation_space != env.observation_space: - raise ValueError("The given environment has a observation_space that doesn't fit the current model") - - if self.action_space != env.action_space: - raise ValueError("The given environment has a action_space that doesn't fit the current model") + if self.check_env(env, self.observation_space, self.action_space) is False: + raise ValueError("Given environment is not compatible with model") # if all fits save new env self.env = env + # and update observation and action space + self.observation_space = env.observation_space + self.action_space = env.action_space def get_parameter_list(self): """ Returns policy and optimizer parameters as a tuple :return: (dict,dict) policy_parameters, opt_parameters """ - return self.get_policy_parameters(),self.get_opt_parameters() + return self.get_policy_parameters(), self.get_opt_parameters() def get_policy_parameters(self): """ @@ -286,15 +301,22 @@ class BaseRLModel(object): "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], kwargs['policy_kwargs'])) + # check if observation space and action space is given + if ("observation_space" not in data or "action_space" not in data) and "env" not in data: + raise ValueError("The observation_space and action_space was not given, can't verify new environments") + # check if given env is valid + if env is not None and cls.check_env(env, data["observation_space"], data["action_space"]) is False: + raise ValueError("The given environment does not comply to the model") + # if no new env was given use stored env if possible if env is None and "env" in data: env = data["env"] - if env is not None: - model = cls(policy=data["policy_class"], env=env, _init_setup_model=True) - else: - model = cls(policy=data["policy_class"], env=env, _init_setup_model=False) + + # first create model, but only setup if a env was given + model = cls(policy=data["policy_class"], env=env, _init_setup_model=env is not None) + + # load parameters model.__dict__.update(data) model.__dict__.update(kwargs) - model.set_env(env) model.load_parameters(params, opt_params) return model @@ -342,7 +364,7 @@ class BaseRLModel(object): # check for all other .pth files other_file = [file_name for file_name in namelist if - os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] + os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] # if there are any other files which end with .pth and aren't "params.pth" # assume that they each are optimizer parameters if len(other_file) > 0: From cf1d7118a5614b1c6743d496199c9e057455c480 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 13:44:02 +0100 Subject: [PATCH 43/50] replaced file with file_path --- torchy_baselines/common/base_class.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index d10dec6..bb98a27 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -339,8 +339,8 @@ class BaseRLModel(object): # Open the zip archive and load data try: - with zipfile.ZipFile(load_path, "r") as file_: - namelist = file_.namelist() + with zipfile.ZipFile(load_path, "r") as archive: + namelist = archive.namelist() # If data or parameters is not in the # zip archive, assume they were stored # as None (_save_to_file_zip allows this). @@ -349,12 +349,12 @@ class BaseRLModel(object): opt_params = None if "data" in namelist and load_data: # Load class parameters and convert to string - json_data = file_.read("data").decode() + json_data = archive.read("data").decode() data = json_to_data(json_data) if "params.pth" in namelist: # Load parameters with build in torch function - with file_.open("params.pth", mode="r") as param_file: + with archive.open("params.pth", mode="r") as param_file: # File has to be seekable so load in BytesIO first file_content = io.BytesIO() file_content.write(param_file.read()) @@ -363,21 +363,21 @@ class BaseRLModel(object): params = th.load(file_content) # check for all other .pth files - other_file = [file_name for file_name in namelist if + other_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] # if there are any other files which end with .pth and aren't "params.pth" # assume that they each are optimizer parameters - if len(other_file) > 0: + if len(other_files) > 0: opt_params = dict() - for file in other_file: - with file_.open(file, mode="r") as opt_param_file: + for file_path in other_files: + with archive.open(file_path, mode="r") as opt_param_file: # File has to be seekable so load in BytesIO first file_content = io.BytesIO() file_content.write(opt_param_file.read()) # go to start of file file_content.seek(0) # save the parameters in dict with file name but trim file ending - opt_params[os.path.splitext(file)[0]] = th.load(file_content) + opt_params[os.path.splitext(file_path)[0]] = th.load(file_content) except zipfile.BadZipFile: # load_path wasn't a zip file raise ValueError("Error: the file {} wasn't a zip-file".format(load_path)) @@ -528,16 +528,16 @@ class BaseRLModel(object): # Create a zip-archive and write our objects # there. This works when save_path is either # str or a file-like - with zipfile.ZipFile(save_path, "w") as file_: + with zipfile.ZipFile(save_path, "w") as archive: # Do not try to save "None" elements if data is not None: - file_.writestr("data", serialized_data) + archive.writestr("data", serialized_data) if params is not None: - with file_.open('params.pth', mode="w") as param_file: + with archive.open('params.pth', mode="w") as param_file: th.save(params, param_file) if opt_params is not None: for file_name, dict in opt_params.items(): - with file_.open(file_name + '.pth', mode="w") as opt_param_file: + with archive.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) def excluded_save_params(self): From 88d4f44d554f31b4d668a53b1ecc47ada4a0b1be Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 13:59:07 +0100 Subject: [PATCH 44/50] added set_env test and set_env wrapping --- tests/test_save_load.py | 25 +++++++++++++++++++++++-- torchy_baselines/common/base_class.py | 17 +++++++++++------ 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index eef023d..3a2cb42 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -7,7 +7,7 @@ import torch as th from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 from torchy_baselines.common.vec_env import DummyVecEnv -from torchy_baselines.common.identity_env import IdentityEnvBox +from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv MODEL_LIST = [ CEMRL, @@ -101,7 +101,28 @@ def test_save_load(model_class): @pytest.mark.parametrize("model_class", MODEL_LIST) def test_set_env(model_class): - pass + """ + Test if set_env function does work correct + :param model_class: (BaseRLModel) A RL model + """ + env = DummyVecEnv([lambda: IdentityEnvBox(10)]) + env2 = DummyVecEnv([lambda: IdentityEnvBox(10)]) + env3 = IdentityEnvBox(10) + + # create model + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), create_eval_env=True) + # learn + model.learn(total_timesteps=1000, eval_freq=500) + + # change env + model.set_env(env2) + # learn again + model.learn(total_timesteps=1000, eval_freq=500) + + # change env test wrapping + model.set_env(env3) + # learn again + model.learn(total_timesteps=1000, eval_freq=500) @pytest.mark.parametrize("model_class", MODEL_LIST) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index bb98a27..4a08917 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -185,6 +185,7 @@ class BaseRLModel(object): def set_env(self, env): """ Checks the validity of the environment, and if it is coherent, set it as the current environment. + Furthermore wrap any non vectorized env into a vectorized checked parameters: - observation_space - action_space @@ -193,13 +194,16 @@ class BaseRLModel(object): """ if self.check_env(env, self.observation_space, self.action_space) is False: raise ValueError("Given environment is not compatible with model") - # if all fits save new env + # it must be coherent now + # if it is not a VecEnv, make it a VecEnv + if not isinstance(env, VecEnv): + if self.verbose >= 1: + print("Wrapping the env in a DummyVecEnv.") + env = DummyVecEnv([lambda: env]) + self.n_envs = env.num_envs self.env = env - # and update observation and action space - self.observation_space = env.observation_space - self.action_space = env.action_space - def get_parameter_list(self): + def get_parameters(self): """ Returns policy and optimizer parameters as a tuple :return: (dict,dict) policy_parameters, opt_parameters @@ -540,7 +544,8 @@ class BaseRLModel(object): with archive.open(file_name + '.pth', mode="w") as opt_param_file: th.save(dict, opt_param_file) - def excluded_save_params(self): + @staticmethod + def excluded_save_params(): """ returns the names of the parameters that should be excluded from save :return: ([str]) List of parameters that should be excluded from save From 03ecb17ef635cd6585c38f6045e30bd02a5fff60 Mon Sep 17 00:00:00 2001 From: "Raffin, Antonin" Date: Thu, 5 Dec 2019 14:41:39 +0100 Subject: [PATCH 45/50] Update error message --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 4a08917..0bd1fae 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -193,7 +193,7 @@ class BaseRLModel(object): :param env: (Gym Environment) The environment for learning a policy """ if self.check_env(env, self.observation_space, self.action_space) is False: - raise ValueError("Given environment is not compatible with model") + raise ValueError("The given environment is not compatible with model: observation and action spaces do not match") # it must be coherent now # if it is not a VecEnv, make it a VecEnv if not isinstance(env, VecEnv): From 464dd773e60002229acd189fb114ebc89dde843d Mon Sep 17 00:00:00 2001 From: "Raffin, Antonin" Date: Thu, 5 Dec 2019 14:46:02 +0100 Subject: [PATCH 46/50] Update comment --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 0bd1fae..40c3061 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -305,7 +305,7 @@ class BaseRLModel(object): "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], kwargs['policy_kwargs'])) - # check if observation space and action space is given + # check if observation space and action space are part of the saved parameters if ("observation_space" not in data or "action_space" not in data) and "env" not in data: raise ValueError("The observation_space and action_space was not given, can't verify new environments") # check if given env is valid From 424a5545670422d859abc706a0b78e5b10573d4d Mon Sep 17 00:00:00 2001 From: "Raffin, Antonin" Date: Thu, 5 Dec 2019 14:50:11 +0100 Subject: [PATCH 47/50] Update docstring --- torchy_baselines/common/base_class.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 40c3061..c70c2a2 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -547,7 +547,9 @@ class BaseRLModel(object): @staticmethod def excluded_save_params(): """ - returns the names of the parameters that should be excluded from save + Returns the names of the parameters that should be excluded by default + when saving the model. + :return: ([str]) List of parameters that should be excluded from save """ return ["env", "eval_env", "replay_buffer", "rollout_buffer"] From 695cdc63a44b52b642759398f3835bfe73e90cc2 Mon Sep 17 00:00:00 2001 From: "Raffin, Antonin" Date: Thu, 5 Dec 2019 14:52:59 +0100 Subject: [PATCH 48/50] Update torchy_baselines/common/base_class.py --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index c70c2a2..8f14cb3 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -558,7 +558,7 @@ class BaseRLModel(object): """ Save all the attributes of the object and the model parameters in a zip-file for continuous learning - :param path: (str) path to the file where the data should be saved + :param path: (str) path to the file where the rl agent should be saved :param exclude: ([str]) name of parameters that should be excluded in addition to the default one :param include: ([str]) name of parameters that might be excluded but should be included anyway """ From bac9d4efed1412d0e8c8984c928b66084ada290a Mon Sep 17 00:00:00 2001 From: "Raffin, Antonin" Date: Thu, 5 Dec 2019 14:53:14 +0100 Subject: [PATCH 49/50] Update torchy_baselines/common/base_class.py --- torchy_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 8f14cb3..251ce1d 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -556,7 +556,7 @@ class BaseRLModel(object): def save(self, path, exclude=None, include=None): """ - Save all the attributes of the object and the model parameters in a zip-file for continuous learning + Save all the attributes of the object and the model parameters in a zip-file. :param path: (str) path to the file where the rl agent should be saved :param exclude: ([str]) name of parameters that should be excluded in addition to the default one From aa67147796ad3c478014edc350420fa30227dc90 Mon Sep 17 00:00:00 2001 From: Noah Dormann Date: Thu, 5 Dec 2019 15:45:05 +0100 Subject: [PATCH 50/50] clarified bytesIO use for load --- torchy_baselines/common/base_class.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 251ce1d..ccb3f5e 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -359,7 +359,8 @@ class BaseRLModel(object): if "params.pth" in namelist: # Load parameters with build in torch function with archive.open("params.pth", mode="r") as param_file: - # File has to be seekable so load in BytesIO first + # File has to be seekable, but param_file is not, so load in BytesIO first + # fixed in python >= 3.7 file_content = io.BytesIO() file_content.write(param_file.read()) # go to start of file @@ -375,7 +376,8 @@ class BaseRLModel(object): opt_params = dict() for file_path in other_files: with archive.open(file_path, mode="r") as opt_param_file: - # File has to be seekable so load in BytesIO first + # File has to be seekable, but opt_param_file is not, so load in BytesIO first + # fixed in python >= 3.7 file_content = io.BytesIO() file_content.write(opt_param_file.read()) # go to start of file