mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-19 21:40:19 +00:00
Add value function for SDE + TD3
This commit is contained in:
parent
1d6f9bf100
commit
57708a628c
3 changed files with 52 additions and 15 deletions
|
|
@ -511,7 +511,7 @@ class BaseRLModel(object):
|
|||
self.actor.reset_noise()
|
||||
# Reset rollout data
|
||||
if self.on_policy_exploration:
|
||||
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones']}
|
||||
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']}
|
||||
|
||||
while total_steps < n_steps or total_episodes < n_episodes:
|
||||
done = False
|
||||
|
|
@ -574,6 +574,7 @@ class BaseRLModel(object):
|
|||
self.rollout_data['actions'].append(scaled_action[0].copy())
|
||||
self.rollout_data['rewards'].append(reward[0].copy())
|
||||
self.rollout_data['dones'].append(np.array(done_bool[0]).copy())
|
||||
self.rollout_data['values'].append(self.vf_net(th.FloatTensor(obs).to(self.device))[0].cpu().detach().numpy())
|
||||
|
||||
obs = new_obs
|
||||
# Save the true unnormalized observation
|
||||
|
|
@ -615,19 +616,24 @@ class BaseRLModel(object):
|
|||
|
||||
# Post processing
|
||||
if self.rollout_data is not None:
|
||||
for key in ['observations', 'actions', 'rewards', 'dones']:
|
||||
for key in ['observations', 'actions', 'rewards', 'dones', 'values']:
|
||||
self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)
|
||||
|
||||
self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
|
||||
# Compute return
|
||||
self.rollout_data['advantage'] = self.rollout_data['rewards'].clone()
|
||||
|
||||
# Compute return and advantage
|
||||
last_return = 0.0
|
||||
for step in reversed(range(len(self.rollout_data['rewards']))):
|
||||
if step == len(self.rollout_data['rewards']) - 1:
|
||||
last_return = self.rollout_data['rewards'][step]
|
||||
next_non_terminal = 1.0 - done[0]
|
||||
next_value = self.vf_net(th.FloatTensor(obs).to(self.device))[0].detach()
|
||||
last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
|
||||
last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
|
||||
self.rollout_data['returns'][step] = last_return
|
||||
self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values']
|
||||
|
||||
return mean_reward, total_steps, total_episodes, obs
|
||||
|
||||
|
|
|
|||
|
|
@ -158,6 +158,27 @@ class Critic(BaseNetwork):
|
|||
return self.q_networks[0](th.cat([obs, action], dim=1))
|
||||
|
||||
|
||||
class ValueFunction(BaseNetwork):
    """
    Value function for TD3 when doing on-policy exploration with SDE.

    The underlying network is the list of layers returned by ``create_mlp``
    wrapped in a ``nn.Sequential`` container.

    :param obs_dim: (int) Dimension of the observation
    :param net_arch: ([int]) Network architecture
        (``[64, 64]`` is used when ``None`` is passed)
    :param activation_fn: (nn.Module) Activation function
    """

    def __init__(self, obs_dim, net_arch=None, activation_fn=nn.Tanh):
        super(ValueFunction, self).__init__()

        # Fall back to the default architecture when the caller gives none.
        arch = [64, 64] if net_arch is None else net_arch

        # The second argument (1) is presumably the output dimension,
        # i.e. one value per observation -- confirm against create_mlp.
        layers = create_mlp(obs_dim, 1, arch, activation_fn)
        self.vf_net = nn.Sequential(*layers)

    def forward(self, obs):
        """
        Run the observations through the value network.

        :param obs: input passed unchanged to the underlying network
        :return: output of the underlying network
        """
        value_estimate = self.vf_net(obs)
        return value_estimate
|
||||
|
||||
|
||||
class TD3Policy(BasePolicy):
|
||||
"""
|
||||
Policy class (with both actor and critic) for TD3.
|
||||
|
|
@ -206,7 +227,9 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
# For SDE only
|
||||
self.use_sde = use_sde
|
||||
self.vf_net = None
|
||||
self.log_std_init = log_std_init
|
||||
self._build(learning_rate)
|
||||
|
||||
|
|
@ -221,6 +244,10 @@ class TD3Policy(BasePolicy):
|
|||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1))
|
||||
|
||||
if self.use_sde:
|
||||
self.vf_net = ValueFunction(self.obs_dim)
|
||||
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()})
|
||||
|
||||
def reset_noise(self):
|
||||
return self.actor.reset_noise()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from torchy_baselines.common.buffers import ReplayBuffer
|
|||
from torchy_baselines.common.evaluation import evaluate_policy
|
||||
from torchy_baselines.td3.policies import TD3Policy
|
||||
from torchy_baselines.common.vec_env import sync_envs_normalization
|
||||
from torchy_baselines.ppo.policies import MlpPolicy
|
||||
|
||||
|
||||
class TD3(BaseRLModel):
|
||||
|
|
@ -85,6 +86,7 @@ class TD3(BaseRLModel):
|
|||
self.sde_ent_coef = sde_ent_coef
|
||||
self.sde_log_std_scheduler = sde_log_std_scheduler
|
||||
self.on_policy_exploration = True
|
||||
self.sde_vf = None
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
|
@ -104,6 +106,7 @@ class TD3(BaseRLModel):
|
|||
self.actor_target = self.policy.actor_target
|
||||
self.critic = self.policy.critic
|
||||
self.critic_target = self.policy.critic_target
|
||||
self.vf_net = self.policy.vf_net
|
||||
|
||||
def select_action(self, observation, deterministic=True):
|
||||
# Normally not needed
|
||||
|
|
@ -209,25 +212,26 @@ class TD3(BaseRLModel):
|
|||
# self._update_learning_rate(self.policy.optimizer)
|
||||
|
||||
# Unpack
|
||||
obs, action, returns = [self.rollout_data[key] for key in ['observations', 'actions', 'returns']]
|
||||
obs, action, advantage, returns = [self.rollout_data[key] for key in ['observations', 'actions', 'advantage', 'returns']]
|
||||
|
||||
# TODO: avoid second computation of everything because of the gradient
|
||||
log_prob, entropy = self.actor.evaluate_actions(obs, action)
|
||||
values = self.vf_net(obs).flatten()
|
||||
|
||||
# Normalize returns
|
||||
# returns = (returns - returns.mean()) / (returns.std() + 1e-8)
|
||||
# returns = (returns - returns.mean())
|
||||
with th.no_grad():
|
||||
current_q1, current_q2 = self.critic(obs, action)
|
||||
# Alternatively use the q value
|
||||
returns = (returns - th.min(current_q1, current_q2))
|
||||
# Normalize advantage
|
||||
# if self.normalize_advantage:
|
||||
# advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
|
||||
|
||||
policy_loss = -(returns * log_prob).mean()
|
||||
# Value loss using the TD(gae_lambda) target
|
||||
value_loss = F.mse_loss(returns, values)
|
||||
|
||||
# A2C loss
|
||||
policy_loss = -(advantage * log_prob).mean()
|
||||
|
||||
# Entropy loss favor exploration
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
loss = policy_loss + self.sde_ent_coef * entropy_loss
|
||||
vf_coef = 0.5
|
||||
loss = policy_loss + self.sde_ent_coef * entropy_loss + vf_coef * value_loss
|
||||
|
||||
# Optimization step
|
||||
self.actor.sde_optimizer.zero_grad()
|
||||
|
|
|
|||
Loading…
Reference in a new issue