From c5420096414dc2c1d63f0bf58c052e6f99f5317b Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 20 Jan 2020 11:17:55 +0100 Subject: [PATCH] Clean up code + bug fixes --- tests/test_custom_policy.py | 5 +-- tests/test_distributions.py | 6 +-- tests/test_save_load.py | 10 ++--- tests/test_sde.py | 8 +--- tests/test_vec_envs.py | 7 ++-- torchy_baselines/common/base_class.py | 41 +++++++------------ torchy_baselines/common/buffers.py | 5 +-- torchy_baselines/common/distributions.py | 3 +- torchy_baselines/common/logger.py | 14 +------ torchy_baselines/common/vec_env/__init__.py | 2 +- .../common/vec_env/vec_normalize.py | 2 +- torchy_baselines/ppo/policies.py | 3 +- torchy_baselines/sac/policies.py | 3 ++ torchy_baselines/sac/sac.py | 9 ++-- torchy_baselines/td3/td3.py | 11 ++--- 15 files changed, 51 insertions(+), 78 deletions(-) diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index d45be88..2555016 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -1,6 +1,3 @@ -import os - -import gym import pytest from torchy_baselines import PPO @@ -15,4 +12,4 @@ from torchy_baselines import PPO [12, dict(pi=[8])], ]) def test_flexible_mlp(net_arch): - model = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000) + _ = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 3d896ab..7e24cda 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,10 +1,9 @@ import pytest -import numpy as np import torch as th +from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \ + StateDependentNoiseDistribution from torchy_baselines.common.utils import set_random_seed -from torchy_baselines.common.distributions import DiagGaussianDistribution, SquashedDiagGaussianDistribution,\ - CategoricalDistribution, TanhBijector, StateDependentNoiseDistribution # TODO: more tests for the other distributions @@ -43,6 +42,7 @@ def test_sde_distribution(): N_ACTIONS = 1 + # TODO: fix for num action > 1 # TODO: analytical form for squashed Gaussian? @pytest.mark.parametrize("dist", [ diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 3a2cb42..f4adc0c 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,13 +1,12 @@ +import numpy as np import os import pytest -from copy import deepcopy -import numpy as np - import torch as th +from copy import deepcopy from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 +from torchy_baselines.common.identity_env import IdentityEnvBox from torchy_baselines.common.vec_env import DummyVecEnv -from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv MODEL_LIST = [ CEMRL, @@ -81,7 +80,8 @@ def test_save_load(model_class): for optimizer, opt_state in opt_params.items(): for param_group_idx, param_group in enumerate(opt_state['param_groups']): for param_key, param_value in param_group.items(): - if param_key == 'params': # don't know how to handle params correctly, therefore only check if we have the same amount + # don't know how to handle params correctly, therefore only check if we have the same amount + if param_key == 'params': assert len(param_value) == len( new_opt_params[optimizer]['param_groups'][param_group_idx][param_key]) else: diff --git a/tests/test_sde.py b/tests/test_sde.py index 851ca42..497f398 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -1,12 +1,8 @@ import pytest - -import gym import torch as th from torch.distributions import Normal from torchy_baselines import A2C, TD3, SAC -from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize -from torchy_baselines.common.monitor import Monitor def test_state_dependent_exploration_grad(): @@ -35,7 +31,7 @@ def test_state_dependent_exploration_grad(): action_dist = Normal(mu, th.sqrt(variance)) # Sum over the action dimension because we assume they are independent - loss = action_dist.log_prob((action).detach()).sum(dim=-1).mean() + loss = action_dist.log_prob(action.detach()).sum(dim=-1).mean() loss.backward() # From Rueckstiess paper: check that the computed gradient @@ -72,6 +68,6 @@ def test_scheduler(): return -2.0 * progress + 1 model = TD3('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True, - verbose=1, sde_log_std_scheduler=scheduler) + verbose=1, sde_log_std_scheduler=scheduler) model.learn(total_timesteps=int(1000), eval_freq=500) assert th.isclose(model.actor.log_std, th.ones_like(model.actor.log_std)).all() diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index efa5119..f2dd1c2 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -325,9 +325,8 @@ def test_vecenv_wrapper_getattr(): assert wrapped.name_test() == CustomWrapperBB double_wrapped = CustomWrapperA(CustomWrapperB(wrapped)) - dummy = double_wrapped.var_a # should not raise as it is directly defined here + _ = double_wrapped.var_a # should not raise as it is directly defined here with pytest.raises(AttributeError): # should raise due to ambiguity - dummy = double_wrapped.var_b + _ = double_wrapped.var_b with pytest.raises(AttributeError): # should raise as does not exist - dummy = double_wrapped.nonexistent_attribute - del dummy # keep linter happy + _ = double_wrapped.nonexistent_attribute diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 4c7a764..0f20211 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -225,7 +225,8 @@ class BaseRLModel(object): :param env: (gym.Env) The environment for learning a policy """ if self.check_env(env, self.observation_space, self.action_space) is False: - raise ValueError("The given environment is not compatible with model: observation and action spaces do not match") + raise ValueError("The given environment is not compatible with model: " + "observation and action spaces do not match") # it must be coherent now # if it is not a VecEnv, make it a VecEnv if not isinstance(env, VecEnv): @@ -258,24 +259,6 @@ class BaseRLModel(object): """ raise NotImplementedError() - def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, - adam_epsilon=1e-8, val_interval=None): - """ - Pretrain a model using behavior cloning: - supervised learning given an expert dataset. - - NOTE: only Box and Discrete spaces are supported for now. - - :param dataset: (ExpertDataset) Dataset manager - :param n_epochs: (int) Number of iterations on the training set - :param learning_rate: (float) Learning rate - :param adam_epsilon: (float) the epsilon value for the adam optimizer - :param val_interval: (int) Report training and validation losses every n epochs. - By default, every 10th of the maximum number of epochs. - :return: (BaseRLModel) the pretrained model - """ - raise NotImplementedError() - @abstractmethod def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run", eval_env=None, eval_freq=-1, n_eval_episodes=5, reset_num_timesteps=True): @@ -308,11 +291,12 @@ class BaseRLModel(object): """ pass - def load_parameters(self, load_dict, opt_params=None): + def load_parameters(self, load_dict, opt_params): """ Load model parameters from a dictionary load_dict should contain all keys from torch.model.state_dict() - If opt_params are given this does also load agent's optimizer-parameters, but can only be handled in child classes. + If opt_params are given this does also load agent's optimizer-parameters, + but can only be handled in child classes. :param load_dict: (dict) dict of parameters from model.state_dict() @@ -350,6 +334,7 @@ class BaseRLModel(object): env = data["env"] # first create model, but only setup if a env was given + # noinspection PyArgumentList model = cls(policy=data["policy_class"], env=env, _init_setup_model=env is not None) # load parameters @@ -365,7 +350,8 @@ class BaseRLModel(object): :param load_path: (str) Where to load the model from :param load_data: (bool) Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) - :return: (dict),(dict),(dict) Class parameters, model parameters (state_dict) and dict of optimizer parameters (dict of state_dict) + :return: (dict),(dict),(dict) Class parameters, model parameters (state_dict) + and dict of optimizer parameters (dict of state_dict) """ # Check if file exists if load_path is a string if isinstance(load_path, str): @@ -403,7 +389,7 @@ class BaseRLModel(object): # check for all other .pth files other_files = [file_name for file_name in namelist if - os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] + os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] # if there are any other files which end with .pth and aren't "params.pth" # assume that they each are optimizer parameters if len(other_files) > 0: @@ -521,7 +507,7 @@ class BaseRLModel(object): episode_reward, episode_timesteps = 0.0, 0 while not done: - if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix self.actor.reset_noise() @@ -577,7 +563,8 @@ class BaseRLModel(object): self.rollout_data['actions'].append(scaled_action[0].copy()) self.rollout_data['rewards'].append(reward[0].copy()) self.rollout_data['dones'].append(np.array(done_bool[0]).copy()) - self.rollout_data['values'].append(self.vf_net(th.FloatTensor(obs).to(self.device))[0].cpu().detach().numpy()) + obs_tensor = th.FloatTensor(obs).to(self.device) + self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy()) obs = new_obs # Save the true unnormalized observation @@ -674,9 +661,9 @@ class BaseRLModel(object): with archive.open('params.pth', mode="w") as param_file: th.save(params, param_file) if opt_params is not None: - for file_name, dict in opt_params.items(): + for file_name, dict_ in opt_params.items(): with archive.open(file_name + '.pth', mode="w") as opt_param_file: - th.save(dict, opt_param_file) + th.save(dict_, opt_param_file) @staticmethod def excluded_save_params(): diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index ae31a1f..369841f 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -1,8 +1,6 @@ import numpy as np import torch as th -from torchy_baselines.common.vec_env import unwrap_vec_normalize - class BaseBuffer(object): """ @@ -79,7 +77,8 @@ class BaseBuffer(object): """ raise NotImplementedError() - def _normalize_obs(self, obs, env=None): + @staticmethod + def _normalize_obs(obs, env=None): if env is not None: # TODO: get rid of pytorch - numpy conversion return th.FloatTensor(env.normalize_obs(obs.numpy())) diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index ba37cd3..9d6189c 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -411,7 +411,8 @@ class TanhBijector(object): super(TanhBijector, self).__init__() self.epsilon = epsilon - def forward(self, x): + @staticmethod + def forward(x): return th.tanh(x) @staticmethod diff --git a/torchy_baselines/common/logger.py b/torchy_baselines/common/logger.py index 5b59b76..15528eb 100644 --- a/torchy_baselines/common/logger.py +++ b/torchy_baselines/common/logger.py @@ -1,11 +1,10 @@ """ Taken from stable-baselines """ -import os import sys -import json -import time import datetime +import json +import os import tempfile import warnings from collections import defaultdict @@ -185,15 +184,6 @@ class CSVOutputFormat(KVWriter): self.file.close() -def summary_val(key, value): - """ - :param key: (str) - :param value: (float) - """ - kwargs = {'tag': key, 'simple_value': float(value)} - return tf.Summary.Value(**kwargs) - - def valid_float_value(value): """ Returns True if the value can be successfully cast into a float diff --git a/torchy_baselines/common/vec_env/__init__.py b/torchy_baselines/common/vec_env/__init__.py index 2c542e5..38099af 100644 --- a/torchy_baselines/common/vec_env/__init__.py +++ b/torchy_baselines/common/vec_env/__init__.py @@ -35,4 +35,4 @@ def sync_envs_normalization(env, eval_env): if isinstance(env_tmp, VecNormalize): eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) env_tmp = env_tmp.venv - eval_env_tmp.venv + eval_env_tmp = eval_env_tmp.venv diff --git a/torchy_baselines/common/vec_env/vec_normalize.py b/torchy_baselines/common/vec_env/vec_normalize.py index 0b9797f..3cb3c67 100644 --- a/torchy_baselines/common/vec_env/vec_normalize.py +++ b/torchy_baselines/common/vec_env/vec_normalize.py @@ -71,7 +71,7 @@ class VecNormalize(VecEnvWrapper): def normalize_obs(self, obs): if self.norm_obs: return np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, - self.clip_obs) + self.clip_obs) return obs def normalize_reward(self, reward): diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index d3b897e..5dd746f 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -157,7 +157,8 @@ class PPOPolicy(BasePolicy): return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic) elif isinstance(self.action_dist, StateDependentNoiseDistribution): - return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde, deterministic=deterministic) + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde, + deterministic=deterministic) def actor_forward(self, obs, deterministic=False): latent_pi, _, latent_sde = self._get_latent(obs) diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index bc5fbcf..12b23cc 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -241,6 +241,9 @@ class SACPolicy(BasePolicy): def make_critic(self): return Critic(**self.net_args).to(self.device) + def forward(self, obs): + return self.actor(obs) + MlpPolicy = SACPolicy diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index f3f45fb..f5a18d2 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -1,5 +1,3 @@ -import time - import torch as th import torch.nn.functional as F import numpy as np @@ -68,7 +66,8 @@ class SAC(BaseRLModel): _init_setup_model=True): super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device, - create_eval_env=create_eval_env, seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq) + create_eval_env=create_eval_env, seed=seed, + use_sde=use_sde, sde_sample_freq=sde_sample_freq) self.learning_rate = learning_rate self.target_entropy = target_entropy @@ -131,7 +130,8 @@ class SAC(BaseRLModel): self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) self.policy = self.policy_class(self.observation_space, self.action_space, - self.learning_rate, use_sde=self.use_sde, device=self.device, **self.policy_kwargs) + self.learning_rate, use_sde=self.use_sde, + device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases() @@ -203,7 +203,6 @@ class SAC(BaseRLModel): ent_coef_loss.backward() self.ent_coef_optimizer.step() - with th.no_grad(): # if self.use_sde: # self.actor.reset_noise(batch_size=batch_size) diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index a24e0db..d9ce51b 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -1,5 +1,3 @@ -import time - import torch as th import torch.nn.functional as F import numpy as np @@ -62,7 +60,8 @@ class TD3(BaseRLModel): seed=0, device='auto', _init_setup_model=True): super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, device, - create_eval_env=create_eval_env, seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq) + create_eval_env=create_eval_env, seed=seed, + use_sde=use_sde, sde_sample_freq=sde_sample_freq) self.buffer_size = buffer_size self.learning_rate = learning_rate @@ -94,7 +93,8 @@ class TD3(BaseRLModel): self.set_random_seed(self.seed) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) self.policy = self.policy_class(self.observation_space, self.action_space, - self.learning_rate, use_sde=self.use_sde, device=self.device, **self.policy_kwargs) + self.learning_rate, use_sde=self.use_sde, + device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases() @@ -209,7 +209,8 @@ class TD3(BaseRLModel): # self._update_learning_rate(self.policy.optimizer) # Unpack - obs, action, advantage, returns = [self.rollout_data[key] for key in ['observations', 'actions', 'advantage', 'returns']] + obs, action, advantage, returns = [self.rollout_data[key] for key in + ['observations', 'actions', 'advantage', 'returns']] log_prob, entropy = self.actor.evaluate_actions(obs, action) values = self.vf_net(obs).flatten()