mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-19 21:40:19 +00:00
Handling multi-dimensional action spaces (#971)
* Handle non 1D action shape * Revert changes of observation (out of the scope of this PR) * Apply changes to DictReplayBuffer * Update tests * Rollout buffer n-D actions space handling * Remove error when non 1D action space * ActorCriticPolicy return action with the proper shape * remove useless reshape * Update changelog * Add tests Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
6ce33f5bd2
commit
c4f54fcf04
5 changed files with 42 additions and 10 deletions
|
|
@ -19,6 +19,7 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
|
||||
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
|
||||
- Added multidimensional action space support (@qgallouedec)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -247,8 +247,7 @@ class ReplayBuffer(BaseBuffer):
|
|||
next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
|
||||
|
||||
# Same, for actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
# Copy to avoid modification by reference
|
||||
self.observations[self.pos] = np.array(obs).copy()
|
||||
|
|
@ -433,6 +432,9 @@ class RolloutBuffer(BaseBuffer):
|
|||
if isinstance(self.observation_space, spaces.Discrete):
|
||||
obs = obs.reshape((self.n_envs,) + self.obs_shape)
|
||||
|
||||
# Same reshape, for actions
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
self.observations[self.pos] = np.array(obs).copy()
|
||||
self.actions[self.pos] = np.array(action).copy()
|
||||
self.rewards[self.pos] = np.array(reward).copy()
|
||||
|
|
@ -586,8 +588,7 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
|
||||
|
||||
# Same reshape, for actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
self.actions[self.pos] = np.array(action).copy()
|
||||
self.rewards[self.pos] = np.array(reward).copy()
|
||||
|
|
|
|||
|
|
@ -658,7 +658,6 @@ def make_proba_distribution(
|
|||
dist_kwargs = {}
|
||||
|
||||
if isinstance(action_space, spaces.Box):
|
||||
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
|
||||
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
|
||||
return cls(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
|
|
|
|||
|
|
@ -336,8 +336,8 @@ class BasePolicy(BaseModel):
|
|||
|
||||
with th.no_grad():
|
||||
actions = self._predict(observation, deterministic=deterministic)
|
||||
# Convert to numpy
|
||||
actions = actions.cpu().numpy()
|
||||
# Convert to numpy, and reshape to the original action shape
|
||||
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if self.squash_output:
|
||||
|
|
@ -592,6 +592,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
actions = distribution.get_actions(deterministic=deterministic)
|
||||
log_prob = distribution.log_prob(actions)
|
||||
actions = actions.reshape((-1,) + self.action_space.shape)
|
||||
return actions, values, log_prob
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
|
||||
|
|
|
|||
|
|
@ -33,6 +33,19 @@ class DummyMultiBinary(gym.Env):
|
|||
return self.observation_space.sample(), 0.0, False, {}
|
||||
|
||||
|
||||
class DummyMultidimensionalAction(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
||||
def step(self, action):
|
||||
return self.observation_space.sample(), 0.0, False, {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
|
||||
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
|
||||
def test_identity_spaces(model_class, env):
|
||||
|
|
@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
|
||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
|
||||
def test_action_spaces(model_class, env):
|
||||
kwargs = {}
|
||||
if model_class in [SAC, DDPG, TD3]:
|
||||
supported_action_space = env == "Pendulum-v1"
|
||||
supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction)
|
||||
kwargs["learning_starts"] = 2
|
||||
kwargs["train_freq"] = 32
|
||||
elif model_class == DQN:
|
||||
supported_action_space = env == "CartPole-v1"
|
||||
elif model_class in [A2C, PPO]:
|
||||
supported_action_space = True
|
||||
kwargs["n_steps"] = 64
|
||||
|
||||
if supported_action_space:
|
||||
model_class("MlpPolicy", env)
|
||||
model = model_class("MlpPolicy", env, **kwargs)
|
||||
if isinstance(env, DummyMultidimensionalAction):
|
||||
model.learn(64)
|
||||
else:
|
||||
with pytest.raises(AssertionError):
|
||||
model_class("MlpPolicy", env)
|
||||
|
||||
|
||||
def test_sde_multi_dim():
|
||||
SAC(
|
||||
"MlpPolicy",
|
||||
DummyMultidimensionalAction(),
|
||||
learning_starts=10,
|
||||
use_sde=True,
|
||||
sde_sample_freq=2,
|
||||
use_sde_at_warmup=True,
|
||||
).learn(20)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
|
||||
@pytest.mark.parametrize("env", ["Taxi-v3"])
|
||||
def test_discrete_obs_space(model_class, env):
|
||||
|
|
|
|||
Loading…
Reference in a new issue