mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
* Added working her version, Online sampling is missing. * Updated test_her. * Added first version of online her sampling. Still problems with tensor dimensions. * Reformat * Fixed tests * Added some comments. * Updated changelog. * Add missing init file * Fixed some small bugs. * Reduced arguments for HER, small changes. * Added getattr. Fixed bug for online sampling. * Updated save/load funtions. Small changes. * Added her to init. * Updated save method. * Updated her ratio. * Move obs_wrapper * Added DQN test. * Fix potential bug * Offline and online her share same sample_goal function. * Changed lists into arrays. * Updated her test. * Fix online sampling * Fixed action bug. Updated time limit for episodes. * Updated convert_dict method to take keys as arguments. * Renamed obs dict wrapper. * Seed bit flipping env * Remove get_episode_dict * Add fast online sampling version * Added documentation. * Vectorized reward computation * Vectorized goal sampling * Update time limit for episodes in online her sampling. * Fix max episode length inference * Bug fix for Fetch envs * Fix for HER + gSDE * Reformat (new black version) * Added info dict to compute new reward. Check her_replay_buffer again. * Fix info buffer * Updated done flag. * Fixes for gSDE * Offline her version uses now HerReplayBuffer as episode storage. * Fix num_timesteps computation * Fix get torch params * Vectorized version for offline sampling. * Modified offline her sampling to use sample method of her_replay_buffer * Updated HER tests. * Updated documentation * Cleanup docstrings * Updated to review comments * Fix pytype * Update according to review comments. * Removed random goal strategy. Updated sample transitions. * Updated migration. Removed time signal removal. * Update doc * Fix potential load issue * Add VecNormalize support for dict obs * Updated saving/loading replay buffer for HER. * Fix test memory usage * Fixed save/load replay buffer. * Fixed save/load replay buffer * Fixed transition index after loading replay buffer in online sampling * Better error handling * Add tests for get_time_limit * More tests for VecNormalize with dict obs * Update doc * Improve HER description * Add test for sde support * Add comments * Add comments * Remove check that was always valid * Fix for terminal observation * Updated buffer size in offline version and reset of HER buffer * Reformat * Update doc * Remove np.empty + add doc * Fix loading * Updated loading replay buffer * Separate online and offline sampling + bug fixes * Update tensorboard log name * Version bump * Bug fix for special case Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
131 lines
4.7 KiB
Python
131 lines
4.7 KiB
Python
from collections import OrderedDict
|
|
from typing import Any, Dict, Optional, Union
|
|
|
|
import numpy as np
|
|
from gym import GoalEnv, spaces
|
|
from gym.envs.registration import EnvSpec
|
|
|
|
from stable_baselines3.common.type_aliases import GymStepReturn
|
|
|
|
|
|
class BitFlippingEnv(GoalEnv):
|
|
"""
|
|
Simple bit flipping env, useful to test HER.
|
|
The goal is to flip all the bits to get a vector of ones.
|
|
In the continuous variant, if the ith action component has a value > 0,
|
|
then the ith bit will be flipped.
|
|
|
|
:param n_bits: Number of bits to flip
|
|
:param continuous: Whether to use the continuous actions version or not,
|
|
by default, it uses the discrete one
|
|
:param max_steps: Max number of steps, by default, equal to n_bits
|
|
:param discrete_obs_space: Whether to use the discrete observation
|
|
version or not, by default, it uses the MultiBinary one
|
|
"""
|
|
|
|
spec = EnvSpec("BitFlippingEnv-v0")
|
|
|
|
def __init__(
|
|
self, n_bits: int = 10, continuous: bool = False, max_steps: Optional[int] = None, discrete_obs_space: bool = False
|
|
):
|
|
super(BitFlippingEnv, self).__init__()
|
|
# The achieved goal is determined by the current state
|
|
# here, it is a special where they are equal
|
|
if discrete_obs_space:
|
|
# In the discrete case, the agent act on the binary
|
|
# representation of the observation
|
|
self.observation_space = spaces.Dict(
|
|
{
|
|
"observation": spaces.Discrete(2 ** n_bits - 1),
|
|
"achieved_goal": spaces.Discrete(2 ** n_bits - 1),
|
|
"desired_goal": spaces.Discrete(2 ** n_bits - 1),
|
|
}
|
|
)
|
|
else:
|
|
self.observation_space = spaces.Dict(
|
|
{
|
|
"observation": spaces.MultiBinary(n_bits),
|
|
"achieved_goal": spaces.MultiBinary(n_bits),
|
|
"desired_goal": spaces.MultiBinary(n_bits),
|
|
}
|
|
)
|
|
|
|
self.obs_space = spaces.MultiBinary(n_bits)
|
|
|
|
if continuous:
|
|
self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32)
|
|
else:
|
|
self.action_space = spaces.Discrete(n_bits)
|
|
self.continuous = continuous
|
|
self.discrete_obs_space = discrete_obs_space
|
|
self.state = None
|
|
self.desired_goal = np.ones((n_bits,))
|
|
if max_steps is None:
|
|
max_steps = n_bits
|
|
self.max_steps = max_steps
|
|
self.current_step = 0
|
|
|
|
def seed(self, seed: int) -> None:
|
|
self.obs_space.seed(seed)
|
|
|
|
def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
|
|
"""
|
|
Convert to discrete space if needed.
|
|
|
|
:param state:
|
|
:return:
|
|
"""
|
|
if self.discrete_obs_space:
|
|
# The internal state is the binary representation of the
|
|
# observed one
|
|
return int(sum([state[i] * 2 ** i for i in range(len(state))]))
|
|
return state
|
|
|
|
def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
|
|
"""
|
|
Helper to create the observation.
|
|
|
|
:return:
|
|
"""
|
|
return OrderedDict(
|
|
[
|
|
("observation", self.convert_if_needed(self.state.copy())),
|
|
("achieved_goal", self.convert_if_needed(self.state.copy())),
|
|
("desired_goal", self.convert_if_needed(self.desired_goal.copy())),
|
|
]
|
|
)
|
|
|
|
def reset(self) -> Dict[str, Union[int, np.ndarray]]:
|
|
self.current_step = 0
|
|
self.state = self.obs_space.sample()
|
|
return self._get_obs()
|
|
|
|
def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
|
|
if self.continuous:
|
|
self.state[action > 0] = 1 - self.state[action > 0]
|
|
else:
|
|
self.state[action] = 1 - self.state[action]
|
|
obs = self._get_obs()
|
|
reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None))
|
|
done = reward == 0
|
|
self.current_step += 1
|
|
# Episode terminate when we reached the goal or the max number of steps
|
|
info = {"is_success": done}
|
|
done = done or self.current_step >= self.max_steps
|
|
return obs, reward, done, info
|
|
|
|
def compute_reward(
|
|
self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]]
|
|
) -> np.float32:
|
|
# Deceptive reward: it is positive only when the goal is achieved
|
|
# vectorized version
|
|
distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
|
|
return -(distance > 0).astype(np.float32)
|
|
|
|
def render(self, mode: str = "human") -> Optional[np.ndarray]:
|
|
if mode == "rgb_array":
|
|
return self.state.copy()
|
|
print(self.state)
|
|
|
|
def close(self) -> None:
|
|
pass
|