diff --git a/tests/test_run.py b/tests/test_run.py index e28c532..b3ed1a1 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -11,7 +11,6 @@ def test_pendulum(): model.load("test_save") os.remove("test_save.pth") - def test_cemrl(): model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1, start_timesteps=100, verbose=1) diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index 7ee555a..7554798 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -1,11 +1,8 @@ import numpy as np import torch as th -from torchy_baselines.common.utils import discount_cumsum class BaseBuffer(object): - """docstring for BaseBuffer.""" - def __init__(self, buffer_size, state_dim, action_dim, device='cpu'): super(BaseBuffer, self).__init__() self.buffer_size = buffer_size @@ -88,25 +85,23 @@ class RolloutBuffer(BaseBuffer): self.values = th.zeros(self.buffer_size, 1) self.log_probs = th.zeros(self.buffer_size, 1) self.advantages = th.zeros(self.buffer_size, 1) - self.path_start_idx = 0 - def finish_path(self, last_value=0): + def compute_returns_and_advantage(self, last_value, done=False): """ - From https://github.com/openai/spinningup/blob/master/spinup/algos/ppo/ppo.py + From PPO2 """ - # No use of dones? - path_slice = slice(self.path_start_idx, self.pos) - rewards = np.append(self.rewards[path_slice].detach().cpu().numpy(), last_value) - values = np.append(self.values[path_slice].detach().cpu().numpy(), last_value) - - # the next two lines implement GAE-Lambda advantage calculation - deltas = rewards[:-1] + self.gamma * values[1:] - values[:-1] - - self.advantages[path_slice, 0] = th.FloatTensor(discount_cumsum(deltas, self.gamma * self.lambda_).copy()) - # the next line computes rewards-to-go, to be targets for the value function - self.returns[path_slice, 0] = th.FloatTensor(discount_cumsum(rewards, self.gamma)[:-1].copy()) - - self.path_start_idx = self.pos + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - float(done) + next_value = last_value + else: + next_non_terminal = 1.0 - self.dones[step + 1] + next_value = self.values[step + 1] + delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step] + last_gae_lam = delta + self.gamma * self.lambda_ * next_non_terminal * last_gae_lam + self.advantages[step] = last_gae_lam + self.returns = self.advantages + self.values def add(self, state, action, reward, done, value, log_prob): self.values[self.pos] = th.FloatTensor([value]) @@ -119,14 +114,9 @@ class RolloutBuffer(BaseBuffer): if self.pos == self.buffer_size: self.full = True - def reset(self): - self.path_start_idx = 0 - super(RolloutBuffer, self).reset() - def get(self, batch_size): assert self.full indices = th.randperm(self.buffer_size) - minibatch_indices = [] start_idx = 0 while start_idx < self.buffer_size: yield self._get_samples(indices[start_idx:start_idx + batch_size]) @@ -135,6 +125,6 @@ class RolloutBuffer(BaseBuffer): def _get_samples(self, batch_inds): return (self.states[batch_inds].to(self.device), self.actions[batch_inds].to(self.device), - self.log_probs[batch_inds].to(self.device), - self.advantages[batch_inds].to(self.device), - self.returns[batch_inds].to(self.device)) + self.log_probs[batch_inds].flatten().to(self.device), + self.advantages[batch_inds].flatten().to(self.device), + self.returns[batch_inds].flatten().to(self.device)) diff --git a/torchy_baselines/common/utils.py b/torchy_baselines/common/utils.py index 71f6182..da3a454 100644 --- a/torchy_baselines/common/utils.py +++ b/torchy_baselines/common/utils.py @@ -21,38 +21,6 @@ def set_random_seed(seed, using_cuda=False): th.cuda.manual_seed(seed) -# From stable_baselines.common.math_util -# def discount(vector, gamma): -# """ -# computes discounted sums along 0th dimension of vector x. -# y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], -# where k = len(x) - t - 1 -# -# :param vector: (np.ndarray) the input vector -# :param gamma: (float) the discount value -# :return: (np.ndarray) the output vector -# """ -# assert vector.ndim >= 1 -# return scipy.signal.lfilter([1], [1, -gamma], vector[::-1], axis=0)[::-1] - - -def discount_cumsum(x, discount): - """ - magic from rllab for computing discounted cumulative sums of vectors. - - input: - vector x, - [x0, - x1, - x2] - - output: - [x0 + discount * x1 + discount^2 * x2, - x1 + discount * x2, - x2] - """ - return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1] - # From stable baselines def explained_variance(y_pred, y_true): """ diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 0ee08da..fcb8a28 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -88,19 +88,19 @@ class PPO(BaseRLModel): # No grad ok? with th.no_grad(): action, value, log_prob = self.policy.forward(obs) - action = action[0].cpu().numpy() + action = action.flatten().cpu().numpy() # Rescale and perform action - obs, reward, done, _ = env.step(np.clip(action, -self.max_action, self.max_action)) + new_obs, reward, done, _ = env.step(np.clip(action, -self.max_action, self.max_action)) n_steps += 1 rollout_buffer.add(obs, action, reward, float(done), value, log_prob) + obs = new_obs if done: - value = 0.0 obs = None - rollout_buffer.finish_path(last_value=value) + rollout_buffer.compute_returns_and_advantage(value, done=done) return obs @@ -113,7 +113,6 @@ class PPO(BaseRLModel): # Unpack state, action, old_log_prob, advantage, return_batch = replay_data - # _, values, log_prob = self.policy.forward(state) values, log_prob, entropy = self.policy.get_policy_stats(state, action) # Normalize advantage @@ -121,11 +120,12 @@ class PPO(BaseRLModel): advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) ratio = th.exp(log_prob - old_log_prob) + policy_loss_1 = advantage * ratio policy_loss_2 = advantage * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range) policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() # value_loss = th.mean((return_batch - value)**2) - value_loss = F.mse_loss(return_batch, values) + value_loss = F.mse_loss(return_batch, values.flatten()) entropy_loss = th.mean(entropy) loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss # loss = policy_loss @@ -138,7 +138,8 @@ class PPO(BaseRLModel): # TODO: clip grad norm? # nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() - print(explained_variance(return_batch.numpy()[:, 0], values[:, 0].detach().cpu().numpy())) + # print(value_loss.item()) + # print(explained_variance(return_batch.numpy(), values.flatten().detach().cpu().numpy())) def learn(self, total_timesteps, callback=None, log_interval=100, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", reset_num_timesteps=True):