diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index b2fc5a7..bdc1a12 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -124,7 +124,7 @@ class BaseBuffer(ABC): """ raise NotImplementedError() - def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor: + def to_torch(self, array: np.ndarray, copy: bool = False) -> th.Tensor: """ Convert a numpy array to a PyTorch tensor. Note: it copies the data by default diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index bf0fa50..0fa6ba2 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -9,7 +9,7 @@ from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy, ContinuousCritic -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, ReplayBufferSamples, Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy @@ -158,6 +158,9 @@ class SAC(OffPolicyAlgorithm): def _setup_model(self) -> None: super()._setup_model() + + self.policy = th.compile(self.policy) + self._create_aliases() # Running mean and running var self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"]) @@ -210,9 +213,22 @@ class SAC(OffPolicyAlgorithm): ent_coef_losses, ent_coefs = [], [] actor_losses, critic_losses = [], [] + # Sample replay buffer + # Sample all data at once to reduce memory transfer time + all_replay_data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) # type: ignore[union-attr] + for gradient_step in range(gradient_steps): - # Sample replay buffer - replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] + # Slice data + slice_indices = slice(gradient_step * batch_size, (gradient_step + 1) * batch_size) + replay_data = ReplayBufferSamples( + observations=all_replay_data.observations[slice_indices], + next_observations=all_replay_data.next_observations[slice_indices], + actions=all_replay_data.actions[slice_indices], + rewards=all_replay_data.rewards[slice_indices], + dones=all_replay_data.dones[slice_indices], + ) + + # replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] # We need to sample because `log_std` may have changed between two gradient steps if self.use_sde: