From e1c1d5c4abd1ca14fdbf53e500e79327d453a3c4 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 18 Sep 2019 22:12:32 +0200 Subject: [PATCH] Bug fixes (not working yet) --- tests/test_run.py | 30 +++++++++++++++--------------- torchy_baselines/ppo/policies.py | 14 ++++++++++++-- torchy_baselines/ppo/ppo.py | 13 +++++++------ 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/tests/test_run.py b/tests/test_run.py index 342ced9..e28c532 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -4,21 +4,21 @@ import gym from torchy_baselines import TD3, CEMRL, PPO -# def test_pendulum(): -# model = TD3('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), start_timesteps=100, verbose=1) -# model.learn(total_timesteps=500, eval_freq=100) -# model.save("test_save") -# model.load("test_save") -# os.remove("test_save.pth") -# -# -# def test_cemrl(): -# model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1, -# start_timesteps=100, verbose=1) -# model.learn(total_timesteps=1000, eval_freq=500) -# model.save("test_save") -# model.load("test_save") -# os.remove("test_save.pth") +def test_pendulum(): + model = TD3('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), start_timesteps=100, verbose=1) + model.learn(total_timesteps=500, eval_freq=100) + model.save("test_save") + model.load("test_save") + os.remove("test_save.pth") + + +def test_cemrl(): + model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1, + start_timesteps=100, verbose=1) + model.learn(total_timesteps=1000, eval_freq=500) + model.save("test_save") + model.load("test_save") + os.remove("test_save.pth") def test_ppo(): model = PPO('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), verbose=1) diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 1aea078..092e603 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -30,18 +30,28 @@ class PPOPolicy(BasePolicy): self.shared_net = nn.Sequential(*shared_net).to(self.device) self.actor_net = nn.Linear(self.net_arch[-1], self.action_dim) self.value_net = nn.Linear(self.net_arch[-1], 1) - self.log_std = nn.Parameter(th.zeros(self.action_dim, 1)) + self.log_std = nn.Parameter(th.zeros(self.action_dim)) self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate) def forward(self, state): + state = th.FloatTensor(state).to(self.device) latent = self.shared_net(state) # TODO: initialize pi_mean weights properly + # TODO: change when multiple envs mean_actions = self.actor_net(latent) - action_distribution = Normal(mean_actions, self.log_std) + action_std = th.ones(mean_actions.size()) * self.log_std.exp() + action_distribution = Normal(mean_actions, action_std) # Sample from the gaussian + # rsample: reparametrization trick action = action_distribution.rsample() + # TODO: handle shape properly + # sum(axis=1) log_prob = action_distribution.log_prob(action) + if len(log_prob.shape) > 1: + log_prob = log_prob.sum(axis=1) + else: + log_prob = log_prob.sum() # entropy = action_distribution.entropy() value = self.value_net(latent) return action, value, log_prob diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 45c8ab2..104b925 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -20,7 +20,7 @@ class PPO(BaseRLModel): """ def __init__(self, policy, env, policy_kwargs=None, verbose=0, - learning_rate=1e-3, seed=0, device='auto', + learning_rate=3e-4, seed=0, device='auto', n_optim=5, batch_size=64, n_steps=256, gamma=0.99, lambda_=0.95, clip_range=0.2, ent_coef=0.01, vf_coef=0.5, @@ -111,24 +111,25 @@ class PPO(BaseRLModel): for it in range(n_iterations): # Sample replay buffer replay_data = self.rollout_buffer.sample(batch_size) - state, action, next_state, done, reward, _, old_log_prob, advantage, return_batch = replay_data + state, _, _, _, _, _, old_log_prob, advantage, return_batch = replay_data - _, value, log_prob = self.policy.forward(state) + _, values, log_prob = self.policy.forward(state) # Normalize advantage # advs = returns - values advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) ratio = th.exp(log_prob - old_log_prob) - policy_loss_1 = -advantage * ratio - policy_loss_2 = -advantage * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range) + policy_loss_1 = advantage * ratio + policy_loss_2 = advantage * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range) policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() # value_loss = th.mean((returns - value)**2) - value_loss = F.mse_loss(return_batch, value) + value_loss = F.mse_loss(return_batch.detach(), values) # Approximate entropy # TODO: replace by distribution entropy entropy_loss = th.mean(-log_prob) loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + # loss = policy_loss # TODO: check kl div # approx_kl_div = th.mean(old_log_prob - log_prob) # Optimization step