diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst index 3bbee83..d7dffb5 100644 --- a/docs/guide/migration.rst +++ b/docs/guide/migration.rst @@ -104,7 +104,7 @@ A2C PyTorch implementation of RMSprop `differs from Tensorflow's `_, which leads to `different and potentially more unstable results `_. Use ``stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike`` optimizer to match the results - with Tensorflow's implementation. This can be done through ``policy_kwargs``: ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))`` + with TensorFlow's implementation. This can be done through ``policy_kwargs``: ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, eps=1e-5))`` PPO diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8a64a29..ee5c331 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.11.0a0 (WIP) +Pre-Release 0.11.0a1 (WIP) ------------------------------- Breaking Changes: @@ -23,6 +23,7 @@ New Features: an environment for specific wrapper. - Added ``env_is_wrapped()`` method for ``VecEnv`` to check if its environments are wrapped with given Gym wrappers. 
+- Added ``monitor_kwargs`` parameter to ``make_vec_env`` and ``make_atari_env`` Bug Fixes: ^^^^^^^^^^ @@ -43,7 +44,9 @@ Documentation: ^^^^^^^^^^^^^^ - Updated algorithm table - Minor docstring improvements regarding rollout (@stheid) - +- Fix migration doc for ``A2C`` (epsilon parameter) +- Fix ``clip_range`` docstring +- Fix duplicated parameter in ``EvalCallback`` docstring (thanks @tfederico) Pre-Release 0.10.0 (2020-10-28) ------------------------------- @@ -518,3 +521,4 @@ And all the contributors: @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray +@tfederico diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 011eb56..4c25040 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -14,7 +14,7 @@ It uses multiple workers to avoid the use of a replay buffer. If you find training unstable or want to match performance of stable-baselines A2C, consider using ``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``. - You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``. + You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, eps=1e-5))``. Read more `here `_. 
diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 4720526..b78bf22 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -7,7 +7,7 @@ from torch.nn import functional as F from stable_baselines3.common import logger from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance @@ -55,7 +55,7 @@ class A2C(OnPolicyAlgorithm): self, policy: Union[str, Type[ActorCriticPolicy]], env: Union[GymEnv, str], - learning_rate: Union[float, LearningRateSchedule] = 7e-4, + learning_rate: Union[float, Schedule] = 7e-4, n_steps: int = 5, gamma: float = 0.99, gae_lambda: float = 1.0, diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index c99e713..908b999 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -18,7 +18,7 @@ from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.policies import BasePolicy, get_policy_from_name from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import ( check_for_correct_spaces, get_device, @@ -89,7 +89,7 @@ class BaseAlgorithm(ABC): policy: Type[BasePolicy], env: Union[GymEnv, str, None], policy_base: Type[BasePolicy], - learning_rate: Union[float, LearningRateSchedule], + learning_rate: 
Union[float, Schedule], policy_kwargs: Dict[str, Any] = None, tensorboard_log: Optional[str] = None, verbose: int = 0, @@ -129,7 +129,7 @@ class BaseAlgorithm(ABC): self.policy = None self.learning_rate = learning_rate self.tensorboard_log = tensorboard_log - self.lr_schedule = None # type: Optional[LearningRateSchedule] + self.lr_schedule = None # type: Optional[Schedule] self._last_obs = None # type: Optional[np.ndarray] self._last_dones = None # type: Optional[np.ndarray] # When using VecNormalize: diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index c114a93..34aac9b 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -273,7 +273,6 @@ class EvalCallback(EventCallback): according to performance on the eval env will be saved. :param deterministic: Whether the evaluation should use a stochastic or deterministic actions. - :param deterministic: Whether to render or not the environment during evaluation :param render: Whether to render or not the environment during evaluation :param verbose: :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 2b8d1f0..177e744 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -45,6 +45,7 @@ def make_vec_env( env_kwargs: Optional[Dict[str, Any]] = None, vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None, vec_env_kwargs: Optional[Dict[str, Any]] = None, + monitor_kwargs: Optional[Dict[str, Any]] = None, ) -> VecEnv: """ Create a wrapped, monitored ``VecEnv``. @@ -63,10 +64,12 @@ def make_vec_env( :param env_kwargs: Optional keyword argument to pass to the env constructor :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. 
+ :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. :return: The wrapped environment """ env_kwargs = {} if env_kwargs is None else env_kwargs vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs + monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs def make_env(rank): def _init(): @@ -83,7 +86,7 @@ def make_vec_env( # Create the monitor folder if needed if monitor_path is not None: os.makedirs(monitor_dir, exist_ok=True) - env = Monitor(env, filename=monitor_path) + env = Monitor(env, filename=monitor_path, **monitor_kwargs) # Optionally, wrap the environment with the provided wrapper if wrapper_class is not None: env = wrapper_class(env) @@ -109,6 +112,7 @@ def make_atari_env( env_kwargs: Optional[Dict[str, Any]] = None, vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, vec_env_kwargs: Optional[Dict[str, Any]] = None, + monitor_kwargs: Optional[Dict[str, Any]] = None, ) -> VecEnv: """ Create a wrapped, monitored VecEnv for Atari. @@ -125,6 +129,7 @@ def make_atari_env( :param env_kwargs: Optional keyword argument to pass to the env constructor :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. + :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. 
:return: The wrapped environment """ if wrapper_kwargs is None: @@ -144,4 +149,5 @@ def make_atari_env( env_kwargs=env_kwargs, vec_env_cls=vec_env_cls, vec_env_kwargs=vec_env_kwargs, + monitor_kwargs=monitor_kwargs, ) diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 37662a6..4545610 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -15,7 +15,7 @@ from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback, RolloutReturn +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule from stable_baselines3.common.utils import safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -76,7 +76,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): policy: Type[BasePolicy], env: Union[GymEnv, str], policy_base: Type[BasePolicy], - learning_rate: Union[float, LearningRateSchedule], + learning_rate: Union[float, Schedule], buffer_size: int = int(1e6), learning_starts: int = 100, batch_size: int = 256, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 527cb9d..9f7a665 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -10,7 +10,7 @@ from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback +from 
stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -52,7 +52,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): self, policy: Union[str, Type[ActorCriticPolicy]], env: Union[GymEnv, str], - learning_rate: Union[float, LearningRateSchedule], + learning_rate: Union[float, Schedule], n_steps: int, gamma: float, gae_lambda: float, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index b0ae2dc..5c97431 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -21,7 +21,7 @@ from stable_baselines3.common.distributions import ( ) from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, preprocess_obs from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp -from stable_baselines3.common.type_aliases import LearningRateSchedule +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import get_device, is_vectorized_observation from stable_baselines3.common.vec_env import VecTransposeImage from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper @@ -365,7 +365,7 @@ class ActorCriticPolicy(BasePolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: LearningRateSchedule, + lr_schedule: Schedule, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, @@ -480,7 +480,7 @@ class ActorCriticPolicy(BasePolicy): self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device ) - def _build(self, lr_schedule: LearningRateSchedule) -> None: + def _build(self, lr_schedule: Schedule) -> None: """ Create the networks and the optimizer. 
@@ -663,7 +663,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: LearningRateSchedule, + lr_schedule: Schedule, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index d189f5d..80cc354 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -14,7 +14,9 @@ GymStepReturn = Tuple[GymObs, float, bool, Dict] TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] -LearningRateSchedule = Callable[[float], float] +# A schedule takes the remaining progress as input +# and outputs a scalar (e.g. learning rate, clip range, ...) +Schedule = Callable[[float], float] class RolloutBufferSamples(NamedTuple): diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 6377c8e..e8ca9f6 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -16,7 +16,7 @@ except ImportError: SummaryWriter = None from stable_baselines3.common import logger -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule +from stable_baselines3.common.type_aliases import GymEnv, Schedule def set_random_seed(seed: int, using_cuda: bool = False) -> None: @@ -70,7 +70,7 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> param_group["lr"] = learning_rate -def get_schedule_fn(value_schedule: Union[LearningRateSchedule, float, int]) -> LearningRateSchedule: +def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule: """ Transform (if needed) learning rate and clip range (for PPO) to callable. 
@@ -88,7 +88,7 @@ def get_schedule_fn(value_schedule: Union[LearningRateSchedule, float, int]) -> return value_schedule -def get_linear_fn(start: float, end: float, end_fraction: float) -> LearningRateSchedule: +def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule: """ Create a function that interpolates linearly between start and end between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``. @@ -112,7 +112,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> LearningRate return func -def constant_fn(val: float) -> LearningRateSchedule: +def constant_fn(val: float) -> Schedule: """ Create a function that returns a constant It is useful for learning rate schedule (to avoid code duplication) diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index d4af6cc..e696aeb 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -4,7 +4,7 @@ import torch as th from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.td3.policies import TD3Policy from stable_baselines3.td3.td3 import TD3 @@ -55,7 +55,7 @@ class DDPG(TD3): self, policy: Union[str, Type[TD3Policy]], env: Union[GymEnv, str], - learning_rate: Union[float, LearningRateSchedule] = 1e-3, + learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = int(1e6), learning_starts: int = 100, batch_size: int = 100, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 2c927af..f715a15 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -6,7 +6,7 @@ from torch.nn import functional as F from stable_baselines3.common import logger from stable_baselines3.common.off_policy_algorithm 
import OffPolicyAlgorithm -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update from stable_baselines3.dqn.policies import DQNPolicy @@ -59,7 +59,7 @@ class DQN(OffPolicyAlgorithm): self, policy: Union[str, Type[DQNPolicy]], env: Union[GymEnv, str], - learning_rate: Union[float, LearningRateSchedule] = 1e-4, + learning_rate: Union[float, Schedule] = 1e-4, buffer_size: int = 1000000, learning_starts: int = 50000, batch_size: Optional[int] = 32, diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 1bbff1d..f72424e 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -6,7 +6,7 @@ from torch import nn from stable_baselines3.common.policies import BasePolicy, register_policy from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp -from stable_baselines3.common.type_aliases import LearningRateSchedule +from stable_baselines3.common.type_aliases import Schedule class QNetwork(BasePolicy): @@ -104,7 +104,7 @@ class DQNPolicy(BasePolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: LearningRateSchedule, + lr_schedule: Schedule, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, @@ -143,7 +143,7 @@ class DQNPolicy(BasePolicy): self.q_net, self.q_net_target = None, None self._build(lr_schedule) - def _build(self, lr_schedule: LearningRateSchedule) -> None: + def _build(self, lr_schedule: Schedule) -> None: """ Create the network and the optimizer. 
@@ -211,7 +211,7 @@ class CnnPolicy(DQNPolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: LearningRateSchedule, + lr_schedule: Schedule, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 0f75c26..a2b6aea 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -8,7 +8,7 @@ from torch.nn import functional as F from stable_baselines3.common import logger from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn @@ -66,14 +66,14 @@ class PPO(OnPolicyAlgorithm): self, policy: Union[str, Type[ActorCriticPolicy]], env: Union[GymEnv, str], - learning_rate: Union[float, LearningRateSchedule] = 3e-4, + learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 2048, batch_size: Optional[int] = 64, n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, - clip_range: float = 0.2, - clip_range_vf: Optional[float] = None, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 22a2b2a..8ba5897 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -14,7 +14,7 @@ from stable_baselines3.common.torch_layers import ( create_mlp, get_actor_critic_arch, ) -from stable_baselines3.common.type_aliases import LearningRateSchedule +from 
stable_baselines3.common.type_aliases import Schedule # CAP the standard deviation of the actor LOG_STD_MAX = 2 @@ -228,7 +228,7 @@ class SACPolicy(BasePolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: LearningRateSchedule, + lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, use_sde: bool = False, @@ -295,7 +295,7 @@ class SACPolicy(BasePolicy): self._build(lr_schedule) - def _build(self, lr_schedule: LearningRateSchedule) -> None: + def _build(self, lr_schedule: Schedule) -> None: self.actor = self.make_actor() self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) @@ -398,7 +398,7 @@ class CnnPolicy(SACPolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: LearningRateSchedule, + lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, use_sde: bool = False, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index d52cf10..e94249f 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -7,7 +7,7 @@ from torch.nn import functional as F from stable_baselines3.common import logger from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm -from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import polyak_update from stable_baselines3.sac.policies import SACPolicy @@ -74,7 +74,7 @@ class SAC(OffPolicyAlgorithm): self, policy: Union[str, Type[SACPolicy]], env: Union[GymEnv, str], - learning_rate: Union[float, LearningRateSchedule] = 3e-4, + learning_rate: Union[float, Schedule] = 3e-4, 
buffer_size: int = int(1e6), learning_starts: int = 100, batch_size: int = 256, diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index e333734..225a7b6 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import ( create_mlp, get_actor_critic_arch, ) -from stable_baselines3.common.type_aliases import LearningRateSchedule +from stable_baselines3.common.type_aliases import Schedule class Actor(BasePolicy): @@ -109,7 +109,7 @@ class TD3Policy(BasePolicy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: LearningRateSchedule, + lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, @@ -164,7 +164,7 @@ class TD3Policy(BasePolicy): self._build(lr_schedule) - def _build(self, lr_schedule: LearningRateSchedule) -> None: + def _build(self, lr_schedule: Schedule) -> None: # Create actor and target # the features extractor should not be shared self.actor = self.make_actor(features_extractor=None) @@ -253,7 +253,7 @@ class CnnPolicy(TD3Policy): self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: LearningRateSchedule, + lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index a0e5004..2c2d273 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -7,7 +7,7 @@ from torch.nn import functional as F from stable_baselines3.common import logger from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm -from 
stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import polyak_update from stable_baselines3.td3.policies import TD3Policy @@ -62,7 +62,7 @@ class TD3(OffPolicyAlgorithm): self, policy: Union[str, Type[TD3Policy]], env: Union[GymEnv, str], - learning_rate: Union[float, LearningRateSchedule] = 1e-3, + learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = int(1e6), learning_starts: int = 100, batch_size: int = 100, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index d22e31d..e0cbcd5 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.11.0a0 +0.11.0a1 diff --git a/tests/test_utils.py b/tests/test_utils.py index 88aa76a..d595fb6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -70,6 +70,20 @@ def test_vec_env_kwargs(): assert env.get_attr("goal_velocity")[0] == 0.11 +def test_vec_env_monitor_kwargs(): + env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": False}) + assert env.get_attr("allow_early_resets")[0] is False + + env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": False}) + assert env.get_attr("allow_early_resets")[0] is False + + env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": True}) + assert env.get_attr("allow_early_resets")[0] is True + + env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": True}) + assert env.get_attr("allow_early_resets")[0] is True + + def test_custom_vec_env(tmp_path): """ Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.