Enable optim

This commit is contained in:
Antonin RAFFIN 2021-01-09 16:11:48 +01:00
parent 5993033c73
commit 8f56befca7
3 changed files with 10 additions and 3 deletions

View file

@ -1,5 +1,7 @@
import os
import torch
from stable_baselines3.a2c import A2C
from stable_baselines3.ddpg import DDPG
from stable_baselines3.dqn import DQN
@ -8,6 +10,11 @@ from stable_baselines3.ppo import PPO
from stable_baselines3.sac import SAC
from stable_baselines3.td3 import TD3
# See https://www.youtube.com/watch?v=9mS1fIYj1So
# PyTorch Performance Tuning Guide
torch.backends.cudnn.benchmark = True
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file, "r") as file_handler:

View file

@ -212,7 +212,7 @@ class PPO(OnPolicyAlgorithm):
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
# Optimization step
self.policy.optimizer.zero_grad()
self.policy.optimizer.zero_grad(set_to_none=True)
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)

View file

@ -240,7 +240,7 @@ class SAC(OffPolicyAlgorithm):
critic_losses.append(critic_loss.item())
# Optimize the critic
self.critic.optimizer.zero_grad()
self.critic.optimizer.zero_grad(set_to_none=True)
critic_loss.backward()
self.critic.optimizer.step()
@ -253,7 +253,7 @@ class SAC(OffPolicyAlgorithm):
actor_losses.append(actor_loss.item())
# Optimize the actor
self.actor.optimizer.zero_grad()
self.actor.optimizer.zero_grad(set_to_none=True)
actor_loss.backward()
self.actor.optimizer.step()