diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index abab7d8..a5e08b5 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -237,7 +237,10 @@ class MlpExtractor(nn.Module): :param activation_fn: (nn.Module) The activation function to use for the networks. :param device: (th.device) """ - def __init__(self, feature_dim, net_arch, activation_fn, device='cpu'): + def __init__(self, feature_dim: int, + net_arch: List[Union[int, Dict[str, List[int]]]], + activation_fn: nn.Module, + device: Union[th.device, str] = 'cpu'): super(MlpExtractor, self).__init__() shared_net, policy_net, value_net = [], [], [] @@ -291,7 +294,7 @@ class MlpExtractor(nn.Module): self.policy_net = nn.Sequential(*policy_net).to(device) self.value_net = nn.Sequential(*value_net).to(device) - def forward(self, features): + def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: """ :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network. If all layers are shared, then ``latent_policy == latent_value`` diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 3e47375..491737a 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -1,13 +1,16 @@ +from typing import Optional, List, Tuple, Callable, Union, Dict from functools import partial +import gym import torch as th import torch.nn as nn import numpy as np -from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, \ - create_sde_feature_extractor -from torchy_baselines.common.distributions import make_proba_distribution,\ - DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution +from torchy_baselines.common.policies import (BasePolicy, register_policy, MlpExtractor, + create_sde_feature_extractor) +from torchy_baselines.common.distributions import (make_proba_distribution, Distribution, + DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution) + class PPOPolicy(BasePolicy): @@ -35,12 +38,21 @@ class PPOPolicy(BasePolicy): :param squash_output: (bool) Whether to squash the output using a tanh function, this allows to ensure boundaries when using SDE. """ - def __init__(self, observation_space, action_space, - learning_rate, net_arch=None, device='cpu', - activation_fn=nn.Tanh, adam_epsilon=1e-5, - ortho_init=True, use_sde=False, - log_std_init=0.0, full_std=True, - sde_net_arch=None, use_expln=False, squash_output=False): + def __init__(self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + learning_rate: Callable, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + device: Union[th.device, str] = 'cpu', + activation_fn: nn.Module = nn.Tanh, + adam_epsilon: float = 1e-5, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False): super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output) self.obs_dim = self.observation_space.shape[0] @@ -83,7 +95,7 @@ class PPOPolicy(BasePolicy): self._build(learning_rate) - def reset_noise(self, n_envs: int = 1): + def reset_noise(self, n_envs: int = 1) -> None: """ Sample new weights for the exploration matrix. @@ -92,7 +104,7 @@ class PPOPolicy(BasePolicy): assert isinstance(self.action_dist, StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE' self.action_dist.sample_weights(self.log_std, batch_size=n_envs) - def _build(self, learning_rate): + def _build(self, learning_rate: Callable) -> None: self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device) @@ -129,7 +141,7 @@ class PPOPolicy(BasePolicy): module.apply(partial(self.init_weights, gain=gain)) self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon) - def forward(self, obs, deterministic=False): + def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: if not isinstance(obs, th.Tensor): obs = th.FloatTensor(obs).to(self.device) latent_pi, latent_vf, latent_sde = self._get_latent(obs) @@ -139,7 +151,7 @@ class PPOPolicy(BasePolicy): log_prob = action_distribution.log_prob(action) return action, value, log_prob - def _get_latent(self, obs): + def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: features = self.features_extractor(obs) latent_pi, latent_vf = self.mlp_extractor(features) # Features for sde @@ -148,7 +160,9 @@ class PPOPolicy(BasePolicy): latent_sde = self.sde_feature_extractor(features) return latent_pi, latent_vf, latent_sde - def _get_action_dist_from_latent(self, latent_pi, latent_sde=None, deterministic=False): + def _get_action_dist_from_latent(self, latent_pi: th.Tensor, + latent_sde: Optional[th.Tensor] = None, + deterministic: bool = False) -> Tuple[th.Tensor, Distribution]: mean_actions = self.action_net(latent_pi) if isinstance(self.action_dist, DiagGaussianDistribution): @@ -169,7 +183,7 @@ class PPOPolicy(BasePolicy): action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic) return action - def evaluate_actions(self, obs, action, deterministic=False): + def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ Evaluate actions according to the current policy, given the observations. @@ -182,13 +196,9 @@ class PPOPolicy(BasePolicy): """ latent_pi, latent_vf, latent_sde = self._get_latent(obs) _, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic) - log_prob = action_distribution.log_prob(action) - value = self.value_net(latent_vf) - return value, log_prob, action_distribution.entropy() - - def value_forward(self, obs): - _, latent_vf, _ = self._get_latent(obs) - return self.value_net(latent_vf) + log_prob = action_distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, action_distribution.entropy() MlpPolicy = PPOPolicy