Handling multi-dimensional action spaces (#971)

* Handle non 1D action shape

* Revert changes of observation (out of the scope of this PR)

* Apply changes  to DictReplayBuffer

* Update tests

* Rollout buffer n-D actions space handling

* Remove error when non 1D action space

* ActorCriticPolicy return action with the proper shape

* remove useless reshape

* Update changelog

* Add tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-08-06 14:19:20 +02:00 committed by GitHub
parent 6ce33f5bd2
commit c4f54fcf04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 42 additions and 10 deletions

View file

@ -19,6 +19,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
- Added multidimensional action space support (@qgallouedec)
Deprecations:
^^^^^^^^^^^^^

View file

@ -247,8 +247,7 @@ class ReplayBuffer(BaseBuffer):
next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
# Same, for actions
if isinstance(self.action_space, spaces.Discrete):
action = action.reshape((self.n_envs, self.action_dim))
action = action.reshape((self.n_envs, self.action_dim))
# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs).copy()
@ -433,6 +432,9 @@ class RolloutBuffer(BaseBuffer):
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs,) + self.obs_shape)
# Same reshape, for actions
action = action.reshape((self.n_envs, self.action_dim))
self.observations[self.pos] = np.array(obs).copy()
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
@ -586,8 +588,7 @@ class DictReplayBuffer(ReplayBuffer):
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
# Same reshape, for actions
if isinstance(self.action_space, spaces.Discrete):
action = action.reshape((self.n_envs, self.action_dim))
action = action.reshape((self.n_envs, self.action_dim))
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()

View file

@ -658,7 +658,6 @@ def make_proba_distribution(
dist_kwargs = {}
if isinstance(action_space, spaces.Box):
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):

View file

@ -336,8 +336,8 @@ class BasePolicy(BaseModel):
with th.no_grad():
actions = self._predict(observation, deterministic=deterministic)
# Convert to numpy
actions = actions.cpu().numpy()
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
if isinstance(self.action_space, gym.spaces.Box):
if self.squash_output:
@ -592,6 +592,7 @@ class ActorCriticPolicy(BasePolicy):
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1,) + self.action_space.shape)
return actions, values, log_prob
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:

View file

@ -33,6 +33,19 @@ class DummyMultiBinary(gym.Env):
return self.observation_space.sample(), 0.0, False, {}
class DummyMultidimensionalAction(gym.Env):
def __init__(self):
super().__init__()
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
def reset(self):
return self.observation_space.sample()
def step(self, action):
return self.observation_space.sample(), 0.0, False, {}
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
def test_identity_spaces(model_class, env):
@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env):
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
def test_action_spaces(model_class, env):
kwargs = {}
if model_class in [SAC, DDPG, TD3]:
supported_action_space = env == "Pendulum-v1"
supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction)
kwargs["learning_starts"] = 2
kwargs["train_freq"] = 32
elif model_class == DQN:
supported_action_space = env == "CartPole-v1"
elif model_class in [A2C, PPO]:
supported_action_space = True
kwargs["n_steps"] = 64
if supported_action_space:
model_class("MlpPolicy", env)
model = model_class("MlpPolicy", env, **kwargs)
if isinstance(env, DummyMultidimensionalAction):
model.learn(64)
else:
with pytest.raises(AssertionError):
model_class("MlpPolicy", env)
def test_sde_multi_dim():
SAC(
"MlpPolicy",
DummyMultidimensionalAction(),
learning_starts=10,
use_sde=True,
sde_sample_freq=2,
use_sde_at_warmup=True,
).learn(20)
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", ["Taxi-v3"])
def test_discrete_obs_space(model_class, env):