mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-29 03:31:08 +00:00
Fix video recorder and add test (#2063)
* Fix video recorder and add test * Update github CI * Install ffmpeg * Revert "Update github CI" This reverts commit 07791e97fccae4f003b2909428b23f59557d7034. * Skip VecVideoRecorder test on github
This commit is contained in:
parent
0fd0db0b7b
commit
57e8b97df5
4 changed files with 65 additions and 14 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.5.0a0 (WIP)
|
||||
Release 2.5.0a1 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -42,6 +42,14 @@ Documentation:
|
|||
- Add FootstepNet Envs to the project page (@cgaspard3333)
|
||||
- Added FRASA to the project page (@MarcDcls)
|
||||
|
||||
Release 2.4.1 (2024-12-20)
|
||||
--------------------------
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug introduced in v2.4.0 where the ``VecVideoRecorder`` would override videos
|
||||
|
||||
|
||||
Release 2.4.0 (2024-11-18)
|
||||
--------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,9 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
:param name_prefix: Prefix to the video name
|
||||
"""
|
||||
|
||||
video_name: str
|
||||
video_path: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
venv: VecEnv,
|
||||
|
|
@ -50,7 +53,7 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
|
||||
if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
|
||||
metadata = temp_env.get_attr("metadata")[0]
|
||||
else:
|
||||
else: # pragma: no cover # assume gym interface
|
||||
metadata = temp_env.metadata
|
||||
|
||||
self.env.metadata = metadata
|
||||
|
|
@ -67,15 +70,12 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
self.step_id = 0
|
||||
self.video_length = video_length
|
||||
|
||||
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
|
||||
self.video_path = os.path.join(self.video_folder, self.video_name)
|
||||
|
||||
self.recording = False
|
||||
self.recorded_frames: list[np.ndarray] = []
|
||||
|
||||
try:
|
||||
import moviepy # noqa: F401
|
||||
except ImportError as e:
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e
|
||||
|
||||
def reset(self) -> VecEnvObs:
|
||||
|
|
@ -85,6 +85,9 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
return obs
|
||||
|
||||
def _start_video_recorder(self) -> None:
|
||||
# Update video name and path
|
||||
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
|
||||
self.video_path = os.path.join(self.video_folder, self.video_name)
|
||||
self._start_recording()
|
||||
self._capture_frame()
|
||||
|
||||
|
|
@ -109,8 +112,6 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
assert self.recording, "Cannot capture a frame, recording wasn't started."
|
||||
|
||||
frame = self.env.render()
|
||||
if isinstance(frame, list):
|
||||
frame = frame[-1]
|
||||
|
||||
if isinstance(frame, np.ndarray):
|
||||
self.recorded_frames.append(frame)
|
||||
|
|
@ -123,12 +124,12 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
def close(self) -> None:
|
||||
"""Closes the wrapper then the video recorder."""
|
||||
VecEnvWrapper.close(self)
|
||||
if self.recording:
|
||||
if self.recording: # pragma: no cover
|
||||
self._stop_recording()
|
||||
|
||||
def _start_recording(self) -> None:
|
||||
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
|
||||
if self.recording:
|
||||
if self.recording: # pragma: no cover
|
||||
self._stop_recording()
|
||||
|
||||
self.recording = True
|
||||
|
|
@ -137,7 +138,7 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
"""Stop current recording and saves the video."""
|
||||
assert self.recording, "_stop_recording was called, but no recording was started"
|
||||
|
||||
if len(self.recorded_frames) == 0:
|
||||
if len(self.recorded_frames) == 0: # pragma: no cover
|
||||
logger.warn("Ignored saving a video as there were zero frames to save.")
|
||||
else:
|
||||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||
|
|
@ -150,5 +151,5 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
|
||||
def __del__(self) -> None:
|
||||
"""Warn the user in case last video wasn't saved."""
|
||||
if len(self.recorded_frames) > 0:
|
||||
if len(self.recorded_frames) > 0: # pragma: no cover
|
||||
logger.warn("Unable to save last video! Did you call close()?")
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.5.0a0
|
||||
2.5.0a1
|
||||
|
|
|
|||
|
|
@ -11,9 +11,17 @@ import numpy as np
|
|||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder
|
||||
|
||||
try:
|
||||
import moviepy
|
||||
|
||||
have_moviepy = True
|
||||
except ImportError:
|
||||
have_moviepy = False
|
||||
|
||||
N_ENVS = 3
|
||||
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
|
||||
|
|
@ -624,3 +632,37 @@ def test_render(vec_env_class):
|
|||
vec_env.render()
|
||||
|
||||
vec_env.close()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not have_moviepy, reason="moviepy is not installed")
|
||||
def test_video_recorder(tmp_path):
|
||||
env_id = "CartPole-v1"
|
||||
video_folder = str(tmp_path)
|
||||
|
||||
vec_env = make_vec_env(env_id, n_envs=1)
|
||||
|
||||
# Wrap to check unwrapping works
|
||||
vec_env = VecNormalize(vec_env)
|
||||
|
||||
# Record the video starting at the first step
|
||||
vec_env = VecVideoRecorder(
|
||||
vec_env,
|
||||
video_folder,
|
||||
record_video_trigger=lambda x: x % 65 == 0,
|
||||
video_length=10,
|
||||
name_prefix=f"agent-{env_id}",
|
||||
)
|
||||
|
||||
model = PPO("MlpPolicy", vec_env, n_steps=64, n_epochs=1, verbose=0)
|
||||
|
||||
model.learn(total_timesteps=128)
|
||||
|
||||
# print all videos in video_folder, should be multiple step 0-100, step 1024-1124
|
||||
video_files = list(map(str, tmp_path.glob("*.mp4")))
|
||||
|
||||
# Clean up
|
||||
vec_env.close()
|
||||
|
||||
assert len(video_files) == 2
|
||||
assert "agent-CartPole-v1-step-65-to-step-75.mp4" in video_files[0]
|
||||
assert "agent-CartPole-v1-step-0-to-step-10.mp4" in video_files[1]
|
||||
|
|
|
|||
Loading…
Reference in a new issue