Mirror of https://github.com/saymrwulf/stable-baselines3.git, synced 2026-05-14 20:58:03 +00:00
Try torch compile and other optimizations
parent d8148deeaa
commit 955e202258
2 changed files with 20 additions and 4 deletions
stable_baselines3/common/buffers.py

@@ -124,7 +124,7 @@ class BaseBuffer(ABC):
         """
         raise NotImplementedError()
 
-    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
+    def to_torch(self, array: np.ndarray, copy: bool = False) -> th.Tensor:
         """
         Convert a numpy array to a PyTorch tensor.
         Note: it copies the data by default
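This flips the default so to_torch no longer copies, which leaves the unchanged docstring line "Note: it copies the data by default" stale. A minimal standalone sketch of the difference, assuming the method body dispatches between th.tensor (copying) and th.as_tensor (non-copying) as in upstream SB3:

import numpy as np
import torch as th

array = np.zeros(3, dtype=np.float32)

copied = th.tensor(array)     # copy=True path: independent storage
shared = th.as_tensor(array)  # copy=False path: shares the numpy buffer on CPU

array[0] = 1.0
print(copied[0].item())  # 0.0 -- the copy does not see the numpy write
print(shared[0].item())  # 1.0 -- the shared tensor does

Skipping the copy saves an allocation per conversion, but callers must then avoid mutating the source array while the tensor is in use.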
stable_baselines3/sac/sac.py

@@ -9,7 +9,7 @@ from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
-from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, ReplayBufferSamples, Schedule
 from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
 from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
 
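ReplayBufferSamples is imported because train() below rebuilds minibatches by hand. For reference, it is a plain NamedTuple of tensors (paraphrased here from stable_baselines3/common/type_aliases.py), so slicing each field and re-wrapping is cheap and keeps downstream code unchanged:

from typing import NamedTuple
import torch as th

class ReplayBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor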
@@ -158,6 +158,9 @@ class SAC(OffPolicyAlgorithm):
 
     def _setup_model(self) -> None:
         super()._setup_model()
+
+        self.policy = th.compile(self.policy)
+
         self._create_aliases()
         # Running mean and running var
         self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
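th.compile (PyTorch >= 2.0) wraps the whole policy before _create_aliases(), so the aliases resolve through the compiled wrapper, which proxies attribute access to the original module. A sketch of the pattern with a hypothetical stand-in module, not SB3's SACPolicy:

import torch as th
import torch.nn as nn

policy = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2))

# th.compile returns an OptimizedModule that forwards attribute access
# to the wrapped module, so code written against the original keeps working.
policy = th.compile(policy)

obs = th.randn(32, 8)
actions = policy(obs)  # the first call triggers tracing and compilation
print(actions.shape)   # torch.Size([32, 2])

One caveat: the wrapper changes the type of self.policy, and its state_dict keys gain an "_orig_mod." prefix, which can affect saving and loading checkpoints.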
@@ -210,9 +213,22 @@ class SAC(OffPolicyAlgorithm):
         ent_coef_losses, ent_coefs = [], []
         actor_losses, critic_losses = [], []
 
+        # Sample replay buffer
+        # Sample all data at once to reduce memory transfer time
+        all_replay_data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)  # type: ignore[union-attr]
+
         for gradient_step in range(gradient_steps):
-            # Sample replay buffer
-            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
+            # Slice data
+            slice_indices = slice(gradient_step * batch_size, (gradient_step + 1) * batch_size)
+            replay_data = ReplayBufferSamples(
+                observations=all_replay_data.observations[slice_indices],
+                next_observations=all_replay_data.next_observations[slice_indices],
+                actions=all_replay_data.actions[slice_indices],
+                rewards=all_replay_data.rewards[slice_indices],
+                dones=all_replay_data.dones[slice_indices],
+            )
+
+            # replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
 
             # We need to sample because `log_std` may have changed between two gradient steps
             if self.use_sde:
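The sampling change trades gradient_steps small host-to-device transfers for one large one: a single sample() call draws batch_size * gradient_steps transitions, and each gradient step then takes a tensor slice, which is a view rather than a new allocation. A standalone sketch of the slicing arithmetic with dummy tensors (hypothetical shapes, no replay buffer involved):

import torch as th

batch_size, gradient_steps, obs_dim = 256, 4, 8

# Stand-in for all_replay_data.observations: one big batch drawn up front.
all_observations = th.randn(batch_size * gradient_steps, obs_dim)

for gradient_step in range(gradient_steps):
    slice_indices = slice(gradient_step * batch_size, (gradient_step + 1) * batch_size)
    observations = all_observations[slice_indices]  # a view, no extra copy
    assert observations.shape == (batch_size, obs_dim)

One behavioral difference: the per-step minibatches are now disjoint slices of a single draw rather than gradient_steps independent draws from the buffer.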