# stable-baselines3/stable_baselines3/common/utils.py

import glob
import os
import random
from collections import deque
from itertools import zip_longest
from typing import Iterable, Optional, Union

import gym
import numpy as np
import torch as th

# Check if tensorboard is available for pytorch
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None

from stable_baselines3.common import logger
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TrainFreq, TrainFrequencyUnit


def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed the different random generators.

    :param seed: the seed shared by the Python, NumPy and PyTorch generators
    :param using_cuda: whether CUDA is in use; if so, CuDNN is also made deterministic
    """
    # Seed python RNG
    random.seed(seed)
    # Seed numpy RNG
    np.random.seed(seed)
    # Seed the RNG for all devices (both CPU and CUDA)
    th.manual_seed(seed)

    if using_cuda:
        # Deterministic operations for CuDNN, it may impact performance
        th.backends.cudnn.deterministic = True
        th.backends.cudnn.benchmark = False


# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Computes fraction of variance that ypred explains about y.
    Returns 1 - Var[y-ypred] / Var[y]

    interpretation:
        ev=0  =>  might as well have predicted zero
        ev=1  =>  perfect prediction
        ev<0  =>  worse than just predicting zero

    :param y_pred: the prediction
    :param y_true: the expected value
    :return: explained variance of ypred and y
    """
    assert y_true.ndim == 1 and y_pred.ndim == 1
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y


