mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-29 23:07:07 +00:00
BitFlippingEnv argument check and docs clarification (#1698)
* made change, not tested yet * add back _obs_space with note on purpose * match formatting * update documentation
This commit is contained in:
parent
2ca94cb73d
commit
fab6cb339d
2 changed files with 71 additions and 51 deletions
|
|
@ -26,6 +26,7 @@ Bug Fixes:
|
|||
- Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()``
|
||||
- Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
|
||||
- Fixed check_env for Sequence observation space (@corentinlger)
|
||||
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -13,15 +13,17 @@ class BitFlippingEnv(Env):
|
|||
Simple bit flipping env, useful to test HER.
|
||||
The goal is to flip all the bits to get a vector of ones.
|
||||
In the continuous variant, if the ith action component has a value > 0,
|
||||
then the ith bit will be flipped.
|
||||
then the ith bit will be flipped. Uses a ``MultiBinary`` observation space
|
||||
by default.
|
||||
|
||||
:param n_bits: Number of bits to flip
|
||||
:param continuous: Whether to use the continuous actions version or not,
|
||||
by default, it uses the discrete one
|
||||
:param max_steps: Max number of steps, by default, equal to n_bits
|
||||
:param discrete_obs_space: Whether to use the discrete observation
|
||||
version or not, by default, it uses the ``MultiBinary`` one
|
||||
:param image_obs_space: Use image as input instead of the ``MultiBinary`` one.
|
||||
version or not, ie a one-hot encoding of all possible states
|
||||
:param image_obs_space: Whether to use an image observation version
|
||||
or not, ie a greyscale image of the state
|
||||
:param channel_first: Whether to use channel-first or last image.
|
||||
"""
|
||||
|
||||
|
|
@ -44,52 +46,11 @@ class BitFlippingEnv(Env):
|
|||
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
|
||||
# The achieved goal is determined by the current state
|
||||
# here, it is a special where they are equal
|
||||
if discrete_obs_space:
|
||||
# In the discrete case, the agent act on the binary
|
||||
# representation of the observation
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"observation": spaces.Discrete(2**n_bits),
|
||||
"achieved_goal": spaces.Discrete(2**n_bits),
|
||||
"desired_goal": spaces.Discrete(2**n_bits),
|
||||
}
|
||||
)
|
||||
elif image_obs_space:
|
||||
# When using image as input,
|
||||
# one image contains the bits 0 -> 0, 1 -> 255
|
||||
# and the rest is filled with zeros
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"observation": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=self.image_shape,
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"achieved_goal": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=self.image_shape,
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"desired_goal": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=self.image_shape,
|
||||
dtype=np.uint8,
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"observation": spaces.MultiBinary(n_bits),
|
||||
"achieved_goal": spaces.MultiBinary(n_bits),
|
||||
"desired_goal": spaces.MultiBinary(n_bits),
|
||||
}
|
||||
)
|
||||
|
||||
self.obs_space = spaces.MultiBinary(n_bits)
|
||||
# observation space for observations given to the model
|
||||
self.observation_space = self._make_observation_space(discrete_obs_space, image_obs_space, n_bits)
|
||||
# observation space used to update internal state
|
||||
self._obs_space = spaces.MultiBinary(n_bits)
|
||||
|
||||
if continuous:
|
||||
self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32)
|
||||
|
|
@ -105,7 +66,7 @@ class BitFlippingEnv(Env):
|
|||
self.current_step = 0
|
||||
|
||||
def seed(self, seed: int) -> None:
|
||||
self.obs_space.seed(seed)
|
||||
self._obs_space.seed(seed)
|
||||
|
||||
def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
|
||||
"""
|
||||
|
|
@ -144,6 +105,64 @@ class BitFlippingEnv(Env):
|
|||
bit_vector = np.array(state).reshape(batch_size, -1)
|
||||
return bit_vector
|
||||
|
||||
def _make_observation_space(self, discrete_obs_space: bool, image_obs_space: bool, n_bits: int) -> spaces.Dict:
|
||||
"""
|
||||
Helper to create observation space
|
||||
|
||||
:param discrete_obs_space: Whether to use the discrete observation version
|
||||
:param image_obs_space: Whether to use the image observation version
|
||||
:param n_bits: The number of bits used to represent the state
|
||||
:return: the environment observation space
|
||||
"""
|
||||
if discrete_obs_space and image_obs_space:
|
||||
raise ValueError("Cannot use both discrete and image observation spaces")
|
||||
|
||||
if discrete_obs_space:
|
||||
# In the discrete case, the agent act on the binary
|
||||
# representation of the observation
|
||||
return spaces.Dict(
|
||||
{
|
||||
"observation": spaces.Discrete(2**n_bits),
|
||||
"achieved_goal": spaces.Discrete(2**n_bits),
|
||||
"desired_goal": spaces.Discrete(2**n_bits),
|
||||
}
|
||||
)
|
||||
|
||||
if image_obs_space:
|
||||
# When using image as input,
|
||||
# one image contains the bits 0 -> 0, 1 -> 255
|
||||
# and the rest is filled with zeros
|
||||
return spaces.Dict(
|
||||
{
|
||||
"observation": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=self.image_shape,
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"achieved_goal": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=self.image_shape,
|
||||
dtype=np.uint8,
|
||||
),
|
||||
"desired_goal": spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=self.image_shape,
|
||||
dtype=np.uint8,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return spaces.Dict(
|
||||
{
|
||||
"observation": spaces.MultiBinary(n_bits),
|
||||
"achieved_goal": spaces.MultiBinary(n_bits),
|
||||
"desired_goal": spaces.MultiBinary(n_bits),
|
||||
}
|
||||
)
|
||||
|
||||
def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
|
||||
"""
|
||||
Helper to create the observation.
|
||||
|
|
@ -162,9 +181,9 @@ class BitFlippingEnv(Env):
|
|||
self, *, seed: Optional[int] = None, options: Optional[Dict] = None
|
||||
) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]:
|
||||
if seed is not None:
|
||||
self.obs_space.seed(seed)
|
||||
self._obs_space.seed(seed)
|
||||
self.current_step = 0
|
||||
self.state = self.obs_space.sample()
|
||||
self.state = self._obs_space.sample()
|
||||
return self._get_obs(), {}
|
||||
|
||||
def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
|
||||
|
|
|
|||
Loading…
Reference in a new issue