diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3932918..8b0a741 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -26,6 +26,7 @@ Bug Fixes: - Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()`` - Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD) - Fixed check_env for Sequence observation space (@corentinlger) +- Prevented instantiating ``BitFlippingEnv`` with conflicting observation spaces (@kylesayrs) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 1089797..3ea0c7b 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -13,15 +13,17 @@ class BitFlippingEnv(Env): Simple bit flipping env, useful to test HER. The goal is to flip all the bits to get a vector of ones. In the continuous variant, if the ith action component has a value > 0, - then the ith bit will be flipped. + then the ith bit will be flipped. Uses a ``MultiBinary`` observation space + by default. :param n_bits: Number of bits to flip :param continuous: Whether to use the continuous actions version or not, by default, it uses the discrete one :param max_steps: Max number of steps, by default, equal to n_bits :param discrete_obs_space: Whether to use the discrete observation - version or not, by default, it uses the ``MultiBinary`` one - :param image_obs_space: Use image as input instead of the ``MultiBinary`` one. + version or not, i.e. a single integer index over all possible states + :param image_obs_space: Whether to use the image observation version + or not, i.e. a greyscale image of the state :param channel_first: Whether to use channel-first or last image. 
""" @@ -44,52 +46,11 @@ class BitFlippingEnv(Env): self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1) # The achieved goal is determined by the current state # here, it is a special where they are equal - if discrete_obs_space: - # In the discrete case, the agent act on the binary - # representation of the observation - self.observation_space = spaces.Dict( - { - "observation": spaces.Discrete(2**n_bits), - "achieved_goal": spaces.Discrete(2**n_bits), - "desired_goal": spaces.Discrete(2**n_bits), - } - ) - elif image_obs_space: - # When using image as input, - # one image contains the bits 0 -> 0, 1 -> 255 - # and the rest is filled with zeros - self.observation_space = spaces.Dict( - { - "observation": spaces.Box( - low=0, - high=255, - shape=self.image_shape, - dtype=np.uint8, - ), - "achieved_goal": spaces.Box( - low=0, - high=255, - shape=self.image_shape, - dtype=np.uint8, - ), - "desired_goal": spaces.Box( - low=0, - high=255, - shape=self.image_shape, - dtype=np.uint8, - ), - } - ) - else: - self.observation_space = spaces.Dict( - { - "observation": spaces.MultiBinary(n_bits), - "achieved_goal": spaces.MultiBinary(n_bits), - "desired_goal": spaces.MultiBinary(n_bits), - } - ) - self.obs_space = spaces.MultiBinary(n_bits) + # observation space for observations given to the model + self.observation_space = self._make_observation_space(discrete_obs_space, image_obs_space, n_bits) + # observation space used to update internal state + self._obs_space = spaces.MultiBinary(n_bits) if continuous: self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) @@ -105,7 +66,7 @@ class BitFlippingEnv(Env): self.current_step = 0 def seed(self, seed: int) -> None: - self.obs_space.seed(seed) + self._obs_space.seed(seed) def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: """ @@ -144,6 +105,64 @@ class BitFlippingEnv(Env): bit_vector = np.array(state).reshape(batch_size, -1) return bit_vector + def _make_observation_space(self, 
discrete_obs_space: bool, image_obs_space: bool, n_bits: int) -> spaces.Dict: + """ + Helper to create the observation space + + :param discrete_obs_space: Whether to use the discrete observation version + :param image_obs_space: Whether to use the image observation version + :param n_bits: The number of bits used to represent the state + :return: the environment observation space + """ + if discrete_obs_space and image_obs_space: + raise ValueError("Cannot use both discrete and image observation spaces") + + if discrete_obs_space: + # In the discrete case, the agent acts on the binary + # representation of the observation + return spaces.Dict( + { + "observation": spaces.Discrete(2**n_bits), + "achieved_goal": spaces.Discrete(2**n_bits), + "desired_goal": spaces.Discrete(2**n_bits), + } + ) + + if image_obs_space: + # When using image as input, + # one image contains the bits 0 -> 0, 1 -> 255 + # and the rest is filled with zeros + return spaces.Dict( + { + "observation": spaces.Box( + low=0, + high=255, + shape=self.image_shape, + dtype=np.uint8, + ), + "achieved_goal": spaces.Box( + low=0, + high=255, + shape=self.image_shape, + dtype=np.uint8, + ), + "desired_goal": spaces.Box( + low=0, + high=255, + shape=self.image_shape, + dtype=np.uint8, + ), + } + ) + + return spaces.Dict( + { + "observation": spaces.MultiBinary(n_bits), + "achieved_goal": spaces.MultiBinary(n_bits), + "desired_goal": spaces.MultiBinary(n_bits), + } + ) + def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: """ Helper to create the observation. @@ -162,9 +181,9 @@ class BitFlippingEnv(Env): self, *, seed: Optional[int] = None, options: Optional[Dict] = None ) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]: if seed is not None: - self.obs_space.seed(seed) + self._obs_space.seed(seed) self.current_step = 0 - self.state = self.obs_space.sample() + self.state = self._obs_space.sample() return self._get_obs(), {} def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: