From dfe1ab96907924f796b0efc893c879c74f6ff31a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 21 Sep 2019 16:03:22 +0200 Subject: [PATCH] Revert buffer update --- torchy_baselines/common/buffers.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index c16dbe8..32c37b3 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -60,19 +60,19 @@ class ReplayBuffer(BaseBuffer): super(ReplayBuffer, self).__init__(buffer_size, state_dim, action_dim, device, n_envs=n_envs) assert n_envs == 1 - self.states = th.zeros(self.buffer_size, self.state_dim) - self.actions = th.zeros(self.buffer_size, self.action_dim) - self.next_states = th.zeros(self.buffer_size, self.state_dim) - self.rewards = th.zeros(self.buffer_size, 1) - self.dones = th.zeros(self.buffer_size, 1) + self.states = th.zeros(self.buffer_size, self.n_envs, self.state_dim) + self.actions = th.zeros(self.buffer_size, self.n_envs, self.action_dim) + self.next_states = th.zeros(self.buffer_size, self.n_envs, self.state_dim) + self.rewards = th.zeros(self.buffer_size, self.n_envs) + self.dones = th.zeros(self.buffer_size, self.n_envs) def add(self, state, next_state, action, reward, done): # Copy to avoid modification by reference - self.states[self.pos] = th.FloatTensor(state[0, :]) - self.next_states[self.pos] = th.FloatTensor(next_state[0, :]) - self.actions[self.pos] = th.FloatTensor(action[0, :]) - self.rewards[self.pos] = th.FloatTensor([reward[0]]) - self.dones[self.pos] = th.FloatTensor([done[0]]) + self.states[self.pos] = th.FloatTensor(np.array(state).copy()) + self.next_states[self.pos] = th.FloatTensor(np.array(next_state).copy()) + self.actions[self.pos] = th.FloatTensor(np.array(action).copy()) + self.rewards[self.pos] = th.FloatTensor(np.array(reward).copy()) + self.dones[self.pos] = th.FloatTensor(np.array(done).copy()) self.pos += 1 if self.pos == self.buffer_size: @@ -80,9 +80,9 @@ class ReplayBuffer(BaseBuffer): self.pos = 0 def _get_samples(self, batch_inds): - return (self.states[batch_inds].to(self.device), - self.actions[batch_inds].to(self.device), - self.next_states[batch_inds].to(self.device), + return (self.states[batch_inds, 0, :].to(self.device), + self.actions[batch_inds, 0, :].to(self.device), + self.next_states[batch_inds, 0, :].to(self.device), self.dones[batch_inds].to(self.device), self.rewards[batch_inds].to(self.device)) @@ -129,12 +129,12 @@ class RolloutBuffer(BaseBuffer): self.returns = self.advantages + self.values def add(self, state, action, reward, done, value, log_prob): - self.values[self.pos] = th.FloatTensor(value.clone().cpu().flatten()) - self.log_probs[self.pos] = th.FloatTensor(log_prob.cpu().clone()) self.states[self.pos] = th.FloatTensor(np.array(state).copy()) self.actions[self.pos] = th.FloatTensor(np.array(action).copy()) self.rewards[self.pos] = th.FloatTensor(np.array(reward).copy()) self.dones[self.pos] = th.FloatTensor(np.array(done).copy()) + self.values[self.pos] = th.FloatTensor(value.clone().cpu().flatten()) + self.log_probs[self.pos] = th.FloatTensor(log_prob.cpu().clone()) self.pos += 1 if self.pos == self.buffer_size: self.full = True