mirror of https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Allow for non-squashed output

parent c5c29a32d9
commit b9e8f6cd93

1 changed file with 22 additions and 11 deletions
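In short: _sample_action previously assumed a squashing (tanh) policy and always rescaled Box actions into [-1, 1] before adding exploration noise. With this change the action is sampled directly from self.policy._predict under th.no_grad() (hence the new obs_as_tensor import), and noise handling branches on self.policy.squash_output: squashed Box actions receive noise in [-1, 1] and are unscaled back to the environment bounds, non-squashed Box actions receive noise on the raw action scale and are clipped to the space bounds, and discrete actions are stored unchanged.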
@@ -17,7 +17,7 @@ from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
-from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
+from stable_baselines3.common.utils import obs_as_tensor, safe_mean, should_collect_more_steps
 from stable_baselines3.common.vec_env import VecEnv
 from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
 
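For context, the newly imported obs_as_tensor converts the numpy observations held in self._last_obs into torch tensors on the model's device, which is what allows _sample_action (below) to query self.policy._predict directly instead of going through self.predict(). A minimal sketch of such a helper, assuming plain-array or dict observations (an illustration of the idea, not the library source verbatim):

import numpy as np
import torch as th

def obs_to_tensor_sketch(obs, device):
    # Single-array observation -> tensor on the target device
    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, device=device)
    # Dict observation space -> convert each entry separately
    if isinstance(obs, dict):
        return {key: th.as_tensor(value, device=device) for key, value in obs.items()}
    raise TypeError(f"Unsupported observation type: {type(obs)}")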
@@ -378,21 +378,24 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         and scaled action that will be stored in the replay buffer.
         The two differs when the action space is not normalized (bounds are not [-1, 1]).
         """
+        scaled_action = np.array([0.0])
         # Select action randomly or according to policy
         if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
             # Warmup phase
-            unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
+            action = np.array([self.action_space.sample() for _ in range(n_envs)])
+            if isinstance(self.action_space, spaces.Box):
+                scaled_action = self.policy.scale_action(action)
         else:
             # Note: when using continuous actions,
-            # we assume that the policy uses tanh to scale the action
             # We use non-deterministic action in the case of SAC, for TD3, it does not matter
             assert self._last_obs is not None, "self._last_obs was not set"
-            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
+            with th.no_grad():
+                obs_tensor = obs_as_tensor(self._last_obs, self.device)
+                th_action = self.policy._predict(obs_tensor, deterministic=False)
+                action = th_action.cpu().numpy()
+            if self.policy.squash_output:
+                scaled_action = action
-
-        # Rescale the action from [low, high] to [-1, 1]
-        if isinstance(self.action_space, spaces.Box):
-            scaled_action = self.policy.scale_action(unscaled_action)
-
+        if isinstance(self.action_space, spaces.Box) and self.policy.squash_output:
             # Add noise to the action (improve exploration)
             if action_noise is not None:
                 scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
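For reference, scale_action and unscale_action used in this hunk are affine maps between the Box bounds [low, high] and the normalized range [-1, 1] that a squashed (tanh) policy outputs. A standalone sketch of that mapping, assuming a classic gym.spaces.Box (free functions for illustration; in the library these are policy methods):

import numpy as np
from gym import spaces

def scale_action(action_space: spaces.Box, action: np.ndarray) -> np.ndarray:
    # Affine map from [low, high] to [-1, 1]
    low, high = action_space.low, action_space.high
    return 2.0 * ((action - low) / (high - low)) - 1.0

def unscale_action(action_space: spaces.Box, scaled_action: np.ndarray) -> np.ndarray:
    # Inverse map from [-1, 1] back to [low, high]
    low, high = action_space.low, action_space.high
    return low + 0.5 * (scaled_action + 1.0) * (high - low)

For example, with bounds low = -2 and high = 2, an action of 1.0 scales to 0.5, and unscaling 0.5 recovers 1.0.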
@@ -400,10 +403,18 @@ class OffPolicyAlgorithm(BaseAlgorithm):
             # We store the scaled action in the buffer
             buffer_action = scaled_action
             action = self.policy.unscale_action(scaled_action)
+        elif isinstance(self.action_space, spaces.Box) and not self.policy.squash_output:
+            # Add noise to the action (improve exploration)
+            if action_noise is not None:
+                action = action + action_noise()
+
+            buffer_action = action
+            # Actions could be on arbitrary scale, so clip the actions to avoid
+            # out of bound error (e.g. if sampling from a Gaussian distribution)
+            action = np.clip(action, self.action_space.low, self.action_space.high)
         else:
             # Discrete case, no need to normalize or clip
-            buffer_action = unscaled_action
-            action = buffer_action
+            buffer_action = action
         return action, buffer_action
 
     def _dump_logs(self) -> None:
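Taken together with the previous hunk, exploration noise is now handled in three branches instead of two. A condensed sketch of the resulting control flow (a hypothetical standalone helper for illustration; in the commit the logic lives inline in _sample_action):

import numpy as np
from gym import spaces

def add_noise_and_clip(action, action_space, action_noise, squash_output):
    # Returns (env_action, buffer_action), mirroring the branching above.
    if isinstance(action_space, spaces.Box) and squash_output:
        # Squashed policy: noise lives in [-1, 1]; store the scaled action,
        # unscale it back to [low, high] for the environment
        scaled = action if action_noise is None else np.clip(action + action_noise(), -1, 1)
        low, high = action_space.low, action_space.high
        return low + 0.5 * (scaled + 1.0) * (high - low), scaled
    if isinstance(action_space, spaces.Box):
        # Non-squashed policy: noise is added on the raw action scale;
        # the buffer keeps the unclipped action, the env receives the clipped one
        if action_noise is not None:
            action = action + action_noise()
        return np.clip(action, action_space.low, action_space.high), action
    # Discrete case: nothing to normalize or clip
    return action, action

Note one consequence of the new non-squashed branch: buffer_action = action is assigned before the np.clip call, so the replay buffer stores the unclipped noisy action while the environment receives the clipped one.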