Fix video recorder and add test (#2063)

* Fix video recorder and add test

* Update github CI

* Install ffmpeg

* Revert "Update github CI"

This reverts commit 07791e97fccae4f003b2909428b23f59557d7034.

* Skip VecVideoRecorder test on github
This commit is contained in:
Antonin RAFFIN 2024-12-21 08:24:25 +01:00 committed by GitHub
parent 0fd0db0b7b
commit 57e8b97df5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 65 additions and 14 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 2.5.0a0 (WIP)
Release 2.5.0a1 (WIP)
--------------------------
Breaking Changes:
@ -42,6 +42,14 @@ Documentation:
- Add FootstepNet Envs to the project page (@cgaspard3333)
- Added FRASA to the project page (@MarcDcls)
Release 2.4.1 (2024-12-20)
--------------------------
Bug Fixes:
^^^^^^^^^^
- Fixed a bug introduced in v2.4.0 where the ``VecVideoRecorder`` would override videos
Release 2.4.0 (2024-11-18)
--------------------------

View file

@ -29,6 +29,9 @@ class VecVideoRecorder(VecEnvWrapper):
:param name_prefix: Prefix to the video name
"""
video_name: str
video_path: str
def __init__(
self,
venv: VecEnv,
@ -50,7 +53,7 @@ class VecVideoRecorder(VecEnvWrapper):
if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
metadata = temp_env.get_attr("metadata")[0]
else:
else: # pragma: no cover # assume gym interface
metadata = temp_env.metadata
self.env.metadata = metadata
@ -67,15 +70,12 @@ class VecVideoRecorder(VecEnvWrapper):
self.step_id = 0
self.video_length = video_length
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
self.video_path = os.path.join(self.video_folder, self.video_name)
self.recording = False
self.recorded_frames: list[np.ndarray] = []
try:
import moviepy # noqa: F401
except ImportError as e:
except ImportError as e: # pragma: no cover
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e
def reset(self) -> VecEnvObs:
@ -85,6 +85,9 @@ class VecVideoRecorder(VecEnvWrapper):
return obs
def _start_video_recorder(self) -> None:
# Update video name and path
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
self.video_path = os.path.join(self.video_folder, self.video_name)
self._start_recording()
self._capture_frame()
@ -109,8 +112,6 @@ class VecVideoRecorder(VecEnvWrapper):
assert self.recording, "Cannot capture a frame, recording wasn't started."
frame = self.env.render()
if isinstance(frame, list):
frame = frame[-1]
if isinstance(frame, np.ndarray):
self.recorded_frames.append(frame)
@ -123,12 +124,12 @@ class VecVideoRecorder(VecEnvWrapper):
def close(self) -> None:
"""Closes the wrapper then the video recorder."""
VecEnvWrapper.close(self)
if self.recording:
if self.recording: # pragma: no cover
self._stop_recording()
def _start_recording(self) -> None:
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
if self.recording:
if self.recording: # pragma: no cover
self._stop_recording()
self.recording = True
@ -137,7 +138,7 @@ class VecVideoRecorder(VecEnvWrapper):
"""Stop current recording and saves the video."""
assert self.recording, "_stop_recording was called, but no recording was started"
if len(self.recorded_frames) == 0:
if len(self.recorded_frames) == 0: # pragma: no cover
logger.warn("Ignored saving a video as there were zero frames to save.")
else:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
@ -150,5 +151,5 @@ class VecVideoRecorder(VecEnvWrapper):
def __del__(self) -> None:
"""Warn the user in case last video wasn't saved."""
if len(self.recorded_frames) > 0:
if len(self.recorded_frames) > 0: # pragma: no cover
logger.warn("Unable to save last video! Did you call close()?")

View file

@ -1 +1 @@
2.5.0a0
2.5.0a1

View file

@ -11,9 +11,17 @@ import numpy as np
import pytest
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder
try:
import moviepy
have_moviepy = True
except ImportError:
have_moviepy = False
N_ENVS = 3
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
@ -624,3 +632,37 @@ def test_render(vec_env_class):
vec_env.render()
vec_env.close()
@pytest.mark.skipif(not have_moviepy, reason="moviepy is not installed")
def test_video_recorder(tmp_path):
env_id = "CartPole-v1"
video_folder = str(tmp_path)
vec_env = make_vec_env(env_id, n_envs=1)
# Wrap to check unwrapping works
vec_env = VecNormalize(vec_env)
# Record the video starting at the first step
vec_env = VecVideoRecorder(
vec_env,
video_folder,
record_video_trigger=lambda x: x % 65 == 0,
video_length=10,
name_prefix=f"agent-{env_id}",
)
model = PPO("MlpPolicy", vec_env, n_steps=64, n_epochs=1, verbose=0)
model.learn(total_timesteps=128)
# print all videos in video_folder, should be multiple step 0-100, step 1024-1124
video_files = list(map(str, tmp_path.glob("*.mp4")))
# Clean up
vec_env.close()
assert len(video_files) == 2
assert "agent-CartPole-v1-step-65-to-step-75.mp4" in video_files[0]
assert "agent-CartPole-v1-step-0-to-step-10.mp4" in video_files[1]