def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
    """
    Update the learning rate for a given optimizer.
    Useful when using a learning rate schedule (e.g. a linear one).

    :param optimizer: the optimizer whose parameter groups are updated
    :param learning_rate: the new learning rate, applied to every parameter group
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate


def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
    """
    Transform (if needed) learning rate and clip range (for PPO)
    to callable.

    :param value_schedule: a constant value (float or int) or an already callable schedule
    :return: a callable schedule
    """
    # If the passed schedule is a number (float or int),
    # create a constant function
    if isinstance(value_schedule, (float, int)):
        # Cast to float to avoid errors
        value_schedule = constant_fn(float(value_schedule))
    else:
        assert callable(value_schedule)
    return value_schedule


def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
    """
    Create a function that interpolates linearly between start and end
    between ``progress_remaining`` = 1 and ``progress_remaining`` = 1 - ``end_fraction``.
    This is used in DQN for linearly annealing the exploration rate
    (epsilon for the epsilon-greedy strategy).

    :param start: value to start with if ``progress_remaining`` = 1
    :param end: value to end with if ``progress_remaining`` = 0
    :param end_fraction: fraction of the training process after which ``end``
        is reached, e.g. with 0.1, ``end`` is reached after 10%
        of the complete training process.
    :return: the linear schedule function
    """

    def func(progress_remaining: float) -> float:
        if (1 - progress_remaining) > end_fraction:
            return end
        else:
            return start + (1 - progress_remaining) * (end - start) / end_fraction

    return func


def constant_fn(val: float) -> Schedule:
    """
    Create a function that returns a constant.
    It is useful for learning rate schedule (to avoid code duplication)

    :param val:
    :return:
    """

    def func(_):
        return val

    return func


def get_device(device: Union[th.device, str] = "auto") -> th.device:
    """
    Retrieve PyTorch device.
    It checks that the requested device is available first.
    For now, it supports only cpu and cuda.
    By default, it tries to use the gpu.

    :param device: One of 'auto', 'cuda', 'cpu'
    :return: the requested device, or the cpu device if cuda is requested but not available
    """
    # Cuda by default
    if device == "auto":
        device = "cuda"
    # Force conversion to th.device
    device = th.device(device)

    # Cuda not available
    if device.type == th.device("cuda").type and not th.cuda.is_available():
        return th.device("cpu")

    return device


def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int:
    """
    Returns the latest run number for the given log name and log path,
    by finding the greatest number in the directories.

    :return: latest run number
    """
    max_run_id = 0
    for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"):
        file_name = path.split(os.sep)[-1]
        ext = file_name.split("_")[-1]
        if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
            max_run_id = int(ext)
    return max_run_id


def configure_logger(
    verbose: int = 0, tensorboard_log: Optional[str] = None, tb_log_name: str = "", reset_num_timesteps: bool = True
) -> None:
    """
    Configure the logger's outputs.

    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param tb_log_name: the name of the tensorboard log (prefix of the run directory)
    :param reset_num_timesteps: if False, keep logging in the latest run directory
        instead of creating a new one
    """
    if tensorboard_log is not None and SummaryWriter is not None:
        latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name)
        if not reset_num_timesteps:
            # Continue training in the same directory
            latest_run_id -= 1
        save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{latest_run_id + 1}")
        if verbose >= 1:
            logger.configure(save_path, ["stdout", "tensorboard"])
        else:
            logger.configure(save_path, ["tensorboard"])
    elif verbose == 0:
        logger.configure(format_strings=[""])


def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> None:
    """
    Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if
    spaces match after loading the model with given env.
    Checked parameters:
    - observation_space
    - action_space

    :param env: Environment to check for valid spaces
    :param observation_space: Observation space to check against
    :param action_space: Action space to check against
    """
    if observation_space != env.observation_space:
        raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")
    if action_space != env.action_space:
        raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")


def is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
    """
    For every observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    """
    if isinstance(observation_space, gym.spaces.Box):
        if observation.shape == observation_space.shape:
            return False
        elif observation.shape[1:] == observation_space.shape:
            return True
        else:
            raise ValueError(
                f"Error: Unexpected observation shape {observation.shape} for "
                + f"Box environment, please use {observation_space.shape} "
                + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
            )
    elif isinstance(observation_space, gym.spaces.Discrete):
        if observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'
            return False
        elif len(observation.shape) == 1:
            return True
        else:
            # The checks above accept a scalar or a 1D array, so report those shapes
            raise ValueError(
                f"Error: Unexpected observation shape {observation.shape} for "
                + "Discrete environment, please use () or (n_env,) for the observation shape."
            )
    elif isinstance(observation_space, gym.spaces.MultiDiscrete):
        if observation.shape == (len(observation_space.nvec),):
            return False
        elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
            return True
        else:
            raise ValueError(
                f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
                + f"environment, please use ({len(observation_space.nvec)},) or "
                + f"(n_env, {len(observation_space.nvec)}) for the observation shape."
            )
    elif isinstance(observation_space, gym.spaces.MultiBinary):
        if observation.shape == (observation_space.n,):
            return False
        elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
            return True
        else:
            raise ValueError(
                f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
                + f"environment, please use ({observation_space.n},) or "
                + f"(n_env, {observation_space.n}) for the observation shape."
            )
    else:
        raise ValueError(
            f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}."
        )


def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
    """
    Compute the mean of an array if there is at least one element.
    For empty array, return NaN. It is used for logging only.

    :param arr:
    :return:
    """
    return np.nan if len(arr) == 0 else np.mean(arr)


def zip_strict(*iterables: Iterable) -> Iterable:
    r"""
    ``zip()`` function but enforces that iterables are of equal length.
    Raises ``ValueError`` if iterables not of equal length.
    Code inspired by Stackoverflow answer for question #32954486.

    :param \*iterables: iterables to ``zip()``
    """
    # As in Stackoverflow #32954486, use
    # new object for "empty" in case we have
    # Nones in iterable.
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError("Iterables have different lengths")
        yield combo


def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.nn.Parameter], tau: float) -> None:
    """
    Perform a Polyak average update on ``target_params`` using ``params``:
    target parameters are slowly updated towards the main parameters.
    ``tau``, the soft update coefficient controls the interpolation:
    ``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
    The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors,
    or a computation graph, reducing memory cost and improving performance.  We scale the target params
    by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target
    params (in place).
    See https://github.com/DLR-RM/stable-baselines3/issues/93

    :param params: parameters to use to update the target params
    :param target_params: parameters to update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    """
    with th.no_grad():
        # zip does not raise an exception if length of parameters does not match.
        for param, target_param in zip_strict(params, target_params):
            target_param.data.mul_(1 - tau)
            th.add(target_param.data, param.data, alpha=tau, out=target_param.data)


def should_collect_more_steps(
    train_freq: TrainFreq,
    num_collected_steps: int,
    num_collected_episodes: int,
) -> bool:
    """
    Helper used in ``collect_rollouts()`` of off-policy algorithms
    to determine the termination condition.

    :param train_freq: How much experience should be collected before updating the policy.
    :param num_collected_steps: The number of already collected steps.
    :param num_collected_episodes: The number of already collected episodes.
    :return: Whether to continue or not collecting experience
        by doing rollouts of the current policy.
    """
    if train_freq.unit == TrainFrequencyUnit.STEP:
        return num_collected_steps < train_freq.frequency

    elif train_freq.unit == TrainFrequencyUnit.EPISODE:
        return num_collected_episodes < train_freq.frequency

    else:
        raise ValueError(
            "The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
            f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
        )