mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Refactor handling of obs and action space
+ remove duplicated code
This commit is contained in:
parent
7251b9d2c2
commit
57b37513b6
12 changed files with 152 additions and 76 deletions
|
|
@ -3,6 +3,29 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.4.0a0 (WIP)
|
||||
------------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Refactor handling of observation and action spaces
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Pre-Release 0.3.0 (2020-02-14)
|
||||
------------------------------
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ class BaseRLModel(ABC):
|
|||
:param env: (Union[GymEnv, str]) The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: (Type[BasePolicy]) The base policy used by this method
|
||||
:param learning_rate: (float or callable) learning rate for the optimizer,
|
||||
it can be a function of the current progress (from 1 to 0)
|
||||
:param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation
|
||||
:param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
|
|
@ -53,6 +55,7 @@ class BaseRLModel(ABC):
|
|||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Callable],
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
|
|
@ -89,7 +92,7 @@ class BaseRLModel(ABC):
|
|||
self.action_noise = None # type: Optional[ActionNoise]
|
||||
self.start_time = None
|
||||
self.policy = None
|
||||
self.learning_rate = None # type: Optional[float]
|
||||
self.learning_rate = learning_rate
|
||||
self.lr_schedule = None # type: Optional[Callable]
|
||||
# Used for SDE only
|
||||
self.use_sde = use_sde
|
||||
|
|
@ -135,7 +138,7 @@ class BaseRLModel(ABC):
|
|||
@abstractmethod
|
||||
def _setup_model(self) -> None:
|
||||
"""
|
||||
Create networks and optimizers
|
||||
Create networks, buffer and optimizers
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
@ -761,6 +764,11 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param learning_rate: (float or callable) learning rate for the optimizer,
|
||||
it can be a function of the current progress (from 1 to 0)
|
||||
:param buffer_size: (int) size of the replay buffer
|
||||
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
|
||||
:param batch_size: (int) Minibatch size for each gradient update
|
||||
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
||||
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param device: Device on which the code should run.
|
||||
|
|
@ -785,6 +793,10 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Callable],
|
||||
buffer_size: int = int(1e6),
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
|
|
@ -796,16 +808,32 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
sde_sample_freq: int = -1,
|
||||
use_sde_at_warmup: bool = False):
|
||||
|
||||
super(OffPolicyRLModel, self).__init__(policy, env, policy_base, policy_kwargs, verbose,
|
||||
super(OffPolicyRLModel, self).__init__(policy, env, policy_base, learning_rate,
|
||||
policy_kwargs, verbose,
|
||||
device, support_multi_env, create_eval_env, monitor_wrapper,
|
||||
seed, use_sde, sde_sample_freq)
|
||||
self.buffer_size = buffer_size
|
||||
self.batch_size = batch_size
|
||||
self.learning_starts = learning_starts
|
||||
self.actor = None
|
||||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
# Update policy keyword arguments
|
||||
self.policy_kwargs['use_sde'] = self.use_sde
|
||||
self.policy_kwargs['device'] = self.device
|
||||
# For SDE only
|
||||
self.rollout_data = None
|
||||
self.on_policy_exploration = False
|
||||
self.actor = None
|
||||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def _setup_model(self):
    """
    Shared model setup for off-policy algorithms.

    Calls ``_setup_lr_schedule()`` and ``set_random_seed()`` with the stored seed,
    then creates the replay buffer and the policy from the observation
    and action spaces, and moves the policy to ``self.device``.
    """
    self._setup_lr_schedule()
    self.set_random_seed(self.seed)

    # The buffer and the policy are both built directly from the spaces
    obs_space, act_space = self.observation_space, self.action_space
    self.replay_buffer = ReplayBuffer(self.buffer_size, obs_space, act_space, self.device)

    self.policy = self.policy_class(obs_space, act_space, self.lr_schedule, **self.policy_kwargs)
    self.policy = self.policy.to(self.device)
|
||||
|
||||
def save_replay_buffer(self, path: str):
|
||||
"""
|
||||
Save the replay buffer as a pickle file.
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@ from typing import Union, Optional, Generator
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
from torchy_baselines.common.vec_env import VecNormalize
|
||||
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
|
||||
from torchy_baselines.common.preprocessing import get_obs_dim, get_action_dim
|
||||
|
||||
|
||||
class BaseBuffer(object):
|
||||
|
|
@ -12,23 +14,24 @@ class BaseBuffer(object):
|
|||
Base class that represents a buffer (rollout or replay)
|
||||
|
||||
:param buffer_size: (int) Max number of elements in the buffer
|
||||
:param obs_dim: (int) Dimension of the observation
|
||||
:param action_dim: (int) Dimension of the action space
|
||||
:param observation_space: (spaces.Space) Observation space
|
||||
:param action_space: (spaces.Space) Action space
|
||||
:param device: (Union[th.device, str]) PyTorch device
|
||||
to which the values will be converted
|
||||
:param n_envs: (int) Number of parallel environments
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
buffer_size: int,
|
||||
obs_dim: int,
|
||||
action_dim: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
n_envs: int = 1):
|
||||
super(BaseBuffer, self).__init__()
|
||||
self.buffer_size = buffer_size
|
||||
self.obs_dim = obs_dim
|
||||
self.action_dim = action_dim
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.obs_dim = get_obs_dim(observation_space)
|
||||
self.action_dim = get_action_dim(action_space)
|
||||
self.pos = 0
|
||||
self.full = False
|
||||
self.device = device
|
||||
|
|
@ -137,19 +140,20 @@ class ReplayBuffer(BaseBuffer):
|
|||
Replay buffer used in off-policy algorithms like SAC/TD3.
|
||||
|
||||
:param buffer_size: (int) Max number of elements in the buffer
|
||||
:param obs_dim: (int) Dimension of the observation
|
||||
:param action_dim: (int) Dimension of the action space
|
||||
:param observation_space: (spaces.Space) Observation space
|
||||
:param action_space: (spaces.Space) Action space
|
||||
:param device: (th.device)
|
||||
:param n_envs: (int) Number of parallel environments
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
buffer_size: int,
|
||||
obs_dim: int,
|
||||
action_dim: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
n_envs: int = 1):
|
||||
super(ReplayBuffer, self).__init__(buffer_size, obs_dim, action_dim, device, n_envs=n_envs)
|
||||
super(ReplayBuffer, self).__init__(buffer_size, observation_space,
|
||||
action_space, device, n_envs=n_envs)
|
||||
|
||||
assert n_envs == 1, "Replay buffer only support single environment for now"
|
||||
|
||||
|
|
@ -194,8 +198,8 @@ class RolloutBuffer(BaseBuffer):
|
|||
Rollout buffer used in on-policy algorithms like A2C/PPO.
|
||||
|
||||
:param buffer_size: (int) Max number of elements in the buffer
|
||||
:param obs_dim: (int) Dimension of the observation
|
||||
:param action_dim: (int) Dimension of the action space
|
||||
:param observation_space: (spaces.Space) Observation space
|
||||
:param action_space: (spaces.Space) Action space
|
||||
:param device: (th.device)
|
||||
:param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||
Equivalent to classic advantage when set to 1.
|
||||
|
|
@ -205,14 +209,15 @@ class RolloutBuffer(BaseBuffer):
|
|||
|
||||
def __init__(self,
|
||||
buffer_size: int,
|
||||
obs_dim: int,
|
||||
action_dim: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1):
|
||||
|
||||
super(RolloutBuffer, self).__init__(buffer_size, obs_dim, action_dim, device, n_envs=n_envs)
|
||||
super(RolloutBuffer, self).__init__(buffer_size, observation_space,
|
||||
action_space, device, n_envs=n_envs)
|
||||
self.gae_lambda = gae_lambda
|
||||
self.gamma = gamma
|
||||
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import torch.nn as nn
|
|||
from torch.distributions import Normal, Categorical
|
||||
from gym import spaces
|
||||
|
||||
from torchy_baselines.common.preprocessing import get_action_dim
|
||||
|
||||
|
||||
class Distribution(object):
|
||||
def __init__(self):
|
||||
|
|
@ -59,7 +61,7 @@ class DiagGaussianDistribution(Distribution):
|
|||
Gaussian distribution with diagonal covariance matrix,
|
||||
for continuous actions.
|
||||
|
||||
:param action_dim: (int) Number of continuous actions
|
||||
:param action_dim: (int) Dimension of the action space.
|
||||
"""
|
||||
|
||||
def __init__(self, action_dim: int):
|
||||
|
|
@ -144,7 +146,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
Gaussian distribution with diagonal covariance matrix,
|
||||
followed by a squashing function (tanh) to ensure bounds.
|
||||
|
||||
:param action_dim: (int) Number of continuous actions
|
||||
:param action_dim: (int) Dimension of the action space.
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
"""
|
||||
|
||||
|
|
@ -252,7 +254,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
It is used to create the noise exploration matrix and
|
||||
compute the log probability of an action with that noise.
|
||||
|
||||
:param action_dim: (int) Number of continuous actions
|
||||
:param action_dim: (int) Dimension of the action space.
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,)
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
|
||||
|
|
@ -495,8 +497,8 @@ def make_proba_distribution(action_space: gym.spaces.Space,
|
|||
if isinstance(action_space, spaces.Box):
|
||||
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
|
||||
if use_sde:
|
||||
return StateDependentNoiseDistribution(action_space.shape[0], **dist_kwargs)
|
||||
return DiagGaussianDistribution(action_space.shape[0], **dist_kwargs)
|
||||
return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
# elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
|
|
|
|||
44
torchy_baselines/common/preprocessing.py
Normal file
44
torchy_baselines/common/preprocessing.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
|
||||
def is_image(observation_space) -> bool:
    """
    Tell whether an observation space should be treated as an image.

    Placeholder: the argument is ignored and the answer is always negative.

    :param observation_space: Observation space to inspect (currently unused)
    :return: (bool) Always ``False``
    """
    image_detected = False
    return image_detected
|
||||
|
||||
|
||||
def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space) -> th.Tensor:
    """
    Preprocess an observation according to its observation space.

    For a ``Box`` space, the observation is divided by 255 when ``is_image()``
    reports an image, and returned unchanged otherwise.

    :param obs: (th.Tensor) Observation
    :param observation_space: (spaces.Space) Observation space
    :return: (th.Tensor) The preprocessed observation
    :raises NotImplementedError: for any space other than ``Box``
    """
    if not isinstance(observation_space, spaces.Box):
        # TODO: one hot encoding (Discrete)
        # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
        raise NotImplementedError()

    return obs / 255.0 if is_image(observation_space) else obs
