Doc fixes and add monitor_kwargs parameter (#230)

* Fix type annotation

* Fix migration doc for A2C

* Update version

* Add `monitor_kwargs` argument

* Update docs/guide/migration.rst

Co-authored-by: Adam Gleave <adam@gleave.me>

* Fix make atari env

* Fix docstring

* Renamed LearningRateSchedule

Co-authored-by: Adam Gleave <adam@gleave.me>
This commit is contained in:
Antonin RAFFIN 2020-11-20 10:28:54 +01:00 committed by GitHub
parent 9069cf55f1
commit d04aad2a20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 74 additions and 49 deletions

View file

@ -104,7 +104,7 @@ A2C
PyTorch implementation of RMSprop `differs from Tensorflow's <https://github.com/pytorch/pytorch/issues/23796>`_,
which leads to `different and potentially more unstable results <https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241>`_.
Use ``stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike`` optimizer to match the results
with Tensorflow's implementation. This can be done through ``policy_kwargs``: ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``
with TensorFlow's implementation. This can be done through ``policy_kwargs``: ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, eps=1e-5))``
PPO

View file

@ -3,7 +3,7 @@
Changelog
==========
Pre-Release 0.11.0a0 (WIP)
Pre-Release 0.11.0a1 (WIP)
-------------------------------
Breaking Changes:
@ -23,6 +23,7 @@ New Features:
an environment for specific wrapper.
- Added ``env_is_wrapped()`` method for ``VecEnv`` to check if its environments are wrapped
with given Gym wrappers.
- Added ``monitor_kwargs`` parameter to ``make_vec_env`` and ``make_atari_env``
Bug Fixes:
^^^^^^^^^^
@ -43,7 +44,9 @@ Documentation:
^^^^^^^^^^^^^^
- Updated algorithm table
- Minor docstring improvements regarding rollout (@stheid)
- Fix migration doc for ``A2C`` (epsilon parameter)
- Fix ``clip_range`` docstring
- Fix duplicated parameter in ``EvalCallback`` docstring (thanks @tfederico)
Pre-Release 0.10.0 (2020-10-28)
-------------------------------
@ -518,3 +521,4 @@ And all the contributors:
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico

View file

@ -14,7 +14,7 @@ It uses multiple workers to avoid the use of a replay buffer.
If you find training unstable or want to match performance of stable-baselines A2C, consider using
``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``.
You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, eps=1e-5))``.
Read more `here <https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241>`_.

View file

@ -7,7 +7,7 @@ from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
@ -55,7 +55,7 @@ class A2C(OnPolicyAlgorithm):
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, LearningRateSchedule] = 7e-4,
learning_rate: Union[float, Schedule] = 7e-4,
n_steps: int = 5,
gamma: float = 0.99,
gae_lambda: float = 1.0,

View file

