2020-07-16 14:12:16 +00:00
|
|
|
import io
|
|
|
|
|
import pathlib
|
2020-06-09 11:54:18 +00:00
|
|
|
import time
|
|
|
|
|
import warnings
|
2021-02-27 16:33:50 +00:00
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
2020-06-09 11:54:18 +00:00
|
|
|
|
|
|
|
|
import gym
|
|
|
|
|
import numpy as np
|
2020-07-16 14:12:16 +00:00
|
|
|
import torch as th
|
2020-06-09 11:54:18 +00:00
|
|
|
|
|
|
|
|
from stable_baselines3.common import logger
|
|
|
|
|
from stable_baselines3.common.base_class import BaseAlgorithm
|
2020-07-16 14:12:16 +00:00
|
|
|
from stable_baselines3.common.buffers import ReplayBuffer
|
|
|
|
|
from stable_baselines3.common.callbacks import BaseCallback
|
|
|
|
|
from stable_baselines3.common.noise import ActionNoise
|
2020-06-09 11:54:18 +00:00
|
|
|
from stable_baselines3.common.policies import BasePolicy
|
2020-07-16 14:12:16 +00:00
|
|
|
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
|
2021-02-27 16:33:50 +00:00
|
|
|
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
|
|
|
|
|
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
|
2020-06-09 11:54:18 +00:00
|
|
|
from stable_baselines3.common.vec_env import VecEnv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OffPolicyAlgorithm(BaseAlgorithm):
|
|
|
|
|
"""
|
|
|
|
|
The base for Off-Policy algorithms (ex: SAC/TD3)
|
|
|
|
|
|
|
|
|
|
:param policy: Policy object
|
|
|
|
|
:param env: The environment to learn from
|
|
|
|
|
(if registered in Gym, can be str. Can be None for loading trained models)
|
|
|
|
|
:param policy_base: The base policy used by this method
|
2020-10-02 17:05:55 +00:00
|
|
|
:param learning_rate: learning rate for the optimizer,
|
2020-06-09 11:54:18 +00:00
|
|
|
it can be a function of the current progress remaining (from 1 to 0)
|
2020-10-02 17:05:55 +00:00
|
|
|
:param buffer_size: size of the replay buffer
|
|
|
|
|
:param learning_starts: how many steps of the model to collect transitions for before learning starts
|
|
|
|
|
:param batch_size: Minibatch size for each gradient update
|
|
|
|
|
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
|
|
|
|
|
:param gamma: the discount factor
|
2021-02-27 16:33:50 +00:00
|
|
|
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
|
|
|
|
|
like ``(5, "step")`` or ``(2, "episode")``.
|
|
|
|
|
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
|
2020-07-16 12:14:22 +00:00
|
|
|
Set to ``-1`` means to do as many gradient steps as steps done in the environment
|
|
|
|
|
during the rollout.
|
2020-10-02 17:05:55 +00:00
|
|
|
:param action_noise: the action noise type (None by default), this can help
|
2020-06-29 09:16:54 +00:00
|
|
|
for hard exploration problems. See common.noise for the different action noise types.
|
2020-10-02 17:05:55 +00:00
|
|
|
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
|
2020-06-29 09:16:54 +00:00
|
|
|
at a cost of more complexity.
|
|
|
|
|
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
2020-06-09 11:54:18 +00:00
|
|
|
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
2020-10-02 17:05:55 +00:00
|
|
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
2020-06-09 11:54:18 +00:00
|
|
|
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
|
|
|
|
|
:param device: Device on which the code should run.
|
|
|
|
|
By default, it will try to use a Cuda compatible device and fallback to cpu
|
|
|
|
|
if it is not possible.
|
|
|
|
|
:param support_multi_env: Whether the algorithm supports training
|
|
|
|
|
with multiple environments (as in A2C)
|
|
|
|
|
:param create_eval_env: Whether to create a second environment that will be
|
|
|
|
|
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
|
|
|
|
:param monitor_wrapper: When creating an environment, whether to wrap it
|
|
|
|
|
or not in a Monitor wrapper.
|
|
|
|
|
:param seed: Seed for the pseudo random generators
|
|
|
|
|
:param use_sde: Whether to use State Dependent Exploration (SDE)
|
|
|
|
|
instead of action noise exploration (default: False)
|
|
|
|
|
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
|
|
|
|
Default: -1 (only sample at the beginning of the rollout)
|
2020-10-02 17:05:55 +00:00
|
|
|
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
|
2020-06-09 11:54:18 +00:00
|
|
|
during the warm up phase (before learning starts)
|
2020-10-02 17:05:55 +00:00
|
|
|
:param sde_support: Whether the model supports gSDE or not
|
2020-10-22 09:56:43 +00:00
|
|
|
:param remove_time_limit_termination: Remove terminations (dones) that are due to time limit.
|
|
|
|
|
See https://github.com/hill-a/stable-baselines/issues/863
|
2020-12-06 12:05:10 +00:00
|
|
|
:param supported_action_spaces: The action spaces supported by the algorithm.
|
2020-06-09 11:54:18 +00:00
|
|
|
"""
|
|
|
|
|
|
2020-07-16 14:12:16 +00:00
|
|
|
def __init__(
    self,
    policy: Type[BasePolicy],
    env: Union[GymEnv, str],
    policy_base: Type[BasePolicy],
    learning_rate: Union[float, Schedule],
    buffer_size: int = 1000000,
    learning_starts: int = 100,
    batch_size: int = 256,
    tau: float = 0.005,
    gamma: float = 0.99,
    train_freq: Union[int, Tuple[int, str]] = (1, "step"),
    gradient_steps: int = 1,
    action_noise: Optional[ActionNoise] = None,
    optimize_memory_usage: bool = False,
    policy_kwargs: Optional[Dict[str, Any]] = None,
    tensorboard_log: Optional[str] = None,
    verbose: int = 0,
    device: Union[th.device, str] = "auto",
    support_multi_env: bool = False,
    create_eval_env: bool = False,
    monitor_wrapper: bool = True,
    seed: Optional[int] = None,
    use_sde: bool = False,
    sde_sample_freq: int = -1,
    use_sde_at_warmup: bool = False,
    sde_support: bool = True,
    remove_time_limit_termination: bool = False,
    supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
):
    """
    Store the off-policy hyperparameters; the model itself is built later in ``_setup_model``.

    All parameters are documented in the class docstring. Arguments shared by every
    algorithm are forwarded to ``BaseAlgorithm.__init__``; the replay-buffer and
    exploration settings are kept on ``self`` for ``_setup_model`` / ``collect_rollouts``.
    """
    super().__init__(
        policy=policy,
        env=env,
        policy_base=policy_base,
        learning_rate=learning_rate,
        policy_kwargs=policy_kwargs,
        tensorboard_log=tensorboard_log,
        verbose=verbose,
        device=device,
        support_multi_env=support_multi_env,
        create_eval_env=create_eval_env,
        monitor_wrapper=monitor_wrapper,
        seed=seed,
        use_sde=use_sde,
        sde_sample_freq=sde_sample_freq,
        supported_action_spaces=supported_action_spaces,
    )
    self.buffer_size = buffer_size
    self.batch_size = batch_size
    self.learning_starts = learning_starts
    self.tau = tau
    self.gamma = gamma
    self.gradient_steps = gradient_steps
    self.action_noise = action_noise
    self.optimize_memory_usage = optimize_memory_usage

    # Remove terminations (dones) that are due to time limit
    # see https://github.com/hill-a/stable-baselines/issues/863
    self.remove_time_limit_termination = remove_time_limit_termination

    # Save train freq parameter, will be converted later to TrainFreq object
    # (see ``_convert_train_freq``, called from ``_setup_model``)
    self.train_freq = train_freq

    # Both are created later (the replay buffer in ``_setup_model``)
    self.actor = None  # type: Optional[th.nn.Module]
    self.replay_buffer = None  # type: Optional[ReplayBuffer]
    # Update policy keyword arguments.
    # NOTE(review): ``self.policy_kwargs`` is not assigned here, so it is presumably
    # set (as a dict, even when ``policy_kwargs`` is None) by ``BaseAlgorithm.__init__``.
    if sde_support:
        self.policy_kwargs["use_sde"] = self.use_sde
    # For gSDE only
    self.use_sde_at_warmup = use_sde_at_warmup
