mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Create OffPolicyRLModel
This commit is contained in:
parent
9d52a7d7d6
commit
16121cf2b8
3 changed files with 253 additions and 196 deletions
|
|
@ -14,11 +14,12 @@ from torchy_baselines.common import logger
|
|||
from torchy_baselines.common.policies import BasePolicy, get_policy_from_name
|
||||
from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
|
||||
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, OptimizerStateDict
|
||||
from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
from torchy_baselines.common.noise import ActionNoise
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
|
||||
|
||||
class BaseRLModel(ABC):
|
||||
|
|
@ -83,15 +84,12 @@ class BaseRLModel(ABC):
|
|||
self.n_envs = None
|
||||
self.num_timesteps = 0
|
||||
self.eval_env = None
|
||||
self.replay_buffer = None
|
||||
self.seed = seed
|
||||
self.action_noise = None # type: ActionNoise
|
||||
self.start_time = None
|
||||
self.policy, self.actor = None, None
|
||||
self.policy = None
|
||||
self.learning_rate = None
|
||||
# Used for SDE only
|
||||
self.rollout_data = None
|
||||
self.on_policy_exploration = False
|
||||
self.use_sde = use_sde
|
||||
self.sde_sample_freq = sde_sample_freq
|
||||
# Track the training progress (from 1 to 0)
|
||||
|
|
@ -548,190 +546,6 @@ class BaseRLModel(ABC):
|
|||
if maybe_ep_info is not None:
|
||||
self.ep_info_buffer.extend([maybe_ep_info])
|
||||
|
||||
def collect_rollouts(self,
                     env: VecEnv,
                     callback: 'BaseCallback',  # Type hint as string to avoid circular import
                     n_episodes: int = 1,
                     n_steps: int = -1,
                     action_noise: Optional[ActionNoise] = None,
                     deterministic: bool = False,
                     learning_starts: int = 0,
                     replay_buffer=None,
                     obs: Optional[np.ndarray] = None,
                     episode_num: int = 0,
                     log_interval: Optional[int] = None) -> Tuple[float, int, int, Optional[np.ndarray], bool]:
    """
    Collect rollout using the current policy (and possibly fill the replay buffer)
    TODO: move this method to off-policy base class.

    The environment is stepped until both ``total_steps >= n_steps`` and
    ``total_episodes >= n_episodes`` hold. Only a single environment is supported.

    :param env: (VecEnv) Vectorized environment, must contain exactly one env
    :param callback: (BaseCallback) Called before every step;
        the rollout is aborted as soon as it returns ``False``
    :param n_episodes: (int) Minimum number of episodes to collect
    :param n_steps: (int) Minimum number of steps to collect; when positive,
        the current episode loop is also left once that many steps were done
    :param action_noise: (ActionNoise) Optional noise added to the scaled action,
        reset at the end of each episode
    :param deterministic: (bool) Not read by this method: ``predict`` is called
        with ``deterministic=not self.use_sde``
    :param learning_starts: (int) While ``self.num_timesteps`` is below this value,
        actions are sampled from the action space (warmup phase)
    :param replay_buffer: (ReplayBuffer) If given, every transition is added to it
        (unnormalized observation and reward when a VecNormalize wrapper is used)
    :param obs: (np.ndarray) Last observation, used to select the first action
    :param episode_num: (int) Number of episodes done before this call (used for logging)
    :param log_interval: (int) Log training infos every ``log_interval`` episodes
        (requires ``verbose >= 1``)
    :return: (float, int, int, Optional[np.ndarray], bool) mean episode reward,
        number of steps and episodes collected, last observation and whether
        training should continue. When the callback aborts the rollout, the
        returned mean reward is 0.0 and the observation is None.
    """
    episode_rewards = []
    total_steps, total_episodes = 0, 0
    assert isinstance(env, VecEnv)
    assert env.num_envs == 1

    # Retrieve unnormalized observation for saving into the buffer
    if self._vec_normalize_env is not None:
        obs_ = self._vec_normalize_env.get_original_obs()

    self.rollout_data = None
    if self.use_sde:
        self.actor.reset_noise()
        # Reset rollout data
        # NOTE(review): nesting under `use_sde` was reconstructed from context
        # (rollout data is flagged "SDE only") -- confirm against the repository
        if self.on_policy_exploration:
            self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']}

    callback.on_rollout_start()
    continue_training = True

    while total_steps < n_steps or total_episodes < n_episodes:
        done = False
        # Reset environment: not needed for VecEnv
        # obs = env.reset()
        episode_reward, episode_timesteps = 0.0, 0

        while not done:

            # Only stop training if return value is False, not when it is None.
            if callback() is False:
                continue_training = False
                return 0.0, total_steps, total_episodes, None, continue_training

            # Resample based on the number of steps collected so far:
            # `n_steps` is constant during the rollout and cannot be used as a counter
            if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.actor.reset_noise()

            # Select action randomly or according to policy
            # TODO: use action from policy when using SDE during the warmup phase?
            # if num_timesteps < learning_starts and not self.use_sde:
            if self.num_timesteps < learning_starts:
                # Warmup phase
                unscaled_action = np.array([self.action_space.sample()])
            else:
                unscaled_action = self.predict(obs, deterministic=not self.use_sde)

            # Rescale the action from [low, high] to [-1, 1]
            scaled_action = self.scale_action(unscaled_action)

            if self.use_sde:
                # When using SDE, the action can be out of bounds
                # TODO: fix with squashing and account for that in the proba distribution
                clipped_action = np.clip(scaled_action, -1, 1)
            else:
                clipped_action = scaled_action

            # Add noise to the action (improve exploration)
            if action_noise is not None:
                # NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
                # Update(October 2019): Not anymore
                clipped_action = np.clip(clipped_action + action_noise(), -1, 1)

            # Rescale and perform action
            new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))

            done_bool = [float(done[0])]
            episode_reward += reward

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos)

            # Keep the unnormalized version of the transition.
            # Done even without a replay buffer so that `new_obs_` is always
            # defined when `obs_` is updated at the end of the step.
            if self._vec_normalize_env is not None:
                new_obs_ = self._vec_normalize_env.get_original_obs()
                reward_ = self._vec_normalize_env.get_original_reward()
            else:
                # Avoid changing the original ones
                obs_, new_obs_, reward_ = obs, new_obs, reward

            # Store data in replay buffer
            if replay_buffer is not None:
                replay_buffer.add(obs_, new_obs_, clipped_action, reward_, done_bool)

            if self.rollout_data is not None:
                # Assume only one env
                self.rollout_data['observations'].append(obs[0].copy())
                self.rollout_data['actions'].append(scaled_action[0].copy())
                self.rollout_data['rewards'].append(reward[0].copy())
                self.rollout_data['dones'].append(np.array(done_bool[0]).copy())
                obs_tensor = th.FloatTensor(obs).to(self.device)
                self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy())

            obs = new_obs
            # Save the true unnormalized observation
            # otherwise obs_ = self._vec_normalize_env.unnormalize_obs(obs)
            # is a good approximation
            if self._vec_normalize_env is not None:
                obs_ = new_obs_

            self.num_timesteps += 1
            episode_timesteps += 1
            total_steps += 1
            if 0 < n_steps <= total_steps:
                break

        if done:
            total_episodes += 1
            episode_rewards.append(episode_reward)
            # TODO: reset SDE matrix at the end of the episode?
            if action_noise is not None:
                action_noise.reset()

            # Display training infos
            if self.verbose >= 1 and log_interval is not None and (
                    episode_num + total_episodes) % log_interval == 0:
                fps = int(self.num_timesteps / (time.time() - self.start_time))
                logger.logkv("episodes", episode_num + total_episodes)
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
                    logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
                # logger.logkv("n_updates", n_updates)
                logger.logkv("fps", fps)
                logger.logkv('time_elapsed', int(time.time() - self.start_time))
                logger.logkv("total timesteps", self.num_timesteps)
                if self.use_sde:
                    logger.logkv("std", (self.actor.get_std()).mean().item())
                logger.dumpkvs()

    mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0

    # Post processing
    if self.rollout_data is not None:
        for key in ['observations', 'actions', 'rewards', 'dones', 'values']:
            self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)

        self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
        self.rollout_data['advantage'] = self.rollout_data['rewards'].clone()

        # Compute return and advantage
        last_return = 0.0
        n_collected = len(self.rollout_data['rewards'])
        for step in reversed(range(n_collected)):
            if step == n_collected - 1:
                # Bootstrap from the value of the last observation
                # NOTE(review): unlike the recursive case below, this term is
                # not discounted by gamma -- confirm whether this is intended
                next_non_terminal = 1.0 - done[0]
                next_value = self.vf_net(th.FloatTensor(obs).to(self.device))[0].detach()
                last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
            else:
                next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
                last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
            self.rollout_data['returns'][step] = last_return
        self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values']

    callback.on_rollout_end()

    return mean_reward, total_steps, total_episodes, obs, continue_training
|
||||
|
||||
@staticmethod
|
||||
def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None,
|
||||
params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None:
|
||||
|
|
@ -830,3 +644,238 @@ class BaseRLModel(ABC):
|
|||
params_to_save[name] = attr.state_dict()
|
||||
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)
|
||||
|
||||
|
||||
class OffPolicyRLModel(BaseRLModel):
    """
    The base RL model for Off-Policy algorithm (ex: SAC/TD3)

    :param policy: Policy object
    :param env: The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param policy_base: The base policy used by this method
    :param policy_kwargs: Additional arguments to be passed to the policy on creation
    :param verbose: The verbosity level: 0 none, 1 training information, 2 debug
    :param device: Device on which the code should run.
        By default, it will try to use a Cuda compatible device and fallback to cpu
        if it is not possible.
    :param support_multi_env: Whether the algorithm supports training
        with multiple environments (as in A2C)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param monitor_wrapper: When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param seed: Seed for the pseudo random generators
    :param use_sde: Whether to use State Dependent Exploration (SDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
        during the warm up phase (before learning starts)
    """

    def __init__(self,
                 policy: Type[BasePolicy],
                 env: Union[GymEnv, str],
                 policy_base: Type[BasePolicy],
                 policy_kwargs: Dict[str, Any] = None,
                 verbose: int = 0,
                 device: Union[th.device, str] = 'auto',
                 support_multi_env: bool = False,
                 create_eval_env: bool = False,
                 monitor_wrapper: bool = True,
                 seed: Optional[int] = None,
                 use_sde: bool = False,
                 sde_sample_freq: int = -1,
                 use_sde_at_warmup: bool = False):

        super(OffPolicyRLModel, self).__init__(policy, env, policy_base, policy_kwargs, verbose,
                                               device, support_multi_env, create_eval_env, monitor_wrapper,
                                               seed, use_sde, sde_sample_freq)
        # For SDE only
        self.rollout_data = None
        self.on_policy_exploration = False
        self.actor = None
        self.replay_buffer = None  # type: Optional[ReplayBuffer]
        self.ep_info_buffer = None  # type: deque
        self.use_sde_at_warmup = use_sde_at_warmup

    def collect_rollouts(self,
                         env: VecEnv,
                         # Type hint as string to avoid circular import
                         callback: 'BaseCallback',
                         n_episodes: int = 1,
                         n_steps: int = -1,
                         action_noise: Optional[ActionNoise] = None,
                         deterministic: bool = False,
                         learning_starts: int = 0,
                         replay_buffer: Optional[ReplayBuffer] = None,
                         obs: Optional[np.ndarray] = None,
                         episode_num: int = 0,
                         log_interval: Optional[int] = None) -> Tuple[float, int, int, Optional[np.ndarray], bool]:
        """
        Collect rollout using the current policy (and possibly fill the replay buffer)

        The environment is stepped until both ``total_steps >= n_steps`` and
        ``total_episodes >= n_episodes`` hold. Only a single environment is supported.

        :param env: (VecEnv) Vectorized environment, must contain exactly one env
        :param callback: (BaseCallback) Called before every step;
            the rollout is aborted as soon as it returns ``False``
        :param n_episodes: (int) Minimum number of episodes to collect
        :param n_steps: (int) Minimum number of steps to collect; when positive,
            the current episode loop is also left once that many steps were done
        :param action_noise: (ActionNoise) Optional noise added to the scaled action,
            reset at the end of each episode
        :param deterministic: (bool) Not read by this method: ``predict`` is called
            with ``deterministic=not self.use_sde``
        :param learning_starts: (int) While ``self.num_timesteps`` is below this value,
            actions are sampled from the action space (warmup phase),
            unless both ``use_sde`` and ``use_sde_at_warmup`` are set
        :param replay_buffer: (ReplayBuffer) If given, every transition is added to it
            (unnormalized observation and reward when a VecNormalize wrapper is used)
        :param obs: (np.ndarray) Last observation, used to select the first action
        :param episode_num: (int) Number of episodes done before this call (used for logging)
        :param log_interval: (int) Log training infos every ``log_interval`` episodes
            (requires ``verbose >= 1``)
        :return: (float, int, int, Optional[np.ndarray], bool) mean episode reward,
            number of steps and episodes collected, last observation and whether
            training should continue. When the callback aborts the rollout, the
            returned mean reward is 0.0 and the observation is None.
        """
        episode_rewards = []
        total_steps, total_episodes = 0, 0
        assert isinstance(env, VecEnv)
        assert env.num_envs == 1

        # Retrieve unnormalized observation for saving into the buffer
        if self._vec_normalize_env is not None:
            obs_ = self._vec_normalize_env.get_original_obs()

        self.rollout_data = None
        if self.use_sde:
            self.actor.reset_noise()
            # Reset rollout data
            # NOTE(review): nesting under `use_sde` was reconstructed from context
            # (rollout data is flagged "For SDE only") -- confirm against the repository
            if self.on_policy_exploration:
                self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']}

        callback.on_rollout_start()
        continue_training = True

        while total_steps < n_steps or total_episodes < n_episodes:
            done = False
            # Reset environment: not needed for VecEnv
            # obs = env.reset()
            episode_reward, episode_timesteps = 0.0, 0

            while not done:

                # Only stop training if return value is False, not when it is None.
                if callback() is False:
                    continue_training = False
                    return 0.0, total_steps, total_episodes, None, continue_training

                # Resample based on the number of steps collected so far:
                # `n_steps` is constant during the rollout and cannot be used as a counter
                if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
                    # Sample a new noise matrix
                    self.actor.reset_noise()

                # Select action randomly or according to policy
                if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
                    # Warmup phase
                    unscaled_action = np.array([self.action_space.sample()])
                else:
                    unscaled_action = self.predict(obs, deterministic=not self.use_sde)

                # Rescale the action from [low, high] to [-1, 1]
                scaled_action = self.scale_action(unscaled_action)

                if self.use_sde:
                    # When using SDE, the action can be out of bounds
                    # TODO: fix with squashing and account for that in the proba distribution
                    clipped_action = np.clip(scaled_action, -1, 1)
                else:
                    clipped_action = scaled_action

                # Add noise to the action (improve exploration)
                if action_noise is not None:
                    # NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
                    # Update(October 2019): Not anymore
                    clipped_action = np.clip(clipped_action + action_noise(), -1, 1)

                # Rescale and perform action
                new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))

                episode_reward += reward

                # Retrieve reward and episode length if using Monitor wrapper
                self._update_info_buffer(infos)

                # Keep the unnormalized version of the transition.
                # Done even without a replay buffer so that `new_obs_` is always
                # defined when `obs_` is updated at the end of the step.
                if self._vec_normalize_env is not None:
                    new_obs_ = self._vec_normalize_env.get_original_obs()
                    reward_ = self._vec_normalize_env.get_original_reward()
                else:
                    # Avoid changing the original ones
                    obs_, new_obs_, reward_ = obs, new_obs, reward

                # Store data in replay buffer
                if replay_buffer is not None:
                    replay_buffer.add(obs_, new_obs_, clipped_action, reward_, done)

                if self.rollout_data is not None:
                    # Assume only one env
                    self.rollout_data['observations'].append(obs[0].copy())
                    self.rollout_data['actions'].append(scaled_action[0].copy())
                    self.rollout_data['rewards'].append(reward[0].copy())
                    self.rollout_data['dones'].append(done[0].copy())
                    obs_tensor = th.FloatTensor(obs).to(self.device)
                    self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy())

                obs = new_obs
                # Save the true unnormalized observation
                # otherwise obs_ = self._vec_normalize_env.unnormalize_obs(obs)
                # is a good approximation
                if self._vec_normalize_env is not None:
                    obs_ = new_obs_

                self.num_timesteps += 1
                episode_timesteps += 1
                total_steps += 1
                if 0 < n_steps <= total_steps:
                    break

            if done:
                total_episodes += 1
                episode_rewards.append(episode_reward)
                # TODO: reset SDE matrix at the end of the episode?
                if action_noise is not None:
                    action_noise.reset()

                # Display training infos
                if self.verbose >= 1 and log_interval is not None and (
                        episode_num + total_episodes) % log_interval == 0:
                    fps = int(self.num_timesteps / (time.time() - self.start_time))
                    logger.logkv("episodes", episode_num + total_episodes)
                    if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                        logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
                        logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
                    # logger.logkv("n_updates", n_updates)
                    logger.logkv("fps", fps)
                    logger.logkv('time_elapsed', int(time.time() - self.start_time))
                    logger.logkv("total timesteps", self.num_timesteps)
                    if self.use_sde:
                        logger.logkv("std", (self.actor.get_std()).mean().item())
                    logger.dumpkvs()

        mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0

        # Post processing
        if self.rollout_data is not None:
            for key in ['observations', 'actions', 'rewards', 'dones', 'values']:
                self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)

            self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
            self.rollout_data['advantage'] = self.rollout_data['rewards'].clone()

            # Compute return and advantage
            last_return = 0.0
            n_collected = len(self.rollout_data['rewards'])
            for step in reversed(range(n_collected)):
                if step == n_collected - 1:
                    # Bootstrap from the value of the last observation
                    # NOTE(review): unlike the recursive case below, this term is
                    # not discounted by gamma -- confirm whether this is intended
                    next_non_terminal = 1.0 - done[0]
                    next_value = self.vf_net(th.FloatTensor(obs).to(self.device))[0].detach()
                    last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
                else:
                    next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
                    last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
                self.rollout_data['returns'][step] = last_return
            self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values']

        callback.on_rollout_end()

        return mean_reward, total_steps, total_episodes, obs, continue_training
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ import torch as th
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.base_class import BaseRLModel
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.sac.policies import SACPolicy
|
||||
from torchy_baselines.common import logger
|
||||
|
||||
|
||||
class SAC(BaseRLModel):
|
||||
class SAC(OffPolicyRLModel):
|
||||
"""
|
||||
Soft Actor-Critic (SAC)
|
||||
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
|
||||
|
|
@ -49,6 +49,8 @@ class SAC(BaseRLModel):
|
|||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
|
||||
|
|
@ -65,13 +67,15 @@ class SAC(BaseRLModel):
|
|||
train_freq=1, gradient_steps=1, n_episodes_rollout=-1,
|
||||
target_entropy='auto', action_noise=None,
|
||||
gamma=0.99, use_sde=False, sde_sample_freq=-1,
|
||||
use_sde_at_warmup=False,
|
||||
tensorboard_log=None, create_eval_env=False,
|
||||
policy_kwargs=None, verbose=0, seed=0, device='auto',
|
||||
_init_setup_model=True):
|
||||
|
||||
super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.target_entropy = target_entropy
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ import torch as th
|
|||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.base_class import BaseRLModel
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.td3.policies import TD3Policy
|
||||
|
||||
|
||||
class TD3(BaseRLModel):
|
||||
class TD3(OffPolicyRLModel):
|
||||
"""
|
||||
Twin Delayed DDPG (TD3)
|
||||
Addressing Function Approximation Error in Actor-Critic Methods.
|
||||
|
|
@ -45,6 +45,8 @@ class TD3(BaseRLModel):
|
|||
:param sde_max_grad_norm: (float)
|
||||
:param sde_ent_coef: (float)
|
||||
:param sde_log_std_scheduler: (callable)
|
||||
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
|
||||
|
|
@ -59,13 +61,15 @@ class TD3(BaseRLModel):
|
|||
policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100,
|
||||
train_freq=-1, gradient_steps=-1, n_episodes_rollout=1,
|
||||
tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5,
|
||||
use_sde=False, sde_sample_freq=-1, sde_max_grad_norm=1, sde_ent_coef=0.0, sde_log_std_scheduler=None,
|
||||
use_sde=False, sde_sample_freq=-1, sde_max_grad_norm=1,
|
||||
sde_ent_coef=0.0, sde_log_std_scheduler=None, use_sde_at_warmup=False,
|
||||
tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0,
|
||||
seed=0, device='auto', _init_setup_model=True):
|
||||
|
||||
super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
||||
self.buffer_size = buffer_size
|
||||
self.learning_rate = learning_rate
|
||||
|
|
|
|||
Loading…
Reference in a new issue