@ -18,7 +18,7 @@ from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import (
check_for_correct_spaces,
get_device,
@ -89,7 +89,7 @@ class BaseAlgorithm(ABC):
policy: Type[BasePolicy],
env: Union[GymEnv, str, None],
policy_base: Type[BasePolicy],
learning_rate: Union[float, LearningRateSchedule],
learning_rate: Union[float, Schedule],
policy_kwargs: Dict[str, Any] = None,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
@ -129,7 +129,7 @@ class BaseAlgorithm(ABC):
self.policy = None
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self.lr_schedule = None # type: Optional[LearningRateSchedule]
self.lr_schedule = None # type: Optional[Schedule]
self._last_obs = None # type: Optional[np.ndarray]
self._last_dones = None # type: Optional[np.ndarray]
# When using VecNormalize:

View file

@ -273,7 +273,6 @@ class EvalCallback(EventCallback):
according to performance on the eval env will be saved.
:param deterministic: Whether the evaluation should
use a stochastic or deterministic actions.
:param deterministic: Whether to render or not the environment during evaluation
:param render: Whether to render or not the environment during evaluation
:param verbose:
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been

View file

@ -45,6 +45,7 @@ def make_vec_env(
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
monitor_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored ``VecEnv``.
@ -63,10 +64,12 @@ def make_vec_env(
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
:return: The wrapped environment
"""
env_kwargs = {} if env_kwargs is None else env_kwargs
vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs
def make_env(rank):
def _init():
@ -83,7 +86,7 @@ def make_vec_env(
# Create the monitor folder if needed
if monitor_path is not None:
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, filename=monitor_path)
env = Monitor(env, filename=monitor_path, **monitor_kwargs)
# Optionally, wrap the environment with the provided wrapper
if wrapper_class is not None:
env = wrapper_class(env)
@ -109,6 +112,7 @@ def make_atari_env(
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
monitor_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored VecEnv for Atari.
@ -125,6 +129,7 @@ def make_atari_env(
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
:return: The wrapped environment
"""
if wrapper_kwargs is None:
@ -144,4 +149,5 @@ def make_atari_env(
env_kwargs=env_kwargs,
vec_env_cls=vec_env_cls,
vec_env_kwargs=vec_env_kwargs,
monitor_kwargs=monitor_kwargs,
)

View file

@ -15,7 +15,7 @@ from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback, RolloutReturn
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.vec_env import VecEnv
@ -76,7 +76,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
policy: Type[BasePolicy],
env: Union[GymEnv, str],
policy_base: Type[BasePolicy],
learning_rate: Union[float, LearningRateSchedule],
learning_rate: Union[float, Schedule],
buffer_size: int = int(1e6),
learning_starts: int = 100,
batch_size: int = 256,

View file

@ -10,7 +10,7 @@ from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.vec_env import VecEnv
@ -52,7 +52,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, LearningRateSchedule],
learning_rate: Union[float, Schedule],
n_steps: int,
gamma: float,
gae_lambda: float,

View file

@ -21,7 +21,7 @@ from stable_baselines3.common.distributions import (
)
from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, preprocess_obs
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import LearningRateSchedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation
from stable_baselines3.common.vec_env import VecTransposeImage
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
@ -365,7 +365,7 @@ class ActorCriticPolicy(BasePolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: LearningRateSchedule,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
@ -480,7 +480,7 @@ class ActorCriticPolicy(BasePolicy):
self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device
)
def _build(self, lr_schedule: LearningRateSchedule) -> None:
def _build(self, lr_schedule: Schedule) -> None:
"""
Create the networks and the optimizer.
@ -663,7 +663,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: LearningRateSchedule,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,

View file

@ -14,7 +14,9 @@ GymStepReturn = Tuple[GymObs, float, bool, Dict]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
LearningRateSchedule = Callable[[float], float]
# A schedule takes the remaining progress as input
# and ouputs a scalar (e.g. learning rate, clip range, ...)
Schedule = Callable[[float], float]
class RolloutBufferSamples(NamedTuple):

View file

@ -16,7 +16,7 @@ except ImportError:
SummaryWriter = None
from stable_baselines3.common import logger
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule
from stable_baselines3.common.type_aliases import GymEnv, Schedule
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
@ -70,7 +70,7 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
param_group["lr"] = learning_rate
def get_schedule_fn(value_schedule: Union[LearningRateSchedule, float, int]) -> LearningRateSchedule:
def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
"""
Transform (if needed) learning rate and clip range (for PPO)
to callable.
@ -88,7 +88,7 @@ def get_schedule_fn(value_schedule: Union[LearningRateSchedule, float, int]) ->
return value_schedule
def get_linear_fn(start: float, end: float, end_fraction: float) -> LearningRateSchedule:
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
"""
Create a function that interpolates linearly between start and end
between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``.
@ -112,7 +112,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> LearningRate
return func
def constant_fn(val: float) -> LearningRateSchedule:
def constant_fn(val: float) -> Schedule:
"""
Create a function that returns a constant
It is useful for learning rate schedule (to avoid code duplication)

View file

@ -4,7 +4,7 @@ import torch as th
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3
@ -55,7 +55,7 @@ class DDPG(TD3):
self,
policy: Union[str, Type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, LearningRateSchedule] = 1e-3,
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = int(1e6),
learning_starts: int = 100,
batch_size: int = 100,

View file

@ -6,7 +6,7 @@ from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
from stable_baselines3.dqn.policies import DQNPolicy
@ -59,7 +59,7 @@ class DQN(OffPolicyAlgorithm):
self,
policy: Union[str, Type[DQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, LearningRateSchedule] = 1e-4,
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1000000,
learning_starts: int = 50000,
batch_size: Optional[int] = 32,

View file

@ -6,7 +6,7 @@ from torch import nn
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import LearningRateSchedule
from stable_baselines3.common.type_aliases import Schedule
class QNetwork(BasePolicy):
@ -104,7 +104,7 @@ class DQNPolicy(BasePolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: LearningRateSchedule,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
@ -143,7 +143,7 @@ class DQNPolicy(BasePolicy):
self.q_net, self.q_net_target = None, None
self._build(lr_schedule)
def _build(self, lr_schedule: LearningRateSchedule) -> None:
def _build(self, lr_schedule: Schedule) -> None:
"""
Create the network and the optimizer.
@ -211,7 +211,7 @@ class CnnPolicy(DQNPolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: LearningRateSchedule,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,

View file

@ -8,7 +8,7 @@ from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
@ -66,14 +66,14 @@ class PPO(OnPolicyAlgorithm):
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, LearningRateSchedule] = 3e-4,
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 2048,
batch_size: Optional[int] = 64,
n_epochs: int = 10,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_range: float = 0.2,
clip_range_vf: Optional[float] = None,
clip_range: Union[float, Schedule] = 0.2,
clip_range_vf: Union[None, float, Schedule] = None,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,

View file

@ -14,7 +14,7 @@ from stable_baselines3.common.torch_layers import (
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import LearningRateSchedule
from stable_baselines3.common.type_aliases import Schedule
# CAP the standard deviation of the actor
LOG_STD_MAX = 2
@ -228,7 +228,7 @@ class SACPolicy(BasePolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: LearningRateSchedule,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
@ -295,7 +295,7 @@ class SACPolicy(BasePolicy):
self._build(lr_schedule)
def _build(self, lr_schedule: LearningRateSchedule) -> None:
def _build(self, lr_schedule: Schedule) -> None:
self.actor = self.make_actor()
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
@ -398,7 +398,7 @@ class CnnPolicy(SACPolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: LearningRateSchedule,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,

View file

@ -7,7 +7,7 @@ from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import SACPolicy
@ -74,7 +74,7 @@ class SAC(OffPolicyAlgorithm):
self,
policy: Union[str, Type[SACPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, LearningRateSchedule] = 3e-4,
learning_rate: Union[float, Schedule] = 3e-4,
buffer_size: int = int(1e6),
learning_starts: int = 100,
batch_size: int = 256,

View file

@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
create_mlp,
get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import LearningRateSchedule
from stable_baselines3.common.type_aliases import Schedule
class Actor(BasePolicy):
@ -109,7 +109,7 @@ class TD3Policy(BasePolicy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: LearningRateSchedule,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
@ -164,7 +164,7 @@ class TD3Policy(BasePolicy):
self._build(lr_schedule)
def _build(self, lr_schedule: LearningRateSchedule) -> None:
def _build(self, lr_schedule: Schedule) -> None:
# Create actor and target
# the features extractor should not be shared
self.actor = self.make_actor(features_extractor=None)
@ -253,7 +253,7 @@ class CnnPolicy(TD3Policy):
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: LearningRateSchedule,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,

View file

@ -7,7 +7,7 @@ from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.td3.policies import TD3Policy
@ -62,7 +62,7 @@ class TD3(OffPolicyAlgorithm):
self,
policy: Union[str, Type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, LearningRateSchedule] = 1e-3,
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = int(1e6),
learning_starts: int = 100,
batch_size: int = 100,

View file

@ -1 +1 @@
0.11.0a0
0.11.0a1

View file

@ -70,6 +70,20 @@ def test_vec_env_kwargs():
assert env.get_attr("goal_velocity")[0] == 0.11
def test_vec_env_monitor_kwargs():
env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": False})
assert env.get_attr("allow_early_resets")[0] is False
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": False})
assert env.get_attr("allow_early_resets")[0] is False
env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": True})
assert env.get_attr("allow_early_resets")[0] is True
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": True})
assert env.get_attr("allow_early_resets")[0] is True
def test_custom_vec_env(tmp_path):
"""
Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.