mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-05 04:17:59 +00:00
Add sde update for TD3
This commit is contained in:
parent
f2a61949ae
commit
a08382faab
3 changed files with 92 additions and 9 deletions
|
|
@ -57,6 +57,8 @@ class BaseRLModel(object):
|
|||
self.replay_buffer = None
|
||||
self.seed = seed
|
||||
self.action_noise = None
|
||||
# Used for SDE only
|
||||
self.rollout_data = None
|
||||
# Track the training progress (from 1 to 0)
|
||||
# this is used to update the learning rate
|
||||
self._current_progress = 1
|
||||
|
|
@ -113,7 +115,7 @@ class BaseRLModel(object):
|
|||
(no need for symmetric action space)
|
||||
"""
|
||||
low, high = self.action_space.low, self.action_space.high
|
||||
return low + (0.5 * (scaled_action + 1.0) * (high - low))
|
||||
return low + (0.5 * (scaled_action + 1.0) * (high - low))
|
||||
|
||||
def _setup_learning_rate(self):
|
||||
"""Transform to callable if needed."""
|
||||
|
|
@ -214,12 +216,14 @@ class BaseRLModel(object):
|
|||
Return a trained model.
|
||||
|
||||
:param total_timesteps: (int) The total number of samples to train on
|
||||
:param seed: (int) The initial seed for training, if None: keep current seed
|
||||
:param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm.
|
||||
It takes the local and global variables. If it returns False, training is aborted.
|
||||
:param log_interval: (int) The number of timesteps before logging.
|
||||
:param tb_log_name: (str) the name of the run for tensorboard log
|
||||
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
|
||||
:param eval_env: (gym.Env)
|
||||
:param eval_freq: (int)
|
||||
:param n_eval_episodes: (int)
|
||||
:return: (BaseRLModel) the trained model
|
||||
"""
|
||||
pass
|
||||
|
|
@ -327,8 +331,12 @@ class BaseRLModel(object):
|
|||
assert isinstance(env, VecEnv)
|
||||
assert env.num_envs == 1
|
||||
|
||||
self.rollout_data = None
|
||||
if hasattr(self, 'use_sde') and self.use_sde:
|
||||
self.actor.reset_noise()
|
||||
# Reset rollout data
|
||||
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones']}
|
||||
# self.rollout_data = {'observations': [], 'actions': [], 'rewards': [], 'returns': [], 'dones': []}
|
||||
|
||||
while total_steps < n_steps or total_episodes < n_episodes:
|
||||
done = False
|
||||
|
|
@ -367,12 +375,19 @@ class BaseRLModel(object):
|
|||
if replay_buffer is not None:
|
||||
replay_buffer.add(obs, new_obs, action, reward, done_bool)
|
||||
|
||||
if self.rollout_data is not None:
|
||||
# Assume only one env
|
||||
self.rollout_data['observations'].append(obs[0].copy())
|
||||
self.rollout_data['actions'].append(action[0].copy())
|
||||
self.rollout_data['rewards'].append(reward[0].copy())
|
||||
self.rollout_data['dones'].append(np.array(done_bool[0]).copy())
|
||||
|
||||
obs = new_obs
|
||||
|
||||
num_timesteps += 1
|
||||
episode_timesteps += 1
|
||||
total_steps += 1
|
||||
if n_steps > 0 and total_steps >= n_steps:
|
||||
if 0 < n_steps <= total_steps:
|
||||
break
|
||||
|
||||
if done:
|
||||
|
|
@ -383,7 +398,8 @@ class BaseRLModel(object):
|
|||
action_noise.reset()
|
||||
|
||||
# Display training infos
|
||||
if self.verbose >= 1 and log_interval is not None and (episode_num + total_episodes) % log_interval == 0:
|
||||
if self.verbose >= 1 and log_interval is not None and (
|
||||
episode_num + total_episodes) % log_interval == 0:
|
||||
fps = int(num_timesteps / (time.time() - self.start_time))
|
||||
logger.logkv("episodes", episode_num + total_episodes)
|
||||
# logger.logkv("mean 100 episode reward", mean_reward)
|
||||
|
|
@ -401,4 +417,20 @@ class BaseRLModel(object):
|
|||
|
||||
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
|
||||
|
||||
# Post processing
|
||||
if self.rollout_data is not None:
|
||||
for key in ['observations', 'actions', 'rewards', 'dones']:
|
||||
self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)
|
||||
|
||||
self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
|
||||
# Compute return
|
||||
last_return = 0.0
|
||||
for step in reversed(range(len(self.rollout_data['rewards']))):
|
||||
if step == len(self.rollout_data['rewards']) - 1:
|
||||
last_return = self.rollout_data['rewards'][step]
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
|
||||
last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
|
||||
self.rollout_data['returns'][step] = last_return
|
||||
|
||||
return mean_reward, total_steps, total_episodes, obs
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ from torchy_baselines.common.policies import BasePolicy, register_policy, create
|
|||
|
||||
class Actor(BaseNetwork):
|
||||
def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU,
|
||||
use_sde=False, log_std_init=-2, clip_noise=0.5):
|
||||
use_sde=False, log_std_init=-2, clip_noise=None, lr_sde=3e-4):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
self.latent_pi, self.log_std = None, None
|
||||
self.weights_dist, self.exploration_mat = None, None
|
||||
self.use_sde = use_sde
|
||||
self.use_sde, self.sde_optimizer = use_sde, None
|
||||
|
||||
if use_sde:
|
||||
latent_dim = net_arch[-1]
|
||||
|
|
@ -21,11 +21,25 @@ class Actor(BaseNetwork):
|
|||
self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init)
|
||||
self.actor_net = nn.Sequential(nn.Linear(net_arch[-1], action_dim), nn.Tanh())
|
||||
self.clip_noise = clip_noise
|
||||
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
|
||||
self.reset_noise()
|
||||
else:
|
||||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
|
||||
self.actor_net = nn.Sequential(*actor_net)
|
||||
|
||||
def get_distribution_stats(self, obs, action):
|
||||
with th.no_grad():
|
||||
latent_pi = self.latent_pi(obs)
|
||||
mean_actions = self.actor_net(latent_pi)
|
||||
variance = th.mm(latent_pi ** 2, th.exp(self.log_std) ** 2)
|
||||
distribution = Normal(mean_actions, th.sqrt(variance))
|
||||
log_prob = distribution.log_prob(action)
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
else:
|
||||
log_prob = log_prob.sum()
|
||||
return log_prob, distribution.entropy()
|
||||
|
||||
def reset_noise(self):
|
||||
self.weights_dist = Normal(th.zeros_like(self.log_std), th.exp(self.log_std))
|
||||
self.exploration_mat = self.weights_dist.rsample()
|
||||
|
|
@ -36,7 +50,8 @@ class Actor(BaseNetwork):
|
|||
if deterministic:
|
||||
return self.actor_net(latent_pi)
|
||||
noise = th.mm(latent_pi.detach(), self.exploration_mat)
|
||||
noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
|
||||
if self.clip_noise is not None:
|
||||
noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
|
||||
# TODO: fix clipping with squashing ?
|
||||
return th.clamp(self.actor_net(latent_pi) + noise, -1, 1)
|
||||
else:
|
||||
|
|
@ -67,7 +82,7 @@ class Critic(BaseNetwork):
|
|||
class TD3Policy(BasePolicy):
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.ReLU, use_sde=False, log_std_init=-2, clip_noise=0.5):
|
||||
activation_fn=nn.ReLU, use_sde=False, log_std_init=-2, clip_noise=None):
|
||||
super(TD3Policy, self).__init__(observation_space, action_space, device)
|
||||
|
||||
if net_arch is None:
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class TD3(BaseRLModel):
|
|||
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
|
||||
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
|
||||
tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
|
||||
use_sde=False,
|
||||
use_sde=False, sde_max_grad_norm=1, sde_ent_coef=0.0,
|
||||
tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
|
||||
seed=0, device='auto', _init_setup_model=True):
|
||||
|
||||
|
|
@ -73,7 +73,10 @@ class TD3(BaseRLModel):
|
|||
self.policy_delay = policy_delay
|
||||
self.target_noise_clip = target_noise_clip
|
||||
self.target_policy_noise = target_policy_noise
|
||||
|
||||
self.use_sde = use_sde
|
||||
self.sde_max_grad_norm = sde_max_grad_norm
|
||||
self.sde_ent_coef = sde_ent_coef
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
|
@ -191,6 +194,37 @@ class TD3(BaseRLModel):
|
|||
if gradient_step % policy_delay == 0:
|
||||
self.train_actor(replay_data=replay_data, tau_actor=self.tau, tau_critic=self.tau)
|
||||
|
||||
def train_sde(self):
|
||||
# Update optimizer learning rate
|
||||
# self._update_learning_rate(self.policy.optimizer)
|
||||
|
||||
# Unpack
|
||||
obs, action, returns = self.rollout_data['observations'], self.rollout_data['actions'], self.rollout_data['returns']
|
||||
|
||||
# TODO: avoid second computation of everything because of the gradient
|
||||
log_prob, entropy = self.actor.get_distribution_stats(obs, action)
|
||||
|
||||
# Normalize returns
|
||||
# returns = (returns - returns.mean()) / (returns.std() + 1e-8)
|
||||
|
||||
policy_loss = -(returns * log_prob).mean()
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
loss = policy_loss + self.sde_ent_coef * entropy_loss
|
||||
|
||||
# Optimization step
|
||||
self.actor.sde_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# print(self.actor.log_std.grad.mean().item(), self.actor.log_std.grad.max().item(), self.actor.log_std.grad.min().item())
|
||||
# print(self.actor.log_std.mean().item(), self.actor.log_std.max().item(), self.actor.log_std.min().item())
|
||||
# Clip grad norm
|
||||
th.nn.utils.clip_grad_norm_([self.actor.log_std], self.sde_max_grad_norm)
|
||||
self.actor.sde_optimizer.step()
|
||||
|
||||
del self.rollout_data
|
||||
|
||||
def learn(self, total_timesteps, callback=None, log_interval=4,
|
||||
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="TD3", reset_num_timesteps=True):
|
||||
|
||||
|
|
@ -226,6 +260,8 @@ class TD3(BaseRLModel):
|
|||
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else episode_timesteps
|
||||
self.train(gradient_steps, batch_size=self.batch_size, policy_delay=self.policy_delay)
|
||||
if self.use_sde:
|
||||
self.train_sde()
|
||||
|
||||
# Evaluate episode
|
||||
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue