From 240833ffef2867b6f4d1c6fcb9335969a20b8bcd Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 11 Feb 2020 17:33:22 +0100 Subject: [PATCH] Add type aliases for buffer samples --- torchy_baselines/common/buffers.py | 7 ++++--- torchy_baselines/common/type_aliases.py | 8 ++++++-- torchy_baselines/td3/td3.py | 5 +++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index 0ac3e2f..c6e9d5c 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -4,6 +4,7 @@ import numpy as np import torch as th from torchy_baselines.common.vec_env import VecNormalize +from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples class BaseBuffer(object): @@ -177,7 +178,7 @@ class ReplayBuffer(BaseBuffer): def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None - ) -> Tuple[th.Tensor, ...]: + ) -> ReplayBufferSamples: data = (self._normalize_obs(self.observations[batch_inds, 0, :], env), self.actions[batch_inds, 0, :], self._normalize_obs(self.next_observations[batch_inds, 0, :], env), @@ -305,7 +306,7 @@ class RolloutBuffer(BaseBuffer): if self.pos == self.buffer_size: self.full = True - def get(self, batch_size: Optional[int] = None) -> Generator[Tuple[th.Tensor, ...], None, None]: + def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]: assert self.full, '' indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data @@ -325,7 +326,7 @@ class RolloutBuffer(BaseBuffer): start_idx += batch_size def _get_samples(self, batch_inds: np.ndarray, - env: Optional[VecNormalize] = None) -> Tuple[th.Tensor, ...]: + env: Optional[VecNormalize] = None) -> RolloutBufferSamples: data = (self.observations[batch_inds], self.actions[batch_inds], self.values[batch_inds].flatten(), diff --git a/torchy_baselines/common/type_aliases.py b/torchy_baselines/common/type_aliases.py index 8378647..b9035db 100644 --- a/torchy_baselines/common/type_aliases.py +++ b/torchy_baselines/common/type_aliases.py @@ -3,12 +3,16 @@ Common aliases for type hing """ from typing import Union, Type, Optional, Dict, Any, List, Tuple -import torch +import torch as th import gym from torchy_baselines.common.vec_env import VecEnv GymEnv = Union[gym.Env, VecEnv] -TensorDict = Dict[str, torch.Tensor] +TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] +# obs, action, old_values, old_log_prob, advantage, return_batch +RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor] +# obs, action, next_obs, done, reward +ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor] diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 6b731e3..05f1b4f 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -6,6 +6,7 @@ import numpy as np from torchy_baselines.common.base_class import OffPolicyRLModel from torchy_baselines.common.buffers import ReplayBuffer +from torchy_baselines.common.type_aliases import ReplayBufferSamples from torchy_baselines.td3.policies import TD3Policy @@ -134,7 +135,7 @@ class TD3(OffPolicyRLModel): def train_critic(self, gradient_steps: int = 1, batch_size: int = 100, - replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None, + replay_data: Optional[ReplayBufferSamples] = None, tau: float = 0.0): # Update optimizer learning rate self._update_learning_rate(self.critic.optimizer) @@ -178,7 +179,7 @@ class TD3(OffPolicyRLModel): batch_size: int = 100, tau_actor: float = 0.005, tau_critic: float = 0.005, - replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None): + replay_data: Optional[ReplayBufferSamples] = None): # Update optimizer learning rate self._update_learning_rate(self.actor.optimizer)