mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-05 00:00:04 +00:00
Fix type annotations of buffers (#1700)
* Fix type annotation and replay buffer * Exclude pytype check * Remove some pytype specific annotation and update changelog * Fix HerReplayBuffer type hints * try remove # type: ignore[assignment] * revert change --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
parent
fab6cb339d
commit
c6c660e51b
13 changed files with 107 additions and 84 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.2.0a5 (WIP)
|
||||
Release 2.2.0a6 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -49,6 +49,9 @@ Others:
|
|||
- Fixed ``stable_baselines3/common/vec_env/vec_video_recorder.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/save_util.py`` type hints
|
||||
- Updated docker images to Ubuntu Jammy using micromamba 1.5
|
||||
- Fixed ``stable_baselines3/common/buffers.py`` type hints
|
||||
- Fixed ``stable_baselines3/her/her_replay_buffer.py`` type hints
|
||||
- Buffers do not call an additional ``.copy()`` when storing new transitions
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -27,19 +27,27 @@ line-length = 127
|
|||
[tool.pytype]
|
||||
inputs = ["stable_baselines3"]
|
||||
disable = ["pyi-error"]
|
||||
# Checked with mypy
|
||||
exclude = [
|
||||
"stable_baselines3/common/buffers.py",
|
||||
"stable_baselines3/common/base_class.py",
|
||||
"stable_baselines3/common/callbacks.py",
|
||||
"stable_baselines3/common/on_policy_algorithm.py",
|
||||
"stable_baselines3/common/vec_env/stacked_observations.py",
|
||||
"stable_baselines3/common/vec_env/subproc_vec_env.py",
|
||||
"stable_baselines3/common/vec_env/patch_gym.py"
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
follow_imports = "silent"
|
||||
show_error_codes = true
|
||||
exclude = """(?x)(
|
||||
stable_baselines3/common/buffers.py$
|
||||
| stable_baselines3/common/distributions.py$
|
||||
stable_baselines3/common/distributions.py$
|
||||
| stable_baselines3/common/off_policy_algorithm.py$
|
||||
| stable_baselines3/common/policies.py$
|
||||
| stable_baselines3/common/vec_env/__init__.py$
|
||||
| stable_baselines3/common/vec_env/vec_normalize.py$
|
||||
| stable_baselines3/her/her_replay_buffer.py$
|
||||
| tests/test_logger.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
)"""
|
||||
|
|
|
|||
|
|
@ -420,9 +420,7 @@ class BaseAlgorithm(ABC):
|
|||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
||||
if reset_num_timesteps or self._last_obs is None:
|
||||
assert self.env is not None
|
||||
# pytype: disable=annotation-type-mismatch
|
||||
self._last_obs = self.env.reset() # type: ignore[assignment]
|
||||
# pytype: enable=annotation-type-mismatch
|
||||
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
if self._vec_normalize_env is not None:
|
||||
|
|
@ -707,7 +705,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
# Gym -> Gymnasium space conversion
|
||||
for key in {"observation_space", "action_space"}:
|
||||
data[key] = _convert_space(data[key]) # pytype: disable=unsupported-operands
|
||||
data[key] = _convert_space(data[key])
|
||||
|
||||
if env is not None:
|
||||
# Wrap first if needed
|
||||
|
|
@ -726,14 +724,12 @@ class BaseAlgorithm(ABC):
|
|||
if "env" in data:
|
||||
env = data["env"]
|
||||
|
||||
# pytype: disable=not-instantiable,wrong-keyword-args
|
||||
model = cls(
|
||||
policy=data["policy_class"],
|
||||
env=env,
|
||||
device=device,
|
||||
_init_setup_model=False, # type: ignore[call-arg]
|
||||
)
|
||||
# pytype: enable=not-instantiable,wrong-keyword-args
|
||||
|
||||
# load parameters
|
||||
model.__dict__.update(data)
|
||||
|
|
@ -776,7 +772,7 @@ class BaseAlgorithm(ABC):
|
|||
# Sample gSDE exploration matrix, so it uses the right device
|
||||
# see issue #44
|
||||
if model.use_sde:
|
||||
model.policy.reset_noise() # type: ignore[operator] # pytype: disable=attribute-error
|
||||
model.policy.reset_noise() # type: ignore[operator]
|
||||
return model
|
||||
|
||||
def get_parameters(self) -> Dict[str, Dict]:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -35,6 +35,9 @@ class BaseBuffer(ABC):
|
|||
:param n_envs: Number of parallel environments
|
||||
"""
|
||||
|
||||
observation_space: spaces.Space
|
||||
obs_shape: Tuple[int, ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
|
|
@ -47,7 +50,7 @@ class BaseBuffer(ABC):
|
|||
self.buffer_size = buffer_size
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.obs_shape = get_obs_shape(observation_space)
|
||||
self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment]
|
||||
|
||||
self.action_dim = get_action_dim(action_space)
|
||||
self.pos = 0
|
||||
|
|
@ -171,6 +174,13 @@ class ReplayBuffer(BaseBuffer):
|
|||
https://github.com/DLR-RM/stable-baselines3/issues/284
|
||||
"""
|
||||
|
||||
observations: np.ndarray
|
||||
next_observations: np.ndarray
|
||||
actions: np.ndarray
|
||||
rewards: np.ndarray
|
||||
dones: np.ndarray
|
||||
timeouts: np.ndarray
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
|
|
@ -201,10 +211,8 @@ class ReplayBuffer(BaseBuffer):
|
|||
|
||||
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
|
||||
|
||||
if optimize_memory_usage:
|
||||
# `observations` contains also the next observation
|
||||
self.next_observations = None
|
||||
else:
|
||||
if not optimize_memory_usage:
|
||||
# When optimizing memory, `observations` contains also the next observation
|
||||
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
|
||||
|
||||
self.actions = np.zeros(
|
||||
|
|
@ -219,9 +227,11 @@ class ReplayBuffer(BaseBuffer):
|
|||
self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
|
||||
if psutil is not None:
|
||||
total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
|
||||
total_memory_usage: float = (
|
||||
self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
|
||||
)
|
||||
|
||||
if self.next_observations is not None:
|
||||
if not optimize_memory_usage:
|
||||
total_memory_usage += self.next_observations.nbytes
|
||||
|
||||
if total_memory_usage > mem_available:
|
||||
|
|
@ -252,16 +262,16 @@ class ReplayBuffer(BaseBuffer):
|
|||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
# Copy to avoid modification by reference
|
||||
self.observations[self.pos] = np.array(obs).copy()
|
||||
self.observations[self.pos] = np.array(obs)
|
||||
|
||||
if self.optimize_memory_usage:
|
||||
self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
|
||||
self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs)
|
||||
else:
|
||||
self.next_observations[self.pos] = np.array(next_obs).copy()
|
||||
self.next_observations[self.pos] = np.array(next_obs)
|
||||
|
||||
self.actions[self.pos] = np.array(action).copy()
|
||||
self.rewards[self.pos] = np.array(reward).copy()
|
||||
self.dones[self.pos] = np.array(done).copy()
|
||||
self.actions[self.pos] = np.array(action)
|
||||
self.rewards[self.pos] = np.array(reward)
|
||||
self.dones[self.pos] = np.array(done)
|
||||
|
||||
if self.handle_timeout_termination:
|
||||
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
|
||||
|
|
@ -457,10 +467,10 @@ class RolloutBuffer(BaseBuffer):
|
|||
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
self.observations[self.pos] = np.array(obs).copy()
|
||||
self.actions[self.pos] = np.array(action).copy()
|
||||
self.rewards[self.pos] = np.array(reward).copy()
|
||||
self.episode_starts[self.pos] = np.array(episode_start).copy()
|
||||
self.observations[self.pos] = np.array(obs)
|
||||
self.actions[self.pos] = np.array(action)
|
||||
self.rewards[self.pos] = np.array(reward)
|
||||
self.episode_starts[self.pos] = np.array(episode_start)
|
||||
self.values[self.pos] = value.clone().cpu().numpy().flatten()
|
||||
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
|
||||
self.pos += 1
|
||||
|
|
@ -498,7 +508,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None,
|
||||
) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
|
||||
) -> RolloutBufferSamples:
|
||||
data = (
|
||||
self.observations[batch_inds],
|
||||
self.actions[batch_inds],
|
||||
|
|
@ -527,10 +537,15 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
https://github.com/DLR-RM/stable-baselines3/issues/284
|
||||
"""
|
||||
|
||||
observation_space: spaces.Dict
|
||||
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
|
||||
observations: Dict[str, np.ndarray] # type: ignore[assignment]
|
||||
next_observations: Dict[str, np.ndarray] # type: ignore[assignment]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "auto",
|
||||
n_envs: int = 1,
|
||||
|
|
@ -576,8 +591,8 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
for _, obs in self.observations.items():
|
||||
obs_nbytes += obs.nbytes
|
||||
|
||||
total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
|
||||
if self.next_observations is not None:
|
||||
total_memory_usage: float = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
|
||||
if not optimize_memory_usage:
|
||||
next_obs_nbytes = 0
|
||||
for _, obs in self.observations.items():
|
||||
next_obs_nbytes += obs.nbytes
|
||||
|
|
@ -592,7 +607,7 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
|
||||
)
|
||||
|
||||
def add(
|
||||
def add( # type: ignore[override]
|
||||
self,
|
||||
obs: Dict[str, np.ndarray],
|
||||
next_obs: Dict[str, np.ndarray],
|
||||
|
|
@ -600,7 +615,7 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
reward: np.ndarray,
|
||||
done: np.ndarray,
|
||||
infos: List[Dict[str, Any]],
|
||||
) -> None: # pytype: disable=signature-mismatch
|
||||
) -> None:
|
||||
# Copy to avoid modification by reference
|
||||
for key in self.observations.keys():
|
||||
# Reshape needed when using multiple envs with discrete observations
|
||||
|
|
@ -612,14 +627,14 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
for key in self.next_observations.keys():
|
||||
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
|
||||
next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key])
|
||||
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
|
||||
self.next_observations[key][self.pos] = np.array(next_obs[key])
|
||||
|
||||
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
self.actions[self.pos] = np.array(action).copy()
|
||||
self.rewards[self.pos] = np.array(reward).copy()
|
||||
self.dones[self.pos] = np.array(done).copy()
|
||||
self.actions[self.pos] = np.array(action)
|
||||
self.rewards[self.pos] = np.array(reward)
|
||||
self.dones[self.pos] = np.array(done)
|
||||
|
||||
if self.handle_timeout_termination:
|
||||
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
|
||||
|
|
@ -629,11 +644,11 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
self.full = True
|
||||
self.pos = 0
|
||||
|
||||
def sample(
|
||||
def sample( # type: ignore[override]
|
||||
self,
|
||||
batch_size: int,
|
||||
env: Optional[VecNormalize] = None,
|
||||
) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
|
||||
) -> DictReplayBufferSamples:
|
||||
"""
|
||||
Sample elements from the replay buffer.
|
||||
|
||||
|
|
@ -644,11 +659,11 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
"""
|
||||
return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)
|
||||
|
||||
def _get_samples(
|
||||
def _get_samples( # type: ignore[override]
|
||||
self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None,
|
||||
) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
|
||||
) -> DictReplayBufferSamples:
|
||||
# Sample randomly the env idx
|
||||
env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
|
||||
|
||||
|
|
@ -658,6 +673,8 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
{key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env
|
||||
)
|
||||
|
||||
assert isinstance(obs_, dict)
|
||||
assert isinstance(next_obs_, dict)
|
||||
# Convert to torch tensor
|
||||
observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
|
||||
next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}
|
||||
|
|
@ -700,12 +717,14 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
:param n_envs: Number of parallel environments
|
||||
"""
|
||||
|
||||
observations: Dict[str, np.ndarray]
|
||||
observation_space: spaces.Dict
|
||||
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
|
||||
observations: Dict[str, np.ndarray] # type: ignore[assignment]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "auto",
|
||||
gae_lambda: float = 1,
|
||||
|
|
@ -723,7 +742,6 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
|
||||
self.observations = {}
|
||||
for key, obs_input_shape in self.obs_shape.items():
|
||||
self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
|
||||
|
|
@ -737,7 +755,7 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
self.generator_ready = False
|
||||
super(RolloutBuffer, self).reset()
|
||||
|
||||
def add(
|
||||
def add( # type: ignore[override]
|
||||
self,
|
||||
obs: Dict[str, np.ndarray],
|
||||
action: np.ndarray,
|
||||
|
|
@ -745,7 +763,7 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
episode_start: np.ndarray,
|
||||
value: th.Tensor,
|
||||
log_prob: th.Tensor,
|
||||
) -> None: # pytype: disable=signature-mismatch
|
||||
) -> None:
|
||||
"""
|
||||
:param obs: Observation
|
||||
:param action: Action
|
||||
|
|
@ -761,7 +779,7 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
log_prob = log_prob.reshape(-1, 1)
|
||||
|
||||
for key in self.observations.keys():
|
||||
obs_ = np.array(obs[key]).copy()
|
||||
obs_ = np.array(obs[key])
|
||||
# Reshape needed when using multiple envs with discrete observations
|
||||
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
|
||||
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
|
||||
|
|
@ -771,19 +789,19 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
self.actions[self.pos] = np.array(action).copy()
|
||||
self.rewards[self.pos] = np.array(reward).copy()
|
||||
self.episode_starts[self.pos] = np.array(episode_start).copy()
|
||||
self.actions[self.pos] = np.array(action)
|
||||
self.rewards[self.pos] = np.array(reward)
|
||||
self.episode_starts[self.pos] = np.array(episode_start)
|
||||
self.values[self.pos] = value.clone().cpu().numpy().flatten()
|
||||
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
|
||||
self.pos += 1
|
||||
if self.pos == self.buffer_size:
|
||||
self.full = True
|
||||
|
||||
def get(
|
||||
def get( # type: ignore[override]
|
||||
self,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME
|
||||
) -> Generator[DictRolloutBufferSamples, None, None]:
|
||||
assert self.full, ""
|
||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
|
|
@ -806,11 +824,11 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||
start_idx += batch_size
|
||||
|
||||
def _get_samples(
|
||||
def _get_samples( # type: ignore[override]
|
||||
self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None,
|
||||
) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
|
||||
) -> DictRolloutBufferSamples:
|
||||
return DictRolloutBufferSamples(
|
||||
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
||||
actions=self.to_torch(self.actions[batch_inds]),
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ except ImportError:
|
|||
# if the progress bar is used
|
||||
tqdm = None
|
||||
|
||||
from stable_baselines3.common import base_class # pytype: disable=pyi-error
|
||||
from stable_baselines3.common import base_class
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
|
||||
|
||||
|
|
@ -680,7 +680,7 @@ class ProgressBarCallback(BaseCallback):
|
|||
using tqdm and rich packages.
|
||||
"""
|
||||
|
||||
pbar: tqdm # pytype: disable=invalid-annotation
|
||||
pbar: tqdm
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -112,18 +112,16 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
|
||||
self.rollout_buffer = buffer_cls(
|
||||
self.n_steps,
|
||||
self.observation_space,
|
||||
self.observation_space, # type: ignore[arg-type]
|
||||
self.action_space,
|
||||
device=self.device,
|
||||
gamma=self.gamma,
|
||||
gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs,
|
||||
)
|
||||
# pytype:disable=not-instantiable
|
||||
self.policy = self.policy_class( # type: ignore[assignment]
|
||||
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
|
||||
)
|
||||
# pytype:enable=not-instantiable
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def collect_rollouts(
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ def preprocess_obs(
|
|||
|
||||
elif isinstance(observation_space, spaces.Discrete):
|
||||
# One hot encoding and convert to float to avoid errors
|
||||
return F.one_hot(obs.long(), num_classes=observation_space.n).float()
|
||||
return F.one_hot(obs.long(), num_classes=int(observation_space.n)).float()
|
||||
|
||||
elif isinstance(observation_space, spaces.MultiDiscrete):
|
||||
# Tensor concatenation of one hot encodings of each Categorical sub-space
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class RMSpropTFLike(Optimizer):
|
|||
group.setdefault("centered", False)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
||||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override]
|
||||
"""Performs a single optimization step.
|
||||
|
||||
:param closure: A closure that reevaluates the model
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Union
|
|||
import gymnasium
|
||||
|
||||
try:
|
||||
import gym # pytype: disable=import-error
|
||||
import gym
|
||||
|
||||
gym_installed = True
|
||||
except ImportError:
|
||||
|
|
@ -37,7 +37,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma
|
|||
)
|
||||
|
||||
try:
|
||||
import shimmy # pytype: disable=import-error
|
||||
import shimmy
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Missing shimmy installation. You provided an OpenAI Gym environment. "
|
||||
|
|
@ -83,7 +83,7 @@ def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Spac
|
|||
)
|
||||
|
||||
try:
|
||||
import shimmy # pytype: disable=import-error
|
||||
import shimmy
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Missing shimmy installation. You provided an OpenAI Gym space. "
|
||||
|
|
|
|||
|
|
@ -9,9 +9,6 @@ from stable_baselines3.common.preprocessing import is_image_space, is_image_spac
|
|||
TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray])
|
||||
|
||||
|
||||
# Disable errors for pytype which doesn't play well with Generic[TypeVar]
|
||||
# mypy check passes though
|
||||
# pytype: disable=attribute-error
|
||||
class StackedObservations(Generic[TObs]):
|
||||
"""
|
||||
Frame stacking wrapper for data.
|
||||
|
|
@ -109,16 +106,14 @@ class StackedObservations(Generic[TObs]):
|
|||
:return: The stacked reset observation
|
||||
"""
|
||||
if isinstance(observation, dict):
|
||||
return {
|
||||
key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()
|
||||
} # pytype: disable=bad-return-type
|
||||
return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}
|
||||
|
||||
self.stacked_obs[...] = 0
|
||||
if self.channels_first:
|
||||
self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation
|
||||
else:
|
||||
self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
|
||||
return self.stacked_obs # pytype: disable=bad-return-type
|
||||
return self.stacked_obs
|
||||
|
||||
def update(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -109,9 +109,7 @@ class SubprocVecEnv(VecEnv):
|
|||
for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
|
||||
args = (work_remote, remote, CloudpickleWrapper(env_fn))
|
||||
# daemon=True: if the main process crashes, we should not cause things to hang
|
||||
# pytype: disable=attribute-error
|
||||
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
|
||||
# pytype: enable=attribute-error
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
work_remote.close()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import torch as th
|
|||
from gymnasium import spaces
|
||||
|
||||
from stable_baselines3.common.buffers import DictReplayBuffer
|
||||
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, TensorDict
|
||||
from stable_baselines3.common.type_aliases import DictReplayBufferSamples
|
||||
from stable_baselines3.common.vec_env import VecEnv, VecNormalize
|
||||
from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy
|
||||
|
||||
|
|
@ -45,10 +45,12 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
False by default.
|
||||
"""
|
||||
|
||||
env: Optional[VecEnv]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
env: VecEnv,
|
||||
device: Union[th.device, str] = "auto",
|
||||
|
|
@ -130,10 +132,10 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
|
||||
self.env = env
|
||||
|
||||
def add(
|
||||
def add( # type: ignore[override]
|
||||
self,
|
||||
obs: TensorDict,
|
||||
next_obs: TensorDict,
|
||||
obs: Dict[str, np.ndarray],
|
||||
next_obs: Dict[str, np.ndarray],
|
||||
action: np.ndarray,
|
||||
reward: np.ndarray,
|
||||
done: np.ndarray,
|
||||
|
|
@ -181,7 +183,7 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
# Update the current episode start
|
||||
self._current_ep_start[env_idx] = self.pos
|
||||
|
||||
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
|
||||
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[override]
|
||||
"""
|
||||
Sample elements from the replay buffer.
|
||||
|
||||
|
|
@ -264,6 +266,8 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
{key: obs[batch_indices, env_indices, :] for key, obs in self.next_observations.items()}, env
|
||||
)
|
||||
|
||||
assert isinstance(obs_, dict)
|
||||
assert isinstance(next_obs_, dict)
|
||||
# Convert to torch tensor
|
||||
observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
|
||||
next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}
|
||||
|
|
@ -309,6 +313,9 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
# The desired goal for the next observation must be the same as the previous one
|
||||
next_obs["desired_goal"] = new_goals
|
||||
|
||||
assert (
|
||||
self.env is not None
|
||||
), "You must initialize HerReplayBuffer with a VecEnv so it can compute rewards for virtual transitions"
|
||||
# Compute new reward
|
||||
rewards = self.env.env_method(
|
||||
"compute_reward",
|
||||
|
|
@ -326,8 +333,8 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
indices=[0],
|
||||
)
|
||||
rewards = rewards[0].astype(np.float32) # env_method returns a list containing one element
|
||||
obs = self._normalize_obs(obs, env)
|
||||
next_obs = self._normalize_obs(next_obs, env)
|
||||
obs = self._normalize_obs(obs, env) # type: ignore[assignment]
|
||||
next_obs = self._normalize_obs(next_obs, env) # type: ignore[assignment]
|
||||
|
||||
# Convert to torch tensor
|
||||
observations = {key: self.to_torch(obs) for key, obs in obs.items()}
|
||||
|
|
@ -342,7 +349,7 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
dones=self.to_torch(
|
||||
self.dones[batch_indices, env_indices] * (1 - self.timeouts[batch_indices, env_indices])
|
||||
).reshape(-1, 1),
|
||||
rewards=self.to_torch(self._normalize_reward(rewards.reshape(-1, 1), env)),
|
||||
rewards=self.to_torch(self._normalize_reward(rewards.reshape(-1, 1), env)), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
def _sample_goals(self, batch_indices: np.ndarray, env_indices: np.ndarray) -> np.ndarray:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a5
|
||||
2.2.0a6
|
||||
|
|
|
|||
Loading…
Reference in a new issue