stable-baselines3/stable_baselines3/common/vec_env/vec_extract_dict_obs.py
WeberSamuel 15c9daa2ba
Fix VecExtractDictObs does not handle terminal observation (#1443)
* VecExtractDictObs: handle terminal_observation

* Added "VecExtractDictObs handles terminal_observation" to changelog

* Update changelog.rst

* Update test_vec_extract_dict_obs.py

Add random dones in env to test if terminal_observation is properly handled

* Made test deterministic

* Fixed bug in test

* Improved test

* Fix format in test

* Update test

* Fix type hint

* Ignore pytype warning

* Ignore pytype

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2023-04-12 15:20:04 +02:00

27 lines
917 B
Python

import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecExtractDictObs(VecEnvWrapper):
    """
    A vectorized wrapper that exposes a single entry of a dictionary observation.

    The wrapped environment's observation space is replaced by its sub-space under ``key``.
    Observations returned by ``reset`` and ``step_wait`` are indexed with ``key``. So is any
    ``terminal_observation`` stored in the per-environment info dicts, which are updated in place.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        self.key = key
        # Assumes venv.observation_space exposes a ``spaces`` mapping (e.g. a Dict space) — TODO confirm
        selected_space = venv.observation_space.spaces[self.key]
        super().__init__(venv=venv, observation_space=selected_space)

    def reset(self) -> np.ndarray:
        """Reset the wrapped environments and return only the ``key`` entry of the observation."""
        return self.venv.reset()[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the wrapped step and narrow every observation to the ``key`` entry.

        :return: (observation[key], reward, done, infos), with each ``terminal_observation``
            in ``infos`` (when present) also narrowed to its ``key`` entry.
        """
        observations, rewards, dones, infos = self.venv.step_wait()
        key = self.key
        for env_info in infos:
            if "terminal_observation" not in env_info:
                continue
            # The terminal observation is still in the full dict format of the wrapped env,
            # so it has to be reduced the same way as the regular observation.
            env_info["terminal_observation"] = env_info["terminal_observation"][key]
        return observations[key], rewards, dones, infos