|
|
|
|
|
|
2021-02-27 18:53:13 +00:00
|
|
|
def _convert_train_freq(self) -> None:
    """
    Convert the ``train_freq`` parameter to a ``TrainFreq`` object (in place).

    Accepted forms: an integer (interpreted as a number of steps), or a
    ``(frequency, unit)`` pair given as a tuple or a list (lists typically come
    from YAML/JSON configs), with ``unit`` being "step" or "episode".
    numpy integer frequencies are accepted and converted to ``int``.
    Does nothing if ``train_freq`` is already a ``TrainFreq``.

    :raises ValueError: if a pair does not have exactly two elements, if the unit
        is not "step"/"episode", or if the frequency is not an integer.
    """
    if isinstance(self.train_freq, TrainFreq):
        return

    train_freq = self.train_freq

    # A bare value is a number of steps; the value itself is validated below
    if not isinstance(train_freq, (tuple, list)):
        train_freq = (train_freq, "step")

    if len(train_freq) != 2:
        raise ValueError(f"`train_freq` must be an integer or a (frequency, unit) pair, not {train_freq}")

    frequency, unit = train_freq

    try:
        unit = TrainFrequencyUnit(unit)
    except ValueError as err:
        raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{unit}'!") from err

    if not isinstance(frequency, (int, np.integer)):
        raise ValueError(f"The frequency of `train_freq` must be an integer and not {frequency}")

    self.train_freq = TrainFreq(int(frequency), unit)
|
|
|
|
|
|
2020-07-02 22:14:21 +00:00
|
|
|
def _setup_model(self) -> None:
    """
    Build the objects needed for training.

    In order: the learning-rate schedule, the random seed, the replay buffer,
    the policy (moved to ``self.device``), and finally the conversion of
    ``train_freq`` into a ``TrainFreq`` object.
    """
    self._setup_lr_schedule()
    # Seed before the buffer and the policy are constructed
    # (presumably so that network initialisation is reproducible)
    self.set_random_seed(self.seed)

    # Transition storage shared by every rollout
    self.replay_buffer = ReplayBuffer(
        self.buffer_size, self.observation_space, self.action_space, self.device,
        optimize_memory_usage=self.optimize_memory_usage,
    )

    # Instantiate the policy, then place it on the configured device
    new_policy = self.policy_class(  # pytype:disable=not-instantiable
        self.observation_space,
        self.action_space,
        self.lr_schedule,
        **self.policy_kwargs,  # pytype:disable=not-instantiable
    )
    self.policy = new_policy.to(self.device)

    # ``train_freq`` may still be an int or a tuple at this point
    self._convert_train_freq()
|
|
|
|
|
|
2020-07-02 22:14:21 +00:00
|
|
|
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
    """
    Pickle the current replay buffer to ``path``.

    :param path: Destination of the pickle. When it is a str or pathlib.Path,
        the path is automatically created if necessary.
    """
    buffer = self.replay_buffer
    assert buffer is not None, "The replay buffer is not defined"
    save_to_pkl(path, buffer, self.verbose)
|
2020-06-09 11:54:18 +00:00
|
|
|
|
2020-07-02 22:14:21 +00:00
|
|
|
def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
    """
    Load a replay buffer from a pickle file and attach it to the model.

    The loaded object is validated before it replaces ``self.replay_buffer``, so a
    bad file leaves the current buffer untouched. Its ``device`` attribute is then
    reset to ``self.device``, so a buffer saved on another machine (e.g. with CUDA)
    matches the current model.

    Security: the file is unpickled, which can execute arbitrary code.
    Only load replay buffers from trusted sources.

    :param path: Path to the pickled replay buffer.
    """
    replay_buffer = load_from_pkl(path, self.verbose)
    assert isinstance(replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"
    # The pickled buffer remembers the device it was created with; align it with
    # the current model (NOTE(review): presumably used when sampled batches are
    # converted to tensors).
    replay_buffer.device = self.device
    self.replay_buffer = replay_buffer
|
|
|
|
|
|
|
|
|
|
def _setup_learn(
    self,
    total_timesteps: int,
    eval_env: Optional[GymEnv],
    callback: MaybeCallback = None,
    eval_freq: int = 10000,
    n_eval_episodes: int = 5,
    log_path: Optional[str] = None,
    reset_num_timesteps: bool = True,
    tb_log_name: str = "run",
) -> Tuple[int, BaseCallback]:
    """
    cf `BaseAlgorithm`.

    Additionally, when the memory-efficient replay buffer is used and the timestep
    counter is reset while the buffer already holds data, the last stored
    transition is marked as done (with a warning) before delegating to
    ``BaseAlgorithm._setup_learn``.
    """
    # Prevent continuity issue by truncating trajectory
    # when using memory efficient replay buffer
    # see https://github.com/DLR-RM/stable-baselines3/issues/46
    truncate_last_traj = (
        self.optimize_memory_usage
        and reset_num_timesteps
        and self.replay_buffer is not None
        and (self.replay_buffer.full or self.replay_buffer.pos > 0)
    )

    if truncate_last_traj:
        # Each fragment ends with a space so the concatenated message reads correctly
        warnings.warn(
            "The last trajectory in the replay buffer will be truncated, "
            "see https://github.com/DLR-RM/stable-baselines3/issues/46. "
            "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False` "
            "to avoid that issue."
        )
        # Go to the previous index (wrapping around the circular buffer)
        pos = (self.replay_buffer.pos - 1) % self.replay_buffer.buffer_size
        self.replay_buffer.dones[pos] = True

    return super()._setup_learn(
        total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, log_path, reset_num_timesteps, tb_log_name
    )
|
|
|
|
|
|
|
|
|
|
def learn(
    self,
    total_timesteps: int,
    callback: MaybeCallback = None,
    log_interval: int = 4,
    eval_env: Optional[GymEnv] = None,
    eval_freq: int = -1,
    n_eval_episodes: int = 5,
    tb_log_name: str = "run",
    eval_log_path: Optional[str] = None,
    reset_num_timesteps: bool = True,
) -> "OffPolicyAlgorithm":
    """
    Alternate between collecting rollouts and training until ``total_timesteps``
    environment steps have been performed (or a callback asks to stop).

    :param total_timesteps: total number of environment steps to perform
    :param callback: callback(s) called during training (cf ``_setup_learn``)
    :param log_interval: log data every ``log_interval`` episodes
    :param eval_env: environment used for periodic evaluation (cf ``BaseAlgorithm``)
    :param eval_freq: evaluation frequency (cf ``BaseAlgorithm``)
    :param n_eval_episodes: number of episodes per evaluation
    :param tb_log_name: name of the tensorboard run
    :param eval_log_path: where to save the evaluation results
    :param reset_num_timesteps: whether to reset the timestep counter
    :return: the trained model (``self``)
    """
    total_timesteps, callback = self._setup_learn(
        total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
    )

    # NOTE: ``locals()`` is handed to the callback, so do not introduce new
    # local variables above this call.
    callback.on_training_start(locals(), globals())

    while self.num_timesteps < total_timesteps:
        rollout = self.collect_rollouts(
            self.env,
            train_freq=self.train_freq,
            action_noise=self.action_noise,
            callback=callback,
            learning_starts=self.learning_starts,
            replay_buffer=self.replay_buffer,
            log_interval=log_interval,
        )

        if rollout.continue_training is False:
            break

        if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
            # A negative ``gradient_steps`` (e.g. -1) means: do as many gradient
            # steps as steps performed in the environment during the rollout
            gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
            # ``gradient_steps == 0`` disables the update for this iteration
            if gradient_steps > 0:
                self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

    callback.on_training_end()

    return self
|
|
|
|
|
|
|
|
|
|
def train(self, gradient_steps: int, batch_size: int) -> None:
    """
    Run ``gradient_steps`` updates, each on a minibatch of ``batch_size`` transitions
    sampled from the replay buffer (gradient descent and target network updates).

    Concrete algorithms must override this method.

    :param gradient_steps: number of gradient updates to perform
    :param batch_size: minibatch size used for each update
    :raises NotImplementedError: always, in this base class
    """
    raise NotImplementedError
|
|
|
|
|
|
2020-07-16 14:12:16 +00:00
|
|
|
def _sample_action(
    self, learning_starts: int, action_noise: Optional[ActionNoise] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Pick the next action according to the exploration strategy.

    During warm-up (fewer than ``learning_starts`` steps, unless gSDE is used at
    warm-up) the action is sampled uniformly from the action space. Afterwards it
    is sampled from the policy (non-deterministically). For ``Box`` action spaces,
    ``action_noise`` (if any) is added in the normalized [-1, 1] space and clipped.

    :param learning_starts: Number of steps before learning for the warm-up phase.
    :param action_noise: Action noise that will be used for exploration.
        Required for deterministic policies (e.g. TD3); it can also be used
        on top of the stochastic policy for SAC.
    :return: the action to take in the environment, and the scaled action to store
        in the replay buffer. The two differ when the action space is not
        normalized (bounds are not [-1, 1]).
    """
    warming_up = self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup)
    if warming_up:
        # Uniform random action, wrapped in a batch dimension of size 1
        unscaled_action = np.array([self.action_space.sample()])
    else:
        # Continuous actions: the policy is assumed to squash its output with tanh.
        # Sampling is non-deterministic (matters for SAC, not for TD3).
        unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

    if not isinstance(self.action_space, gym.spaces.Box):
        # Discrete case, no need to normalize or clip
        return unscaled_action, unscaled_action

    # Rescale from [low, high] to [-1, 1]; this normalized form is what gets stored
    buffer_action = self.policy.scale_action(unscaled_action)
    if action_noise is not None:
        # Noise improves exploration; stay inside the normalized bounds
        buffer_action = np.clip(buffer_action + action_noise(), -1, 1)

    return self.policy.unscale_action(buffer_action), buffer_action
|
|
|
|
|
|
|
|
|
|
def _dump_logs(self) -> None:
    """
    Record training statistics and write them with ``logger.dump``.

    Logged values: episode count, mean episode reward/length (when episode info
    is available), fps, elapsed time, total timesteps, the gSDE std (when
    ``use_sde``) and the success rate (when available).
    """
    # Read the clock once so that fps and time_elapsed agree. Clamp the value:
    # a coarse timer can report 0 (ZeroDivisionError), and time.time() is a wall
    # clock that may jump backwards (negative fps).
    time_elapsed = max(time.time() - self.start_time, 1e-8)
    fps = int(self.num_timesteps / time_elapsed)
    logger.record("time/episodes", self._episode_num, exclude="tensorboard")
    if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
        logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
        logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
    logger.record("time/fps", fps)
    logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
    logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
    if self.use_sde:
        logger.record("train/std", (self.actor.get_std()).mean().item())

    if len(self.ep_success_buffer) > 0:
        logger.record("rollout/success rate", safe_mean(self.ep_success_buffer))
    # Pass the number of timesteps for tensorboard
    logger.dump(step=self.num_timesteps)
|
|
|
|
|
|
|
|
|
|
def _on_step(self) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Method called after each step in the environment.
|
|
|
|
|
It is meant to trigger DQN target network update
|
|
|
|
|
but can be used for other purposes
|
|
|
|
|
"""
|
|
|
|
|
pass
|
|
|
|
|
|
2021-02-27 16:33:50 +00:00
|
|
|
def _store_transition(
    self,
    replay_buffer: ReplayBuffer,
    buffer_action: np.ndarray,
    new_obs: np.ndarray,
    reward: np.ndarray,
    done: np.ndarray,
    infos: List[Dict[str, Any]],
) -> None:
    """
    Store transition in the replay buffer.
    We store the normalized action and the unnormalized observation.
    It also handles terminal observations (because VecEnv resets automatically).
    When ``remove_time_limit_termination`` is enabled (and ``optimize_memory_usage``
    is not), transitions that ended only because of a time limit are stored as
    non-terminal.

    :param replay_buffer: Replay buffer object where to store the transition.
    :param buffer_action: normalized action
    :param new_obs: next observation in the current episode
        or first observation of the episode (when done is True)
    :param reward: reward for the current transition
    :param done: Termination signal
    :param infos: List of additional information about the transition.
        It contains the terminal observations and, with gym's ``TimeLimit``
        wrapper, the ``"TimeLimit.truncated"`` flag.
    """
    # Store only the unnormalized version
    if self._vec_normalize_env is not None:
        new_obs_ = self._vec_normalize_env.get_original_obs()
        reward_ = self._vec_normalize_env.get_original_reward()
    else:
        # Avoid changing the original ones
        self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward

    # As the VecEnv resets automatically, new_obs is already the
    # first observation of the next episode
    if done and infos[0].get("terminal_observation") is not None:
        next_obs = infos[0]["terminal_observation"]
        # VecNormalize normalizes the terminal observation
        if self._vec_normalize_env is not None:
            next_obs = self._vec_normalize_env.unnormalize_obs(next_obs)
    else:
        next_obs = new_obs_

    # Remove terminations (dones) that are due to a time limit, so that the target
    # keeps bootstrapping from ``next_obs`` (see hill-a/stable-baselines#863).
    # NOTE(review): skipped with ``optimize_memory_usage``: that buffer variant
    # presumably shares storage between next_obs and the following obs, so the
    # terminal observation would be replaced by the next episode's first one.
    done_ = done
    if (
        self.remove_time_limit_termination
        and not self.optimize_memory_usage
        and done
        and infos[0].get("TimeLimit.truncated", False)
    ):
        done_ = np.zeros_like(done)

    replay_buffer.add(self._last_original_obs, next_obs, buffer_action, reward_, done_)

    self._last_obs = new_obs
    # Save the unnormalized observation
    if self._vec_normalize_env is not None:
        self._last_original_obs = new_obs_
|
|
|
|
|
|
2020-07-16 14:12:16 +00:00
|
|
|
def collect_rollouts(
    self,
    env: VecEnv,
    callback: BaseCallback,
    train_freq: TrainFreq,
    replay_buffer: ReplayBuffer,
    action_noise: Optional[ActionNoise] = None,
    learning_starts: int = 0,
    log_interval: Optional[int] = None,
) -> RolloutReturn:
    """
    Collect experiences and store them into a ``ReplayBuffer``.

    Steps ``env`` with the exploration policy (``_sample_action``) until
    ``should_collect_more_steps`` reports that ``train_freq`` is satisfied,
    storing every transition through ``_store_transition``.

    :param env: The training environment (must be a VecEnv with a single env)
    :param callback: Callback that will be called at each step
        (and at the beginning and end of the rollout)
    :param train_freq: How much experience to collect
        by doing rollouts of current policy.
        Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
        or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
        with ``<n>`` being an integer greater than 0.
    :param action_noise: Action noise that will be used for exploration
        Required for deterministic policy (e.g. TD3). This can also be used
        in addition to the stochastic policy for SAC.
    :param learning_starts: Number of steps before learning for the warm-up phase.
    :param replay_buffer: Buffer where the collected transitions are stored.
    :param log_interval: Log data every ``log_interval`` episodes
    :return: ``RolloutReturn(mean_reward, num_collected_steps, num_collected_episodes,
        continue_training)``. ``mean_reward`` is the mean return of the episodes
        completed during this call (0.0 if none). When ``callback.on_step()``
        returns False, the method returns immediately with reward 0.0 and
        ``continue_training=False``; in that case ``callback.on_rollout_end()`` is
        not called and the last transition is not stored.
    """
    # NOTE: ``total_timesteps`` here is a per-episode length list, not the global
    # step counter; it is filled but not used in the return value.
    episode_rewards, total_timesteps = [], []
    num_collected_steps, num_collected_episodes = 0, 0

    assert isinstance(env, VecEnv), "You must pass a VecEnv"
    assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
    assert train_freq.frequency > 0, "Should at least collect one step or episode."

    if self.use_sde:
        self.actor.reset_noise()

    callback.on_rollout_start()
    continue_training = True

    while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
        # NOTE(review): these are reset on every outer iteration and every call, so
        # an episode that spans two calls only contributes the reward collected in
        # the call where it finishes to ``mean_reward``.
        done = False
        episode_reward, episode_timesteps = 0.0, 0

        while not done:

            if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.actor.reset_noise()

            # Select action randomly or according to policy
            action, buffer_action = self._sample_action(learning_starts, action_noise)

            # Rescale and perform action
            new_obs, reward, done, infos = env.step(action)

            self.num_timesteps += 1
            episode_timesteps += 1
            num_collected_steps += 1

            # Give access to local variables.
            # NOTE: this exposes every local name of this method to user callbacks,
            # so renaming locals here is a breaking change for them.
            callback.update_locals(locals())
            # Only stop training if return value is False, not when it is None.
            if callback.on_step() is False:
                return RolloutReturn(0.0, num_collected_steps, num_collected_episodes, continue_training=False)

            episode_reward += reward

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos, done)

            # Store data in replay buffer (normalized action and unnormalized observation)
            self._store_transition(replay_buffer, buffer_action, new_obs, reward, done, infos)

            self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

            # For DQN, check if the target network should be updated
            # and update the exploration schedule
            # For SAC/TD3, the update is done as the same time as the gradient update
            # see https://github.com/hill-a/stable-baselines/issues/900
            self._on_step()

            # Step-based train_freq may be reached in the middle of an episode
            if not should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
                break

        # ``done`` is False here when the inner loop stopped mid-episode
        if done:
            num_collected_episodes += 1
            self._episode_num += 1
            episode_rewards.append(episode_reward)
            total_timesteps.append(episode_timesteps)

            if action_noise is not None:
                action_noise.reset()

            # Log training infos
            if log_interval is not None and self._episode_num % log_interval == 0:
                self._dump_logs()

    mean_reward = np.mean(episode_rewards) if num_collected_episodes > 0 else 0.0

    callback.on_rollout_end()

    return RolloutReturn(mean_reward, num_collected_steps, num_collected_episodes, continue_training)
|