Add type aliases for buffer samples

This commit is contained in:
Antonin Raffin 2020-02-11 17:33:22 +01:00
parent 2ce31c1e21
commit 240833ffef
3 changed files with 13 additions and 7 deletions

View file

@ -4,6 +4,7 @@ import numpy as np
import torch as th
from torchy_baselines.common.vec_env import VecNormalize
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
class BaseBuffer(object):
@ -177,7 +178,7 @@ class ReplayBuffer(BaseBuffer):
def _get_samples(self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None
) -> Tuple[th.Tensor, ...]:
) -> ReplayBufferSamples:
data = (self._normalize_obs(self.observations[batch_inds, 0, :], env),
self.actions[batch_inds, 0, :],
self._normalize_obs(self.next_observations[batch_inds, 0, :], env),
@ -305,7 +306,7 @@ class RolloutBuffer(BaseBuffer):
if self.pos == self.buffer_size:
self.full = True
def get(self, batch_size: Optional[int] = None) -> Generator[Tuple[th.Tensor, ...], None, None]:
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
assert self.full, ''
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
@ -325,7 +326,7 @@ class RolloutBuffer(BaseBuffer):
start_idx += batch_size
def _get_samples(self, batch_inds: np.ndarray,
env: Optional[VecNormalize] = None) -> Tuple[th.Tensor, ...]:
env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
data = (self.observations[batch_inds],
self.actions[batch_inds],
self.values[batch_inds].flatten(),

View file

@ -3,12 +3,16 @@ Common aliases for type hints
"""
from typing import Union, Type, Optional, Dict, Any, List, Tuple
import torch
import torch as th
import gym
from torchy_baselines.common.vec_env import VecEnv
GymEnv = Union[gym.Env, VecEnv]
TensorDict = Dict[str, torch.Tensor]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
# obs, action, old_values, old_log_prob, advantage, return_batch
RolloutBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]
# obs, action, next_obs, done, reward
ReplayBufferSamples = Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]

View file

@ -6,6 +6,7 @@ import numpy as np
from torchy_baselines.common.base_class import OffPolicyRLModel
from torchy_baselines.common.buffers import ReplayBuffer
from torchy_baselines.common.type_aliases import ReplayBufferSamples
from torchy_baselines.td3.policies import TD3Policy
@ -134,7 +135,7 @@ class TD3(OffPolicyRLModel):
def train_critic(self, gradient_steps: int = 1,
batch_size: int = 100,
replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None,
replay_data: Optional[ReplayBufferSamples] = None,
tau: float = 0.0):
# Update optimizer learning rate
self._update_learning_rate(self.critic.optimizer)
@ -178,7 +179,7 @@ class TD3(OffPolicyRLModel):
batch_size: int = 100,
tau_actor: float = 0.005,
tau_critic: float = 0.005,
replay_data: Optional[Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]] = None):
replay_data: Optional[ReplayBufferSamples] = None):
# Update optimizer learning rate
self._update_learning_rate(self.actor.optimizer)