From d4e2dc8a9cab534a5e112774914e6bae6d7fe479 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 6 Sep 2019 14:01:10 +0200 Subject: [PATCH] Add CEM-RL --- tests/test_td3.py | 11 +- torchy_baselines/__init__.py | 1 + torchy_baselines/cem_rl/__init__.py | 1 + torchy_baselines/cem_rl/cem.py | 98 +++++++++ torchy_baselines/cem_rl/cem_rl.py | 285 ++++++++++++++++++++++++++ torchy_baselines/common/evaluation.py | 5 +- torchy_baselines/common/policies.py | 35 ++++ torchy_baselines/td3/policies.py | 28 ++- torchy_baselines/td3/td3.py | 12 +- 9 files changed, 467 insertions(+), 9 deletions(-) create mode 100644 torchy_baselines/cem_rl/__init__.py create mode 100644 torchy_baselines/cem_rl/cem.py create mode 100644 torchy_baselines/cem_rl/cem_rl.py diff --git a/tests/test_td3.py b/tests/test_td3.py index a6fdb15..b67103a 100644 --- a/tests/test_td3.py +++ b/tests/test_td3.py @@ -2,7 +2,7 @@ import os import gym -from torchy_baselines import TD3 +from torchy_baselines import TD3, CEMRL def test_pendulum(): model = TD3('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), start_timesteps=100, verbose=1) @@ -10,3 +10,12 @@ def test_pendulum(): model.save("test_save") model.load("test_save") os.remove("test_save.pth") + + +def test_cemrl(): + model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), pop_size=5, n_grad=2, + start_timesteps=100, verbose=1) + model.learn(total_timesteps=1000, eval_freq=500) + model.save("test_save") + model.load("test_save") + os.remove("test_save.pth") diff --git a/torchy_baselines/__init__.py b/torchy_baselines/__init__.py index 5656e71..89bea73 100644 --- a/torchy_baselines/__init__.py +++ b/torchy_baselines/__init__.py @@ -1,3 +1,4 @@ from torchy_baselines.td3 import TD3 +from torchy_baselines.cem_rl import CEMRL __version__ = "0.0.1" diff --git a/torchy_baselines/cem_rl/__init__.py b/torchy_baselines/cem_rl/__init__.py new file mode 100644 index 0000000..b93cd30 --- /dev/null +++ b/torchy_baselines/cem_rl/__init__.py @@ -0,0 +1 @@ +from torchy_baselines.cem_rl.cem_rl import CEMRL diff --git a/torchy_baselines/cem_rl/cem.py b/torchy_baselines/cem_rl/cem.py new file mode 100644 index 0000000..0ec81a2 --- /dev/null +++ b/torchy_baselines/cem_rl/cem.py @@ -0,0 +1,98 @@ +import numpy as np + + +# TODO: add more from https://github.com/hardmaru/estool/blob/master/es.py +# or https://github.com/facebookresearch/nevergrad + +class CEM(object): + + """ + Cross-entropy methods. + """ + + def __init__(self, num_params, + mu_init=None, + sigma_init=1e-3, + pop_size=256, + damp=1e-3, + damp_limit=1e-5, + parents=None, + elitism=False, + antithetic=False): + super(CEM, self).__init__() + # misc + self.num_params = num_params + + # distribution parameters + if mu_init is None: + self.mu = np.zeros(self.num_params) + else: + self.mu = np.array(mu_init) + self.sigma = sigma_init + self.damp = damp + self.damp_limit = damp_limit + self.tau = 0.95 + self.cov = self.sigma * np.ones(self.num_params) + + # elite stuff + self.elitism = elitism + self.elite = np.sqrt(self.sigma) * np.random.rand(self.num_params) + self.elite_score = None + + # sampling stuff + self.pop_size = pop_size + self.antithetic = antithetic + + if self.antithetic: + assert (self.pop_size % 2 == 0), "Population size must be even" + if parents is None or parents <= 0: + self.parents = pop_size // 2 + else: + self.parents = parents + self.weights = np.array([np.log((self.parents + 1) / i) + for i in range(1, self.parents + 1)]) + self.weights /= self.weights.sum() + + def ask(self, pop_size): + """ + Returns a list of candidates parameters + """ + if self.antithetic and not pop_size % 2: + epsilon_half = np.random.randn(pop_size // 2, self.num_params) + epsilon = np.concatenate([epsilon_half, - epsilon_half]) + + else: + epsilon = np.random.randn(pop_size, self.num_params) + + inds = self.mu + epsilon * np.sqrt(self.cov) + if self.elitism: + inds[-1] = self.elite + + return inds + + def tell(self, solutions, scores): + """ + Updates the distribution + """ + scores = np.array(scores) + scores *= -1 + idx_sorted = np.argsort(scores) + + old_mu = self.mu + self.damp = self.damp * self.tau + (1 - self.tau) * self.damp_limit + # self.mu = self.weights @ solutions[idx_sorted[:self.parents]] + self.mu = self.weights.dot(solutions[idx_sorted[:self.parents]]) + + z = (solutions[idx_sorted[:self.parents]] - old_mu) + self.cov = 1 / self.parents * self.weights.dot(z * z) + self.damp * np.ones(self.num_params) + + self.elite = solutions[idx_sorted[0]] + self.elite_score = scores[idx_sorted[0]] + # print(self.cov) + + def get_distrib_params(self): + """ + Returns the parameters of the distrubtion: + the mean and sigma + """ + return np.copy(self.mu), np.copy(self.cov) diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py new file mode 100644 index 0000000..33b7b3f --- /dev/null +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -0,0 +1,285 @@ +import sys +import time + +import torch as th +import torch.nn.functional as F +import numpy as np + +from torchy_baselines import TD3 +from torchy_baselines.common.evaluation import evaluate_policy +from torchy_baselines.cem_rl.cem import CEM + + +class CEMRL(TD3): + """ + Implementation of CEM-RL + + Paper: https://arxiv.org/abs/1810.01222 + Code: https://github.com/apourchot/CEM-RL + """ + + def __init__(self, policy, env, policy_kwargs=None, verbose=0, + sigma_init=1e-3, pop_size=10, damp=1e-3, damp_limit=1e-5, + elitism=False, n_grad=5, + buffer_size=int(1e6), learning_rate=1e-3, seed=0, device='cpu', + action_noise_std=0.0, start_timesteps=100, _init_setup_model=True): + + super(CEMRL, self).__init__(policy, env, policy_kwargs, verbose, + buffer_size, learning_rate, seed, device, + action_noise_std, start_timesteps, _init_setup_model=False) + + self.es = None + self.sigma_init = sigma_init + self.pop_size = pop_size + self.damp = damp + self.damp_limit = damp_limit + self.elitism = elitism + self.n_grad = n_grad + self.es_params = None + self.fitnesses = [] + + if _init_setup_model: + self._setup_model() + + def _setup_model(self, seed=None): + super(CEMRL, self)._setup_model() + params_vector = self.actor.parameters_to_vector() + self.es = CEM(len(params_vector), mu_init=params_vector, + sigma_init=self.sigma_init, damp=self.damp, damp_limit=self.damp_limit, + pop_size=self.pop_size, antithetic=not self.pop_size % 2, parents=self.pop_size // 2, + elitism=self.elitism) + + def select_action(self, observation): + with th.no_grad(): + observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device) + return self.actor(observation).cpu().data.numpy().flatten() + + def predict(self, observation, state=None, mask=None, deterministic=True): + """ + Get the model's action from an observation + + :param observation: (np.ndarray) the input observation + :param state: (np.ndarray) The last states (can be None, used in recurrent policies) + :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) + :param deterministic: (bool) Whether or not to return deterministic actions. + :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) + """ + return self.max_action * self.select_action(observation) + + def train_critic(self, n_iterations, batch_size=100, discount=0.99, + policy_noise=0.2, noise_clip=0.5): + + for it in range(n_iterations): + # Sample replay buffer + state, action, next_state, done, reward = self.replay_buffer.sample(batch_size) + + # Select action according to policy and add clipped noise + noise = action.clone().data.normal_(0, policy_noise) + noise = noise.clamp(-noise_clip, noise_clip) + next_action = (self.actor_target(next_state) + noise).clamp(-1, 1) + + # Compute the target Q value + target_q1, target_q2 = self.critic_target(next_state, next_action) + target_q = th.min(target_q1, target_q2) + target_q = reward + ((1 - done) * discount * target_q).detach() + + # Get current Q estimates + current_q1, current_q2 = self.critic(state, action) + + # Compute critic loss + critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) + + # Optimize the critic + self.critic.optimizer.zero_grad() + critic_loss.backward() + self.critic.optimizer.step() + + # # Update the frozen target models + # for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + # target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) + + def train_actor(self, n_iterations, batch_size=100, tau=0.005): + + for it in range(n_iterations): + # Sample replay buffer + state, action, next_state, done, reward = self.replay_buffer.sample(batch_size) + + # Compute actor loss + actor_loss = -self.critic.q1_forward(state, self.actor(state)).mean() + + # Optimize the actor + self.actor.optimizer.zero_grad() + actor_loss.backward() + self.actor.optimizer.step() + + # Update the frozen target models + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) + + for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): + target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) + + def train(self, n_iterations, batch_size=100, discount=0.99, + tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2): + + for it in range(n_iterations): + + # Sample replay buffer + state, action, next_state, done, reward = self.replay_buffer.sample(batch_size) + + # Select action according to policy and add clipped noise + noise = action.clone().data.normal_(0, policy_noise) + noise = noise.clamp(-noise_clip, noise_clip) + next_action = (self.actor_target(next_state) + noise).clamp(-1, 1) + + # Compute the target Q value + target_q1, target_q2 = self.critic_target(next_state, next_action) + target_q = th.min(target_q1, target_q2) + target_q = reward + ((1 - done) * discount * target_q).detach() + + # Get current Q estimates + current_q1, current_q2 = self.critic(state, action) + + # Compute critic loss + critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) + + # Optimize the critic + self.critic.optimizer.zero_grad() + critic_loss.backward() + self.critic.optimizer.step() + + # Delayed policy updates + if it % policy_freq == 0: + + # Compute actor loss + actor_loss = -self.critic.q1_forward(state, self.actor(state)).mean() + + # Optimize the actor + self.actor.optimizer.zero_grad() + actor_loss.backward() + self.actor.optimizer.step() + + # Update the frozen target models + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) + + for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): + target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) + + def learn(self, total_timesteps, callback=None, log_interval=100, + eval_freq=-1, n_eval_episodes=5, tb_log_name="CEMRL", reset_num_timesteps=True): + + timesteps_since_eval = 0 + actor_steps = 0 + episode_num = 0 + evaluations = [] + start_time = time.time() + + while self.num_timesteps < total_timesteps: + + self.fitnesses = [] + self.es_params = self.es.ask(self.pop_size) + + if callback is not None: + # Only stop training if return value is False, not when it is None. + if callback(locals(), globals()) is False: + break + + if self.num_timesteps > 0: + # self.train(episode_timesteps) + # Gradient steps for half of the population + for i in range(self.n_grad): + # set params + self.actor.load_from_vector(self.es_params[i]) + self.actor_target.load_from_vector(self.es_params[i]) + self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=self.learning_rate) + + # In the paper: 2 * actor_steps // self.n_grad + self.train_critic(actor_steps // self.n_grad) + + self.train_actor(actor_steps) + + # Get the params back in the population + self.es_params[i] = self.actor.parameters_to_vector() + + # Evaluate episode + if 0 < eval_freq <= timesteps_since_eval: + timesteps_since_eval %= eval_freq + + self.actor.load_from_vector(self.es.mu) + + mean_reward, _ = evaluate_policy(self, self.env, n_eval_episodes) + evaluations.append(mean_reward) + + if self.verbose > 0: + print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1])) + print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - start_time))) + sys.stdout.flush() + + actor_steps = 0 + # evaluate all actors + for params in self.es_params: + + self.actor.load_from_vector(params) + + # Reset environment + obs = self.env.reset() + episode_reward = 0 + episode_timesteps = 0 + episode_num += 1 + done = False + + while not done: + # Select action randomly or according to policy + if self.num_timesteps < self.start_timesteps: + action = self.env.action_space.sample() + else: + action = self.select_action(np.array(obs)) + + if self.action_noise_std > 0: + # NOTE: in the original implementation, the noise is applied to the unscaled action + action_noise = np.random.normal(0, self.action_noise_std, size=self.action_space.shape[0]) + action = (action + action_noise).clip(-1, 1) + + # Rescale and perform action + new_obs, reward, done, _ = self.env.step(self.max_action * action) + + if hasattr(self.env, '_max_episode_steps'): + done_bool = 0 if episode_timesteps + 1 == self.env._max_episode_steps else float(done) + else: + done_bool = float(done) + + episode_reward += reward + + # Store data in replay buffer + # self.replay_buffer.add(state, next_state, action, reward, done) + self.replay_buffer.add(obs, new_obs, action, reward, done_bool) + + obs = new_obs + episode_timesteps += 1 + + if self.verbose > 1: + print("Total T: {} Episode Num: {} Episode T: {} Reward: {}".format( + self.num_timesteps, episode_num, episode_timesteps, episode_reward)) + + actor_steps += episode_timesteps + self.fitnesses.append(episode_reward) + + self.es.tell(self.es_params, self.fitnesses) + + self.num_timesteps += actor_steps + timesteps_since_eval += actor_steps + return self + + def save(self, path): + if not path.endswith('.pth'): + path += '.pth' + th.save(self.policy.state_dict(), path) + + def load(self, path, env=None, **_kwargs): + if not path.endswith('.pth'): + path += '.pth' + if env is not None: + pass + self.policy.load_state_dict(th.load(path)) + self._create_aliases() diff --git a/torchy_baselines/common/evaluation.py b/torchy_baselines/common/evaluation.py index fb3f37f..441cb1c 100644 --- a/torchy_baselines/common/evaluation.py +++ b/torchy_baselines/common/evaluation.py @@ -5,7 +5,7 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, render=F """ Runs policy for n episodes and returns average reward """ - mean_reward = 0.0 + mean_reward, n_steps = 0.0, 0 for _ in range(n_eval_episodes): obs = env.reset() done = False @@ -13,9 +13,10 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, render=F action = model.predict(np.array(obs), deterministic=deterministic) obs, reward, done, _ = env.step(action) mean_reward += reward + n_steps += 1 if render: env.render() mean_reward /= n_eval_episodes - return mean_reward + return mean_reward, n_steps diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 8047c43..212dda8 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -15,6 +15,41 @@ class BasePolicy(nn.Module): self.action_space = action_space self.device = device + def forward(self, *_args, **kwargs): + raise NotImplementedError() + + def save(self, path): + """ + Save model to a given location. + + :param path: (str) + """ + th.save(self.state_dict(), path) + + def load(self, path): + """ + Load saved model from path. + + :param path: (str) + """ + self.load_state_dict(th.load(path)) + + def load_from_vector(self, vector): + """ + Load parameters from a 1D vector. + + :param vector: (np.ndarray) + """ + th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) + + def parameters_to_vector(self): + """ + Convert the parameters to a 1D vector. + + :return: (np.ndarray) + """ + return th.nn.utils.parameters_to_vector(self.parameters()) + _policy_registry = dict() diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index efee4b7..e3b05c5 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -4,7 +4,31 @@ import torch.nn as nn from torchy_baselines.common.policies import BasePolicy, register_policy -class Actor(nn.Module): +class BaseNetwork(nn.Module): + """docstring for BaseNetwork.""" + + def __init__(self, device='cpu'): + super(BaseNetwork, self).__init__() + + def load_from_vector(self, vector): + """ + Load parameters from a 1D vector. + + :param vector: (np.ndarray) + """ + device = next(self.parameters()).device + th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(device), self.parameters()) + + def parameters_to_vector(self): + """ + Convert the parameters to a 1D vector. + + :return: (np.ndarray) + """ + return th.nn.utils.parameters_to_vector(self.parameters()).detach().numpy() + + +class Actor(BaseNetwork): def __init__(self, state_dim, action_dim, net_arch=None): super(Actor, self).__init__() @@ -26,7 +50,7 @@ class Actor(nn.Module): return self.actor_net(x) -class Critic(nn.Module): +class Critic(BaseNetwork): def __init__(self, state_dim, action_dim, net_arch=None): super(Critic, self).__init__() diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index bba946c..c7dc2f0 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -20,11 +20,14 @@ class TD3(BaseRLModel): """ def __init__(self, policy, env, policy_kwargs=None, verbose=0, - buffer_size=int(1e6), learning_rate=1e-3, seed=0, device='cpu', + buffer_size=int(1e6), learning_rate=1e-3, seed=0, device='auto', action_noise_std=0.1, start_timesteps=100, _init_setup_model=True): super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose) + if device == 'auto': + device = 'cuda' if th.cuda.is_available() else 'cpu' + self.max_action = np.abs(self.action_space.high) self.replay_buffer = None self.device = device @@ -32,7 +35,7 @@ class TD3(BaseRLModel): self.learning_rate = learning_rate self.buffer_size = buffer_size self.start_timesteps = start_timesteps - self.seed = 0 + self.seed = seed if _init_setup_model: self._setup_model() @@ -143,9 +146,10 @@ class TD3(BaseRLModel): self.train(episode_timesteps) # Evaluate episode - if eval_freq > 0 and timesteps_since_eval >= eval_freq: + if 0 < eval_freq <= timesteps_since_eval: timesteps_since_eval %= eval_freq - evaluations.append(evaluate_policy(self, self.env, n_eval_episodes)) + mean_reward, _ = evaluate_policy(self, self.env, n_eval_episodes) + evaluations.append(mean_reward) if self.verbose > 0: print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1])) print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - start_time)))