Quick and dirty SDE version for TD3

This commit is contained in:
Antonin Raffin 2019-11-07 17:31:52 +01:00
parent 95c741c707
commit db87e0d36a
4 changed files with 70 additions and 17 deletions

View file

@ -3,7 +3,7 @@ import pytest
import torch as th
from torch.distributions import Normal
from torchy_baselines import A2C
from torchy_baselines import A2C, TD3
def test_state_dependent_exploration():
@ -54,4 +54,12 @@ def test_state_dependent_noise(model_class):
model = model_class('MlpPolicy', env, n_steps=200, max_grad_norm=1, use_rms_prop=False,
use_sde=True, ent_coef=0.00, verbose=1, create_eval_env=True, learning_rate=3e-4,
policy_kwargs=dict(log_std_init=0.0, ortho_init=False, net_arch=[256, dict(pi=[256], vf=[256])]), seed=None)
model.learn(total_timesteps=int(20000), log_interval=5, eval_freq=10000, eval_env=eval_env)
# model.learn(total_timesteps=int(20000), log_interval=5, eval_freq=10000, eval_env=eval_env)
model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env)
@pytest.mark.parametrize("model_class", [TD3])
def test_state_dependent_offpolicy_noise(model_class):
model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
verbose=1, policy_kwargs=dict(log_std_init=-2))
model.learn(total_timesteps=int(20000), eval_freq=1000)

View file

@ -327,6 +327,9 @@ class BaseRLModel(object):
assert isinstance(env, VecEnv)
assert env.num_envs == 1
if hasattr(self, 'use_sde') and self.use_sde:
self.policy.reset_noise()
while total_steps < n_steps or total_episodes < n_episodes:
done = False
# Reset environment: not needed for VecEnv
@ -338,6 +341,8 @@ class BaseRLModel(object):
if num_timesteps < learning_starts:
action = np.array([self.action_space.sample()])
else:
if hasattr(self, 'use_sde'):
deterministic = not self.use_sde
action = self.predict(obs, deterministic=deterministic)
# Rescale the action from [low, high] to [-1, 1]

View file

@ -1,19 +1,46 @@
import torch as th
import torch.nn as nn
from torch.distributions import Normal
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork
class Actor(BaseNetwork):
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU):
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
use_sde=False, log_std_init=0.0, clip_noise=0.1):
super(Actor, self).__init__()
# TODO: orthogonal initialization?
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
self.actor_net = nn.Sequential(*actor_net)
self.latent_pi, self.log_std = None, None
self.weights_dist, self.exploration_mat = None, None
self.use_sde = use_sde
def forward(self, obs):
return self.actor_net(obs)
if use_sde:
latent_dim = net_arch[-1]
latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
self.latent_pi = nn.Sequential(*latent_pi)
self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
self.actor_net = nn.Sequential(nn.Linear(net_arch[-1], action_dim), nn.Tanh())
self.clip_noise = clip_noise
self.reset_noise()
else:
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
self.actor_net = nn.Sequential(*actor_net)
def reset_noise(self):
self.weights_dist = Normal(th.zeros_like(self.log_std), th.exp(self.log_std))
self.exploration_mat = self.weights_dist.rsample()
def forward(self, obs, deterministic=True):
if self.use_sde:
latent_pi = self.latent_pi(obs)
if deterministic:
return self.actor_net(latent_pi)
noise = th.mm(latent_pi.detach(), self.exploration_mat)
# noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
# TODO: fix clipping
return th.clamp(self.actor_net(latent_pi) + noise, -1, 1)
else:
return self.actor_net(obs)
class Critic(BaseNetwork):
@ -40,7 +67,7 @@ class Critic(BaseNetwork):
class TD3Policy(BasePolicy):
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.ReLU):
activation_fn=nn.ReLU, use_sde=False, log_std_init=0.0):
super(TD3Policy, self).__init__(observation_space, action_space, device)
if net_arch is None:
@ -56,8 +83,14 @@ class TD3Policy(BasePolicy):
'net_arch': self.net_arch,
'activation_fn': self.activation_fn
}
self.actor_kwargs = self.net_args.copy()
self.actor_kwargs['use_sde'] = use_sde
self.actor_kwargs['log_std_init'] = log_std_init
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
self.use_sde = use_sde
self.log_std_init = log_std_init
self._build(learning_rate)
def _build(self, learning_rate):
@ -71,14 +104,17 @@ class TD3Policy(BasePolicy):
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1))
def reset_noise(self):
return self.actor.reset_noise()
def make_actor(self):
return Actor(**self.net_args).to(self.device)
return Actor(**self.actor_kwargs).to(self.device)
def make_critic(self):
return Critic(**self.net_args).to(self.device)
def forward(self, obs):
return self.actor(obs)
def forward(self, obs, deterministic=True):
return self.actor(obs, deterministic=deterministic)
MlpPolicy = TD3Policy

View file

@ -37,6 +37,8 @@ class TD3(BaseRLModel):
:param target_policy_noise: (float) Standard deviation of gaussian noise added to target policy
(smoothing noise)
:param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
instead of action noise exploration (default: False)
:param create_eval_env: (bool) Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
@ -50,6 +52,7 @@ class TD3(BaseRLModel):
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
use_sde=False,
tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
seed=0, device='auto', _init_setup_model=True):
@ -70,6 +73,7 @@ class TD3(BaseRLModel):
self.policy_delay = policy_delay
self.target_noise_clip = target_noise_clip
self.target_policy_noise = target_policy_noise
self.use_sde = use_sde
if _init_setup_model:
self._setup_model()
@ -79,8 +83,8 @@ class TD3(BaseRLModel):
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
self.set_random_seed(self.seed)
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
self.policy = self.policy(self.observation_space, self.action_space,
self.learning_rate, device=self.device, **self.policy_kwargs)
self.policy = self.policy(self.observation_space, self.action_space, self.learning_rate,
use_sde=self.use_sde, device=self.device, **self.policy_kwargs)
self.policy = self.policy.to(self.device)
self._create_aliases()
@ -90,12 +94,12 @@ class TD3(BaseRLModel):
self.critic = self.policy.critic
self.critic_target = self.policy.critic_target
def select_action(self, observation):
def select_action(self, observation, deterministic=True):
# Normally not needed
observation = np.array(observation)
with th.no_grad():
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
return self.actor(observation).cpu().numpy()
return self.actor(observation, deterministic=deterministic).cpu().numpy()
def predict(self, observation, state=None, mask=None, deterministic=True):
"""
@ -107,7 +111,7 @@ class TD3(BaseRLModel):
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
"""
return self.unscale_action(self.select_action(observation))
return self.unscale_action(self.select_action(observation, deterministic=deterministic))
def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0):
# Update optimizer learning rate