Add value function for SDE + TD3

This commit is contained in:
Antonin Raffin 2019-12-17 15:01:08 +01:00
parent 1d6f9bf100
commit 57708a628c
3 changed files with 52 additions and 15 deletions

View file

@ -511,7 +511,7 @@ class BaseRLModel(object):
self.actor.reset_noise()
# Reset rollout data
if self.on_policy_exploration:
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones']}
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']}
while total_steps < n_steps or total_episodes < n_episodes:
done = False
@ -574,6 +574,7 @@ class BaseRLModel(object):
self.rollout_data['actions'].append(scaled_action[0].copy())
self.rollout_data['rewards'].append(reward[0].copy())
self.rollout_data['dones'].append(np.array(done_bool[0]).copy())
self.rollout_data['values'].append(self.vf_net(th.FloatTensor(obs).to(self.device))[0].cpu().detach().numpy())
obs = new_obs
# Save the true unnormalized observation
@ -615,19 +616,24 @@ class BaseRLModel(object):
# Post processing
if self.rollout_data is not None:
for key in ['observations', 'actions', 'rewards', 'dones']:
for key in ['observations', 'actions', 'rewards', 'dones', 'values']:
self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)
self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
# Compute return
self.rollout_data['advantage'] = self.rollout_data['rewards'].clone()
# Compute return and advantage
last_return = 0.0
for step in reversed(range(len(self.rollout_data['rewards']))):
if step == len(self.rollout_data['rewards']) - 1:
last_return = self.rollout_data['rewards'][step]
next_non_terminal = 1.0 - done[0]
next_value = self.vf_net(th.FloatTensor(obs).to(self.device))[0].detach()
last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
else:
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
self.rollout_data['returns'][step] = last_return
self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values']
return mean_reward, total_steps, total_episodes, obs

View file

@ -158,6 +158,27 @@ class Critic(BaseNetwork):
return self.q_networks[0](th.cat([obs, action], dim=1))
class ValueFunction(BaseNetwork):
    """
    Value function for TD3 when doing on-policy exploration with SDE.

    Wraps an MLP built by ``create_mlp`` that takes an observation as input.
    The MLP is requested with an output size of 1 (presumably one scalar
    value estimate per observation -- confirm against ``create_mlp``).

    :param obs_dim: (int) Dimension of the observation
    :param net_arch: ([int]) Network architecture,
        falls back to ``[64, 64]`` when ``None`` is given
    :param activation_fn: (nn.Module) Activation function
    """

    def __init__(self, obs_dim, net_arch=None, activation_fn=nn.Tanh):
        super(ValueFunction, self).__init__()

        # Build the default per call so no list is shared between instances
        hidden_sizes = [64, 64] if net_arch is None else net_arch
        layers = create_mlp(obs_dim, 1, hidden_sizes, activation_fn)
        self.vf_net = nn.Sequential(*layers)

    def forward(self, obs):
        """
        Run the observation through the value network.

        :param obs: (th.Tensor) Observation(s) fed to the network
        :return: (th.Tensor) Output of the underlying MLP
        """
        value = self.vf_net(obs)
        return value
class TD3Policy(BasePolicy):
"""
Policy class (with both actor and critic) for TD3.
@ -206,7 +227,9 @@ class TD3Policy(BasePolicy):
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
# For SDE only
self.use_sde = use_sde
self.vf_net = None
self.log_std_init = log_std_init
self._build(learning_rate)
@ -221,6 +244,10 @@ class TD3Policy(BasePolicy):
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1))
if self.use_sde:
self.vf_net = ValueFunction(self.obs_dim)
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()})
def reset_noise(self):
return self.actor.reset_noise()

View file

@ -9,6 +9,7 @@ from torchy_baselines.common.buffers import ReplayBuffer
from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.td3.policies import TD3Policy
from torchy_baselines.common.vec_env import sync_envs_normalization
from torchy_baselines.ppo.policies import MlpPolicy
class TD3(BaseRLModel):
@ -85,6 +86,7 @@ class TD3(BaseRLModel):
self.sde_ent_coef = sde_ent_coef
self.sde_log_std_scheduler = sde_log_std_scheduler
self.on_policy_exploration = True
self.sde_vf = None
if _init_setup_model:
self._setup_model()
@ -104,6 +106,7 @@ class TD3(BaseRLModel):
self.actor_target = self.policy.actor_target
self.critic = self.policy.critic
self.critic_target = self.policy.critic_target
self.vf_net = self.policy.vf_net
def select_action(self, observation, deterministic=True):
# Normally not needed
@ -209,25 +212,26 @@ class TD3(BaseRLModel):
# self._update_learning_rate(self.policy.optimizer)
# Unpack
obs, action, returns = [self.rollout_data[key] for key in ['observations', 'actions', 'returns']]
obs, action, advantage, returns = [self.rollout_data[key] for key in ['observations', 'actions', 'advantage', 'returns']]
# TODO: avoid second computation of everything because of the gradient
log_prob, entropy = self.actor.evaluate_actions(obs, action)
values = self.vf_net(obs).flatten()
# Normalize returns
# returns = (returns - returns.mean()) / (returns.std() + 1e-8)
# returns = (returns - returns.mean())
with th.no_grad():
current_q1, current_q2 = self.critic(obs, action)
# Alternatively use the q value
returns = (returns - th.min(current_q1, current_q2))
# Normalize advantage
# if self.normalize_advantage:
# advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
policy_loss = -(returns * log_prob).mean()
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(returns, values)
# A2C loss
policy_loss = -(advantage * log_prob).mean()
# Entropy loss favor exploration
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.sde_ent_coef * entropy_loss
vf_coef = 0.5
loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss
# Optimization step
self.actor.sde_optimizer.zero_grad()