mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Doc fixes and add monitor_kwargs parameter (#230)
* Fix type annotation * Fix migration doc for A2C * Update version * Add `monitor_kwargs` argument * Update docs/guide/migration.rst Co-authored-by: Adam Gleave <adam@gleave.me> * Fix make atari env * Fix docstring * Renamed LearningRateSchedule Co-authored-by: Adam Gleave <adam@gleave.me>
This commit is contained in:
parent
9069cf55f1
commit
d04aad2a20
22 changed files with 74 additions and 49 deletions
|
|
@ -104,7 +104,7 @@ A2C
|
|||
PyTorch implementation of RMSprop `differs from TensorFlow's <https://github.com/pytorch/pytorch/issues/23796>`_,
|
||||
which leads to `different and potentially more unstable results <https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241>`_.
|
||||
Use ``stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike`` optimizer to match the results
|
||||
with Tensorflow's implementation. This can be done through ``policy_kwargs``: ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``
|
||||
with TensorFlow's implementation. This can be done through ``policy_kwargs``: ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, eps=1e-5))``
|
||||
|
||||
|
||||
PPO
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.11.0a0 (WIP)
|
||||
Pre-Release 0.11.0a1 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -23,6 +23,7 @@ New Features:
|
|||
an environment for specific wrapper.
|
||||
- Added ``env_is_wrapped()`` method for ``VecEnv`` to check if its environments are wrapped
|
||||
with given Gym wrappers.
|
||||
- Added ``monitor_kwargs`` parameter to ``make_vec_env`` and ``make_atari_env``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -43,7 +44,9 @@ Documentation:
|
|||
^^^^^^^^^^^^^^
|
||||
- Updated algorithm table
|
||||
- Minor docstring improvements regarding rollout (@stheid)
|
||||
|
||||
- Fix migration doc for ``A2C`` (epsilon parameter)
|
||||
- Fix ``clip_range`` docstring
|
||||
- Fix duplicated parameter in ``EvalCallback`` docstring (thanks @tfederico)
|
||||
|
||||
Pre-Release 0.10.0 (2020-10-28)
|
||||
-------------------------------
|
||||
|
|
@ -518,3 +521,4 @@ And all the contributors:
|
|||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
|
||||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
|
||||
@tfederico
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ It uses multiple workers to avoid the use of a replay buffer.
|
|||
|
||||
If you find training unstable or want to match performance of stable-baselines A2C, consider using
|
||||
``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
|
||||
You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``.
|
||||
You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, eps=1e-5))``.
|
||||
Read more `here <https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241>`_.
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from torch.nn import functional as F
|
|||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance
|
||||
|
||||
|
||||
|
|
@ -55,7 +55,7 @@ class A2C(OnPolicyAlgorithm):
|
|||
self,
|
||||
policy: Union[str, Type[ActorCriticPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, LearningRateSchedule] = 7e-4,
|
||||
learning_rate: Union[float, Schedule] = 7e-4,
|
||||
n_steps: int = 5,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 1.0,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from stable_baselines3.common.noise import ActionNoise
|
|||
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import (
|
||||
check_for_correct_spaces,
|
||||
get_device,
|
||||
|
|
@ -89,7 +89,7 @@ class BaseAlgorithm(ABC):
|
|||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str, None],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, LearningRateSchedule],
|
||||
learning_rate: Union[float, Schedule],
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
|
|
@ -129,7 +129,7 @@ class BaseAlgorithm(ABC):
|
|||
self.policy = None
|
||||
self.learning_rate = learning_rate
|
||||
self.tensorboard_log = tensorboard_log
|
||||
self.lr_schedule = None # type: Optional[LearningRateSchedule]
|
||||
self.lr_schedule = None # type: Optional[Schedule]
|
||||
self._last_obs = None # type: Optional[np.ndarray]
|
||||
self._last_dones = None # type: Optional[np.ndarray]
|
||||
# When using VecNormalize:
|
||||
|
|
|
|||
|
|
@ -273,7 +273,6 @@ class EvalCallback(EventCallback):
|
|||
according to performance on the eval env will be saved.
|
||||
:param deterministic: Whether the evaluation should
|
||||
use stochastic or deterministic actions.
|
||||
:param deterministic: Whether to render or not the environment during evaluation
|
||||
:param render: Whether to render or not the environment during evaluation
|
||||
:param verbose:
|
||||
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ def make_vec_env(
|
|||
env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
|
||||
vec_env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
monitor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> VecEnv:
|
||||
"""
|
||||
Create a wrapped, monitored ``VecEnv``.
|
||||
|
|
@ -63,10 +64,12 @@ def make_vec_env(
|
|||
:param env_kwargs: Optional keyword arguments to pass to the env constructor
|
||||
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
|
||||
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
|
||||
:param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
|
||||
:return: The wrapped environment
|
||||
"""
|
||||
env_kwargs = {} if env_kwargs is None else env_kwargs
|
||||
vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
|
||||
monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs
|
||||
|
||||
def make_env(rank):
|
||||
def _init():
|
||||
|
|
@ -83,7 +86,7 @@ def make_vec_env(
|
|||
# Create the monitor folder if needed
|
||||
if monitor_path is not None:
|
||||
os.makedirs(monitor_dir, exist_ok=True)
|
||||
env = Monitor(env, filename=monitor_path)
|
||||
env = Monitor(env, filename=monitor_path, **monitor_kwargs)
|
||||
# Optionally, wrap the environment with the provided wrapper
|
||||
if wrapper_class is not None:
|
||||
env = wrapper_class(env)
|
||||
|
|
@ -109,6 +112,7 @@ def make_atari_env(
|
|||
env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
|
||||
vec_env_kwargs: Optional[Dict[str, Any]] = None,
|
||||
monitor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> VecEnv:
|
||||
"""
|
||||
Create a wrapped, monitored VecEnv for Atari.
|
||||
|
|
@ -125,6 +129,7 @@ def make_atari_env(
|
|||
:param env_kwargs: Optional keyword arguments to pass to the env constructor
|
||||
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
|
||||
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
|
||||
:param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
|
||||
:return: The wrapped environment
|
||||
"""
|
||||
if wrapper_kwargs is None:
|
||||
|
|
@ -144,4 +149,5 @@ def make_atari_env(
|
|||
env_kwargs=env_kwargs,
|
||||
vec_env_cls=vec_env_cls,
|
||||
vec_env_kwargs=vec_env_kwargs,
|
||||
monitor_kwargs=monitor_kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from stable_baselines3.common.callbacks import BaseCallback
|
|||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback, RolloutReturn
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule
|
||||
from stable_baselines3.common.utils import safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, LearningRateSchedule],
|
||||
learning_rate: Union[float, Schedule],
|
||||
buffer_size: int = int(1e6),
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from stable_baselines3.common.base_class import BaseAlgorithm
|
|||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
self,
|
||||
policy: Union[str, Type[ActorCriticPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, LearningRateSchedule],
|
||||
learning_rate: Union[float, Schedule],
|
||||
n_steps: int,
|
||||
gamma: float,
|
||||
gae_lambda: float,
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from stable_baselines3.common.distributions import (
|
|||
)
|
||||
from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, preprocess_obs
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp
|
||||
from stable_baselines3.common.type_aliases import LearningRateSchedule
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.utils import get_device, is_vectorized_observation
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage
|
||||
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
|
||||
|
|
@ -365,7 +365,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: LearningRateSchedule,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
|
|
@ -480,7 +480,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device
|
||||
)
|
||||
|
||||
def _build(self, lr_schedule: LearningRateSchedule) -> None:
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
"""
|
||||
Create the networks and the optimizer.
|
||||
|
||||
|
|
@ -663,7 +663,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: LearningRateSchedule,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ GymStepReturn = Tuple[GymObs, float, bool, Dict]
|
|||
TensorDict = Dict[str, th.Tensor]
|
||||
OptimizerStateDict = Dict[str, Any]
|
||||
MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
|
||||
LearningRateSchedule = Callable[[float], float]
|
||||
# A schedule takes the remaining progress as input
|
||||
# and outputs a scalar (e.g. learning rate, clip range, ...)
|
||||
Schedule = Callable[[float], float]
|
||||
|
||||
|
||||
class RolloutBufferSamples(NamedTuple):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ except ImportError:
|
|||
SummaryWriter = None
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule
|
||||
from stable_baselines3.common.type_aliases import GymEnv, Schedule
|
||||
|
||||
|
||||
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
|
||||
|
|
@ -70,7 +70,7 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
|
|||
param_group["lr"] = learning_rate
|
||||
|
||||
|
||||
def get_schedule_fn(value_schedule: Union[LearningRateSchedule, float, int]) -> LearningRateSchedule:
|
||||
def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
|
||||
"""
|
||||
Transform (if needed) learning rate and clip range (for PPO)
|
||||
to callable.
|
||||
|
|
@ -88,7 +88,7 @@ def get_schedule_fn(value_schedule: Union[LearningRateSchedule, float, int]) ->
|
|||
return value_schedule
|
||||
|
||||
|
||||
def get_linear_fn(start: float, end: float, end_fraction: float) -> LearningRateSchedule:
|
||||
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
|
||||
"""
|
||||
Create a function that interpolates linearly between start and end
|
||||
between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``.
|
||||
|
|
@ -112,7 +112,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> LearningRate
|
|||
return func
|
||||
|
||||
|
||||
def constant_fn(val: float) -> LearningRateSchedule:
|
||||
def constant_fn(val: float) -> Schedule:
|
||||
"""
|
||||
Create a function that returns a constant
|
||||
It is useful for learning rate schedule (to avoid code duplication)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch as th
|
|||
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.td3.policies import TD3Policy
|
||||
from stable_baselines3.td3.td3 import TD3
|
||||
|
||||
|
|
@ -55,7 +55,7 @@ class DDPG(TD3):
|
|||
self,
|
||||
policy: Union[str, Type[TD3Policy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, LearningRateSchedule] = 1e-3,
|
||||
learning_rate: Union[float, Schedule] = 1e-3,
|
||||
buffer_size: int = int(1e6),
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 100,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from torch.nn import functional as F
|
|||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
||||
from stable_baselines3.dqn.policies import DQNPolicy
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
self,
|
||||
policy: Union[str, Type[DQNPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, LearningRateSchedule] = 1e-4,
|
||||
learning_rate: Union[float, Schedule] = 1e-4,
|
||||
buffer_size: int = 1000000,
|
||||
learning_starts: int = 50000,
|
||||
batch_size: Optional[int] = 32,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from torch import nn
|
|||
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
|
||||
from stable_baselines3.common.type_aliases import LearningRateSchedule
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
|
||||
|
||||
class QNetwork(BasePolicy):
|
||||
|
|
@ -104,7 +104,7 @@ class DQNPolicy(BasePolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: LearningRateSchedule,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
|
|
@ -143,7 +143,7 @@ class DQNPolicy(BasePolicy):
|
|||
self.q_net, self.q_net_target = None, None
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: LearningRateSchedule) -> None:
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
"""
|
||||
Create the network and the optimizer.
|
||||
|
||||
|
|
@ -211,7 +211,7 @@ class CnnPolicy(DQNPolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: LearningRateSchedule,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from torch.nn import functional as F
|
|||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
||||
|
||||
|
||||
|
|
@ -66,14 +66,14 @@ class PPO(OnPolicyAlgorithm):
|
|||
self,
|
||||
policy: Union[str, Type[ActorCriticPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, LearningRateSchedule] = 3e-4,
|
||||
learning_rate: Union[float, Schedule] = 3e-4,
|
||||
n_steps: int = 2048,
|
||||
batch_size: Optional[int] = 64,
|
||||
n_epochs: int = 10,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95,
|
||||
clip_range: float = 0.2,
|
||||
clip_range_vf: Optional[float] = None,
|
||||
clip_range: Union[float, Schedule] = 0.2,
|
||||
clip_range_vf: Union[None, float, Schedule] = None,
|
||||
ent_coef: float = 0.0,
|
||||
vf_coef: float = 0.5,
|
||||
max_grad_norm: float = 0.5,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
create_mlp,
|
||||
get_actor_critic_arch,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import LearningRateSchedule
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
LOG_STD_MAX = 2
|
||||
|
|
@ -228,7 +228,7 @@ class SACPolicy(BasePolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: LearningRateSchedule,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
|
|
@ -295,7 +295,7 @@ class SACPolicy(BasePolicy):
|
|||
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: LearningRateSchedule) -> None:
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
self.actor = self.make_actor()
|
||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
|
|
@ -398,7 +398,7 @@ class CnnPolicy(SACPolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: LearningRateSchedule,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from torch.nn import functional as F
|
|||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import polyak_update
|
||||
from stable_baselines3.sac.policies import SACPolicy
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
self,
|
||||
policy: Union[str, Type[SACPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, LearningRateSchedule] = 3e-4,
|
||||
learning_rate: Union[float, Schedule] = 3e-4,
|
||||
buffer_size: int = int(1e6),
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
create_mlp,
|
||||
get_actor_critic_arch,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import LearningRateSchedule
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
|
||||
|
||||
class Actor(BasePolicy):
|
||||
|
|
@ -109,7 +109,7 @@ class TD3Policy(BasePolicy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: LearningRateSchedule,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
|
|
@ -164,7 +164,7 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: LearningRateSchedule) -> None:
|
||||
def _build(self, lr_schedule: Schedule) -> None:
|
||||
# Create actor and target
|
||||
# the features extractor should not be shared
|
||||
self.actor = self.make_actor(features_extractor=None)
|
||||
|
|
@ -253,7 +253,7 @@ class CnnPolicy(TD3Policy):
|
|||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: LearningRateSchedule,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from torch.nn import functional as F
|
|||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, LearningRateSchedule, MaybeCallback
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import polyak_update
|
||||
from stable_baselines3.td3.policies import TD3Policy
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
self,
|
||||
policy: Union[str, Type[TD3Policy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, LearningRateSchedule] = 1e-3,
|
||||
learning_rate: Union[float, Schedule] = 1e-3,
|
||||
buffer_size: int = int(1e6),
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 100,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.11.0a0
|
||||
0.11.0a1
|
||||
|
|
|
|||
|
|
@ -70,6 +70,20 @@ def test_vec_env_kwargs():
|
|||
assert env.get_attr("goal_velocity")[0] == 0.11
|
||||
|
||||
|
||||
def test_vec_env_monitor_kwargs():
|
||||
env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": False})
|
||||
assert env.get_attr("allow_early_resets")[0] is False
|
||||
|
||||
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": False})
|
||||
assert env.get_attr("allow_early_resets")[0] is False
|
||||
|
||||
env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": True})
|
||||
assert env.get_attr("allow_early_resets")[0] is True
|
||||
|
||||
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1, seed=0, monitor_kwargs={"allow_early_resets": True})
|
||||
assert env.get_attr("allow_early_resets")[0] is True
|
||||
|
||||
|
||||
def test_custom_vec_env(tmp_path):
|
||||
"""
|
||||
Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.
|
||||
|
|
|
|||
Loading…
Reference in a new issue