Add sde update for TD3

This commit is contained in:
Antonin Raffin 2019-11-12 18:37:13 +01:00
parent f2a61949ae
commit a08382faab
3 changed files with 92 additions and 9 deletions

View file

@ -57,6 +57,8 @@ class BaseRLModel(object):
self.replay_buffer = None
self.seed = seed
self.action_noise = None
# Used for SDE only
self.rollout_data = None
# Track the training progress (from 1 to 0)
# this is used to update the learning rate
self._current_progress = 1
@ -113,7 +115,7 @@ class BaseRLModel(object):
(no need for symmetric action space)
"""
low, high = self.action_space.low, self.action_space.high
return low + (0.5 * (scaled_action + 1.0) * (high - low))
return low + (0.5 * (scaled_action + 1.0) * (high - low))
def _setup_learning_rate(self):
"""Transform to callable if needed."""
@ -214,12 +216,14 @@ class BaseRLModel(object):
Return a trained model.
:param total_timesteps: (int) The total number of samples to train on
:param seed: (int) The initial seed for training, if None: keep current seed
:param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm.
It takes the local and global variables. If it returns False, training is aborted.
:param log_interval: (int) The number of timesteps before logging.
:param tb_log_name: (str) the name of the run for tensorboard log
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
:param eval_env: (gym.Env)
:param eval_freq: (int)
:param n_eval_episodes: (int)
:return: (BaseRLModel) the trained model
"""
pass
@ -327,8 +331,12 @@ class BaseRLModel(object):
assert isinstance(env, VecEnv)
assert env.num_envs == 1
self.rollout_data = None
if hasattr(self, 'use_sde') and self.use_sde:
self.actor.reset_noise()
# Reset rollout data
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones']}
# self.rollout_data = {'observations': [], 'actions': [], 'rewards': [], 'returns': [], 'dones': []}
while total_steps < n_steps or total_episodes < n_episodes:
done = False
@ -367,12 +375,19 @@ class BaseRLModel(object):
if replay_buffer is not None:
replay_buffer.add(obs, new_obs, action, reward, done_bool)
if self.rollout_data is not None:
# Assume only one env
self.rollout_data['observations'].append(obs[0].copy())
self.rollout_data['actions'].append(action[0].copy())
self.rollout_data['rewards'].append(reward[0].copy())
self.rollout_data['dones'].append(np.array(done_bool[0]).copy())
obs = new_obs
num_timesteps += 1
episode_timesteps += 1
total_steps += 1
if n_steps > 0 and total_steps >= n_steps:
if 0 < n_steps <= total_steps:
break
if done:
@ -383,7 +398,8 @@ class BaseRLModel(object):
action_noise.reset()
# Display training infos
if self.verbose >= 1 and log_interval is not None and (episode_num + total_episodes) % log_interval == 0:
if self.verbose >= 1 and log_interval is not None and (
episode_num + total_episodes) % log_interval == 0:
fps = int(num_timesteps / (time.time() - self.start_time))
logger.logkv("episodes", episode_num + total_episodes)
# logger.logkv("mean 100 episode reward", mean_reward)
@ -401,4 +417,20 @@ class BaseRLModel(object):
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
# Post processing
if self.rollout_data is not None:
for key in ['observations', 'actions', 'rewards', 'dones']:
self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)
self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
# Compute return
last_return = 0.0
for step in reversed(range(len(self.rollout_data['rewards']))):
if step == len(self.rollout_data['rewards']) - 1:
last_return = self.rollout_data['rewards'][step]
else:
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
self.rollout_data['returns'][step] = last_return
return mean_reward, total_steps, total_episodes, obs

View file

