mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-27 22:55:17 +00:00
More typing
This commit is contained in:
parent
6ebad92e1b
commit
35d0d2b320
2 changed files with 38 additions and 25 deletions
|
|
@ -237,7 +237,10 @@ class MlpExtractor(nn.Module):
|
|||
:param activation_fn: (nn.Module) The activation function to use for the networks.
|
||||
:param device: (th.device)
|
||||
"""
|
||||
def __init__(self, feature_dim, net_arch, activation_fn, device='cpu'):
|
||||
def __init__(self, feature_dim: int,
|
||||
net_arch: List[Union[int, Dict[str, List[int]]]],
|
||||
activation_fn: nn.Module,
|
||||
device: Union[th.device, str] = 'cpu'):
|
||||
super(MlpExtractor, self).__init__()
|
||||
|
||||
shared_net, policy_net, value_net = [], [], []
|
||||
|
|
@ -291,7 +294,7 @@ class MlpExtractor(nn.Module):
|
|||
self.policy_net = nn.Sequential(*policy_net).to(device)
|
||||
self.value_net = nn.Sequential(*value_net).to(device)
|
||||
|
||||
def forward(self, features):
|
||||
def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
:return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
|
||||
If all layers are shared, then ``latent_policy == latent_value``
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
from typing import Optional, List, Tuple, Callable, Union, Dict
|
||||
from functools import partial
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, MlpExtractor, \
|
||||
create_sde_feature_extractor
|
||||
from torchy_baselines.common.distributions import make_proba_distribution,\
|
||||
DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, MlpExtractor,
|
||||
create_sde_feature_extractor)
|
||||
from torchy_baselines.common.distributions import (make_proba_distribution, Distribution,
|
||||
DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution)
|
||||
|
||||
|
||||
|
||||
class PPOPolicy(BasePolicy):
|
||||
|
|
@ -35,12 +38,21 @@ class PPOPolicy(BasePolicy):
|
|||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries when using SDE.
|
||||
"""
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.Tanh, adam_epsilon=1e-5,
|
||||
ortho_init=True, use_sde=False,
|
||||
log_std_init=0.0, full_std=True,
|
||||
sde_net_arch=None, use_expln=False, squash_output=False):
|
||||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
learning_rate: Callable,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
activation_fn: nn.Module = nn.Tanh,
|
||||
adam_epsilon: float = 1e-5,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = 0.0,
|
||||
full_std: bool = True,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
squash_output: bool = False):
|
||||
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
|
||||
|
|
@ -83,7 +95,7 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
self._build(learning_rate)
|
||||
|
||||
def reset_noise(self, n_envs: int = 1):
|
||||
def reset_noise(self, n_envs: int = 1) -> None:
|
||||
"""
|
||||
Sample new weights for the exploration matrix.
|
||||
|
||||
|
|
@ -92,7 +104,7 @@ class PPOPolicy(BasePolicy):
|
|||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
|
||||
def _build(self, learning_rate):
|
||||
def _build(self, learning_rate: Callable) -> None:
|
||||
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
|
||||
activation_fn=self.activation_fn, device=self.device)
|
||||
|
||||
|
|
@ -129,7 +141,7 @@ class PPOPolicy(BasePolicy):
|
|||
module.apply(partial(self.init_weights, gain=gain))
|
||||
self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon)
|
||||
|
||||
def forward(self, obs, deterministic=False):
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
if not isinstance(obs, th.Tensor):
|
||||
obs = th.FloatTensor(obs).to(self.device)
|
||||
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
|
||||
|
|
@ -139,7 +151,7 @@ class PPOPolicy(BasePolicy):
|
|||
log_prob = action_distribution.log_prob(action)
|
||||
return action, value, log_prob
|
||||
|
||||
def _get_latent(self, obs):
|
||||
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
features = self.features_extractor(obs)
|
||||
latent_pi, latent_vf = self.mlp_extractor(features)
|
||||
# Features for sde
|
||||
|
|
@ -148,7 +160,9 @@ class PPOPolicy(BasePolicy):
|
|||
latent_sde = self.sde_feature_extractor(features)
|
||||
return latent_pi, latent_vf, latent_sde
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi, latent_sde=None, deterministic=False):
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor,
|
||||
latent_sde: Optional[th.Tensor] = None,
|
||||
deterministic: bool = False) -> Tuple[th.Tensor, Distribution]:
|
||||
mean_actions = self.action_net(latent_pi)
|
||||
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
|
|
@ -169,7 +183,7 @@ class PPOPolicy(BasePolicy):
|
|||
action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
|
||||
return action
|
||||
|
||||
def evaluate_actions(self, obs, action, deterministic=False):
|
||||
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Evaluate actions according to the current policy,
|
||||
given the observations.
|
||||
|
|
@ -182,13 +196,9 @@ class PPOPolicy(BasePolicy):
|
|||
"""
|
||||
latent_pi, latent_vf, latent_sde = self._get_latent(obs)
|
||||
_, action_distribution = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
value = self.value_net(latent_vf)
|
||||
return value, log_prob, action_distribution.entropy()
|
||||
|
||||
def value_forward(self, obs):
|
||||
_, latent_vf, _ = self._get_latent(obs)
|
||||
return self.value_net(latent_vf)
|
||||
log_prob = action_distribution.log_prob(actions)
|
||||
values = self.value_net(latent_vf)
|
||||
return values, log_prob, action_distribution.entropy()
|
||||
|
||||
|
||||
MlpPolicy = PPOPolicy
|
||||
|
|
|
|||
Loading…
Reference in a new issue