mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Fix tensorboad video slow numpy->torch conversion (#1910)
* fixed tb video docs * updated changelog * add comment on expected render() output * Update changelog.rst --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
e93175084f
commit
35eccaf04f
2 changed files with 8 additions and 5 deletions
|
|
@ -192,6 +192,7 @@ Here is an example of how to render an episode and log the resulting video to Te
|
|||
|
||||
import gymnasium as gym
|
||||
import torch as th
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
|
@ -226,6 +227,9 @@ Here is an example of how to render an episode and log the resulting video to Te
|
|||
:param _locals: A dictionary containing all local variables of the callback's scope
|
||||
:param _globals: A dictionary containing all global variables of the callback's scope
|
||||
"""
|
||||
# We expect `render()` to return a uint8 array with values in [0, 255] or a float array
|
||||
# with values in [0, 1], as described in
|
||||
# https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video
|
||||
screen = self._eval_env.render(mode="rgb_array")
|
||||
# PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
|
||||
screens.append(screen.transpose(2, 0, 1))
|
||||
|
|
@ -239,7 +243,7 @@ Here is an example of how to render an episode and log the resulting video to Te
|
|||
)
|
||||
self.logger.record(
|
||||
"trajectory/video",
|
||||
Video(th.ByteTensor([screens]), fps=40),
|
||||
Video(th.from_numpy(np.asarray([screens])), fps=40),
|
||||
exclude=("stdout", "log", "json", "csv"),
|
||||
)
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -35,8 +35,8 @@ Bug Fixes:
|
|||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Added ER-MRL to the project page
|
||||
|
||||
- Added ER-MRL to the project page (@corentinlger)
|
||||
- Updated Tensorboard Logging Videos documentation (@NickLucche)
|
||||
|
||||
Release 2.3.1 (2024-04-22)
|
||||
--------------------------
|
||||
|
|
@ -50,7 +50,6 @@ Documentation:
|
|||
- Updated SBX documentation (CrossQ and deprecated DroQ)
|
||||
- Updated RL Tips and Tricks section
|
||||
|
||||
|
||||
Release 2.3.0 (2024-03-31)
|
||||
--------------------------
|
||||
|
||||
|
|
@ -1641,4 +1640,4 @@ And all the contributors:
|
|||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
|
||||
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
|
||||
@marekm4 @stagoverflow @rushitnshah @markscsmith
|
||||
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche
|
||||
|
|
|
|||
Loading…
Reference in a new issue