mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Revert buffer update
This commit is contained in:
parent
a196306d9e
commit
dfe1ab9690
1 changed files with 15 additions and 15 deletions
|
|
@ -60,19 +60,19 @@ class ReplayBuffer(BaseBuffer):
|
|||
super(ReplayBuffer, self).__init__(buffer_size, state_dim, action_dim, device, n_envs=n_envs)
|
||||
|
||||
assert n_envs == 1
|
||||
self.states = th.zeros(self.buffer_size, self.state_dim)
|
||||
self.actions = th.zeros(self.buffer_size, self.action_dim)
|
||||
self.next_states = th.zeros(self.buffer_size, self.state_dim)
|
||||
self.rewards = th.zeros(self.buffer_size, 1)
|
||||
self.dones = th.zeros(self.buffer_size, 1)
|
||||
self.states = th.zeros(self.buffer_size, self.n_envs, self.state_dim)
|
||||
self.actions = th.zeros(self.buffer_size, self.n_envs, self.action_dim)
|
||||
self.next_states = th.zeros(self.buffer_size, self.n_envs, self.state_dim)
|
||||
self.rewards = th.zeros(self.buffer_size, self.n_envs)
|
||||
self.dones = th.zeros(self.buffer_size, self.n_envs)
|
||||
|
||||
def add(self, state, next_state, action, reward, done):
|
||||
# Copy to avoid modification by reference
|
||||
self.states[self.pos] = th.FloatTensor(state[0, :])
|
||||
self.next_states[self.pos] = th.FloatTensor(next_state[0, :])
|
||||
self.actions[self.pos] = th.FloatTensor(action[0, :])
|
||||
self.rewards[self.pos] = th.FloatTensor([reward[0]])
|
||||
self.dones[self.pos] = th.FloatTensor([done[0]])
|
||||
self.states[self.pos] = th.FloatTensor(np.array(state).copy())
|
||||
self.next_states[self.pos] = th.FloatTensor(np.array(next_state).copy())
|
||||
self.actions[self.pos] = th.FloatTensor(np.array(action).copy())
|
||||
self.rewards[self.pos] = th.FloatTensor(np.array(reward).copy())
|
||||
self.dones[self.pos] = th.FloatTensor(np.array(done).copy())
|
||||
|
||||
self.pos += 1
|
||||
if self.pos == self.buffer_size:
|
||||
|
|
@ -80,9 +80,9 @@ class ReplayBuffer(BaseBuffer):
|
|||
self.pos = 0
|
||||
|
||||
def _get_samples(self, batch_inds):
|
||||
return (self.states[batch_inds].to(self.device),
|
||||
self.actions[batch_inds].to(self.device),
|
||||
self.next_states[batch_inds].to(self.device),
|
||||
return (self.states[batch_inds, 0, :].to(self.device),
|
||||
self.actions[batch_inds, 0, :].to(self.device),
|
||||
self.next_states[batch_inds, 0, :].to(self.device),
|
||||
self.dones[batch_inds].to(self.device),
|
||||
self.rewards[batch_inds].to(self.device))
|
||||
|
||||
|
|
@ -129,12 +129,12 @@ class RolloutBuffer(BaseBuffer):
|
|||
self.returns = self.advantages + self.values
|
||||
|
||||
def add(self, state, action, reward, done, value, log_prob):
|
||||
self.values[self.pos] = th.FloatTensor(value.clone().cpu().flatten())
|
||||
self.log_probs[self.pos] = th.FloatTensor(log_prob.cpu().clone())
|
||||
self.states[self.pos] = th.FloatTensor(np.array(state).copy())
|
||||
self.actions[self.pos] = th.FloatTensor(np.array(action).copy())
|
||||
self.rewards[self.pos] = th.FloatTensor(np.array(reward).copy())
|
||||
self.dones[self.pos] = th.FloatTensor(np.array(done).copy())
|
||||
self.values[self.pos] = th.FloatTensor(value.clone().cpu().flatten())
|
||||
self.log_probs[self.pos] = th.FloatTensor(log_prob.cpu().clone())
|
||||
self.pos += 1
|
||||
if self.pos == self.buffer_size:
|
||||
self.full = True
|
||||
|
|
|
|||
Loading…
Reference in a new issue