|
||||
|
||||
|
||||
def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
    """
    Get the dimension of the observation space.

    :param observation_space: (spaces.Space) Observation space
    :return: (Union[int, Tuple[int, ...]]) The full shape when ``is_image()`` reports
        an image ``Box`` space, the number of elements of the flattened observation
        for any other ``Box`` space, and 1 for a ``Discrete`` space
    :raises NotImplementedError: for any other kind of space
    """
    if isinstance(observation_space, spaces.Box):
        if is_image(observation_space):
            # Images are not flattened: keep the complete shape
            return observation_space.shape
        # np.prod returns a NumPy scalar rather than a built-in int;
        # cast it so the result matches the annotation and get_action_dim
        return int(np.prod(observation_space.shape))
    elif isinstance(observation_space, spaces.Discrete):
        # Observation is an int
        return 1
    else:
        # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
        raise NotImplementedError()
|
||||
|
||||
|
||||
def get_action_dim(action_space: spaces.Space) -> int:
    """
    Get the dimension of the action space.

    :param action_space: (spaces.Space) Action space
    :return: (int) Number of elements of the action for a ``Box`` space,
        1 for a ``Discrete`` space
    :raises NotImplementedError: for any other kind of space
    """
    if isinstance(action_space, spaces.Box):
        flat_size = np.prod(action_space.shape)
        return int(flat_size)

    if isinstance(action_space, spaces.Discrete):
        # Action is an int
        return 1

    raise NotImplementedError()
