mirror of https://github.com/saymrwulf/stable-baselines3.git
Enable optim

parent 5993033c73
commit 8f56befca7

3 changed files with 10 additions and 3 deletions
stable_baselines3/__init__.py

@@ -1,5 +1,7 @@
 import os
 
+import torch
+
 from stable_baselines3.a2c import A2C
 from stable_baselines3.ddpg import DDPG
 from stable_baselines3.dqn import DQN
@@ -8,6 +10,11 @@ from stable_baselines3.ppo import PPO
 from stable_baselines3.sac import SAC
 from stable_baselines3.td3 import TD3
 
+# See https://www.youtube.com/watch?v=9mS1fIYj1So
+# PyTorch Performance Tuning Guide
+
+torch.backends.cudnn.benchmark = True
+
 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")
 with open(version_file, "r") as file_handler:
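The hunks above import torch at package import time and turn on the cuDNN autotuner globally. As the linked Performance Tuning Guide explains, torch.backends.cudnn.benchmark = True makes cuDNN time several convolution algorithms on the first forward pass for each new input shape and cache the fastest one. This pays off when observation shapes are fixed (the common case for CNN policies), but it re-triggers autotuning whenever shapes change, and the chosen algorithm can vary between runs, so exact reproducibility is lost. A minimal sketch of the behavior; the 84x84 Atari-style shape is illustrative, not taken from the commit:

import torch
import torch.nn as nn

# Let cuDNN benchmark candidate convolution algorithms once per new
# input shape and cache the fastest choice for subsequent calls.
torch.backends.cudnn.benchmark = True

device = "cuda" if torch.cuda.is_available() else "cpu"
conv = nn.Conv2d(4, 32, kernel_size=8, stride=4).to(device)

# Illustrative batch of four stacked 84x84 frames (not from the commit).
x = torch.randn(64, 4, 84, 84, device=device)

# The first call with this exact shape is slower (autotuning runs);
# repeated calls with the same shape reuse the cached algorithm.
for _ in range(3):
    y = conv(x)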
stable_baselines3/ppo/ppo.py

@@ -212,7 +212,7 @@ class PPO(OnPolicyAlgorithm):
                 loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
 
                 # Optimization step
-                self.policy.optimizer.zero_grad()
+                self.policy.optimizer.zero_grad(set_to_none=True)
                 loss.backward()
                 # Clip grad norm
                 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
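This hunk, like the two SAC hunks below, swaps zero_grad() for zero_grad(set_to_none=True), another tip from the same tuning guide, available since PyTorch 1.7 (and the default since PyTorch 2.0). Instead of launching a kernel to zero-fill every parameter's .grad tensor, the optimizer releases the tensors outright; the next backward() then assigns freshly computed gradients rather than accumulating into zeroed buffers, saving a memset and a read-modify-write per parameter. The one caveat: .grad is None between steps, so code that inspects gradients directly must expect that. A small standalone sketch of the semantic difference (not from the commit):

import torch

p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([p], lr=0.1)

(p ** 2).sum().backward()
print(p.grad)                     # tensor([2., 2., 2.])

opt.zero_grad(set_to_none=False)  # zero-fills the existing .grad tensor in place
print(p.grad)                     # tensor([0., 0., 0.])

(p ** 2).sum().backward()
opt.zero_grad(set_to_none=True)   # drops the tensor instead of zeroing it
print(p.grad)                     # None; the next backward() allocates afresh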
stable_baselines3/sac/sac.py

@@ -240,7 +240,7 @@ class SAC(OffPolicyAlgorithm):
             critic_losses.append(critic_loss.item())
 
             # Optimize the critic
-            self.critic.optimizer.zero_grad()
+            self.critic.optimizer.zero_grad(set_to_none=True)
             critic_loss.backward()
             self.critic.optimizer.step()
 
@@ -253,7 +253,7 @@ class SAC(OffPolicyAlgorithm):
             actor_losses.append(actor_loss.item())
 
             # Optimize the actor
-            self.actor.optimizer.zero_grad()
+            self.actor.optimizer.zero_grad(set_to_none=True)
             actor_loss.backward()
             self.actor.optimizer.step()
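The SAC hunks apply the same change to the critic and actor optimizers. Whether either flag is a measurable win depends on model size and hardware; a hypothetical micro-benchmark (not part of the commit, with arbitrary layer sizes) comparing the two zero_grad modes via torch.utils.benchmark:

import torch
import torch.nn as nn
from torch.utils import benchmark

# Arbitrary model size for illustration; the gap between the two modes
# grows with parameter count and is larger on GPU than on CPU.
net = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
opt = torch.optim.Adam(net.parameters())

def step(set_to_none: bool) -> None:
    loss = net(torch.randn(64, 512)).sum()
    opt.zero_grad(set_to_none=set_to_none)
    loss.backward()
    opt.step()

for flag in (False, True):
    timer = benchmark.Timer(
        stmt="step(flag)",
        globals={"step": step, "flag": flag},
    )
    print(f"set_to_none={flag}: {timer.timeit(100)}")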