Type td3 policies

This commit is contained in:
Antonin RAFFIN 2020-03-16 13:31:06 +01:00
parent d4ddb3d021
commit a67bb75438
2 changed files with 38 additions and 28 deletions

View file

@ -4,8 +4,8 @@ import gym
import torch as th
import torch.nn as nn
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, \
create_sde_feature_extractor
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
create_sde_feature_extractor)
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
# CAP the standard deviation of the actor
@ -234,7 +234,7 @@ class SACPolicy(BasePolicy):
return self.predict(obs, deterministic=False)
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.actor.forward(observation, deterministic)
return self.actor(observation, deterministic)
MlpPolicy = SACPolicy

View file

@ -1,11 +1,12 @@
import torch
from typing import Optional, List, Tuple, Callable, Union
import gym
import torch as th
import torch.nn as nn
from typing import List, Tuple, Optional
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
create_sde_feature_extractor)
from torchy_baselines.common.distributions import StateDependentNoiseDistribution
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, \
create_sde_feature_extractor
class Actor(BaseNetwork):
@ -76,7 +77,7 @@ class Actor(BaseNetwork):
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_output=True)
self.mu = nn.Sequential(*actor_net)
def get_std(self) -> torch.Tensor:
def get_std(self) -> th.Tensor:
"""
Retrieve the standard deviation of the action distribution.
Only useful when using SDE.
@ -92,7 +93,7 @@ class Actor(BaseNetwork):
mean_actions = self.mu(latent_pi)
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
def _get_latent(self, obs) -> Tuple[torch.Tensor, torch.Tensor]:
def _get_latent(self, obs) -> Tuple[th.Tensor, th.Tensor]:
latent_pi = self.latent_pi(obs)
if self.sde_feature_extractor is not None:
@ -101,7 +102,7 @@ class Actor(BaseNetwork):
latent_sde = latent_pi
return latent_pi, latent_sde
def evaluate_actions(self, obs: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def evaluate_actions(self, obs: th.Tensor, action: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
Evaluate actions according to the current policy,
given the observations. Only useful when using SDE.
@ -123,7 +124,7 @@ class Actor(BaseNetwork):
"""
self.action_dist.sample_weights(self.log_std)
def forward(self, obs: torch.Tensor, deterministic: bool = True) -> torch.Tensor:
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
if self.use_sde:
latent_pi, latent_sde = self._get_latent(obs)
if deterministic:
@ -162,11 +163,11 @@ class Critic(BaseNetwork):
q2_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
self.q2_net = nn.Sequential(*q2_net)
def forward(self, obs: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, obs: th.Tensor, action: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
qvalue_input = th.cat([obs, action], dim=1)
return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
def q1_forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
def q1_forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
return self.q1_net(th.cat([obs, action], dim=1))
@ -175,10 +176,11 @@ class ValueFunction(BaseNetwork):
Value function for TD3 when doing on-policy exploration with SDE.
:param obs_dim: (int) Dimension of the observation
:param net_arch: ([int]) Network architecture
:param net_arch: (Optional[List[int]]) Network architecture
:param activation_fn: (nn.Module) Activation function
"""
def __init__(self, obs_dim, net_arch=None, activation_fn=nn.Tanh):
def __init__(self, obs_dim: int, net_arch: Optional[List[int]] = None,
activation_fn: nn.Module = nn.Tanh):
super(ValueFunction, self).__init__()
if net_arch is None:
@ -187,7 +189,7 @@ class ValueFunction(BaseNetwork):
vf_net = create_mlp(obs_dim, 1, net_arch, activation_fn)
self.vf_net = nn.Sequential(*vf_net)
def forward(self, obs):
def forward(self, obs: th.Tensor) -> th.Tensor:
return self.vf_net(obs)
@ -210,10 +212,18 @@ class TD3Policy(BasePolicy):
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
"""
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3,
clip_noise=None, lr_sde=3e-4, sde_net_arch=None, use_expln=False):
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
learning_rate: Callable,
net_arch: Optional[List[int]] = None,
device: Union[th.device, str] = 'cpu',
activation_fn: nn.Module = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
clip_noise: Optional[float] = None,
lr_sde: float = 3e-4,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False):
super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
# Default network architecture, from the original paper
@ -249,7 +259,7 @@ class TD3Policy(BasePolicy):
self.log_std_init = log_std_init
self._build(learning_rate)
def _build(self, learning_rate):
def _build(self, learning_rate: Callable) -> None:
self.actor = self.make_actor()
self.actor_target = self.make_actor()
self.actor_target.load_state_dict(self.actor.state_dict())
@ -262,22 +272,22 @@ class TD3Policy(BasePolicy):
if self.use_sde:
self.vf_net = ValueFunction(self.obs_dim)
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()})
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) # pytype: disable=attribute-error
def reset_noise(self):
def reset_noise(self) -> None:
return self.actor.reset_noise()
def make_actor(self):
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)
def make_critic(self):
def make_critic(self) -> Critic:
return Critic(**self.net_args).to(self.device)
def forward(self, obs, deterministic=True):
return self.actor(obs, deterministic=deterministic)
def forward(self, observation: th.Tensor, deterministic: bool = False):
return self.predict(observation, deterministic=deterministic)
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.forward(observation, deterministic)
return self.actor(observation, deterministic=deterministic)
MlpPolicy = TD3Policy