diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index bcac479..de93489 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -1,5 +1,7 @@ import os +import torch + from stable_baselines3.a2c import A2C from stable_baselines3.ddpg import DDPG from stable_baselines3.dqn import DQN @@ -8,6 +10,11 @@ from stable_baselines3.ppo import PPO from stable_baselines3.sac import SAC from stable_baselines3.td3 import TD3 +# Enable the cuDNN autotuner; see the PyTorch Performance Tuning Guide: +# https://www.youtube.com/watch?v=9mS1fIYj1So + +torch.backends.cudnn.benchmark = True + # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") with open(version_file, "r") as file_handler: diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 52579b8..09b0e99 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -212,7 +212,7 @@ class PPO(OnPolicyAlgorithm): loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss # Optimization step - self.policy.optimizer.zero_grad() + self.policy.optimizer.zero_grad(set_to_none=True) loss.backward() # Clip grad norm th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index cd7a413..772770f 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -240,7 +240,7 @@ class SAC(OffPolicyAlgorithm): critic_losses.append(critic_loss.item()) # Optimize the critic - self.critic.optimizer.zero_grad() + self.critic.optimizer.zero_grad(set_to_none=True) critic_loss.backward() self.critic.optimizer.step() @@ -253,7 +253,7 @@ class SAC(OffPolicyAlgorithm): actor_losses.append(actor_loss.item()) # Optimize the actor - self.actor.optimizer.zero_grad() + self.actor.optimizer.zero_grad(set_to_none=True) actor_loss.backward() self.actor.optimizer.step()