Try torch compile and other optimization

Antonin Raffin 2024-07-07 15:20:21 +02:00
parent d8148deeaa
commit 955e202258
2 changed files with 20 additions and 4 deletions

stable_baselines3/common/buffers.py

@@ -124,7 +124,7 @@ class BaseBuffer(ABC):
         """
         raise NotImplementedError()
 
-    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
+    def to_torch(self, array: np.ndarray, copy: bool = False) -> th.Tensor:
         """
         Convert a numpy array to a PyTorch tensor.
         Note: it copies the data by default
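
For context, a minimal sketch of the pattern this default flips to, assuming the usual `th.tensor` (always copies) vs. `th.as_tensor` (reuses the numpy buffer when possible) split; the free function and variable names below are illustrative, not the exact `BaseBuffer` method:

import numpy as np
import torch as th

def to_torch(array: np.ndarray, device: th.device, copy: bool = False) -> th.Tensor:
    # th.tensor always allocates a fresh copy; th.as_tensor shares the
    # underlying numpy buffer when dtype and device already match.
    if copy:
        return th.tensor(array, device=device)
    return th.as_tensor(array, device=device)

arr = np.zeros(3, dtype=np.float32)
tensor = to_torch(arr, th.device("cpu"))
arr[0] = 1.0
print(tensor[0])  # tensor(1.) -- with copy=False the tensor aliases the array

The speedup comes from skipping one copy per sampled batch; the cost is that a tensor handed out by the buffer can now alias buffer memory and change under the caller if that slot is later overwritten.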

stable_baselines3/sac/sac.py

@@ -9,7 +9,7 @@ from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
-from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, ReplayBufferSamples, Schedule
 from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
 from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
@@ -158,6 +158,9 @@ class SAC(OffPolicyAlgorithm):
     def _setup_model(self) -> None:
         super()._setup_model()
+        self.policy = th.compile(self.policy)
         self._create_aliases()
         # Running mean and running var
         self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
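
A self-contained sketch of what `th.compile` does to a module here, with a toy `nn.Sequential` standing in for the real `SACPolicy` (the shapes and the `policy`/`obs` names are assumptions for illustration):

import torch as th
import torch.nn as nn

policy = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2))
policy = th.compile(policy)  # wraps the module in an OptimizedModule

obs = th.randn(32, 8)
actions = policy(obs)  # first call compiles; same-shape calls reuse the graph

# Attribute access is forwarded to the wrapped module, and the raw module
# stays reachable at policy._orig_mod.
print(type(policy).__name__)  # OptimizedModule

This forwarding is why calling `_create_aliases()` after the compile still works: `self.actor` and `self.critic` resolve through the wrapper to the original submodules.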
@@ -210,9 +213,22 @@
         ent_coef_losses, ent_coefs = [], []
         actor_losses, critic_losses = [], []
 
+        # Sample replay buffer
+        # Sample all data at once to reduce memory transfer time
+        all_replay_data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)  # type: ignore[union-attr]
         for gradient_step in range(gradient_steps):
-            # Sample replay buffer
-            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
+            # Slice data
+            slice_indices = slice(gradient_step * batch_size, (gradient_step + 1) * batch_size)
+            replay_data = ReplayBufferSamples(
+                observations=all_replay_data.observations[slice_indices],
+                next_observations=all_replay_data.next_observations[slice_indices],
+                actions=all_replay_data.actions[slice_indices],
+                rewards=all_replay_data.rewards[slice_indices],
+                dones=all_replay_data.dones[slice_indices],
+            )
+            # replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
             # We need to sample because `log_std` may have changed between two gradient steps
             if self.use_sde:
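
The slicing trick above replaces gradient_steps separate sample() calls, each with its own host-to-device transfer, by one large sample viewed in minibatch-sized chunks. A self-contained sketch of the pattern, with random tensors standing in for real buffer contents and a `ReplayBufferSamples` written to mirror the named tuple in stable_baselines3.common.type_aliases:

from typing import NamedTuple

import torch as th

class ReplayBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor

batch_size, gradient_steps = 4, 3
n = batch_size * gradient_steps  # sample everything in one call
all_replay_data = ReplayBufferSamples(
    observations=th.randn(n, 8),
    actions=th.randn(n, 2),
    next_observations=th.randn(n, 8),
    dones=th.zeros(n, 1),
    rewards=th.zeros(n, 1),
)

for gradient_step in range(gradient_steps):
    # Slicing a tensor with a slice object returns a view, so each minibatch
    # costs no extra copy and no extra transfer.
    idx = slice(gradient_step * batch_size, (gradient_step + 1) * batch_size)
    replay_data = ReplayBufferSamples(*(field[idx] for field in all_replay_data))
    assert replay_data.observations.shape == (batch_size, 8)

Since the replay buffer does not change during these gradient steps and SB3 draws indices with replacement, one big draw is essentially equivalent to per-step draws; the win is purely fewer sample() calls and transfers.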