|
||||
|
|
@ -6,6 +6,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.preprocessing import get_obs_dim
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, MlpExtractor,
|
||||
create_sde_feature_extractor)
|
||||
from torchy_baselines.common.distributions import (make_proba_distribution, Distribution,
|
||||
|
|
@ -54,7 +55,7 @@ class PPOPolicy(BasePolicy):
|
|||
use_expln: bool = False,
|
||||
squash_output: bool = False):
|
||||
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
self.obs_dim = get_obs_dim(self.observation_space)
|
||||
|
||||
# Default network architecture, from stable-baselines
|
||||
if net_arch is None:
|
||||
|
|
|
|||
|
|
@ -98,11 +98,10 @@ class PPO(BaseRLModel):
|
|||
device: Union[th.device, str] = 'auto',
|
||||
_init_setup_model: bool = True):
|
||||
|
||||
super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs=policy_kwargs,
|
||||
super(PPO, self).__init__(policy, env, PPOPolicy, learning_rate, policy_kwargs=policy_kwargs,
|
||||
verbose=verbose, device=device, use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
create_eval_env=create_eval_env, support_multi_env=True, seed=seed)
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.batch_size = batch_size
|
||||
self.n_epochs = n_epochs
|
||||
self.n_steps = n_steps
|
||||
|
|
@ -123,19 +122,12 @@ class PPO(BaseRLModel):
|
|||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
# TODO: preprocessing: one hot vector for obs discrete
|
||||
state_dim = self.observation_space.shape[0]
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
# Action is a 1D vector
|
||||
action_dim = self.action_space.shape[0]
|
||||
elif isinstance(self.action_space, spaces.Discrete):
|
||||
# Action is a scalar
|
||||
action_dim = 1
|
||||
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
|
||||
self.rollout_buffer = RolloutBuffer(self.n_steps, self.observation_space,
|
||||
self.action_space, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, use_sde=self.use_sde, device=self.device,
|
||||
**self.policy_kwargs)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import gym
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
|
||||
create_sde_feature_extractor)
|
||||
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
|
@ -191,8 +192,8 @@ class SACPolicy(BasePolicy):
|
|||
if net_arch is None:
|
||||
net_arch = [256, 256]
|
||||
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
self.action_dim = self.action_space.shape[0]
|
||||
self.obs_dim = get_obs_dim(self.observation_space)
|
||||
self.action_dim = get_action_dim(self.action_space)
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.net_args = {
|
||||
|
|
|
|||
|
|
@ -89,7 +89,9 @@ class SAC(OffPolicyRLModel):
|
|||
device: Union[th.device, str] = 'auto',
|
||||
_init_setup_model: bool = True):
|
||||
|
||||
super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device,
|
||||
super(SAC, self).__init__(policy, env, SACPolicy, learning_rate,
|
||||
buffer_size, learning_starts, batch_size,
|
||||
policy_kwargs, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
|
@ -97,11 +99,6 @@ class SAC(OffPolicyRLModel):
|
|||
self.target_entropy = target_entropy
|
||||
self.log_ent_coef = None # type: Optional[th.Tensor]
|
||||
self.target_update_interval = target_update_interval
|
||||
self.buffer_size = buffer_size
|
||||
# In the original paper, same learning rate is used for all networks
|
||||
self.learning_rate = learning_rate
|
||||
self.learning_starts = learning_starts
|
||||
self.batch_size = batch_size
|
||||
self.tau = tau
|
||||
# Entropy coefficient / Entropy temperature
|
||||
# Inverse of the reward scale
|
||||
|
|
@ -118,10 +115,8 @@ class SAC(OffPolicyRLModel):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
if self.seed is not None:
|
||||
self.set_random_seed(self.seed)
|
||||
super(SAC, self)._setup_model()
|
||||
self._create_aliases()
|
||||
|
||||
# Target entropy is used when learning the entropy coefficient
|
||||
if self.target_entropy == 'auto':
|
||||
|
|
@ -152,13 +147,6 @@ class SAC(OffPolicyRLModel):
|
|||
# is passed
|
||||
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
|
||||
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, use_sde=self.use_sde,
|
||||
device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
self.actor = self.policy.actor
|
||||
self.critic = self.policy.critic
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import gym
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
|
||||
create_sde_feature_extractor)
|
||||
from torchy_baselines.common.distributions import StateDependentNoiseDistribution
|
||||
|
|
@ -230,8 +231,8 @@ class TD3Policy(BasePolicy):
|
|||
if net_arch is None:
|
||||
net_arch = [400, 300]
|
||||
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
self.action_dim = self.action_space.shape[0]
|
||||
self.obs_dim = get_obs_dim(self.observation_space)
|
||||
self.action_dim = get_action_dim(self.action_space)
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.net_args = {
|
||||
|
|
|
|||
|
|
@ -88,18 +88,16 @@ class TD3(OffPolicyRLModel):
|
|||
device: Union[th.device, str] = 'auto',
|
||||
_init_setup_model: bool = True):
|
||||
|
||||
super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, device,
|
||||
super(TD3, self).__init__(policy, env, TD3Policy, learning_rate,
|
||||
buffer_size, learning_starts, batch_size,
|
||||
policy_kwargs, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
||||
self.buffer_size = buffer_size
|
||||
self.learning_rate = learning_rate
|
||||
self.learning_starts = learning_starts
|
||||
self.train_freq = train_freq
|
||||
self.gradient_steps = gradient_steps
|
||||
self.n_episodes_rollout = n_episodes_rollout
|
||||
self.batch_size = batch_size
|
||||
self.tau = tau
|
||||
self.gamma = gamma
|
||||
self.action_noise = action_noise
|
||||
|
|
@ -118,14 +116,7 @@ class TD3(OffPolicyRLModel):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, use_sde=self.use_sde,
|
||||
device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
super(TD3, self)._setup_model()
|
||||
self._create_aliases()
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.3.0
|
||||
0.4.0a0
|
||||
|
|
|
|||
Loading…
Reference in a new issue