mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-23 22:20:18 +00:00
Type td3 policies
This commit is contained in:
parent
d4ddb3d021
commit
a67bb75438
2 changed files with 38 additions and 28 deletions
|
|
@ -4,8 +4,8 @@ import gym
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, \
|
||||
create_sde_feature_extractor
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
|
||||
create_sde_feature_extractor)
|
||||
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
|
|
@ -234,7 +234,7 @@ class SACPolicy(BasePolicy):
|
|||
return self.predict(obs, deterministic=False)
|
||||
|
||||
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.actor.forward(observation, deterministic)
|
||||
return self.actor(observation, deterministic)
|
||||
|
||||
|
||||
MlpPolicy = SACPolicy
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import torch
|
||||
from typing import Optional, List, Tuple, Callable, Union
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
|
||||
create_sde_feature_extractor)
|
||||
from torchy_baselines.common.distributions import StateDependentNoiseDistribution
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, \
|
||||
create_sde_feature_extractor
|
||||
|
||||
|
||||
class Actor(BaseNetwork):
|
||||
|
|
@ -76,7 +77,7 @@ class Actor(BaseNetwork):
|
|||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_output=True)
|
||||
self.mu = nn.Sequential(*actor_net)
|
||||
|
||||
def get_std(self) -> torch.Tensor:
|
||||
def get_std(self) -> th.Tensor:
|
||||
"""
|
||||
Retrieve the standard deviation of the action distribution.
|
||||
Only useful when using SDE.
|
||||
|
|
@ -92,7 +93,7 @@ class Actor(BaseNetwork):
|
|||
mean_actions = self.mu(latent_pi)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
|
||||
|
||||
def _get_latent(self, obs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_latent(self, obs) -> Tuple[th.Tensor, th.Tensor]:
|
||||
latent_pi = self.latent_pi(obs)
|
||||
|
||||
if self.sde_feature_extractor is not None:
|
||||
|
|
@ -101,7 +102,7 @@ class Actor(BaseNetwork):
|
|||
latent_sde = latent_pi
|
||||
return latent_pi, latent_sde
|
||||
|
||||
def evaluate_actions(self, obs: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def evaluate_actions(self, obs: th.Tensor, action: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Evaluate actions according to the current policy,
|
||||
given the observations. Only useful when using SDE.
|
||||
|
|
@ -123,7 +124,7 @@ class Actor(BaseNetwork):
|
|||
"""
|
||||
self.action_dist.sample_weights(self.log_std)
|
||||
|
||||
def forward(self, obs: torch.Tensor, deterministic: bool = True) -> torch.Tensor:
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
if self.use_sde:
|
||||
latent_pi, latent_sde = self._get_latent(obs)
|
||||
if deterministic:
|
||||
|
|
@ -162,11 +163,11 @@ class Critic(BaseNetwork):
|
|||
q2_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
|
||||
self.q2_net = nn.Sequential(*q2_net)
|
||||
|
||||
def forward(self, obs: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
qvalue_input = th.cat([obs, action], dim=1)
|
||||
return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
|
||||
|
||||
def q1_forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
|
||||
def q1_forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
|
||||
return self.q1_net(th.cat([obs, action], dim=1))
|
||||
|
||||
|
||||
|
|
@ -175,10 +176,11 @@ class ValueFunction(BaseNetwork):
|
|||
Value function for TD3 when doing on-policy exploration with SDE.
|
||||
|
||||
:param obs_dim: (int) Dimension of the observation
|
||||
:param net_arch: ([int]) Network architecture
|
||||
:param net_arch: (Optional[List[int]]) Network architecture
|
||||
:param activation_fn: (nn.Module) Activation function
|
||||
"""
|
||||
def __init__(self, obs_dim, net_arch=None, activation_fn=nn.Tanh):
|
||||
def __init__(self, obs_dim: int, net_arch: Optional[List[int]] = None,
|
||||
activation_fn: nn.Module = nn.Tanh):
|
||||
super(ValueFunction, self).__init__()
|
||||
|
||||
if net_arch is None:
|
||||
|
|
@ -187,7 +189,7 @@ class ValueFunction(BaseNetwork):
|
|||
vf_net = create_mlp(obs_dim, 1, net_arch, activation_fn)
|
||||
self.vf_net = nn.Sequential(*vf_net)
|
||||
|
||||
def forward(self, obs):
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
return self.vf_net(obs)
|
||||
|
||||
|
||||
|
|
@ -210,10 +212,18 @@ class TD3Policy(BasePolicy):
|
|||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
"""
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3,
|
||||
clip_noise=None, lr_sde=3e-4, sde_net_arch=None, use_expln=False):
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
learning_rate: Callable,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
activation_fn: nn.Module = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = -3,
|
||||
clip_noise: Optional[float] = None,
|
||||
lr_sde: float = 3e-4,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False):
|
||||
super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
|
||||
|
||||
# Default network architecture, from the original paper
|
||||
|
|
@ -249,7 +259,7 @@ class TD3Policy(BasePolicy):
|
|||
self.log_std_init = log_std_init
|
||||
self._build(learning_rate)
|
||||
|
||||
def _build(self, learning_rate):
|
||||
def _build(self, learning_rate: Callable) -> None:
|
||||
self.actor = self.make_actor()
|
||||
self.actor_target = self.make_actor()
|
||||
self.actor_target.load_state_dict(self.actor.state_dict())
|
||||
|
|
@ -262,22 +272,22 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
if self.use_sde:
|
||||
self.vf_net = ValueFunction(self.obs_dim)
|
||||
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()})
|
||||
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) # pytype: disable=attribute-error
|
||||
|
||||
def reset_noise(self):
|
||||
def reset_noise(self) -> None:
|
||||
return self.actor.reset_noise()
|
||||
|
||||
def make_actor(self):
|
||||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
|
||||
def make_critic(self):
|
||||
def make_critic(self) -> Critic:
|
||||
return Critic(**self.net_args).to(self.device)
|
||||
|
||||
def forward(self, obs, deterministic=True):
|
||||
return self.actor(obs, deterministic=deterministic)
|
||||
def forward(self, observation: th.Tensor, deterministic: bool = False):
|
||||
return self.predict(observation, deterministic=deterministic)
|
||||
|
||||
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.forward(observation, deterministic)
|
||||
return self.actor(observation, deterministic=deterministic)
|
||||
|
||||
|
||||
MlpPolicy = TD3Policy
|
||||
|
|
|
|||
Loading…
Reference in a new issue