From b9e8f6cd93166c7eb478380321c28ddc1e48005e Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Sun, 9 Feb 2025 10:00:39 +0100
Subject: [PATCH] Allow for non-squashed output

---
 .../common/off_policy_algorithm.py | 33 ++++++++++++-------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index c3e1c66..df41202 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -17,7 +17,7 @@ from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
-from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
+from stable_baselines3.common.utils import obs_as_tensor, safe_mean, should_collect_more_steps
 from stable_baselines3.common.vec_env import VecEnv
 from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
 
@@ -378,21 +378,24 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         and scaled action that will be stored in the replay buffer.
         The two differs when the action space is not normalized (bounds are not [-1, 1]).
         """
+        scaled_action = np.array([0.0])  # placeholder so scaled_action is always defined (only used for Box spaces)
         # Select action randomly or according to policy
         if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
             # Warmup phase
-            unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
+            action = np.array([self.action_space.sample() for _ in range(n_envs)])
+            if isinstance(self.action_space, spaces.Box):
+                scaled_action = self.policy.scale_action(action)
         else:
-            # Note: when using continuous actions,
-            # we assume that the policy uses tanh to scale the action
-            # We use non-deterministic action in the case of SAC, for TD3, it does not matter
             assert self._last_obs is not None, "self._last_obs was not set"
-            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
+            with th.no_grad():
+                obs_tensor = obs_as_tensor(self._last_obs, self.device)
+                th_action = self.policy._predict(obs_tensor, deterministic=False)
+            action = th_action.cpu().numpy()
+            if self.policy.squash_output:
+                scaled_action = action
 
         # Rescale the action from [low, high] to [-1, 1]
-        if isinstance(self.action_space, spaces.Box):
-            scaled_action = self.policy.scale_action(unscaled_action)
-
+        if isinstance(self.action_space, spaces.Box) and self.policy.squash_output:
             # Add noise to the action (improve exploration)
             if action_noise is not None:
                 scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
@@ -400,10 +403,18 @@ class OffPolicyAlgorithm(BaseAlgorithm):
             # We store the scaled action in the buffer
             buffer_action = scaled_action
             action = self.policy.unscale_action(scaled_action)
+        elif isinstance(self.action_space, spaces.Box) and not self.policy.squash_output:
+            # Add noise to the action (improve exploration)
+            if action_noise is not None:
+                action = action + action_noise()
+
+            buffer_action = action
+            # Actions could be on arbitrary scale, so clip the actions to avoid
+            # out of bound error (e.g. if sampling from a Gaussian distribution)
+            action = np.clip(action, self.action_space.low, self.action_space.high)
         else:
             # Discrete case, no need to normalize or clip
-            buffer_action = unscaled_action
-            action = buffer_action
+            buffer_action = action
         return action, buffer_action
 
     def _dump_logs(self) -> None:
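
Note for reviewers (not part of the patch): the diff splits continuous-action handling into a squashed path (policy output in [-1, 1], rescaled before reaching the environment) and a new non-squashed path (policy output already on the environment's scale, clipped to the bounds). The standalone sketch below illustrates the two paths side by side; the bounds, the scale_action/unscale_action helpers, and the Gaussian noise are simplified stand-ins for the SB3 objects, not the actual implementation.

import numpy as np

# Bounds of a hypothetical 1-D Box action space.
low, high = np.array([-2.0]), np.array([2.0])

def scale_action(action):
    # Map [low, high] -> [-1, 1], mirroring policy.scale_action.
    return 2.0 * (action - low) / (high - low) - 1.0

def unscale_action(scaled_action):
    # Map [-1, 1] -> [low, high], mirroring policy.unscale_action.
    return low + 0.5 * (scaled_action + 1.0) * (high - low)

def sample_action(policy_output, squash_output, action_noise=None):
    """Return (action sent to the env, action stored in the buffer)."""
    if squash_output:
        # Squashed path: the policy output already lives in [-1, 1].
        scaled_action = policy_output
        if action_noise is not None:
            scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
        buffer_action = scaled_action
        action = unscale_action(scaled_action)
    else:
        # Non-squashed path: the policy output is on the env's own scale.
        action = policy_output
        if action_noise is not None:
            action = action + action_noise()
        buffer_action = action
        # Clip to the action-space bounds to avoid out-of-bound errors
        # (e.g. when sampling from an unbounded Gaussian).
        action = np.clip(action, low, high)
    return action, buffer_action

noise = lambda: np.random.normal(0.0, 0.1, size=low.shape)
print(sample_action(np.array([0.5]), squash_output=True, action_noise=noise))
print(sample_action(np.array([1.8]), squash_output=False, action_noise=noise))

One consequence of the split worth noting: in the non-squashed case the buffer stores the noisy but unclipped action while the environment receives the clipped one, whereas in the squashed case both are derived from the same clipped [-1, 1] value.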