From dcb54b5301d89d606783c3592e544ffbefa4080b Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Mon, 23 Mar 2020 14:48:38 +0100
Subject: [PATCH] Remove CEMRL

---
 README.md                           |   1 -
 docs/index.rst                      |   1 -
 docs/misc/changelog.rst             |   1 +
 docs/modules/cem_rl.rst             |  96 ------------
 tests/test_callbacks.py             |  11 +-
 tests/test_predict.py               |   3 +-
 tests/test_run.py                   |   8 +-
 tests/test_save_load.py             |   3 +-
 tests/test_vec_normalize.py         |   4 +-
 torchy_baselines/__init__.py        |   1 -
 torchy_baselines/cem_rl/__init__.py |   2 -
 torchy_baselines/cem_rl/cem.py      | 132 -----------------
 torchy_baselines/cem_rl/cem_rl.py   | 217 ----------------------
 torchy_baselines/sac/sac.py         |   1 +
 torchy_baselines/td3/td3.py         |  90 ++++--------
 torchy_baselines/version.txt        |   2 +-
 16 files changed, 40 insertions(+), 533 deletions(-)
 delete mode 100644 docs/modules/cem_rl.rst
 delete mode 100644 torchy_baselines/cem_rl/__init__.py
 delete mode 100644 torchy_baselines/cem_rl/cem.py
 delete mode 100644 torchy_baselines/cem_rl/cem_rl.py

diff --git a/README.md b/README.md
index ad745eb..843d55e 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,6 @@ NOTE: Python 3.6 is required!
 ## Implemented Algorithms
 
 - A2C
-- CEM-RL (with TD3)
 - PPO
 - SAC
 - TD3
diff --git a/docs/index.rst b/docs/index.rst
index 6ae60c1..c46e317 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -29,7 +29,6 @@ RL Baselines zoo also offers a simple interface to train, evaluate agents and d
 
    modules/base
    modules/a2c
-   modules/cem_rl
    modules/ppo
    modules/sac
    modules/td3
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index d58ba20..cb5d7aa 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -9,6 +9,7 @@ Pre-Release 0.4.0a0 (WIP)
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
+- Removed CEMRL
 
 New Features:
 ^^^^^^^^^^^^^
diff --git a/docs/modules/cem_rl.rst b/docs/modules/cem_rl.rst
deleted file mode 100644
index bc243b9..0000000
--- a/docs/modules/cem_rl.rst
+++ /dev/null
@@ -1,96 +0,0 @@
-.. _cem_rl:
-
-.. automodule:: torchy_baselines.cem_rl
-
-
-CEM RL
-======
-
-Combining the cross-entropy method (CEM) with Twin Delayed Deep Deterministic Policy Gradient (TD3).
-
-
-.. rubric:: Available Policies
-
-.. autosummary::
-    :nosignatures:
-
-    MlpPolicy
-
-
-Notes
------
-
-- Original paper: https://arxiv.org/abs/1810.01222 and https://openreview.net/forum?id=BkeU5j0ctQ
-- Original Implementation: https://github.com/apourchot/CEM-RL
-
-
-.. note::
-
-    CEM RL is currently only implemented for TD3.
-
-
-.. note::
-
-    The default policy for CEM RL differs a bit from the other algorithms' ``MlpPolicy``:
-    it uses ReLU instead of tanh activation, to match the original paper.
-
-
-Can I use?
-----------
-
-- Recurrent policies: ❌
-- Multi processing: ❌
-- Gym spaces:
-
-
-============= ====== ===========
-Space         Action Observation
-============= ====== ===========
-Discrete      ❌      ❌
-Box           ✔️      ✔️
-MultiDiscrete ❌      ❌
-MultiBinary   ❌      ❌
-============= ====== ===========
-
-
-Example
--------
-
-.. code-block:: python
-
-    import numpy as np
-
-    from torchy_baselines import CEMRL
-    from torchy_baselines.td3.policies import MlpPolicy
-
-    # n_grad = 0 corresponds to CEM (in fact CMA-ES without history)
-    model = CEMRL(MlpPolicy, 'Pendulum-v0', pop_size=10, n_grad=5, verbose=1)
-    model.learn(total_timesteps=50000, log_interval=10)
-    model.save("td3_pendulum")
-    env = model.get_env()
-
-    del model  # remove to demonstrate saving and loading
-
-    model = CEMRL.load("td3_pendulum")
-
-    obs = env.reset()
-    while True:
-        action, _states = model.predict(obs)
-        obs, rewards, dones, info = env.step(action)
-        env.render()
-
-Parameters
-----------
-
-.. autoclass:: CEMRL
-    :members:
-    :inherited-members:
-
-.. _cemrl_policies:
-
-CEM RL Policies
----------------
-
-.. autoclass:: MlpPolicy
-    :members:
-    :inherited-members:
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index a4ffede..5f0fc07 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -4,22 +4,17 @@ import shutil
 import pytest
 import gym
 
-from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
+from torchy_baselines import A2C, PPO, SAC, TD3
 from torchy_baselines.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
                                                EveryNTimesteps, StopTrainingOnRewardThreshold)
 
 
-@pytest.mark.parametrize("model_class", [A2C, CEMRL, PPO, SAC, TD3])
+@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
 def test_callbacks(model_class):
     log_folder = './logs/callbacks/'
 
-    kwargs = {}
-    if model_class == CEMRL:
-        kwargs['pop_size'] = 2
-        kwargs['n_grad'] = 1
-
     # Create RL model
     # Small network for fast test
-    model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]), **kwargs)
+    model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]))
 
     checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
diff --git a/tests/test_predict.py b/tests/test_predict.py
index 2ffab6b..5fd5064 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -1,11 +1,10 @@
 import gym
 import pytest
 
-from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
+from torchy_baselines import A2C, PPO, SAC, TD3
 from torchy_baselines.common.vec_env import DummyVecEnv
 
 MODEL_LIST = [
-    CEMRL,
     PPO,
     A2C,
     TD3,
diff --git a/tests/test_run.py b/tests/test_run.py
index db33cc7..5fd1fe1 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 
-from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
+from torchy_baselines import A2C, PPO, SAC, TD3
 from torchy_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
 
 action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
@@ -14,12 +14,6 @@ def test_td3(action_noise):
     model.learn(total_timesteps=1000, eval_freq=500)
 
 
-def test_cemrl():
-    model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1,
-                  learning_starts=100, verbose=1, create_eval_env=True, action_noise=action_noise)
-    model.learn(total_timesteps=1000, eval_freq=500)
-
-
 @pytest.mark.parametrize("model_class", [A2C, PPO])
 @pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
 def test_onpolicy(model_class, env_id):
diff --git a/tests/test_save_load.py b/tests/test_save_load.py
index 9a0dad8..9c528f4 100644
--- a/tests/test_save_load.py
+++ b/tests/test_save_load.py
@@ -4,12 +4,11 @@ import pytest
 import torch as th
 from copy import deepcopy
 
-from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
+from torchy_baselines import A2C, PPO, SAC, TD3
 from torchy_baselines.common.identity_env import IdentityEnvBox
 from torchy_baselines.common.vec_env import DummyVecEnv
 
 MODEL_LIST = [
-    CEMRL,
     PPO,
     A2C,
     TD3,
diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py
index d80462c..73c0824 100644
--- a/tests/test_vec_normalize.py
+++ b/tests/test_vec_normalize.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from torchy_baselines.common.running_mean_std import RunningMeanStd
 from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize, VecFrameStack, sync_envs_normalization, unwrap_vec_normalize
-from torchy_baselines import CEMRL, SAC, TD3
+from torchy_baselines import SAC, TD3
 
 ENV_ID = 'Pendulum-v0'
 
@@ -116,7 +116,7 @@ def test_normalize_external():
     assert np.all(norm_rewards < 1)
 
 
-@pytest.mark.parametrize("model_class", [SAC, TD3, CEMRL])
+@pytest.mark.parametrize("model_class", [SAC, TD3])
 def test_offpolicy_normalization(model_class):
     env = DummyVecEnv([make_env])
     env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
diff --git a/torchy_baselines/__init__.py b/torchy_baselines/__init__.py
index acf5912..28742f5 100644
--- a/torchy_baselines/__init__.py
+++ b/torchy_baselines/__init__.py
@@ -1,7 +1,6 @@
 import os
 
 from torchy_baselines.a2c import A2C
-from torchy_baselines.cem_rl import CEMRL
 from torchy_baselines.ppo import PPO
 from torchy_baselines.sac import SAC
 from torchy_baselines.td3 import TD3
diff --git a/torchy_baselines/cem_rl/__init__.py b/torchy_baselines/cem_rl/__init__.py
deleted file mode 100644
index 52c3c9a..0000000
--- a/torchy_baselines/cem_rl/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from torchy_baselines.cem_rl.cem_rl import CEMRL
-from torchy_baselines.td3.policies import MlpPolicy
diff --git a/torchy_baselines/cem_rl/cem.py b/torchy_baselines/cem_rl/cem.py
deleted file mode 100644
index 7527b99..0000000
--- a/torchy_baselines/cem_rl/cem.py
+++ /dev/null
@@ -1,132 +0,0 @@
-import numpy as np
-from typing import Tuple, Optional, List
-
-
-# TODO: add more from https://github.com/hardmaru/estool/blob/master/es.py
-# or https://github.com/facebookresearch/nevergrad
-
-class CEM(object):
-    """
-    Cross-entropy method with diagonal covariance (separable CEM).
-
-    :param num_params: (int) Number of parameters per individual (dimension of the problem)
-    :param mu_init: (np.ndarray) Initial mean of the population distribution.
-        Taken to be zero if None is passed.
-    :param sigma_init: (float) Initial standard deviation of the population distribution
-    :param pop_size: (int) Number of individuals in the population
-    :param damping_init: (float) Initial damping value, used to prevent early convergence
-    :param damping_final: (float) Final value of the damping
-    :param parents: (int) Number of parents used to compute the new distribution
-        of individuals.
-    :param elitism: (bool) Keep the best known individual in the population
-    :param antithetic: (bool) Use a finite difference like method for sampling
-        (mu + epsilon, mu - epsilon)
-    """
-    def __init__(self,
-                 num_params: int,
-                 mu_init: Optional[np.ndarray] = None,
-                 sigma_init: float = 1e-3,
-                 pop_size: int = 256,
-                 damping_init: float = 1e-3,
-                 damping_final: float = 1e-5,
-                 parents: Optional[int] = None,
-                 elitism: bool = False,
-                 antithetic: bool = False):
-        super(CEM, self).__init__()
-
-        self.num_params = num_params
-
-        # Distribution parameters
-        if mu_init is None:
-            self.mu = np.zeros(self.num_params)
-        else:
-            self.mu = np.array(mu_init)
-
-        self.sigma = sigma_init
-        # Damping parameters
-        self.damping = damping_init
-        self.damping_final = damping_final
-        # Exponential moving average decay for damping
-        self.tau = 0.95
-        # Covariance matrix, here only the diagonal
-        self.cov = self.sigma * np.ones(self.num_params)
-
-        # Elite bookkeeping
-        self.elitism = elitism
-        self.elite = np.sqrt(self.sigma) * np.random.rand(self.num_params)
-        self.elite_score = None
-
-        # Sampling parameters
-        self.pop_size = pop_size
-        self.antithetic = antithetic
-
-        if self.antithetic:
-            assert (self.pop_size % 2 == 0), "Population size must be even"
-
-        if parents is None or parents <= 0:
-            self.parents = pop_size // 2
-        else:
-            self.parents = parents
-
-        # Weights for computing the new mean of the distribution
-        # from the parents: the better the individual, the higher the weight
-        self.weights = np.array([np.log((self.parents + 1) / i)
-                                 for i in range(1, self.parents + 1)])
-        self.weights /= self.weights.sum()
-
-    def ask(self, pop_size: int) -> List[np.ndarray]:
-        """
-        Returns a list of candidate parameters.
-
-        :param pop_size: (int)
-        :return: ([np.ndarray])
-        """
-        if self.antithetic and not pop_size % 2:
-            epsilon_half = np.random.randn(pop_size // 2, self.num_params)
-            epsilon = np.concatenate([epsilon_half,
-                                      -epsilon_half])
-        else:
-            epsilon = np.random.randn(pop_size, self.num_params)
-
-        individuals = self.mu + epsilon * np.sqrt(self.cov)
-
-        # Keep the best known individual in the population
-        if self.elitism:
-            individuals[-1] = self.elite
-
-        return individuals
-
-    def tell(self, solutions: List[np.ndarray], scores: List[float]) -> None:
-        """
-        Updates the distribution.
-
-        :param solutions: ([np.ndarray])
-        :param scores: ([float]) episode rewards
-        """
-        # Convert rewards (we want to maximize) to costs (we want to minimize)
-        scores = np.array(scores)
-        scores *= -1
-        # Sort the individuals by fitness
-        idx_sorted = np.argsort(scores)
-
-        old_mu = self.mu
-        # Update damping using an exponential moving average
-        self.damping = self.damping * self.tau + (1 - self.tau) * self.damping_final
-        # self.mu = self.weights @ solutions[idx_sorted[:self.parents]]
-        self.mu = self.weights.dot(solutions[idx_sorted[:self.parents]])
-
-        # CMA-ES style would be to use the new mean here
-        z = (solutions[idx_sorted[:self.parents]] - old_mu)
-        self.cov = 1 / self.parents * self.weights.dot(z * z) + self.damping * np.ones(self.num_params)
-
-        # Retrieve the best individual
-        self.elite = solutions[idx_sorted[0]]
-        self.elite_score = scores[idx_sorted[0]]
-
-    def get_distrib_params(self) -> Tuple[np.ndarray, np.ndarray]:
-        """
-        Returns the parameters of the distribution:
-        the mean and the diagonal of the covariance.
-
-        :return: (np.ndarray, np.ndarray)
-        """
-        return np.copy(self.mu), np.copy(self.cov)
diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py
deleted file mode 100644
index 1e78055..0000000
--- a/torchy_baselines/cem_rl/cem_rl.py
+++ /dev/null
@@ -1,217 +0,0 @@
-from typing import Type, Union, Callable, Optional, Dict, Any
-
-import torch as th
-
-from torchy_baselines.common.base_class import OffPolicyRLModel
-from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
-from torchy_baselines.common.noise import ActionNoise
-from torchy_baselines.td3.td3 import TD3, TD3Policy
-from torchy_baselines.cem_rl.cem import CEM
-
-
-class CEMRL(TD3):
-    """
-    Implementation of CEM-RL, i.e. CEM combined with TD3.
-
-    Paper: https://arxiv.org/abs/1810.01222
-    Code: https://github.com/apourchot/CEM-RL
-
-    :param policy: (TD3Policy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
-    :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str)
-    :param learning_rate: (float or callable) learning rate for the Adam optimizer,
-        the same learning rate will be used for all networks (Q-Values, Actor and Value function),
-        it can be a function of the current progress (from 1 to 0)
-    :param buffer_size: (int) size of the replay buffer
-    :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
-    :param batch_size: (int) Minibatch size for each gradient update
-    :param tau: (float) the soft update coefficient ("polyak update", between 0 and 1)
-    :param gamma: (float) the discount factor
-    :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes.
-    :param action_noise: (ActionNoise) the action noise type (None by default), this can help
-        for hard exploration problems. Cf common.noise for the different action noise types.
-    :param policy_delay: (int) Policy and target networks will only be updated once every policy_delay
-        training steps. The Q values will be updated policy_delay times more often (once per training step).
-    :param target_policy_noise: (float) Standard deviation of Gaussian noise added to target policy
-        (smoothing noise)
-    :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
-    :param sigma_init: (float) Initial standard deviation of the population distribution
-    :param pop_size: (int) Number of individuals in the population
-    :param damping_init: (float) Initial damping value, used to prevent early convergence
-    :param damping_final: (float) Final value of the damping
-    :param elitism: (bool) Keep the best known individual in the population
-    :param n_grad: (int) Number of individuals that will receive a gradient update.
-        Half of the population size in the paper.
-    :param update_style: (str) Update style for the individuals that will use the gradient:
-        - original: original implementation (actor_steps // n_grad steps for the critic
-          and actor_steps gradient steps per individual)
-        - original_td3: same as before but the target networks are only updated afterward
-        - td3_like: use policy delay and `actor_steps` steps for both the critic and the individual
-        - other: `2 * (actor_steps // self.n_grad)` for the critic and the individual
-    :param create_eval_env: (bool) Whether to create a second environment that will be
-        used for evaluating the agent periodically.
-        (Only available when passing string for the environment)
-    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
-    :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug
-    :param seed: (int) Seed for the pseudo random generators
-    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
-        Setting it to auto, the code will be run on the GPU if possible.
-    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
-    """
-    def __init__(self, policy: Union[str, Type[TD3Policy]],
-                 env: Union[GymEnv, str],
-                 learning_rate: Union[float, Callable] = 1e-3,
-                 buffer_size: int = int(1e6),
-                 learning_starts: int = 100,
-                 batch_size: int = 100,
-                 tau: float = 0.005,
-                 gamma: float = 0.99,
-                 n_episodes_rollout: int = 1,
-                 action_noise: Optional[ActionNoise] = None,
-                 policy_delay: int = 2,
-                 target_policy_noise: float = 0.2,
-                 target_noise_clip: float = 0.5,
-                 sigma_init: float = 1e-3,
-                 pop_size: int = 10,
-                 damping_init: float = 1e-3,
-                 damping_final: float = 1e-5,
-                 elitism: bool = False,
-                 n_grad: int = 5,
-                 update_style: str = 'original',
-                 tensorboard_log: Optional[str] = None,
-                 create_eval_env: bool = False,
-                 policy_kwargs: Dict[str, Any] = None,
-                 verbose: int = 0,
-                 seed: Optional[int] = None,
-                 device: Union[th.device, str] = 'auto',
-                 _init_setup_model: bool = True):
-
-        super(CEMRL, self).__init__(policy, env,
-                                    buffer_size=buffer_size, learning_rate=learning_rate, seed=seed, device=device,
-                                    action_noise=action_noise, target_policy_noise=target_policy_noise,
-                                    target_noise_clip=target_noise_clip, learning_starts=learning_starts,
-                                    n_episodes_rollout=n_episodes_rollout, tau=tau, gamma=gamma,
-                                    policy_kwargs=policy_kwargs, verbose=verbose,
-                                    policy_delay=policy_delay, batch_size=batch_size,
-                                    create_eval_env=create_eval_env, tensorboard_log=tensorboard_log,
-                                    _init_setup_model=False)
-
-        # Evolution strategy method that follows the CMA-ES interface (ask-tell),
-        # for now, only CEM is implemented
-        self.es = None  # type: Optional[CEM]
-        self.sigma_init = sigma_init
-        self.pop_size = pop_size
-        self.damping_init = damping_init
-        self.damping_final = damping_final
-        self.elitism = elitism
-        self.n_grad = n_grad
-        self.es_params = None
-        self.update_style = update_style
-        self.fitnesses = []
-
-        if _init_setup_model:
-            self._setup_model()
-
-    def _setup_model(self) -> None:
-        super(CEMRL, self)._setup_model()
-        params_vector = self.actor.parameters_to_vector()
-        self.es = CEM(len(params_vector), mu_init=params_vector,
-                      sigma_init=self.sigma_init, damping_init=self.damping_init, damping_final=self.damping_final,
-                      pop_size=self.pop_size, antithetic=not self.pop_size % 2, parents=self.pop_size // 2,
-                      elitism=self.elitism)
-
-    def learn(self,
-              total_timesteps: int,
-              callback: MaybeCallback = None,
-              log_interval: int = 4,
-              eval_env: Optional[GymEnv] = None,
-              eval_freq: int = -1,
-              n_eval_episodes: int = 5,
-              tb_log_name: str = "CEMRL",
-              eval_log_path: Optional[str] = None,
-              reset_num_timesteps: bool = True) -> OffPolicyRLModel:
-
-        episode_num, obs, callback = self._setup_learn(eval_env, callback, eval_freq,
-                                                       n_eval_episodes, eval_log_path, reset_num_timesteps)
-        actor_steps = 0
-        continue_training = True
-
-        callback.on_training_start(locals(), globals())
-
-        while self.num_timesteps < total_timesteps:
-
-            self.fitnesses = []
-            self.es_params = self.es.ask(self.pop_size)
-
-            if self.num_timesteps > 0:
-                # self.train(episode_timesteps)
-                # Gradient steps for half of the population
-                for i in range(self.n_grad):
-                    # Set the candidate parameters
-                    self.actor.load_from_vector(self.es_params[i])
-                    self.actor_target.load_from_vector(self.es_params[i])
-                    self.actor.optimizer = th.optim.Adam(self.actor.parameters(),
-                                                         lr=self.lr_schedule(self._current_progress))
-
-                    # In the paper: 2 * actor_steps // self.n_grad
-                    # In the original implementation: actor_steps // self.n_grad
-                    # Difference with the TD3 implementation:
-                    # the target critic is updated in train_critic()
-                    # instead of train_actor(), and there is no policy delay
-                    # Issue with this update style: the bigger the population, the slower the code
-                    if self.update_style == 'original':
-                        self.train_critic(actor_steps // self.n_grad, tau=self.tau)
-                        self.train_actor(actor_steps, tau_actor=self.tau, tau_critic=0.0)
-                    elif self.update_style == 'original_td3':
-                        self.train_critic(actor_steps // self.n_grad, tau=0.0)
-                        self.train_actor(actor_steps, tau_actor=self.tau, tau_critic=self.tau)
-                    else:
-                        # Closer to TD3: with policy delay
-                        if self.update_style == 'td3_like':
-                            n_training_steps = actor_steps
-                        else:
-                            # Scales with a bigger population
-                            # but fewer training steps per agent
-                            n_training_steps = 2 * (actor_steps // self.n_grad)
-                        for it in range(n_training_steps):
-                            # Sample replay buffer
-                            replay_data = self.replay_buffer.sample(self.batch_size, env=self._vec_normalize_env)
-                            self.train_critic(replay_data=replay_data)
-
-                            # Delayed policy updates
-                            if it % self.policy_delay == 0:
-                                self.train_actor(replay_data=replay_data, tau_actor=self.tau, tau_critic=self.tau)
-
-                    # Put the updated parameters back into the population
-                    self.es_params[i] = self.actor.parameters_to_vector()
-
-            actor_steps = 0
-            # Evaluate all the actors
-            for params in self.es_params:
-                self.actor.load_from_vector(params)
-
-                rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout,
-                                                n_steps=-1, action_noise=self.action_noise,
-                                                callback=callback,
-                                                learning_starts=self.learning_starts,
-                                                replay_buffer=self.replay_buffer,
-                                                obs=obs, episode_num=episode_num,
-                                                log_interval=log_interval)
-
-                # Unpack
-                episode_reward, episode_timesteps, n_episodes, obs, continue_training = rollout
-
-                if continue_training is False:
-                    break
-
-                episode_num += n_episodes
-                actor_steps += episode_timesteps
-                self.fitnesses.append(episode_reward)
-
-            if continue_training is False:
-                break
-
-            self._update_current_progress(self.num_timesteps, total_timesteps)
-            self.es.tell(self.es_params, self.fitnesses)
-
-        callback.on_training_end()
-
-        return self
diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py
index 496fe0d..d81bbbb 100644
--- a/torchy_baselines/sac/sac.py
+++ b/torchy_baselines/sac/sac.py
@@ -158,6 +158,7 @@ class SAC(OffPolicyRLModel):
         if self.ent_coef_optimizer is not None:
             optimizers += [self.ent_coef_optimizer]
 
+        # Update learning rate according to lr schedule
         self._update_learning_rate(optimizers)
 
         ent_coef_losses, ent_coefs = [], []
diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py
index 79d3d88..9158a0d 100644
--- a/torchy_baselines/td3/td3.py
+++ b/torchy_baselines/td3/td3.py
@@ -126,27 +126,26 @@ class TD3(OffPolicyRLModel):
         self.critic_target = self.policy.critic_target
         self.vf_net = self.policy.vf_net
 
-    def train_critic(self, gradient_steps: int = 1,
-                     batch_size: int = 100,
-                     replay_data: Optional[ReplayBufferSamples] = None,
-                     tau: float = 0.0) -> None:
-        # Update optimizer learning rate
-        self._update_learning_rate(self.critic.optimizer)
+    def train(self, gradient_steps: int, batch_size: int = 100, policy_delay: int = 2) -> None:
+
+        # Update learning rate according to lr schedule
+        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])
 
         for gradient_step in range(gradient_steps):
+
             # Sample replay buffer
-            if replay_data is None:
-                replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
+            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
 
-            # Select action according to policy and add clipped noise
-            noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
-            noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
-            next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
+            with th.no_grad():
+                # Select action according to policy and add clipped noise
+                noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
+                noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
+                next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
 
-            # Compute the target Q value
-            target_q1, target_q2 = self.critic_target(replay_data.next_observations, next_actions)
-            target_q = th.min(target_q1, target_q2)
-            target_q = replay_data.rewards + ((1 - replay_data.dones) * self.gamma * target_q).detach()
+                # Compute the target Q value
+                target_q1, target_q2 = self.critic_target(replay_data.next_observations, next_actions)
+                target_q = th.min(target_q1, target_q2)
+                target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
 
             # Get current Q estimates
             current_q1, current_q2 = self.critic(replay_data.observations, replay_data.actions)
@@ -159,53 +158,22 @@ class TD3(OffPolicyRLModel):
             critic_loss.backward()
             self.critic.optimizer.step()
 
-        # Update the frozen target models
-        # Note: by default, for TD3, this update is done in train_actor
-        # however, for CEMRL it is done here
-        if tau > 0:
-            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
-                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
-
-    def train_actor(self, gradient_steps: int = 1,
-                    batch_size: int = 100,
-                    tau_actor: float = 0.005,
-                    tau_critic: float = 0.005,
-                    replay_data: Optional[ReplayBufferSamples] = None) -> None:
-        # Update optimizer learning rate
-        self._update_learning_rate(self.actor.optimizer)
-
-        for gradient_step in range(gradient_steps):
-            # Sample replay buffer
-            if replay_data is None:
-                replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
-
-            # Compute actor loss
-            actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
-
-            # Optimize the actor
-            self.actor.optimizer.zero_grad()
-            actor_loss.backward()
-            self.actor.optimizer.step()
-
-            # Update the frozen target models
-            if tau_critic > 0:
-                for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
-                    target_param.data.copy_(tau_critic * param.data + (1 - tau_critic) * target_param.data)
-
-            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
-                target_param.data.copy_(tau_actor * param.data + (1 - tau_actor) * target_param.data)
-
-    def train(self, gradient_steps: int, batch_size: int = 100, policy_delay: int = 2) -> None:
-
-        for gradient_step in range(gradient_steps):
-
-            # Sample replay buffer
-            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
-            self.train_critic(replay_data=replay_data)
-
             # Delayed policy updates
             if gradient_step % policy_delay == 0:
-                self.train_actor(replay_data=replay_data, tau_actor=self.tau, tau_critic=self.tau)
+                # Compute actor loss
+                actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
+
+                # Optimize the actor
+                self.actor.optimizer.zero_grad()
+                actor_loss.backward()
+                self.actor.optimizer.step()
+
+                # Update the frozen target networks
+                for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
+                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
 
         self._n_updates += gradient_steps
         logger.logkv("n_updates", self._n_updates)
diff --git a/torchy_baselines/version.txt b/torchy_baselines/version.txt
index f28aaa5..6495aed 100644
--- a/torchy_baselines/version.txt
+++ b/torchy_baselines/version.txt
@@ -1 +1 @@
-0.4.0a0
+0.4.0a1
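
Note for anyone still relying on the removed algorithm: the deleted ``cem.py`` above only
depends on numpy, so it can be vendored and used on its own. A minimal sketch of its ask/tell
interface, assuming the deleted file was copied next to your script as ``cem.py`` (the toy
fitness function and all hyper-parameter values below are illustrative, not part of the library):

.. code-block:: python

    import numpy as np

    from cem import CEM  # vendored copy of the deleted torchy_baselines/cem_rl/cem.py

    def fitness(params: np.ndarray) -> float:
        # Toy objective (stands in for an episode return): maximize -||params||^2
        return -float(np.sum(params ** 2))

    num_params = 10
    es = CEM(num_params, mu_init=np.ones(num_params), sigma_init=1e-1, pop_size=20)

    for _ in range(50):
        # Sample a population of candidate parameter vectors
        candidates = es.ask(es.pop_size)
        # Evaluate each candidate (in CEM-RL this was one rollout per individual)
        scores = [fitness(candidate) for candidate in candidates]
        # Update the search distribution from the ranked candidates
        es.tell(candidates, scores)

    mu, cov = es.get_distrib_params()
    print("Final mean (should be close to zero):", np.round(mu, 2))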
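For readers following the TD3 refactor above (``train_critic()``/``train_actor()`` merged into a
single ``train()``), here is a minimal self-contained sketch of the same update pattern: the
critic is trained at every gradient step, while the actor and both target networks only move
every ``policy_delay`` steps via polyak averaging. It uses a single critic and random tensors in
place of TD3's twin critics and replay buffer; every name and size here is illustrative, not the
library API:

.. code-block:: python

    import copy

    import torch as th
    import torch.nn as nn

    obs_dim, act_dim = 3, 1
    tau, gamma, policy_delay = 0.005, 0.99, 2

    actor = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(), nn.Linear(32, act_dim), nn.Tanh())
    critic = nn.Sequential(nn.Linear(obs_dim + act_dim, 32), nn.ReLU(), nn.Linear(32, 1))
    actor_target, critic_target = copy.deepcopy(actor), copy.deepcopy(critic)
    actor_opt = th.optim.Adam(actor.parameters(), lr=1e-3)
    critic_opt = th.optim.Adam(critic.parameters(), lr=1e-3)

    def sample_batch(batch_size: int = 100):
        # Stand-in for a replay buffer sample (random data, for illustration only)
        return (th.randn(batch_size, obs_dim),         # observations
                th.rand(batch_size, act_dim) * 2 - 1,  # actions in [-1, 1]
                th.randn(batch_size, 1),               # rewards
                th.randn(batch_size, obs_dim),         # next observations
                th.zeros(batch_size, 1))               # dones

    for gradient_step in range(10):
        obs, actions, rewards, next_obs, dones = sample_batch()

        with th.no_grad():
            # Target policy smoothing: clipped noise on the target action
            noise = (0.2 * th.randn_like(actions)).clamp(-0.5, 0.5)
            next_actions = (actor_target(next_obs) + noise).clamp(-1, 1)
            target_q = rewards + (1 - dones) * gamma * critic_target(th.cat([next_obs, next_actions], dim=1))

        # Critic update (every gradient step)
        current_q = critic(th.cat([obs, actions], dim=1))
        critic_loss = nn.functional.mse_loss(current_q, target_q)
        critic_opt.zero_grad()
        critic_loss.backward()
        critic_opt.step()

        # Delayed policy updates: actor and targets only move every `policy_delay` steps
        if gradient_step % policy_delay == 0:
            actor_loss = -critic(th.cat([obs, actor(obs)], dim=1)).mean()
            actor_opt.zero_grad()
            actor_loss.backward()
            actor_opt.step()

            with th.no_grad():
                # Polyak averaging: target <- tau * online + (1 - tau) * target
                for net, target in ((critic, critic_target), (actor, actor_target)):
                    for param, target_param in zip(net.parameters(), target.parameters()):
                        target_param.data.mul_(1 - tau).add_(tau * param.data)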