stable-baselines3/stable_baselines3/common/vec_env/vec_extract_dict_obs.py
Antonin RAFFIN 63a0bb9da1
Type annotation bundle (logger, vec env, custom envs) (#1479)
* Switch from List to Sequence for `seed()` type hint

* Fix logger type hints

* Improve replay buffer type hints

* Fix custom envs type annotations

* Fix VecMonitor type hints

* Fix RMSprop type hint

* Fix vec extract dict obs type hints

* Fix vec frame stack type annotations

* Fix base vec env type hints

* Fix dummy vec env type hints

* Fix for mypy

* Fixes for the tests

* mypy doesn't like when we overwrite type

* fix step of SimpleMultiObsEnv

* remove useless type specification

* Rm useless type hint

* Improve logger type hint

* format

* rm useless type hint

* Re-add variables in constructor, remove unused import

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2023-05-04 20:27:15 +02:00

33 lines
1.2 KiB
Python

import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecExtractDictObs(VecEnvWrapper):
    """
    Vectorized wrapper that exposes a single entry of a ``Dict`` observation space.

    The wrapped env must have a ``spaces.Dict`` observation space; this wrapper's
    observation space becomes the sub-space stored under ``key``.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        self.key = key
        wrapped_space = venv.observation_space
        assert isinstance(
            wrapped_space, spaces.Dict
        ), f"VecExtractDictObs can only be used with Dict obs space, not {venv.observation_space}"
        # Only the sub-space selected by ``key`` is advertised to downstream code.
        super().__init__(venv=venv, observation_space=wrapped_space.spaces[self.key])

    def reset(self) -> np.ndarray:
        """
        Reset the wrapped vectorized env and return only the observation under ``self.key``.

        :return: The value stored under ``self.key`` in the dict observation.
        """
        dict_obs = self.venv.reset()
        assert isinstance(dict_obs, dict)
        return dict_obs[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        """
        Collect the wrapped env's step results and keep only the ``self.key`` observation.

        Any ``terminal_observation`` entry found in ``infos`` is also reduced to its
        ``self.key`` component; the info dicts are modified in place.

        :return: (observation under ``self.key``, rewards, dones, infos)
        """
        observations, rewards, dones, infos = self.venv.step_wait()
        assert isinstance(observations, dict)
        for env_info in infos:
            if "terminal_observation" not in env_info:
                continue
            terminal_obs = env_info["terminal_observation"]
            env_info["terminal_observation"] = terminal_obs[self.key]
        return observations[self.key], rewards, dones, infos