Allow for non-squashed output

Antonin RAFFIN 2025-02-09 10:00:39 +01:00
parent c5c29a32d9
commit b9e8f6cd93


@@ -17,7 +17,7 @@ from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
-from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
+from stable_baselines3.common.utils import obs_as_tensor, safe_mean, should_collect_more_steps
 from stable_baselines3.common.vec_env import VecEnv
 from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
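The newly imported obs_as_tensor is the SB3 helper that converts a numpy observation (or a dict of arrays, for Dict observation spaces) into a torch tensor on the policy's device; the sampling code below now calls it directly instead of going through self.predict(). A minimal sketch of what it does, assuming a CPU device and a dummy batched observation:

    import numpy as np
    import torch as th

    from stable_baselines3.common.utils import obs_as_tensor

    # Dummy batched observation (n_envs=1, obs_dim=4); the shape is illustrative only
    obs = np.zeros((1, 4), dtype=np.float32)
    obs_tensor = obs_as_tensor(obs, th.device("cpu"))  # numpy array -> th.Tensor on the device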
@@ -378,21 +378,24 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         and scaled action that will be stored in the replay buffer.
         The two differ when the action space is not normalized (bounds are not [-1, 1]).
         """
+        scaled_action = np.array([0.0])
         # Select action randomly or according to policy
         if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
             # Warmup phase
-            unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
+            action = np.array([self.action_space.sample() for _ in range(n_envs)])
+            if isinstance(self.action_space, spaces.Box):
+                scaled_action = self.policy.scale_action(action)
         else:
             # Note: when using continuous actions,
             # we assume that the policy uses tanh to scale the action
             # We use non-deterministic action in the case of SAC, for TD3, it does not matter
             assert self._last_obs is not None, "self._last_obs was not set"
-            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
+            with th.no_grad():
+                obs_tensor = obs_as_tensor(self._last_obs, self.device)
+                th_action = self.policy._predict(obs_tensor, deterministic=False)
+            action = th_action.cpu().numpy()
+            if self.policy.squash_output:
+                scaled_action = action
 
-        # Rescale the action from [low, high] to [-1, 1]
-        if isinstance(self.action_space, spaces.Box):
-            scaled_action = self.policy.scale_action(unscaled_action)
+        if isinstance(self.action_space, spaces.Box) and self.policy.squash_output:
             # Add noise to the action (improve exploration)
             if action_noise is not None:
                 scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
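For context, scale_action and unscale_action on BasePolicy are inverse affine maps between the Box bounds and [-1, 1]; the squashed branch above adds and clips noise entirely in the [-1, 1] space and only unscales at the end. A standalone sketch of that rescaling, assuming finite bounds (the low/high values are made up for illustration):

    import numpy as np

    low, high = np.array([-2.0]), np.array([2.0])  # illustrative Box bounds

    def scale_action(action: np.ndarray) -> np.ndarray:
        # Map [low, high] -> [-1, 1]
        return 2.0 * ((action - low) / (high - low)) - 1.0

    def unscale_action(scaled_action: np.ndarray) -> np.ndarray:
        # Map [-1, 1] -> [low, high]
        return low + 0.5 * (scaled_action + 1.0) * (high - low)

    # Round trip recovers the original action
    assert np.allclose(unscale_action(scale_action(np.array([1.5]))), [1.5])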
@@ -400,10 +403,18 @@ class OffPolicyAlgorithm(BaseAlgorithm):
             # We store the scaled action in the buffer
             buffer_action = scaled_action
             action = self.policy.unscale_action(scaled_action)
+        elif isinstance(self.action_space, spaces.Box) and not self.policy.squash_output:
+            # Add noise to the action (improve exploration)
+            if action_noise is not None:
+                action = action + action_noise()
+            buffer_action = action
+            # Actions could be on arbitrary scale, so clip the actions to avoid
+            # out of bound error (e.g. if sampling from a Gaussian distribution)
+            action = np.clip(action, self.action_space.low, self.action_space.high)
         else:
             # Discrete case, no need to normalize or clip
-            buffer_action = unscaled_action
-            action = buffer_action
+            buffer_action = action
 
         return action, buffer_action
 
     def _dump_logs(self) -> None:
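Taken together, the two Box branches now behave differently depending on squash_output: with squashing, noise is added and clipped in the normalized [-1, 1] space and the scaled action is stored in the buffer; without squashing, noise is added on the raw action scale, the buffer keeps the unclipped action, and only the action sent to the environment is clipped to the Box bounds. A self-contained sketch of both paths (bounds and noise values are invented for illustration):

    import numpy as np

    low, high = np.array([-2.0]), np.array([2.0])  # illustrative Box bounds
    action = np.array([1.8])                       # action on the env scale (e.g. a warmup sample)
    noise = np.array([0.5])                        # one draw of action noise

    # squash_output=True: scale to [-1, 1], add and clip noise there,
    # store the scaled action in the buffer, unscale before stepping the env
    scaled_action = 2.0 * ((action - low) / (high - low)) - 1.0
    buffer_action = np.clip(scaled_action + noise, -1, 1)
    env_action = low + 0.5 * (buffer_action + 1.0) * (high - low)

    # squash_output=False: add noise on the raw scale, store the unclipped
    # action in the buffer, clip only the action sent to the environment
    buffer_action_ns = action + noise
    env_action_ns = np.clip(buffer_action_ns, low, high)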