stable-baselines3/stable_baselines3/common/bit_flipping_env.py
Megan Klaiber dd6e361204
Implement HER (#120)
* Added working her version, Online sampling is missing.

* Updated test_her.

* Added first version of online her sampling. Still problems with tensor dimensions.

* Reformat

* Fixed tests

* Added some comments.

* Updated changelog.

* Add missing init file

* Fixed some small bugs.

* Reduced arguments for HER, small changes.

* Added getattr. Fixed bug for online sampling.

* Updated save/load functions. Small changes.

* Added her to init.

* Updated save method.

* Updated her ratio.

* Move obs_wrapper

* Added DQN test.

* Fix potential bug

* Offline and online her share same sample_goal function.

* Changed lists into arrays.

* Updated her test.

* Fix online sampling

* Fixed action bug. Updated time limit for episodes.

* Updated convert_dict method to take keys as arguments.

* Renamed obs dict wrapper.

* Seed bit flipping env

* Remove get_episode_dict

* Add fast online sampling version

* Added documentation.

* Vectorized reward computation

* Vectorized goal sampling

* Update time limit for episodes in online her sampling.

* Fix max episode length inference

* Bug fix for Fetch envs

* Fix for HER + gSDE

* Reformat (new black version)

* Added info dict to compute new reward. Check her_replay_buffer again.

* Fix info buffer

* Updated done flag.

* Fixes for gSDE

* Offline HER version now uses HerReplayBuffer as episode storage.

* Fix num_timesteps computation

* Fix get torch params

* Vectorized version for offline sampling.

* Modified offline her sampling to use sample method of her_replay_buffer

* Updated HER tests.

* Updated documentation

* Cleanup docstrings

* Updated to review comments

* Fix pytype

* Update according to review comments.

* Removed random goal strategy. Updated sample transitions.

* Updated migration. Removed time signal removal.

* Update doc

* Fix potential load issue

* Add VecNormalize support for dict obs

* Updated saving/loading replay buffer for HER.

* Fix test memory usage

* Fixed save/load replay buffer.

* Fixed save/load replay buffer

* Fixed transition index after loading replay buffer in online sampling

* Better error handling

* Add tests for get_time_limit

* More tests for VecNormalize with dict obs

* Update doc

* Improve HER description

* Add test for sde support

* Add comments

* Add comments

* Remove check that was always valid

* Fix for terminal observation

* Updated buffer size in offline version and reset of HER buffer

* Reformat

* Update doc

* Remove np.empty + add doc

* Fix loading

* Updated loading replay buffer

* Separate online and offline sampling + bug fixes

* Update tensorboard log name

* Version bump

* Bug fix for special case

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-10-22 11:56:43 +02:00

131 lines
4.7 KiB
Python

from collections import OrderedDict
from typing import Any, Dict, Optional, Union
import numpy as np
from gym import GoalEnv, spaces
from gym.envs.registration import EnvSpec
from stable_baselines3.common.type_aliases import GymStepReturn
class BitFlippingEnv(GoalEnv):
    """
    Simple bit flipping env, useful to test HER.

    The goal is to flip all the bits to get a vector of ones.
    In the continuous variant, if the ith action component has a value > 0,
    then the ith bit will be flipped.

    :param n_bits: Number of bits to flip
    :param continuous: Whether to use the continuous actions version or not,
        by default, it uses the discrete one
    :param max_steps: Max number of steps, by default, equal to n_bits
    :param discrete_obs_space: Whether to use the discrete observation
        version or not, by default, it uses the MultiBinary one
    """

    spec = EnvSpec("BitFlippingEnv-v0")

    def __init__(
        self, n_bits: int = 10, continuous: bool = False, max_steps: Optional[int] = None, discrete_obs_space: bool = False
    ):
        super().__init__()
        # The achieved goal is determined by the current state;
        # here, it is a special case where they are equal.
        if discrete_obs_space:
            # In the discrete case, the agent acts on the binary
            # representation of the observation.
            # n_bits bits encode the integers 0 .. 2 ** n_bits - 1 (the all-ones
            # goal being the largest), and Discrete(n) covers 0 .. n - 1,
            # hence 2 ** n_bits values are required.
            self.observation_space = spaces.Dict(
                {
                    "observation": spaces.Discrete(2 ** n_bits),
                    "achieved_goal": spaces.Discrete(2 ** n_bits),
                    "desired_goal": spaces.Discrete(2 ** n_bits),
                }
            )
        else:
            self.observation_space = spaces.Dict(
                {
                    "observation": spaces.MultiBinary(n_bits),
                    "achieved_goal": spaces.MultiBinary(n_bits),
                    "desired_goal": spaces.MultiBinary(n_bits),
                }
            )

        # Space used to sample the internal (always bit-vector) state in reset()
        self.obs_space = spaces.MultiBinary(n_bits)

        if continuous:
            self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32)
        else:
            self.action_space = spaces.Discrete(n_bits)
        self.continuous = continuous
        self.discrete_obs_space = discrete_obs_space
        # The internal state is only created by reset()
        self.state = None
        self.desired_goal = np.ones((n_bits,))
        if max_steps is None:
            max_steps = n_bits
        self.max_steps = max_steps
        self.current_step = 0

    def seed(self, seed: Optional[int] = None) -> None:
        """
        Seed the space used to sample the initial state.

        :param seed: Seed forwarded to the sampling space
        """
        self.obs_space.seed(seed)

    def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
        """
        Convert to discrete space if needed.

        :param state: Bit vector (entries equal to 0 or 1)
        :return: The integer whose binary representation is ``state``
            (index 0 being the least significant bit) when the discrete
            observation space is used, otherwise ``state`` unchanged
        """
        if self.discrete_obs_space:
            # The internal state is the binary representation of the
            # observed one.
            # Each bit is cast to a Python int before shifting so that the
            # accumulation cannot overflow a fixed-width NumPy dtype.
            return sum(int(bit) << index for index, bit in enumerate(state))
        return state

    def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
        """
        Helper to create the observation.

        :return: Ordered dict with the keys "observation", "achieved_goal"
            and "desired_goal"; the first two are both built from the current state
        """
        return OrderedDict(
            [
                ("observation", self.convert_if_needed(self.state.copy())),
                ("achieved_goal", self.convert_if_needed(self.state.copy())),
                ("desired_goal", self.convert_if_needed(self.desired_goal.copy())),
            ]
        )

    def reset(self) -> Dict[str, Union[int, np.ndarray]]:
        """
        Start a new episode: reset the step counter and sample a new state.

        :return: The first observation of the episode
        """
        self.current_step = 0
        self.state = self.obs_space.sample()
        return self._get_obs()

    def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
        """
        Flip the bit(s) selected by the action.

        :param action: Index of the bit to flip (discrete actions), or a vector
            whose strictly positive components mark the bits to flip (continuous actions)
        :return: observation, reward, done flag and info dict
            (``info["is_success"]`` tells whether the goal was reached)
        """
        if self.continuous:
            flip_mask = action > 0
            self.state[flip_mask] = 1 - self.state[flip_mask]
        else:
            self.state[action] = 1 - self.state[action]
        obs = self._get_obs()
        reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None))
        # The reward is 0 only when the goal is reached
        success = reward == 0
        self.current_step += 1
        info = {"is_success": success}
        # Episode terminates when we reached the goal or the max number of steps
        done = success or self.current_step >= self.max_steps
        return obs, reward, done, info

    def compute_reward(
        self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]]
    ) -> np.float32:
        """
        Sparse reward: 0 when the achieved goal equals the desired one, -1 otherwise.

        Vectorized: the comparison is done along the last axis, so inputs with
        leading dimensions yield one reward per leading index.

        :param achieved_goal: Achieved goal(s), bit vector(s) or integer encoding
        :param desired_goal: Desired goal(s), same format as ``achieved_goal``
        :param _info: Unused
        :return: The reward(s) as float32
        """
        # In the discrete observation variant the goals are plain integers:
        # promote the (0-d) difference to 1-d so there is a last axis to reduce over.
        difference = np.atleast_1d(np.asarray(achieved_goal) - np.asarray(desired_goal))
        distance = np.linalg.norm(difference, axis=-1)
        return -(distance > 0).astype(np.float32)

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Show the current state.

        :param mode: With "rgb_array", a copy of the state is returned;
            with any other value the state is printed
        :return: A copy of the state in "rgb_array" mode, otherwise None
        """
        if mode == "rgb_array":
            return self.state.copy()
        print(self.state)
        return None

    def close(self) -> None:
        """Nothing to clean up."""
        pass