mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-27 03:11:57 +00:00
Add type aliases for buffer samples
This commit is contained in:
parent
2ce31c1e21
commit
240833ffef
3 changed files with 13 additions and 7 deletions
|
|
@ -4,6 +4,7 @@ import numpy as np
|
|||
import torch as th
|
||||
|
||||
from torchy_baselines.common.vec_env import VecNormalize
|
||||
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
|
||||
|
||||
|
||||
class BaseBuffer(object):
|
||||
|
|
@ -177,7 +178,7 @@ class ReplayBuffer(BaseBuffer):
|
|||
def _get_samples(self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None
|
||||
) -> Tuple[th.Tensor, ...]:
|
||||
) -> ReplayBufferSamples:
|
||||
data = (self._normalize_obs(self.observations[batch_inds, 0, :], env),
|
||||
self.actions[batch_inds, 0, :],
|
||||
self._normalize_obs(self.next_observations[batch_inds, 0, :], env),
|
||||
|
|
@ -305,7 +306,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
if self.pos == self.buffer_size:
|
||||
self.full = True
|
||||
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[Tuple[th.Tensor, ...], None, None]:
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
|
||||
assert self.full, ''
|
||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
|
|
@ -325,7 +326,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
start_idx += batch_size
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None) -> Tuple[th.Tensor, ...]:
|
||||
env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
|
||||
data = (self.observations[batch_inds],
|
||||
self.actions[batch_inds],
|
||||
self.values[batch_inds].flatten(),
|
||||
|
|
|
|||
|
|
@ -3,12 +3,16 @@ Common aliases for type hing
|
|||
"""
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch as th
|
||||
import gym
|
||||
|
||||
from torchy_baselines.common.vec_env import VecEnv
|
||||
|
||||
|
||||
GymEnv = Union[gym.Env, VecEnv]
|
||||
TensorDict = Dict[str, torch.Tensor]
|
||||
TensorDict = Dict[str, th.Tensor]
|
||||
OptimizerStateDict = Dict[str, Any]
|
||||
# obs, action, old_values, old_log_prob, advantage, return_batch
|
||||
RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
|
||||
# obs, action, next_obs, done, reward
|
||||
ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import numpy as np
|
|||
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.common.type_aliases import ReplayBufferSamples
|
||||
from torchy_baselines.td3.policies import TD3Policy
|
||||
|
||||
|
||||
|
|
@ -134,7 +135,7 @@ class TD3(OffPolicyRLModel):
|
|||
|
||||
def train_critic(self, gradient_steps: int = 1,
|
||||
batch_size: int = 100,
|
||||
replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None,
|
||||
replay_data: Optional[ReplayBufferSamples] = None,
|
||||
tau: float = 0.0):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.critic.optimizer)
|
||||
|
|
@ -178,7 +179,7 @@ class TD3(OffPolicyRLModel):
|
|||
batch_size: int = 100,
|
||||
tau_actor: float = 0.005,
|
||||
tau_critic: float = 0.005,
|
||||
replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None):
|
||||
replay_data: Optional[ReplayBufferSamples] = None):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.actor.optimizer)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue