mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-01 23:30:53 +00:00
Merge pull request #50 from Antonin-Raffin/refactor/off-policy
Add Off Policy base class
This commit is contained in:
commit
02a080f647
9 changed files with 330 additions and 203 deletions
|
|
@ -11,11 +11,14 @@ Breaking Changes:
|
|||
- Python 2 support was dropped, Torchy Baselines now requires Python 3.6 or above
|
||||
- Return type of `evaluation.evaluate_policy()` has been changed
|
||||
- Refactored the replay buffer to avoid transformation between PyTorch and NumPy
|
||||
- Created `OffPolicyRLModel` base class
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Add `seed()` method to `VecEnv` class
|
||||
- Add support for Callback (cf https://github.com/hill-a/stable-baselines/pull/644)
|
||||
- Add methods for saving and loading replay buffer
|
||||
- Add `extend()` method to the buffers
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -133,3 +133,28 @@ def test_exclude_include_saved_params(model_class):
|
|||
|
||||
# clear file from os
|
||||
os.remove("test_save.zip")
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3])
|
||||
def test_save_load_replay_buffer(model_class):
|
||||
log_folder = 'logs'
|
||||
replay_path = os.path.join(log_folder, 'replay_buffer.pkl')
|
||||
os.makedirs(log_folder, exist_ok=True)
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=1000)
|
||||
model.learn(500)
|
||||
old_replay_buffer = deepcopy(model.replay_buffer)
|
||||
model.save_replay_buffer(log_folder)
|
||||
model.replay_buffer = None
|
||||
model.load_replay_buffer(replay_path)
|
||||
|
||||
assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations)
|
||||
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
|
||||
assert np.allclose(old_replay_buffer.next_observations, model.replay_buffer.next_observations)
|
||||
assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards)
|
||||
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
|
||||
|
||||
# test extending replay buffer
|
||||
model.replay_buffer.extend(old_replay_buffer.observations, old_replay_buffer.next_observations,
|
||||
old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones)
|
||||
|
||||
# clear file from os
|
||||
os.remove(replay_path)
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ class CEMRL(TD3):
|
|||
|
||||
rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout,
|
||||
n_steps=-1, action_noise=self.action_noise,
|
||||
deterministic=False, callback=callback,
|
||||
callback=callback,
|
||||
learning_starts=self.learning_starts,
|
||||
replay_buffer=self.replay_buffer,
|
||||
obs=obs, episode_num=episode_num,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import time
|
|||
import os
|
||||
import io
|
||||
import zipfile
|
||||
import pickle
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
|
|
@ -14,11 +15,12 @@ from torchy_baselines.common import logger
|
|||
from torchy_baselines.common.policies import BasePolicy, get_policy_from_name
|
||||
from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
|
||||
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, OptimizerStateDict
|
||||
from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
from torchy_baselines.common.noise import ActionNoise
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
|
||||
|
||||
class BaseRLModel(ABC):
|
||||
|
|
@ -83,20 +85,20 @@ class BaseRLModel(ABC):
|
|||
self.n_envs = None
|
||||
self.num_timesteps = 0
|
||||
self.eval_env = None
|
||||
self.replay_buffer = None
|
||||
self.seed = seed
|
||||
self.action_noise = None # type: ActionNoise
|
||||
self.start_time = None
|
||||
self.policy, self.actor = None, None
|
||||
self.policy = None
|
||||
self.learning_rate = None
|
||||
# Used for SDE only
|
||||
self.rollout_data = None
|
||||
self.on_policy_exploration = False
|
||||
self.use_sde = use_sde
|
||||
self.sde_sample_freq = sde_sample_freq
|
||||
# Track the training progress (from 1 to 0)
|
||||
# this is used to update the learning rate
|
||||
self._current_progress = 1
|
||||
# Buffers for logging
|
||||
self.ep_info_buffer = None # type: deque
|
||||
self.ep_success_buffer = None # type: deque
|
||||
|
||||
# Create and wrap the env if needed
|
||||
if env is not None:
|
||||
|
|
@ -196,7 +198,7 @@ class BaseRLModel(ABC):
|
|||
update_learning_rate(optimizer, self.learning_rate(self._current_progress))
|
||||
|
||||
@staticmethod
|
||||
def safe_mean(arr: Union[np.ndarray, list]) -> np.ndarray:
|
||||
def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
|
||||
"""
|
||||
Compute the mean of an array if there is at least one element.
|
||||
For empty array, return NaN. It is used for logging only.
|
||||
|
|
@ -516,6 +518,7 @@ class BaseRLModel(ABC):
|
|||
"""
|
||||
self.start_time = time.time()
|
||||
self.ep_info_buffer = deque(maxlen=100)
|
||||
self.ep_success_buffer = deque(maxlen=100)
|
||||
|
||||
if self.action_noise is not None:
|
||||
self.action_noise.reset()
|
||||
|
|
@ -536,201 +539,22 @@ class BaseRLModel(ABC):
|
|||
|
||||
return episode_num, obs, callback
|
||||
|
||||
def _update_info_buffer(self, infos: List[Dict[str, Any]]) -> None:
|
||||
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
|
||||
"""
|
||||
Retrieve reward and episode length and update the buffer
|
||||
if using Monitor wrapper.
|
||||
|
||||
:param infos: ([dict])
|
||||
"""
|
||||
for info in infos:
|
||||
if dones is None:
|
||||
dones = np.array([False] * len(infos))
|
||||
for idx, info in enumerate(infos):
|
||||
maybe_ep_info = info.get('episode')
|
||||
maybe_is_success = info.get('is_success')
|
||||
if maybe_ep_info is not None:
|
||||
self.ep_info_buffer.extend([maybe_ep_info])
|
||||
|
||||
def collect_rollouts(self,
|
||||
env: VecEnv,
|
||||
callback: 'BaseCallback', # Type hint as string to avoid circular import
|
||||
n_episodes: int = 1,
|
||||
n_steps: int = -1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
deterministic: bool = False,
|
||||
learning_starts: int = 0,
|
||||
replay_buffer=None,
|
||||
obs: Optional[np.ndarray] = None,
|
||||
episode_num: int = 0,
|
||||
log_interval: Optional[int] = None) -> Tuple[float, int, int, Optional[np.ndarray], bool]:
|
||||
"""
|
||||
Collect rollout using the current policy (and possibly fill the replay buffer)
|
||||
TODO: move this method to off-policy base class.
|
||||
|
||||
:param env: (VecEnv)
|
||||
:param n_episodes: (int)
|
||||
:param n_steps: (int)
|
||||
:param action_noise: (ActionNoise)
|
||||
:param deterministic: (bool)
|
||||
:param callback: (BaseCallback)
|
||||
:param learning_starts: (int)
|
||||
:param replay_buffer: (ReplayBuffer)
|
||||
:param obs: (np.ndarray)
|
||||
:param episode_num: (int)
|
||||
:param log_interval: (int)
|
||||
"""
|
||||
episode_rewards = []
|
||||
total_timesteps = []
|
||||
total_steps, total_episodes = 0, 0
|
||||
assert isinstance(env, VecEnv)
|
||||
assert env.num_envs == 1
|
||||
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
if self._vec_normalize_env is not None:
|
||||
obs_ = self._vec_normalize_env.get_original_obs()
|
||||
|
||||
self.rollout_data = None
|
||||
if self.use_sde:
|
||||
self.actor.reset_noise()
|
||||
# Reset rollout data
|
||||
if self.on_policy_exploration:
|
||||
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']}
|
||||
|
||||
callback.on_rollout_start()
|
||||
continue_training = True
|
||||
|
||||
while total_steps < n_steps or total_episodes < n_episodes:
|
||||
done = False
|
||||
# Reset environment: not needed for VecEnv
|
||||
# obs = env.reset()
|
||||
episode_reward, episode_timesteps = 0.0, 0
|
||||
|
||||
while not done:
|
||||
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback() is False:
|
||||
continue_training = False
|
||||
return 0.0, total_steps, total_episodes, None, continue_training
|
||||
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.actor.reset_noise()
|
||||
|
||||
# Select action randomly or according to policy
|
||||
# TODO: use action from policy when using SDE during the warmup phase?
|
||||
# if num_timesteps < learning_starts and not self.use_sde:
|
||||
if self.num_timesteps < learning_starts:
|
||||
# Warmup phase
|
||||
unscaled_action = np.array([self.action_space.sample()])
|
||||
else:
|
||||
unscaled_action = self.predict(obs, deterministic=not self.use_sde)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
scaled_action = self.scale_action(unscaled_action)
|
||||
|
||||
if self.use_sde:
|
||||
# When using SDE, the action can be out of bounds
|
||||
# TODO: fix with squashing and account for that in the proba distribution
|
||||
clipped_action = np.clip(scaled_action, -1, 1)
|
||||
else:
|
||||
clipped_action = scaled_action
|
||||
|
||||
# Add noise to the action (improve exploration)
|
||||
if action_noise is not None:
|
||||
# NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
|
||||
# Update(October 2019): Not anymore
|
||||
clipped_action = np.clip(clipped_action + action_noise(), -1, 1)
|
||||
|
||||
# Rescale and perform action
|
||||
new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))
|
||||
|
||||
done_bool = [float(done[0])]
|
||||
episode_reward += reward
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
self._update_info_buffer(infos)
|
||||
|
||||
# Store data in replay buffer
|
||||
if replay_buffer is not None:
|
||||
# Store only the unnormalized version
|
||||
if self._vec_normalize_env is not None:
|
||||
new_obs_ = self._vec_normalize_env.get_original_obs()
|
||||
reward_ = self._vec_normalize_env.get_original_reward()
|
||||
else:
|
||||
# Avoid changing the original ones
|
||||
obs_, new_obs_, reward_ = obs, new_obs, reward
|
||||
|
||||
replay_buffer.add(obs_, new_obs_, clipped_action, reward_, done_bool)
|
||||
|
||||
if self.rollout_data is not None:
|
||||
# Assume only one env
|
||||
self.rollout_data['observations'].append(obs[0].copy())
|
||||
self.rollout_data['actions'].append(scaled_action[0].copy())
|
||||
self.rollout_data['rewards'].append(reward[0].copy())
|
||||
self.rollout_data['dones'].append(np.array(done_bool[0]).copy())
|
||||
obs_tensor = th.FloatTensor(obs).to(self.device)
|
||||
self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy())
|
||||
|
||||
obs = new_obs
|
||||
# Save the true unnormalized observation
|
||||
# otherwise obs_ = self._vec_normalize_env.unnormalize_obs(obs)
|
||||
# is a good approximation
|
||||
if self._vec_normalize_env is not None:
|
||||
obs_ = new_obs_
|
||||
|
||||
self.num_timesteps += 1
|
||||
episode_timesteps += 1
|
||||
total_steps += 1
|
||||
if 0 < n_steps <= total_steps:
|
||||
break
|
||||
|
||||
if done:
|
||||
total_episodes += 1
|
||||
episode_rewards.append(episode_reward)
|
||||
total_timesteps.append(episode_timesteps)
|
||||
# TODO: reset SDE matrix at the end of the episode?
|
||||
if action_noise is not None:
|
||||
action_noise.reset()
|
||||
|
||||
# Display training infos
|
||||
if self.verbose >= 1 and log_interval is not None and (
|
||||
episode_num + total_episodes) % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.logkv("episodes", episode_num + total_episodes)
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
# logger.logkv("n_updates", n_updates)
|
||||
logger.logkv("fps", fps)
|
||||
logger.logkv('time_elapsed', int(time.time() - self.start_time))
|
||||
logger.logkv("total timesteps", self.num_timesteps)
|
||||
if self.use_sde:
|
||||
logger.logkv("std", (self.actor.get_std()).mean().item())
|
||||
logger.dumpkvs()
|
||||
|
||||
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
|
||||
|
||||
# Post processing
|
||||
if self.rollout_data is not None:
|
||||
for key in ['observations', 'actions', 'rewards', 'dones', 'values']:
|
||||
self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)
|
||||
|
||||
self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
|
||||
self.rollout_data['advantage'] = self.rollout_data['rewards'].clone()
|
||||
|
||||
# Compute return and advantage
|
||||
last_return = 0.0
|
||||
for step in reversed(range(len(self.rollout_data['rewards']))):
|
||||
if step == len(self.rollout_data['rewards']) - 1:
|
||||
next_non_terminal = 1.0 - done[0]
|
||||
next_value = self.vf_net(th.FloatTensor(obs).to(self.device))[0].detach()
|
||||
last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
|
||||
last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
|
||||
self.rollout_data['returns'][step] = last_return
|
||||
self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values']
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return mean_reward, total_steps, total_episodes, obs, continue_training
|
||||
if maybe_is_success is not None and dones[idx]:
|
||||
self.ep_success_buffer.append(maybe_is_success)
|
||||
|
||||
@staticmethod
|
||||
def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None,
|
||||
|
|
@ -830,3 +654,260 @@ class BaseRLModel(ABC):
|
|||
params_to_save[name] = attr.state_dict()
|
||||
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)
|
||||
|
||||
|
||||
class OffPolicyRLModel(BaseRLModel):
|
||||
"""
|
||||
The base RL model for Off-Policy algorithm (ex: SAC/TD3)
|
||||
|
||||
:param policy: Policy object
|
||||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
||||
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param device: Device on which the code should run.
|
||||
By default, it will try to use a Cuda compatible device and fallback to cpu
|
||||
if it is not possible.
|
||||
:param support_multi_env: Whether the algorithm supports training
|
||||
with multiple environments (as in A2C)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param seed: Seed for the pseudo random generators
|
||||
:param use_sde: Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
"""
|
||||
def __init__(self,
|
||||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
support_multi_env: bool = False,
|
||||
create_eval_env: bool = False,
|
||||
monitor_wrapper: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
use_sde_at_warmup: bool = False):
|
||||
|
||||
super(OffPolicyRLModel, self).__init__(policy, env, policy_base, policy_kwargs, verbose,
|
||||
device, support_multi_env, create_eval_env, monitor_wrapper,
|
||||
seed, use_sde, sde_sample_freq)
|
||||
# For SDE only
|
||||
self.rollout_data = None
|
||||
self.on_policy_exploration = False
|
||||
self.actor = None
|
||||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def save_replay_buffer(self, path: str):
|
||||
"""
|
||||
Save the replay buffer as a pickle file.
|
||||
|
||||
:param path: (str) Path to a log folder
|
||||
"""
|
||||
assert self.replay_buffer is not None, "The replay buffer is not defined"
|
||||
with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler:
|
||||
pickle.dump(self.replay_buffer, file_handler)
|
||||
|
||||
def load_replay_buffer(self, path: str):
|
||||
"""
|
||||
|
||||
:param path: (str) Path to the pickled replay buffer.
|
||||
"""
|
||||
with open(path, 'rb') as file_handler:
|
||||
self.replay_buffer = pickle.load(file_handler)
|
||||
assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
|
||||
|
||||
def collect_rollouts(self,
|
||||
env: VecEnv,
|
||||
# Type hint as string to avoid circular import
|
||||
callback: 'BaseCallback',
|
||||
n_episodes: int = 1,
|
||||
n_steps: int = -1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
learning_starts: int = 0,
|
||||
replay_buffer: Optional[ReplayBuffer] = None,
|
||||
obs: Optional[np.ndarray] = None,
|
||||
episode_num: int = 0,
|
||||
log_interval: Optional[int] = None) -> Tuple[float, int, int, Optional[np.ndarray], bool]:
|
||||
"""
|
||||
Collect rollout using the current policy (and possibly fill the replay buffer)
|
||||
|
||||
:param env: (VecEnv) The training environment
|
||||
:param n_episodes: (int) Number of episodes to use to collect rollout data
|
||||
You can also specify a `n_steps` instead
|
||||
:param n_steps: (int) Number of steps to use to collect rollout data
|
||||
You can also specify a `n_episodes` instead.
|
||||
:param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration
|
||||
Required for deterministic policy (e.g. TD3). This can also be used
|
||||
in addition to the stochastic policy for SAC.
|
||||
:param callback: (BaseCallback) Callback that will be called at each step
|
||||
(and at the beginning and end of the rollout)
|
||||
:param learning_starts: (int) Number of steps before learning for the warm-up phase.
|
||||
:param replay_buffer: (ReplayBuffer)
|
||||
:param obs: (np.ndarray) Last observation from the environment
|
||||
:param episode_num: (int) Episode index
|
||||
:param log_interval: (int) Log data every `log_interval` episodes
|
||||
"""
|
||||
episode_rewards, total_timesteps = [], []
|
||||
total_steps, total_episodes = 0, 0
|
||||
|
||||
assert isinstance(env, VecEnv), "You must pass a VecEnv"
|
||||
assert env.num_envs == 1, "OffPolicyRLModel only support single environment"
|
||||
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
if self._vec_normalize_env is not None:
|
||||
obs_ = self._vec_normalize_env.get_original_obs()
|
||||
|
||||
self.rollout_data = None
|
||||
if self.use_sde:
|
||||
self.actor.reset_noise()
|
||||
# Reset rollout data
|
||||
if self.on_policy_exploration:
|
||||
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']}
|
||||
|
||||
callback.on_rollout_start()
|
||||
continue_training = True
|
||||
|
||||
while total_steps < n_steps or total_episodes < n_episodes:
|
||||
done = False
|
||||
episode_reward, episode_timesteps = 0.0, 0
|
||||
|
||||
while not done:
|
||||
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback() is False:
|
||||
continue_training = False
|
||||
return 0.0, total_steps, total_episodes, None, continue_training
|
||||
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.actor.reset_noise()
|
||||
|
||||
# Select action randomly or according to policy
|
||||
if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
|
||||
# Warmup phase
|
||||
unscaled_action = np.array([self.action_space.sample()])
|
||||
else:
|
||||
unscaled_action = self.predict(obs, deterministic=not self.use_sde)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
scaled_action = self.scale_action(unscaled_action)
|
||||
|
||||
if self.use_sde:
|
||||
# When using SDE, the action can be out of bounds
|
||||
# TODO: fix with squashing and account for that in the proba distribution
|
||||
clipped_action = np.clip(scaled_action, -1, 1)
|
||||
else:
|
||||
clipped_action = scaled_action
|
||||
|
||||
# Add noise to the action (improve exploration)
|
||||
if action_noise is not None:
|
||||
# NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
|
||||
# Update(October 2019): Not anymore
|
||||
clipped_action = np.clip(clipped_action + action_noise(), -1, 1)
|
||||
|
||||
# Rescale and perform action
|
||||
new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))
|
||||
|
||||
episode_reward += reward
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
self._update_info_buffer(infos, done)
|
||||
|
||||
# Store data in replay buffer
|
||||
if replay_buffer is not None:
|
||||
# Store only the unnormalized version
|
||||
if self._vec_normalize_env is not None:
|
||||
new_obs_ = self._vec_normalize_env.get_original_obs()
|
||||
reward_ = self._vec_normalize_env.get_original_reward()
|
||||
else:
|
||||
# Avoid changing the original ones
|
||||
obs_, new_obs_, reward_ = obs, new_obs, reward
|
||||
|
||||
replay_buffer.add(obs_, new_obs_, clipped_action, reward_, done)
|
||||
|
||||
if self.rollout_data is not None:
|
||||
# Assume only one env
|
||||
self.rollout_data['observations'].append(obs[0].copy())
|
||||
self.rollout_data['actions'].append(scaled_action[0].copy())
|
||||
self.rollout_data['rewards'].append(reward[0].copy())
|
||||
self.rollout_data['dones'].append(done[0].copy())
|
||||
obs_tensor = th.FloatTensor(obs).to(self.device)
|
||||
self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy())
|
||||
|
||||
obs = new_obs
|
||||
# Save the true unnormalized observation
|
||||
# otherwise obs_ = self._vec_normalize_env.unnormalize_obs(obs)
|
||||
# is a good approximation
|
||||
if self._vec_normalize_env is not None:
|
||||
obs_ = new_obs_
|
||||
|
||||
self.num_timesteps += 1
|
||||
episode_timesteps += 1
|
||||
total_steps += 1
|
||||
if 0 < n_steps <= total_steps:
|
||||
break
|
||||
|
||||
if done:
|
||||
total_episodes += 1
|
||||
episode_rewards.append(episode_reward)
|
||||
total_timesteps.append(episode_timesteps)
|
||||
# TODO: reset SDE matrix at the end of the episode?
|
||||
if action_noise is not None:
|
||||
action_noise.reset()
|
||||
|
||||
# Display training infos
|
||||
if self.verbose >= 1 and log_interval is not None and (
|
||||
episode_num + total_episodes) % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.logkv("episodes", episode_num + total_episodes)
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
# logger.logkv("n_updates", n_updates)
|
||||
logger.logkv("fps", fps)
|
||||
logger.logkv('time_elapsed', int(time.time() - self.start_time))
|
||||
logger.logkv("total timesteps", self.num_timesteps)
|
||||
if self.use_sde:
|
||||
logger.logkv("std", (self.actor.get_std()).mean().item())
|
||||
|
||||
if len(self.ep_success_buffer) > 0:
|
||||
logger.logkv('success rate', self.safe_mean(self.ep_success_buffer))
|
||||
logger.dumpkvs()
|
||||
|
||||
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
|
||||
|
||||
# Post processing
|
||||
if self.rollout_data is not None:
|
||||
for key in ['observations', 'actions', 'rewards', 'dones', 'values']:
|
||||
self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)
|
||||
|
||||
self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
|
||||
self.rollout_data['advantage'] = self.rollout_data['rewards'].clone()
|
||||
|
||||
# Compute return and advantage
|
||||
last_return = 0.0
|
||||
for step in reversed(range(len(self.rollout_data['rewards']))):
|
||||
if step == len(self.rollout_data['rewards']) - 1:
|
||||
next_non_terminal = 1.0 - done[0]
|
||||
next_value = self.vf_net(th.FloatTensor(obs).to(self.device))[0].detach()
|
||||
last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
|
||||
last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
|
||||
self.rollout_data['returns'][step] = last_return
|
||||
self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values']
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return mean_reward, total_steps, total_episodes, obs, continue_training
|
||||
|
|
|
|||
|
|
@ -61,6 +61,14 @@ class BaseBuffer(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def extend(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Add a new batch of transitions to the buffer
|
||||
"""
|
||||
# Do a for loop along the batch axis
|
||||
for data in zip(*args):
|
||||
self.add(*data)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the buffer.
|
||||
|
|
|
|||
|
|
@ -232,6 +232,7 @@ class EvalCallback(EventCallback):
|
|||
self.n_eval_episodes = n_eval_episodes
|
||||
self.eval_freq = eval_freq
|
||||
self.best_mean_reward = -np.inf
|
||||
self.last_mean_reward = -np.inf
|
||||
self.deterministic = deterministic
|
||||
self.render = render
|
||||
|
||||
|
|
@ -280,6 +281,7 @@ class EvalCallback(EventCallback):
|
|||
|
||||
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
|
||||
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
|
||||
self.last_mean_reward = mean_reward
|
||||
|
||||
if self.verbose > 0:
|
||||
print(f"Eval num_timesteps={self.num_timesteps}, "
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class Monitor(gym.Wrapper):
|
|||
|
||||
def __init__(self,
|
||||
env: gym.Env,
|
||||
filename: Optional[str],
|
||||
filename: Optional[str] = None,
|
||||
allow_early_resets: bool = True,
|
||||
reset_keywords=(),
|
||||
info_keywords=()):
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ import torch as th
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.base_class import BaseRLModel
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.sac.policies import SACPolicy
|
||||
from torchy_baselines.common import logger
|
||||
|
||||
|
||||
class SAC(BaseRLModel):
|
||||
class SAC(OffPolicyRLModel):
|
||||
"""
|
||||
Soft Actor-Critic (SAC)
|
||||
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
|
||||
|
|
@ -49,6 +49,8 @@ class SAC(BaseRLModel):
|
|||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
|
||||
|
|
@ -65,13 +67,15 @@ class SAC(BaseRLModel):
|
|||
train_freq=1, gradient_steps=1, n_episodes_rollout=-1,
|
||||
target_entropy='auto', action_noise=None,
|
||||
gamma=0.99, use_sde=False, sde_sample_freq=-1,
|
||||
use_sde_at_warmup=False,
|
||||
tensorboard_log=None, create_eval_env=False,
|
||||
policy_kwargs=None, verbose=0, seed=0, device='auto',
|
||||
_init_setup_model=True):
|
||||
|
||||
super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.target_entropy = target_entropy
|
||||
|
|
@ -268,7 +272,7 @@ class SAC(BaseRLModel):
|
|||
while self.num_timesteps < total_timesteps:
|
||||
rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout,
|
||||
n_steps=self.train_freq, action_noise=self.action_noise,
|
||||
deterministic=False, callback=callback,
|
||||
callback=callback,
|
||||
learning_starts=self.learning_starts,
|
||||
replay_buffer=self.replay_buffer,
|
||||
obs=obs, episode_num=episode_num,
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ import torch as th
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.base_class import BaseRLModel
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.td3.policies import TD3Policy
|
||||
|
||||
|
||||
class TD3(BaseRLModel):
|
||||
class TD3(OffPolicyRLModel):
|
||||
"""
|
||||
Twin Delayed DDPG (TD3)
|
||||
Addressing Function Approximation Error in Actor-Critic Methods.
|
||||
|
|
@ -45,6 +45,8 @@ class TD3(BaseRLModel):
|
|||
:param sde_max_grad_norm: (float)
|
||||
:param sde_ent_coef: (float)
|
||||
:param sde_log_std_scheduler: (callable)
|
||||
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
|
||||
|
|
@ -59,13 +61,15 @@ class TD3(BaseRLModel):
|
|||
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
|
||||
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
|
||||
tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
|
||||
use_sde=False, sde_sample_freq=-1, sde_max_grad_norm=1, sde_ent_coef=0.0, sde_log_std_scheduler=None,
|
||||
use_sde=False, sde_sample_freq=-1, sde_max_grad_norm=1,
|
||||
sde_ent_coef=0.0, sde_log_std_scheduler=None, use_sde_at_warmup=False,
|
||||
tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
|
||||
seed=0, device='auto', _init_setup_model=True):
|
||||
|
||||
super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
||||
self.buffer_size = buffer_size
|
||||
self.learning_rate = learning_rate
|
||||
|
|
@ -263,7 +267,7 @@ class TD3(BaseRLModel):
|
|||
|
||||
rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout,
|
||||
n_steps=self.train_freq, action_noise=self.action_noise,
|
||||
deterministic=False, callback=callback,
|
||||
callback=callback,
|
||||
learning_starts=self.learning_starts,
|
||||
replay_buffer=self.replay_buffer,
|
||||
obs=obs, episode_num=episode_num,
|
||||
|
|
|
|||
Loading…
Reference in a new issue