mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Update Gymnasium to v1.0.0 (#1837)
* Update Gymnasium to v1.0.0a1 * Comment out `gymnasium.wrappers.monitor` (todo update to VideoRecord) * Fix ruff warnings * Register Atari envs * Update `getattr` to `Env.get_wrapper_attr` * Reorder imports * Fix `seed` order * Fix collecting `max_steps` * Copy and paste video recorder to prevent the need to rewrite the vec vide recorder wrapper * Use `typing.List` rather than list * Fix env attribute forwarding * Separate out env attribute collection from its utilisation * Update for Gymnasium alpha 2 * Remove assert for OrderedDict * Update setup.py * Add type: ignore * Test with Gymnasium main * Remove `gymnasium.logger.debug/info` * Fix github CI yaml * Run gym 0.29.1 on python 3.10 * Update lower bounds * Integrate video recorder * Remove ordered dict * Update changelog --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
dd3d0acf15
commit
8f0b488bc5
16 changed files with 148 additions and 120 deletions
20
.github/workflows/ci.yml
vendored
20
.github/workflows/ci.yml
vendored
|
|
@ -21,7 +21,12 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
|
||||
include:
|
||||
# Default version
|
||||
- gymnasium-version: "1.0.0"
|
||||
# Add a new config to test gym<1.0
|
||||
- python-version: "3.10"
|
||||
gymnasium-version: "0.29.1"
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
|
@ -37,15 +42,14 @@ jobs:
|
|||
# See https://github.com/astral-sh/uv/issues/1497
|
||||
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Install Atari Roms
|
||||
uv pip install --system autorom
|
||||
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
|
||||
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
|
||||
AutoROM --accept-license --source-file Roms.tar.gz
|
||||
|
||||
uv pip install --system .[extra_no_roms,tests,docs]
|
||||
uv pip install --system .[extra,tests,docs]
|
||||
# Use headless version
|
||||
uv pip install --system opencv-python-headless
|
||||
- name: Install specific version of gym
|
||||
run: |
|
||||
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
|
||||
# Only run for python 3.10, downgrade gym to 0.29.1
|
||||
if: matrix.gymnasium-version != '1.0.0'
|
||||
- name: Lint with ruff
|
||||
run: |
|
||||
make lint
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ dependencies:
|
|||
- python=3.11
|
||||
- pytorch=2.5.0=py3.11_cpu_0
|
||||
- pip:
|
||||
- gymnasium>=0.28.1,<0.30
|
||||
- gymnasium>=0.29.1,<1.1.0
|
||||
- cloudpickle
|
||||
- opencv-python-headless
|
||||
- pandas
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.4.0a10 (WIP)
|
||||
Release 2.4.0a11 (WIP)
|
||||
--------------------------
|
||||
|
||||
**New algorithm: CrossQ in SB3 Contrib**
|
||||
**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**
|
||||
|
||||
.. note::
|
||||
|
||||
|
|
@ -24,12 +24,14 @@ Release 2.4.0a10 (WIP)
|
|||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Increase minimum required version of Gymnasium to 0.29.1
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
|
||||
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
|
||||
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
|
||||
- Added support for Gymnasium v1.0
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -69,6 +71,7 @@ Others:
|
|||
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
|
||||
- Switched to uv to download packages faster on GitHub CI
|
||||
- Updated dependencies for read the doc
|
||||
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
|
|||
# ClassVar, implicit optional check not needed for tests
|
||||
"./tests/*.py" = ["RUF012", "RUF013"]
|
||||
|
||||
|
||||
[tool.ruff.lint.mccabe]
|
||||
# Unlike Flake8, default to a complexity level of 10.
|
||||
max-complexity = 15
|
||||
|
|
|
|||
43
setup.py
43
setup.py
|
|
@ -70,37 +70,13 @@ model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
|
|||
|
||||
""" # noqa:E501
|
||||
|
||||
# Atari Games download is sometimes problematic:
|
||||
# https://github.com/Farama-Foundation/AutoROM/issues/39
|
||||
# That's why we define extra packages without it.
|
||||
extra_no_roms = [
|
||||
# For render
|
||||
"opencv-python",
|
||||
"pygame",
|
||||
# Tensorboard support
|
||||
"tensorboard>=2.9.1",
|
||||
# Checking memory taken by replay buffer
|
||||
"psutil",
|
||||
# For progress bar callback
|
||||
"tqdm",
|
||||
"rich",
|
||||
# For atari games,
|
||||
"shimmy[atari]~=1.3.0",
|
||||
"pillow",
|
||||
]
|
||||
|
||||
extra_packages = extra_no_roms + [ # noqa: RUF005
|
||||
# For atari roms,
|
||||
"autorom[accept-rom-license]~=0.6.1",
|
||||
]
|
||||
|
||||
|
||||
setup(
|
||||
name="stable_baselines3",
|
||||
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
|
||||
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"gymnasium>=0.28.1,<0.30",
|
||||
"gymnasium>=0.29.1,<1.1.0",
|
||||
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
|
||||
"torch>=1.13",
|
||||
# For saving models
|
||||
|
|
@ -133,8 +109,21 @@ setup(
|
|||
# Copy button for code snippets
|
||||
"sphinx_copybutton",
|
||||
],
|
||||
"extra": extra_packages,
|
||||
"extra_no_roms": extra_no_roms,
|
||||
"extra": [
|
||||
# For render
|
||||
"opencv-python",
|
||||
"pygame",
|
||||
# Tensorboard support
|
||||
"tensorboard>=2.9.1",
|
||||
# Checking memory taken by replay buffer
|
||||
"psutil",
|
||||
# For progress bar callback
|
||||
"tqdm",
|
||||
"rich",
|
||||
# For atari games,
|
||||
"ale-py>=0.9.0",
|
||||
"pillow",
|
||||
],
|
||||
},
|
||||
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
|
||||
author="Antonin Raffin",
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
||||
from stable_baselines3.common.vec_env.patch_gym import _patch_env
|
||||
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
|
||||
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
|
||||
|
||||
|
||||
class DummyVecEnv(VecEnv):
|
||||
|
|
@ -110,12 +110,12 @@ class DummyVecEnv(VecEnv):
|
|||
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]
|
||||
|
||||
def _obs_from_buf(self) -> VecEnvObs:
|
||||
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
|
||||
return dict_to_obs(self.observation_space, deepcopy(self.buf_obs))
|
||||
|
||||
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
|
||||
"""Return attribute from vectorized environment (see base class)."""
|
||||
target_envs = self._get_target_envs(indices)
|
||||
return [getattr(env_i, attr_name) for env_i in target_envs]
|
||||
return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]
|
||||
|
||||
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
|
||||
"""Set attribute inside vectorized environments (see base class)."""
|
||||
|
|
@ -126,7 +126,7 @@ class DummyVecEnv(VecEnv):
|
|||
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
|
||||
"""Call instance methods of vectorized environments."""
|
||||
target_envs = self._get_target_envs(indices)
|
||||
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
|
||||
return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]
|
||||
|
||||
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
|
||||
"""Check if worker environments are wrapped with a given wrapper"""
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma
|
|||
"Missing shimmy installation. You provided an OpenAI Gym environment. "
|
||||
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
|
||||
"In order to use OpenAI Gym environments with SB3, you need to "
|
||||
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
|
||||
"install shimmy (`pip install 'shimmy>=2.0'`)."
|
||||
) from e
|
||||
|
||||
warnings.warn(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import multiprocessing as mp
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
import gymnasium as gym
|
||||
|
|
@ -54,10 +53,10 @@ def _worker(
|
|||
elif cmd == "get_spaces":
|
||||
remote.send((env.observation_space, env.action_space))
|
||||
elif cmd == "env_method":
|
||||
method = getattr(env, data[0])
|
||||
method = env.get_wrapper_attr(data[0])
|
||||
remote.send(method(*data[1], **data[2]))
|
||||
elif cmd == "get_attr":
|
||||
remote.send(getattr(env, data))
|
||||
remote.send(env.get_wrapper_attr(data))
|
||||
elif cmd == "set_attr":
|
||||
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
|
||||
elif cmd == "is_wrapped":
|
||||
|
|
@ -129,7 +128,7 @@ class SubprocVecEnv(VecEnv):
|
|||
results = [remote.recv() for remote in self.remotes]
|
||||
self.waiting = False
|
||||
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
|
||||
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
|
||||
return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
|
||||
|
||||
def reset(self) -> VecEnvObs:
|
||||
for env_idx, remote in enumerate(self.remotes):
|
||||
|
|
@ -139,7 +138,7 @@ class SubprocVecEnv(VecEnv):
|
|||
# Seeds and options are only used once
|
||||
self._reset_seeds()
|
||||
self._reset_options()
|
||||
return _flatten_obs(obs, self.observation_space)
|
||||
return _stack_obs(obs, self.observation_space)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.closed:
|
||||
|
|
@ -206,27 +205,28 @@ class SubprocVecEnv(VecEnv):
|
|||
return [self.remotes[i] for i in indices]
|
||||
|
||||
|
||||
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
|
||||
def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
|
||||
"""
|
||||
Flatten observations, depending on the observation space.
|
||||
Stack observations (convert from a list of single env obs to a stack of obs),
|
||||
depending on the observation space.
|
||||
|
||||
:param obs: observations.
|
||||
A list or tuple of observations, one per environment.
|
||||
Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
|
||||
:return: flattened observations.
|
||||
A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
|
||||
:return: Concatenated observations.
|
||||
A NumPy array or a dict or tuple of stacked numpy arrays.
|
||||
Each NumPy array has the environment index as its first axis.
|
||||
"""
|
||||
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
|
||||
assert len(obs) > 0, "need observations from at least one environment"
|
||||
assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment"
|
||||
assert len(obs_list) > 0, "need observations from at least one environment"
|
||||
|
||||
if isinstance(space, spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
|
||||
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
|
||||
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
|
||||
assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces"
|
||||
assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space"
|
||||
return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload]
|
||||
elif isinstance(space, spaces.Tuple):
|
||||
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
|
||||
assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space"
|
||||
obs_len = len(space.spaces)
|
||||
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index]
|
||||
return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index]
|
||||
else:
|
||||
return np.stack(obs) # type: ignore[arg-type]
|
||||
return np.stack(obs_list) # type: ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
Helpers for dealing with vectorized environments.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -12,17 +11,6 @@ from stable_baselines3.common.preprocessing import check_for_nested_spaces
|
|||
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
|
||||
|
||||
|
||||
def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Deep-copy a dict of numpy arrays.
|
||||
|
||||
:param obs: a dict of numpy arrays.
|
||||
:return: a dict of copied numpy arrays.
|
||||
"""
|
||||
assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
|
||||
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
|
||||
|
||||
|
||||
def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
|
||||
"""
|
||||
Convert an internal representation raw_obs into the appropriate type
|
||||
|
|
@ -60,13 +48,13 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[
|
|||
"""
|
||||
check_for_nested_spaces(obs_space)
|
||||
if isinstance(obs_space, spaces.Dict):
|
||||
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
|
||||
assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces"
|
||||
subspaces = obs_space.spaces
|
||||
elif isinstance(obs_space, spaces.Tuple):
|
||||
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment]
|
||||
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment,misc]
|
||||
else:
|
||||
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
|
||||
subspaces = {None: obs_space} # type: ignore[assignment]
|
||||
subspaces = {None: obs_space} # type: ignore[assignment,dict-item]
|
||||
keys = []
|
||||
shapes = {}
|
||||
dtypes = {}
|
||||
|
|
@ -74,4 +62,4 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[
|
|||
keys.append(key)
|
||||
shapes[key] = box.shape
|
||||
dtypes[key] = box.dtype
|
||||
return keys, shapes, dtypes
|
||||
return keys, shapes, dtypes # type: ignore[return-value]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
from typing import Callable
|
||||
import os.path
|
||||
from typing import Callable, List
|
||||
|
||||
from gymnasium.wrappers.monitoring import video_recorder
|
||||
import numpy as np
|
||||
from gymnasium import error, logger
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
|
||||
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
|
|
@ -13,6 +15,11 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
|
||||
It requires ffmpeg or avconv to be installed on the machine.
|
||||
|
||||
Note: for now it only allows to record one video and all videos
|
||||
must have at least two frames.
|
||||
|
||||
The video recorder code was adapted from Gymnasium v1.0.
|
||||
|
||||
:param venv:
|
||||
:param video_folder: Where to save videos
|
||||
:param record_video_trigger: Function that defines when to start recording.
|
||||
|
|
@ -22,8 +29,6 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
:param name_prefix: Prefix to the video name
|
||||
"""
|
||||
|
||||
video_recorder: video_recorder.VideoRecorder
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
venv: VecEnv,
|
||||
|
|
@ -51,6 +56,8 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
self.env.metadata = metadata
|
||||
assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}"
|
||||
|
||||
self.frames_per_sec = self.env.metadata.get("render_fps", 30)
|
||||
|
||||
self.record_video_trigger = record_video_trigger
|
||||
self.video_folder = os.path.abspath(video_folder)
|
||||
# Create output folder if needed
|
||||
|
|
@ -60,54 +67,88 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
self.step_id = 0
|
||||
self.video_length = video_length
|
||||
|
||||
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
|
||||
self.video_path = os.path.join(self.video_folder, self.video_name)
|
||||
|
||||
self.recording = False
|
||||
self.recorded_frames = 0
|
||||
self.recorded_frames: list[np.ndarray] = []
|
||||
|
||||
try:
|
||||
import moviepy # noqa: F401
|
||||
except ImportError as e:
|
||||
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e
|
||||
|
||||
def reset(self) -> VecEnvObs:
|
||||
obs = self.venv.reset()
|
||||
self.start_video_recorder()
|
||||
if self._video_enabled():
|
||||
self._start_video_recorder()
|
||||
return obs
|
||||
|
||||
def start_video_recorder(self) -> None:
|
||||
self.close_video_recorder()
|
||||
|
||||
video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
|
||||
base_path = os.path.join(self.video_folder, video_name)
|
||||
self.video_recorder = video_recorder.VideoRecorder(
|
||||
env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
|
||||
)
|
||||
|
||||
self.video_recorder.capture_frame()
|
||||
self.recorded_frames = 1
|
||||
self.recording = True
|
||||
def _start_video_recorder(self) -> None:
|
||||
self._start_recording()
|
||||
self._capture_frame()
|
||||
|
||||
def _video_enabled(self) -> bool:
|
||||
return self.record_video_trigger(self.step_id)
|
||||
|
||||
def step_wait(self) -> VecEnvStepReturn:
|
||||
obs, rews, dones, infos = self.venv.step_wait()
|
||||
obs, rewards, dones, infos = self.venv.step_wait()
|
||||
|
||||
self.step_id += 1
|
||||
if self.recording:
|
||||
self.video_recorder.capture_frame()
|
||||
self.recorded_frames += 1
|
||||
if self.recorded_frames > self.video_length:
|
||||
print(f"Saving video to {self.video_recorder.path}")
|
||||
self.close_video_recorder()
|
||||
self._capture_frame()
|
||||
if len(self.recorded_frames) > self.video_length:
|
||||
print(f"Saving video to {self.video_path}")
|
||||
self._stop_recording()
|
||||
elif self._video_enabled():
|
||||
self.start_video_recorder()
|
||||
self._start_video_recorder()
|
||||
|
||||
return obs, rews, dones, infos
|
||||
return obs, rewards, dones, infos
|
||||
|
||||
def close_video_recorder(self) -> None:
|
||||
if self.recording:
|
||||
self.video_recorder.close()
|
||||
self.recording = False
|
||||
self.recorded_frames = 1
|
||||
def _capture_frame(self) -> None:
|
||||
assert self.recording, "Cannot capture a frame, recording wasn't started."
|
||||
|
||||
frame = self.env.render()
|
||||
if isinstance(frame, List):
|
||||
frame = frame[-1]
|
||||
|
||||
if isinstance(frame, np.ndarray):
|
||||
self.recorded_frames.append(frame)
|
||||
else:
|
||||
self._stop_recording()
|
||||
logger.warn(
|
||||
f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}."
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the wrapper then the video recorder."""
|
||||
VecEnvWrapper.close(self)
|
||||
self.close_video_recorder()
|
||||
if self.recording:
|
||||
self._stop_recording()
|
||||
|
||||
def __del__(self):
|
||||
self.close_video_recorder()
|
||||
def _start_recording(self) -> None:
|
||||
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
|
||||
if self.recording:
|
||||
self._stop_recording()
|
||||
|
||||
self.recording = True
|
||||
|
||||
def _stop_recording(self) -> None:
|
||||
"""Stop current recording and saves the video."""
|
||||
assert self.recording, "_stop_recording was called, but no recording was started"
|
||||
|
||||
if len(self.recorded_frames) == 0:
|
||||
logger.warn("Ignored saving a video as there were zero frames to save.")
|
||||
else:
|
||||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||
|
||||
clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
|
||||
clip.write_videofile(self.video_path)
|
||||
|
||||
self.recorded_frames = []
|
||||
self.recording = False
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Warn the user in case last video wasn't saved."""
|
||||
if len(self.recorded_frames) > 0:
|
||||
logger.warn("Unable to save last video! Did you call close()?")
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.4.0a10
|
||||
2.4.0a11
|
||||
|
|
|
|||
|
|
@ -117,12 +117,11 @@ def test_consistency(model_class):
|
|||
"""
|
||||
use_discrete_actions = model_class == DQN
|
||||
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
|
||||
dict_env.seed(10)
|
||||
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
|
||||
env = gym.wrappers.FlattenObservation(dict_env)
|
||||
dict_env.seed(10)
|
||||
obs, _ = dict_env.reset()
|
||||
|
||||
kwargs = {}
|
||||
n_steps = 256
|
||||
|
||||
if model_class in {A2C, PPO}:
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class CheckGAECallback(BaseCallback):
|
|||
buffer = self.model.rollout_buffer
|
||||
rollout_size = buffer.size()
|
||||
|
||||
max_steps = self.training_env.envs[0].max_steps
|
||||
max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps")
|
||||
gamma = self.model.gamma
|
||||
gae_lambda = self.model.gae_lambda
|
||||
value = self.model.policy.constant_value
|
||||
|
|
|
|||
|
|
@ -592,6 +592,7 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path):
|
|||
"""
|
||||
|
||||
STATS_WINDOW_SIZE = 10
|
||||
|
||||
# Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE
|
||||
dummy_successes = [
|
||||
[True] * 3 + [False] * 7,
|
||||
|
|
@ -603,16 +604,17 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path):
|
|||
# Monitor the env to track the success info
|
||||
monitor_file = str(tmp_path / "monitor.csv")
|
||||
env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",))
|
||||
steps_per_log = env.unwrapped.steps_per_log
|
||||
|
||||
# Equip the model of a custom logger to check the success_rate info
|
||||
model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1)
|
||||
model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=steps_per_log, verbose=1)
|
||||
logger = InMemoryLogger()
|
||||
model.set_logger(logger)
|
||||
|
||||
# Make the model learn and check that the success rate corresponds to the ratio of dummy successes
|
||||
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
|
||||
model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1)
|
||||
assert logger.name_to_value["rollout/success_rate"] == 0.3
|
||||
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
|
||||
model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1)
|
||||
assert logger.name_to_value["rollout/success_rate"] == 0.5
|
||||
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
|
||||
model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1)
|
||||
assert logger.name_to_value["rollout/success_rate"] == 0.8
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import ale_py
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
|
@ -24,6 +25,8 @@ from stable_baselines3.common.utils import (
|
|||
)
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||
|
||||
gym.register_envs(ale_py)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")])
|
||||
@pytest.mark.parametrize("n_envs", [1, 2])
|
||||
|
|
|
|||
|
|
@ -307,7 +307,7 @@ def test_vecenv_dict_spaces(vec_env_class):
|
|||
space = spaces.Dict(SPACES)
|
||||
|
||||
def obs_assert(obs):
|
||||
assert isinstance(obs, collections.OrderedDict)
|
||||
assert isinstance(obs, dict)
|
||||
assert obs.keys() == space.spaces.keys()
|
||||
for key, values in obs.items():
|
||||
check_vecenv_obs(values, space.spaces[key])
|
||||
|
|
|
|||
Loading…
Reference in a new issue