mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-23 22:20:18 +00:00
Quick and dirty SDE version for TD3
This commit is contained in:
parent
95c741c707
commit
db87e0d36a
4 changed files with 70 additions and 17 deletions
|
|
@ -3,7 +3,7 @@ import pytest
|
|||
import torch as th
|
||||
from torch.distributions import Normal
|
||||
|
||||
from torchy_baselines import A2C
|
||||
from torchy_baselines import A2C, TD3
|
||||
|
||||
|
||||
def test_state_dependent_exploration():
|
||||
|
|
@ -54,4 +54,12 @@ def test_state_dependent_noise(model_class):
|
|||
model = model_class('MlpPolicy', env, n_steps=200, max_grad_norm=1, use_rms_prop=False,
|
||||
use_sde=True, ent_coef=0.00, verbose=1, create_eval_env=True, learning_rate=3e-4,
|
||||
policy_kwargs=dict(log_std_init=0.0, ortho_init=False, net_arch=[256, dict(pi=[256], vf=[256])]), seed=None)
|
||||
model.learn(total_timesteps=int(20000), log_interval=5, eval_freq=10000, eval_env=eval_env)
|
||||
# model.learn(total_timesteps=int(20000), log_interval=5, eval_freq=10000, eval_env=eval_env)
|
||||
model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TD3])
def test_state_dependent_offpolicy_noise(model_class):
    """
    Smoke test: build an off-policy model with State Dependent Exploration
    enabled on 'Pendulum-v0' and run a training session with periodic evaluation.

    :param model_class: the algorithm class under test (parametrized by pytest)
    """
    # Start from a low initial log std for the exploration noise
    sde_policy_kwargs = dict(log_std_init=-2)
    model = model_class(
        'MlpPolicy',
        'Pendulum-v0',
        use_sde=True,
        seed=None,
        create_eval_env=True,
        verbose=1,
        policy_kwargs=sde_policy_kwargs,
    )
    model.learn(total_timesteps=20000, eval_freq=1000)
|
||||
|
|
|
|||
|
|
@ -327,6 +327,9 @@ class BaseRLModel(object):
|
|||
assert isinstance(env, VecEnv)
|
||||
assert env.num_envs == 1
|
||||
|
||||
if hasattr(self, 'use_sde') and self.use_sde:
|
||||
self.policy.reset_noise()
|
||||
|
||||
while total_steps < n_steps or total_episodes < n_episodes:
|
||||
done = False
|
||||
# Reset environment: not needed for VecEnv
|
||||
|
|
@ -338,6 +341,8 @@ class BaseRLModel(object):
|
|||
if num_timesteps < learning_starts:
|
||||
action = np.array([self.action_space.sample()])
|
||||
else:
|
||||
if hasattr(self, 'use_sde'):
|
||||
deterministic = not self.use_sde
|
||||
action = self.predict(obs, deterministic=deterministic)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
|
|
|
|||
|
|
@ -1,19 +1,46 @@
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork
|
||||
|
||||
|
||||
class Actor(BaseNetwork):
|
||||
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU):
|
||||
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
|
||||
use_sde=False, log_std_init=0.0, clip_noise=0.1):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
# TODO: orthogonal initialization?
|
||||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
|
||||
self.actor_net = nn.Sequential(*actor_net)
|
||||
self.latent_pi, self.log_std = None, None
|
||||
self.weights_dist, self.exploration_mat = None, None
|
||||
self.use_sde = use_sde
|
||||
|
||||
def forward(self, obs):
|
||||
return self.actor_net(obs)
|
||||
if use_sde:
|
||||
latent_dim = net_arch[-1]
|
||||
latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
|
||||
self.latent_pi = nn.Sequential(*latent_pi)
|
||||
self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
|
||||
self.actor_net = nn.Sequential(nn.Linear(net_arch[-1], action_dim), nn.Tanh())
|
||||
self.clip_noise = clip_noise
|
||||
self.reset_noise()
|
||||
else:
|
||||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
|
||||
self.actor_net = nn.Sequential(*actor_net)
|
||||
|
||||
def reset_noise(self):
    """
    Sample a fresh exploration matrix for State Dependent Exploration.

    Builds a zero-mean Normal distribution whose standard deviation is
    ``exp(self.log_std)`` (element-wise), stores it in ``self.weights_dist``
    and stores one reparameterized sample of it in ``self.exploration_mat``.

    NOTE(review): ``self.log_std`` is only set when the actor is built with
    ``use_sde=True``; otherwise it is None and this call fails.
    """
    zero_mean = th.zeros_like(self.log_std)
    std = th.exp(self.log_std)
    weights_dist = Normal(zero_mean, std)
    self.weights_dist = weights_dist
    # rsample() keeps the sample differentiable w.r.t. log_std
    self.exploration_mat = weights_dist.rsample()
