Revert buffer update

This commit is contained in:
Antonin RAFFIN 2019-09-21 16:03:22 +02:00
parent a196306d9e
commit dfe1ab9690

View file

@ -60,19 +60,19 @@ class ReplayBuffer(BaseBuffer):
super(ReplayBuffer, self).__init__(buffer_size, state_dim, action_dim, device, n_envs=n_envs)
assert n_envs == 1
self.states = th.zeros(self.buffer_size, self.state_dim)
self.actions = th.zeros(self.buffer_size, self.action_dim)
self.next_states = th.zeros(self.buffer_size, self.state_dim)
self.rewards = th.zeros(self.buffer_size, 1)
self.dones = th.zeros(self.buffer_size, 1)
self.states = th.zeros(self.buffer_size, self.n_envs, self.state_dim)
self.actions = th.zeros(self.buffer_size, self.n_envs, self.action_dim)
self.next_states = th.zeros(self.buffer_size, self.n_envs, self.state_dim)
self.rewards = th.zeros(self.buffer_size, self.n_envs)
self.dones = th.zeros(self.buffer_size, self.n_envs)
def add(self, state, next_state, action, reward, done):
# Copy to avoid modification by reference
self.states[self.pos] = th.FloatTensor(state[0, :])
self.next_states[self.pos] = th.FloatTensor(next_state[0, :])
self.actions[self.pos] = th.FloatTensor(action[0, :])
self.rewards[self.pos] = th.FloatTensor([reward[0]])
self.dones[self.pos] = th.FloatTensor([done[0]])
self.states[self.pos] = th.FloatTensor(np.array(state).copy())
self.next_states[self.pos] = th.FloatTensor(np.array(next_state).copy())
self.actions[self.pos] = th.FloatTensor(np.array(action).copy())
self.rewards[self.pos] = th.FloatTensor(np.array(reward).copy())
self.dones[self.pos] = th.FloatTensor(np.array(done).copy())
self.pos += 1
if self.pos == self.buffer_size:
@ -80,9 +80,9 @@ class ReplayBuffer(BaseBuffer):
self.pos = 0
def _get_samples(self, batch_inds):
return (self.states[batch_inds].to(self.device),
self.actions[batch_inds].to(self.device),
self.next_states[batch_inds].to(self.device),
return (self.states[batch_inds, 0, :].to(self.device),
self.actions[batch_inds, 0, :].to(self.device),
self.next_states[batch_inds, 0, :].to(self.device),
self.dones[batch_inds].to(self.device),
self.rewards[batch_inds].to(self.device))
@ -129,12 +129,12 @@ class RolloutBuffer(BaseBuffer):
self.returns = self.advantages + self.values
def add(self, state, action, reward, done, value, log_prob):
self.values[self.pos] = th.FloatTensor(value.clone().cpu().flatten())
self.log_probs[self.pos] = th.FloatTensor(log_prob.cpu().clone())
self.states[self.pos] = th.FloatTensor(np.array(state).copy())
self.actions[self.pos] = th.FloatTensor(np.array(action).copy())
self.rewards[self.pos] = th.FloatTensor(np.array(reward).copy())
self.dones[self.pos] = th.FloatTensor(np.array(done).copy())
self.values[self.pos] = th.FloatTensor(value.clone().cpu().flatten())
self.log_probs[self.pos] = th.FloatTensor(log_prob.cpu().clone())
self.pos += 1
if self.pos == self.buffer_size:
self.full = True