More typing

This commit is contained in:
Antonin Raffin 2020-03-10 18:09:45 +01:00
parent 6ebad92e1b
commit 35d0d2b320
2 changed files with 38 additions and 25 deletions

View file

@ -237,7 +237,10 @@ class MlpExtractor(nn.Module):
:param activation_fn: (nn.Module) The activation function to use for the networks.
:param device: (Union[th.device, str]) Device to which the policy and value networks are moved.
"""
def __init__(self, feature_dim, net_arch, activation_fn, device='cpu'):
def __init__(self, feature_dim: int,
net_arch: List[Union[int, Dict[str, List[int]]]],
activation_fn: nn.Module,
device: Union[th.device, str] = 'cpu'):
super(MlpExtractor, self).__init__()
shared_net, policy_net, value_net = [], [], []
@ -291,7 +294,7 @@ class MlpExtractor(nn.Module):
self.policy_net = nn.Sequential(*policy_net).to(device)
self.value_net = nn.Sequential(*value_net).to(device)
def forward(self, features):
def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
:return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
If all layers are shared, then ``latent_policy == latent_value``

View file

@ -1,13 +1,16 @@
from typing import Optional, List, Tuple, Callable, Union, Dict
from functools import partial
import gym
import torch as th
import torch.nn as nn
import numpy as np
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, \
create_sde_feature_extractor
from torchy_baselines.common.distributions import make_proba_distribution,\
DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution
from torchy_baselines.common.policies import (BasePolicy, register_policy, MlpExtractor,
create_sde_feature_extractor)
from torchy_baselines.common.distributions import (make_proba_distribution, Distribution,
DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution)
class PPOPolicy(BasePolicy):
@ -35,12 +38,21 @@ class PPOPolicy(BasePolicy):
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows ensuring boundaries when using SDE.
"""
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.Tanh, adam_epsilon=1e-5,
ortho_init=True, use_sde=False,
log_std_init=0.0, full_std=True,
sde_net_arch=None, use_expln=False, squash_output=False):
def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
learning_rate: Callable,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
device: Union[th.device, str] = 'cpu',
activation_fn: nn.Module = nn.Tanh,
adam_epsilon: float = 1e-5,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False):
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
self.obs_dim = self.observation_space.shape[0]
@ -83,7 +95,7 @@ class PPOPolicy(BasePolicy):
self._build(learning_rate)
def reset_noise(self, n_envs: int = 1):
def reset_noise(self, n_envs: int = 1) -> None:
"""
Sample new weights for the exploration matrix.
@ -92,7 +104,7 @@ class PPOPolicy(BasePolicy):
assert isinstance(self.action_dist, StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build(self, learning_rate):
def _build(self, learning_rate: Callable) -> None:
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
activation_fn=self.activation_fn, device=self.device)
@ -129,7 +141,7 @@ class PPOPolicy(BasePolicy):
module.apply(partial(self.init_weights, gain=gain))
self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon)
def forward(self, obs, deterministic=False):
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
if not isinstance(obs, th.Tensor):
obs = th.FloatTensor(obs).to(self.device)
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
@ -139,7 +151,7 @@ class PPOPolicy(BasePolicy):
log_prob = action_distribution.log_prob(action)
return action, value, log_prob
def _get_latent(self, obs):
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
features = self.features_extractor(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
# Features for sde
@ -148,7 +160,9 @@ class PPOPolicy(BasePolicy):
latent_sde = self.sde_feature_extractor(features)
return latent_pi, latent_vf, latent_sde
def _get_action_dist_from_latent(self, latent_pi, latent_sde=None, deterministic=False):
def _get_action_dist_from_latent(self, latent_pi: th.Tensor,
latent_sde: Optional[th.Tensor] = None,
deterministic: bool = False) -> Tuple[th.Tensor, Distribution]:
mean_actions = self.action_net(latent_pi)
if isinstance(self.action_dist, DiagGaussianDistribution):
@ -169,7 +183,7 @@ class PPOPolicy(BasePolicy):
action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
return action
def evaluate_actions(self, obs, action, deterministic=False):
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Evaluate actions according to the current policy,
given the observations.
@ -182,13 +196,9 @@ class PPOPolicy(BasePolicy):
"""
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
_, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
log_prob = action_distribution.log_prob(action)
value = self.value_net(latent_vf)
return value, log_prob, action_distribution.entropy()
def value_forward(self, obs):
_, latent_vf, _ = self._get_latent(obs)
return self.value_net(latent_vf)
log_prob = action_distribution.log_prob(actions)
values = self.value_net(latent_vf)
return values, log_prob, action_distribution.entropy()
MlpPolicy = PPOPolicy