|
||||
|
||||
def forward(self, obs, deterministic=True):
    """
    Compute the action for the given observation.

    Without SDE the observation goes straight through ``self.actor_net``.
    With SDE the observation is first mapped to latent features by
    ``self.latent_pi``; the deterministic action is ``self.actor_net(features)``
    and, when ``deterministic`` is False, state-dependent noise
    ``features @ self.exploration_mat`` is added before clamping to [-1, 1].

    :param obs: (th.Tensor) observation (th.mm needs a 2-D latent, so a
        batched input is presumably expected -- confirm against callers)
    :param deterministic: (bool) when True, no exploration noise is added
        (ignored when SDE is disabled)
    :return: (th.Tensor) the action
    """
    if not self.use_sde:
        return self.actor_net(obs)

    features = self.latent_pi(obs)
    if deterministic:
        return self.actor_net(features)

    # The features are detached so the noise term does not backpropagate into latent_pi
    exploration_noise = th.mm(features.detach(), self.exploration_mat)
    # TODO: fix clipping -- the noise itself is not yet clamped to
    # [-self.clip_noise, self.clip_noise]; only the final action is clamped
    noisy_action = self.actor_net(features) + exploration_noise
    return th.clamp(noisy_action, -1, 1)
|
||||
|
||||
|
||||
class Critic(BaseNetwork):
|
||||
|
|
@ -40,7 +67,7 @@ class Critic(BaseNetwork):
|
|||
class TD3Policy(BasePolicy):
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.ReLU):
|
||||
activation_fn=nn.ReLU, use_sde=False, log_std_init=0.0):
|
||||
super(TD3Policy, self).__init__(observation_space, action_space, device)
|
||||
|
||||
if net_arch is None:
|
||||
|
|
@ -56,8 +83,14 @@ class TD3Policy(BasePolicy):
|
|||
'net_arch': self.net_arch,
|
||||
'activation_fn': self.activation_fn
|
||||
}
|
||||
self.actor_kwargs = self.net_args.copy()
|
||||
self.actor_kwargs['use_sde'] = use_sde
|
||||
self.actor_kwargs['log_std_init'] = log_std_init
|
||||
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
self.use_sde = use_sde
|
||||
self.log_std_init = log_std_init
|
||||
self._build(learning_rate)
|
||||
|
||||
def _build(self, learning_rate):
|
||||
|
|
@ -71,14 +104,17 @@ class TD3Policy(BasePolicy):
|
|||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1))
|
||||
|
||||
def reset_noise(self):
    """
    Resample the exploration noise of the actor.

    Delegates to ``self.actor.reset_noise()`` and forwards its return value.
    """
    actor = self.actor
    return actor.reset_noise()
|
||||
|
||||
def make_actor(self):
|
||||
return Actor(**self.net_args).to(self.device)
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
|
||||
def make_critic(self):
    """
    Create a new critic network.

    :return: (Critic) a critic built from ``self.net_args`` and moved to ``self.device``
    """
    critic = Critic(**self.net_args)
    return critic.to(self.device)
|
||||
|
||||
def forward(self, obs):
|
||||
return self.actor(obs)
|
||||
def forward(self, obs, deterministic=True):
|
||||
return self.actor(obs, deterministic=deterministic)
|
||||
|
||||
|
||||
MlpPolicy = TD3Policy
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ class TD3(BaseRLModel):
|
|||
:param target_policy_noise: (float) Standard deviation of gaussian noise added to target policy
|
||||
(smoothing noise)
|
||||
:param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
|
||||
|
|
@ -50,6 +52,7 @@ class TD3(BaseRLModel):
|
|||
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
|
||||
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
|
||||
tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
|
||||
use_sde=False,
|
||||
tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
|
||||
seed=0, device='auto', _init_setup_model=True):
|
||||
|
||||
|
|
@ -70,6 +73,7 @@ class TD3(BaseRLModel):
|
|||
self.policy_delay = policy_delay
|
||||
self.target_noise_clip = target_noise_clip
|
||||
self.target_policy_noise = target_policy_noise
|
||||
self.use_sde = use_sde
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
|
@ -79,8 +83,8 @@ class TD3(BaseRLModel):
|
|||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy(self.observation_space, self.action_space, self.learning_rate,
|
||||
use_sde=self.use_sde, device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
||||
|
|
@ -90,12 +94,12 @@ class TD3(BaseRLModel):
|
|||
self.critic = self.policy.critic
|
||||
self.critic_target = self.policy.critic_target
|
||||
|
||||
def select_action(self, observation):
|
||||
def select_action(self, observation, deterministic=True):
|
||||
# Normally not needed
|
||||
observation = np.array(observation)
|
||||
with th.no_grad():
|
||||
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
|
||||
return self.actor(observation).cpu().numpy()
|
||||
return self.actor(observation, deterministic=deterministic).cpu().numpy()
|
||||
|
||||
def predict(self, observation, state=None, mask=None, deterministic=True):
|
||||
"""
|
||||
|
|
@ -107,7 +111,7 @@ class TD3(BaseRLModel):
|
|||
:param deterministic: (bool) Whether or not to return deterministic actions.
|
||||
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
|
||||
"""
|
||||
return self.unscale_action(self.select_action(observation))
|
||||
return self.unscale_action(self.select_action(observation, deterministic=deterministic))
|
||||
|
||||
def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0):
|
||||
# Update optimizer learning rate
|
||||
|
|
|
|||
Loading…
Reference in a new issue