From a67bb75438546c2e7532318624f672f07423cc7f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 16 Mar 2020 13:31:06 +0100 Subject: [PATCH] Type td3 policies --- torchy_baselines/sac/policies.py | 6 ++-- torchy_baselines/td3/policies.py | 60 +++++++++++++++++++------------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index 00d7457..46d8a2e 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -4,8 +4,8 @@ import gym import torch as th import torch.nn as nn -from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, \ - create_sde_feature_extractor +from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork, + create_sde_feature_extractor) from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution # CAP the standard deviation of the actor @@ -234,7 +234,7 @@ class SACPolicy(BasePolicy): return self.predict(obs, deterministic=False) def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: - return self.actor.forward(observation, deterministic) + return self.actor(observation, deterministic) MlpPolicy = SACPolicy diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index fc86c77..d7d3e81 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -1,11 +1,12 @@ -import torch +from typing import Optional, List, Tuple, Callable, Union + +import gym import torch as th import torch.nn as nn -from typing import List, Tuple, Optional +from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork, + create_sde_feature_extractor) from torchy_baselines.common.distributions import StateDependentNoiseDistribution -from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, \ - create_sde_feature_extractor class Actor(BaseNetwork): @@ -76,7 +77,7 @@ class Actor(BaseNetwork): actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_output=True) self.mu = nn.Sequential(*actor_net) - def get_std(self) -> torch.Tensor: + def get_std(self) -> th.Tensor: """ Retrieve the standard deviation of the action distribution. Only useful when using SDE. @@ -92,7 +93,7 @@ class Actor(BaseNetwork): mean_actions = self.mu(latent_pi) return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) - def _get_latent(self, obs) -> Tuple[torch.Tensor, torch.Tensor]: + def _get_latent(self, obs) -> Tuple[th.Tensor, th.Tensor]: latent_pi = self.latent_pi(obs) if self.sde_feature_extractor is not None: @@ -101,7 +102,7 @@ class Actor(BaseNetwork): latent_sde = latent_pi return latent_pi, latent_sde - def evaluate_actions(self, obs: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def evaluate_actions(self, obs: th.Tensor, action: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: """ Evaluate actions according to the current policy, given the observations. Only useful when using SDE. @@ -123,7 +124,7 @@ class Actor(BaseNetwork): """ self.action_dist.sample_weights(self.log_std) - def forward(self, obs: torch.Tensor, deterministic: bool = True) -> torch.Tensor: + def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: if self.use_sde: latent_pi, latent_sde = self._get_latent(obs) if deterministic: @@ -162,11 +163,11 @@ class Critic(BaseNetwork): q2_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn) self.q2_net = nn.Sequential(*q2_net) - def forward(self, obs: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, obs: th.Tensor, action: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: qvalue_input = th.cat([obs, action], dim=1) return self.q1_net(qvalue_input), self.q2_net(qvalue_input) - def q1_forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + def q1_forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor: return self.q1_net(th.cat([obs, action], dim=1)) @@ -175,10 +176,11 @@ class ValueFunction(BaseNetwork): Value function for TD3 when doing on-policy exploration with SDE. :param obs_dim: (int) Dimension of the observation - :param net_arch: ([int]) Network architecture + :param net_arch: (Optional[List[int]]) Network architecture :param activation_fn: (nn.Module) Activation function """ - def __init__(self, obs_dim, net_arch=None, activation_fn=nn.Tanh): + def __init__(self, obs_dim: int, net_arch: Optional[List[int]] = None, + activation_fn: nn.Module = nn.Tanh): super(ValueFunction, self).__init__() if net_arch is None: @@ -187,7 +189,7 @@ class ValueFunction(BaseNetwork): vf_net = create_mlp(obs_dim, 1, net_arch, activation_fn) self.vf_net = nn.Sequential(*vf_net) - def forward(self, obs): + def forward(self, obs: th.Tensor) -> th.Tensor: return self.vf_net(obs) @@ -210,10 +212,18 @@ class TD3Policy(BasePolicy): a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. """ - def __init__(self, observation_space, action_space, - learning_rate, net_arch=None, device='cpu', - activation_fn=nn.ReLU, use_sde=False, log_std_init=-3, - clip_noise=None, lr_sde=3e-4, sde_net_arch=None, use_expln=False): + def __init__(self, observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + learning_rate: Callable, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = 'cpu', + activation_fn: nn.Module = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + clip_noise: Optional[float] = None, + lr_sde: float = 3e-4, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False): super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True) # Default network architecture, from the original paper @@ -249,7 +259,7 @@ class TD3Policy(BasePolicy): self.log_std_init = log_std_init self._build(learning_rate) - def _build(self, learning_rate): + def _build(self, learning_rate: Callable) -> None: self.actor = self.make_actor() self.actor_target = self.make_actor() self.actor_target.load_state_dict(self.actor.state_dict()) @@ -262,22 +272,22 @@ class TD3Policy(BasePolicy): if self.use_sde: self.vf_net = ValueFunction(self.obs_dim) - self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) + self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) # pytype: disable=attribute-error - def reset_noise(self): + def reset_noise(self) -> None: return self.actor.reset_noise() - def make_actor(self): + def make_actor(self) -> Actor: return Actor(**self.actor_kwargs).to(self.device) - def make_critic(self): + def make_critic(self) -> Critic: return Critic(**self.net_args).to(self.device) - def forward(self, obs, deterministic=True): - return self.actor(obs, deterministic=deterministic) + def forward(self, observation: th.Tensor, deterministic: bool = False): + return self.predict(observation, deterministic=deterministic) def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: - return self.forward(observation, deterministic) + return self.actor(observation, deterministic=deterministic) MlpPolicy = TD3Policy