@ -7,12 +7,12 @@ from torchy_baselines.common.policies import BasePolicy, register_policy, create
class Actor(BaseNetwork):
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
use_sde=False, log_std_init=-2, clip_noise=0.5):
use_sde=False, log_std_init=-2, clip_noise=None, lr_sde=3e-4):
super(Actor, self).__init__()
self.latent_pi, self.log_std = None, None
self.weights_dist, self.exploration_mat = None, None
self.use_sde = use_sde
self.use_sde, self.sde_optimizer = use_sde, None
if use_sde:
latent_dim = net_arch[-1]
@ -21,11 +21,25 @@ class Actor(BaseNetwork):
self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
self.actor_net = nn.Sequential(nn.Linear(net_arch[-1], action_dim), nn.Tanh())
self.clip_noise = clip_noise
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
self.reset_noise()
else:
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
self.actor_net = nn.Sequential(*actor_net)
def get_distribution_stats(self, obs, action):
with th.no_grad():
latent_pi = self.latent_pi(obs)
mean_actions = self.actor_net(latent_pi)
variance = th.mm(latent_pi ** 2, th.exp(self.log_std) ** 2)
distribution = Normal(mean_actions, th.sqrt(variance))
log_prob = distribution.log_prob(action)
if len(log_prob.shape) > 1:
log_prob = log_prob.sum(axis=1)
else:
log_prob = log_prob.sum()
return log_prob, distribution.entropy()
def reset_noise(self):
self.weights_dist = Normal(th.zeros_like(self.log_std), th.exp(self.log_std))
self.exploration_mat = self.weights_dist.rsample()
@ -36,7 +50,8 @@ class Actor(BaseNetwork):
if deterministic:
return self.actor_net(latent_pi)
noise = th.mm(latent_pi.detach(), self.exploration_mat)
noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
if self.clip_noise is not None:
noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
# TODO: fix clipping with squashing ?
return th.clamp(self.actor_net(latent_pi) + noise, -1, 1)
else:
@ -67,7 +82,7 @@ class Critic(BaseNetwork):
class TD3Policy(BasePolicy):
def __init__(self, observation_space, action_space,
learning_rate, net_arch=None, device='cpu',
activation_fn=nn.ReLU, use_sde=False, log_std_init=-2, clip_noise=0.5):
activation_fn=nn.ReLU, use_sde=False, log_std_init=-2, clip_noise=None):
super(TD3Policy, self).__init__(observation_space, action_space, device)
if net_arch is None:

View file

@ -52,7 +52,7 @@ class TD3(BaseRLModel):
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
use_sde=False,
use_sde=False, sde_max_grad_norm=1, sde_ent_coef=0.0,
tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
seed=0, device='auto', _init_setup_model=True):
@ -73,7 +73,10 @@ class TD3(BaseRLModel):
self.policy_delay = policy_delay
self.target_noise_clip = target_noise_clip
self.target_policy_noise = target_policy_noise
self.use_sde = use_sde
self.sde_max_grad_norm = sde_max_grad_norm
self.sde_ent_coef = sde_ent_coef
if _init_setup_model:
self._setup_model()
@ -191,6 +194,37 @@ class TD3(BaseRLModel):
if gradient_step % policy_delay == 0:
self.train_actor(replay_data=replay_data, tau_actor=self.tau, tau_critic=self.tau)
def train_sde(self):
# Update optimizer learning rate
# self._update_learning_rate(self.policy.optimizer)
# Unpack
obs, action, returns = self.rollout_data['observations'], self.rollout_data['actions'], self.rollout_data['returns']
# TODO: avoid second computation of everything because of the gradient
log_prob, entropy = self.actor.get_distribution_stats(obs, action)
# Normalize returns
# returns = (returns - returns.mean()) / (returns.std() + 1e-8)
policy_loss = -(returns * log_prob).mean()
# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.sde_ent_coef * entropy_loss
# Optimization step
self.actor.sde_optimizer.zero_grad()
loss.backward()
# print(self.actor.log_std.grad.mean().item(), self.actor.log_std.grad.max().item(), self.actor.log_std.grad.min().item())
# print(self.actor.log_std.mean().item(), self.actor.log_std.max().item(), self.actor.log_std.min().item())
# Clip grad norm
th.nn.utils.clip_grad_norm_([self.actor.log_std], self.sde_max_grad_norm)
self.actor.sde_optimizer.step()
del self.rollout_data
def learn(self, total_timesteps, callback=None, log_interval=4,
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="TD3", reset_num_timesteps=True):
@ -226,6 +260,8 @@ class TD3(BaseRLModel):
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
if self.use_sde:
self.train_sde()
# Evaluate episode
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None: