mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Bug fixes (not working yet)
This commit is contained in:
parent
6bb7e183d2
commit
e1c1d5c4ab
3 changed files with 34 additions and 23 deletions
|
|
@ -4,21 +4,21 @@ import gym
|
|||
|
||||
from torchy_baselines import TD3, CEMRL, PPO
|
||||
|
||||
# def test_pendulum():
|
||||
# model = TD3('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), start_timesteps=100, verbose=1)
|
||||
# model.learn(total_timesteps=500, eval_freq=100)
|
||||
# model.save("test_save")
|
||||
# model.load("test_save")
|
||||
# os.remove("test_save.pth")
|
||||
#
|
||||
#
|
||||
# def test_cemrl():
|
||||
# model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1,
|
||||
# start_timesteps=100, verbose=1)
|
||||
# model.learn(total_timesteps=1000, eval_freq=500)
|
||||
# model.save("test_save")
|
||||
# model.load("test_save")
|
||||
# os.remove("test_save.pth")
|
||||
def test_pendulum():
|
||||
model = TD3('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), start_timesteps=100, verbose=1)
|
||||
model.learn(total_timesteps=500, eval_freq=100)
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.pth")
|
||||
|
||||
|
||||
def test_cemrl():
|
||||
model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1,
|
||||
start_timesteps=100, verbose=1)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.save("test_save")
|
||||
model.load("test_save")
|
||||
os.remove("test_save.pth")
|
||||
|
||||
def test_ppo():
|
||||
model = PPO('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||
|
|
|
|||
|
|
@ -30,18 +30,28 @@ class PPOPolicy(BasePolicy):
|
|||
self.shared_net = nn.Sequential(*shared_net).to(self.device)
|
||||
self.actor_net = nn.Linear(self.net_arch[-1], self.action_dim)
|
||||
self.value_net = nn.Linear(self.net_arch[-1], 1)
|
||||
self.log_std = nn.Parameter(th.zeros(self.action_dim, 1))
|
||||
self.log_std = nn.Parameter(th.zeros(self.action_dim))
|
||||
self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate)
|
||||
|
||||
def forward(self, state):
|
||||
|
||||
state = th.FloatTensor(state).to(self.device)
|
||||
latent = self.shared_net(state)
|
||||
# TODO: initialize pi_mean weights properly
|
||||
# TODO: change when multiple envs
|
||||
mean_actions = self.actor_net(latent)
|
||||
action_distribution = Normal(mean_actions, self.log_std)
|
||||
action_std = th.ones(mean_actions.size()) * self.log_std.exp()
|
||||
action_distribution = Normal(mean_actions, action_std)
|
||||
# Sample from the gaussian
|
||||
# rsample: reparametrization trick
|
||||
action = action_distribution.rsample()
|
||||
# TODO: handle shape properly
|
||||
# sum(axis=1)
|
||||
log_prob = action_distribution.log_prob(action)
|
||||
if len(log_prob.shape) > 1:
|
||||
log_prob = log_prob.sum(axis=1)
|
||||
else:
|
||||
log_prob = log_prob.sum()
|
||||
# entropy = action_distribution.entropy()
|
||||
value = self.value_net(latent)
|
||||
return action, value, log_prob
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class PPO(BaseRLModel):
|
|||
"""
|
||||
|
||||
def __init__(self, policy, env, policy_kwargs=None, verbose=0,
|
||||
learning_rate=1e-3, seed=0, device='auto',
|
||||
learning_rate=3e-4, seed=0, device='auto',
|
||||
n_optim=5, batch_size=64, n_steps=256,
|
||||
gamma=0.99, lambda_=0.95, clip_range=0.2,
|
||||
ent_coef=0.01, vf_coef=0.5,
|
||||
|
|
@ -111,24 +111,25 @@ class PPO(BaseRLModel):
|
|||
for it in range(n_iterations):
|
||||
# Sample replay buffer
|
||||
replay_data = self.rollout_buffer.sample(batch_size)
|
||||
state, action, next_state, done, reward, _, old_log_prob, advantage, return_batch = replay_data
|
||||
state, _, _, _, _, _, old_log_prob, advantage, return_batch = replay_data
|
||||
|
||||
_, value, log_prob = self.policy.forward(state)
|
||||
_, values, log_prob = self.policy.forward(state)
|
||||
|
||||
# Normalize advantage
|
||||
# advs = returns - values
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
|
||||
|
||||
ratio = th.exp(log_prob - old_log_prob)
|
||||
policy_loss_1 = -advantage * ratio
|
||||
policy_loss_2 = -advantage * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
|
||||
policy_loss_1 = advantage * ratio
|
||||
policy_loss_2 = advantage * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
|
||||
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
|
||||
# value_loss = th.mean((returns - value)**2)
|
||||
value_loss = F.mse_loss(return_batch, value)
|
||||
value_loss = F.mse_loss(return_batch.detach(), values)
|
||||
# Approximate entropy
|
||||
# TODO: replace by distribution entropy
|
||||
entropy_loss = th.mean(-log_prob)
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
# loss = policy_loss
|
||||
# TODO: check kl div
|
||||
# approx_kl_div = th.mean(old_log_prob - log_prob)
|
||||
# Optimization step
|
||||
|
|
|
|||
Loading…
Reference in a new issue