Mirror of https://github.com/saymrwulf/stable-baselines3.git
Synced 2026-05-14 20:58:03 +00:00

Cleanup, bug fixes + more tests

Parent: 73fb8d1c63
Commit: 041f2bc59a

8 changed files with 208 additions and 132 deletions
@@ -12,10 +12,10 @@ Breaking Changes:

 New Features:
 ^^^^^^^^^^^^^
-- Added ``optimizer`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
+- Added ``optimizer_class`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
   customize optimizers
 - Complete independent save/load for policies
-- Add ``CnnPolicies`` to support images as input
+- Add ``CnnPolicy`` and ``VecTransposeImage`` to support images as input


 Bug Fixes:
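
Usage implied by the renamed option: the optimizer now travels through ``policy_kwargs``, as the A2C hunk further down does internally. A minimal sketch of the user-facing call, assuming the string-to-``gym.make`` env shorthand applies here; the env id and hyperparameter values are illustrative, not taken from this commit:

    import torch as th
    from torchy_baselines import PPO

    # Swap the default Adam for RMSprop via policy_kwargs;
    # 'optimizer_class' replaces the old 'optimizer' key.
    policy_kwargs = dict(optimizer_class=th.optim.RMSprop,
                         optimizer_kwargs=dict(alpha=0.99, eps=1e-5, weight_decay=0))
    model = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=1000)
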
@@ -22,7 +22,8 @@ def test_cnn(model_class):
     else:
         # Avoid memory error when using replay buffer
         # Reduce the size of the features
-        kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=512)))
+        kwargs = dict(buffer_size=250,
+                      policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
     model = model_class('CnnPolicy', env, **kwargs).learn(250)

     obs = env.reset()
@@ -8,6 +8,8 @@ import torch as th
 from torchy_baselines import A2C, PPO, SAC, TD3
 from torchy_baselines.common.identity_env import IdentityEnvBox
 from torchy_baselines.common.vec_env import DummyVecEnv
+from torchy_baselines.common.identity_env import FakeImageEnv


 MODEL_LIST = [
     PPO,
@@ -30,12 +32,11 @@ def test_save_load(model_class):
     env = DummyVecEnv([lambda: IdentityEnvBox(10)])

     # create model
-    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
+    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1)
     model.learn(total_timesteps=500, eval_freq=250)

     env.reset()
-    observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
-    observations = observations.reshape(10, -1)
+    observations = np.concatenate([env.step(env.action_space.sample())[0] for _ in range(10)], axis=0)

     # Get dictionary of current parameters
     params = deepcopy(model.policy.state_dict())
@@ -90,7 +91,7 @@ def test_set_env(model_class):
     env3 = IdentityEnvBox(10)

     # create model
-    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), create_eval_env=True)
+    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]))
     # learn
     model.learn(total_timesteps=1000, eval_freq=500)
@@ -115,7 +116,7 @@ def test_exclude_include_saved_params(model_class):
     env = DummyVecEnv([lambda: IdentityEnvBox(10)])

     # create model, set verbose as 2, which is not standard
-    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2, create_eval_env=True)
+    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2)

     # Check if exclude works
     model.save("test_save.zip", exclude=["verbose"])
@@ -163,21 +164,34 @@ def test_save_load_replay_buffer(model_class):


 @pytest.mark.parametrize("model_class", MODEL_LIST)
-def test_save_load_policy(model_class):
+@pytest.mark.parametrize("policy_str", ['MlpPolicy', 'CnnPolicy'])
+def test_save_load_policy(model_class, policy_str):
     """
     Test saving and loading policy only.

     :param model_class: (BaseRLModel) A RL model
+    :param policy_str: (str) Name of the policy.
     """
-    env = DummyVecEnv([lambda: IdentityEnvBox(10)])
+    kwargs = {}
+    if policy_str == 'MlpPolicy':
+        env = IdentityEnvBox(10)
+    else:
+        if model_class in [SAC, TD3]:
+            # Avoid memory error when using replay buffer
+            # Reduce the size of the features
+            kwargs = dict(buffer_size=250)
+        env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=3,
+                           discrete=False)
+
+    env = DummyVecEnv([lambda: env])

     # create model
-    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
+    model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]),
+                        verbose=1, **kwargs)
     model.learn(total_timesteps=500, eval_freq=250)

     env.reset()
-    observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
-    observations = observations.reshape(10, -1)
+    observations = np.concatenate([env.step(env.action_space.sample())[0] for _ in range(10)], axis=0)

     policy = model.policy
     policy_class = policy.__class__
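
The round-trip check in this test compares ``state_dict()`` snapshots taken before and after reloading. A self-contained sketch of that comparison; the helper name is ours, not from the repo:

    import torch as th
    import torch.nn as nn
    from copy import deepcopy

    def params_equal(params_a: dict, params_b: dict) -> bool:
        # True if every named parameter matches between the two snapshots
        return all(th.allclose(params_a[name], params_b[name]) for name in params_a)

    net = nn.Linear(4, 2)  # stand-in for model.policy
    before = deepcopy(net.state_dict())
    # ... save the policy, reload it, then assert nothing drifted:
    assert params_equal(before, net.state_dict())
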
@@ -82,8 +82,8 @@ class A2C(PPO):

         self.normalize_advantage = normalize_advantage
         # Override PPO optimizer to match original implementation
-        if use_rms_prop and 'optimizer' not in self.policy_kwargs:
-            self.policy_kwargs['optimizer'] = th.optim.RMSprop
+        if use_rms_prop and 'optimizer_class' not in self.policy_kwargs:
+            self.policy_kwargs['optimizer_class'] = th.optim.RMSprop
             self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps,
                                                           weight_decay=0)
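
For reference, the override above amounts to constructing RMSprop directly with the settings of the original implementation. A sketch with stand-in parameters; ``alpha``, ``eps`` and ``weight_decay`` come from the hunk, the learning rate is illustrative:

    import torch as th
    import torch.nn as nn

    params = nn.Linear(4, 2).parameters()  # stand-in for the policy parameters
    optimizer = th.optim.RMSprop(params, lr=7e-4, alpha=0.99, eps=1e-5,
                                 weight_decay=0)
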
@@ -12,6 +12,79 @@ from torchy_baselines.common.utils import get_device, get_schedule_fn
 from torchy_baselines.common.vec_env import VecTransposeImage


+class BaseFeaturesExtractor(nn.Module):
+    """
+    Base class that represents a features extractor.
+
+    :param observation_space: (gym.Space)
+    :param features_dim: (int) Number of features extracted.
+    """
+    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
+        super(BaseFeaturesExtractor, self).__init__()
+        assert features_dim > 0
+        self._observation_space = observation_space
+        self._features_dim = features_dim
+
+    @property
+    def features_dim(self) -> int:
+        return self._features_dim
+
+    def forward(self, observations: th.Tensor) -> th.Tensor:
+        raise NotImplementedError()
+
+
+class FlattenExtractor(BaseFeaturesExtractor):
+    """
+    Feature extractor that flattens the input.
+    Used as a placeholder when feature extraction is not needed.
+
+    :param observation_space: (gym.Space)
+    """
+    def __init__(self, observation_space: gym.Space):
+        super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
+        self.flatten = nn.Flatten()
+
+    def forward(self, observations: th.Tensor) -> th.Tensor:
+        return self.flatten(observations)
+
+
+class NatureCNN(BaseFeaturesExtractor):
+    """
+    CNN from DQN nature paper: https://arxiv.org/abs/1312.5602
+
+    :param observation_space: (gym.Space)
+    :param features_dim: (int) Number of features extracted.
+        This corresponds to the number of units of the last layer.
+    """
+    def __init__(self, observation_space: gym.spaces.Box,
+                 features_dim: int = 512):
+        super(NatureCNN, self).__init__(observation_space, features_dim)
+        # TODO: custom init?
+        # We assume CxWxH images (channels first)
+        # Re-ordering will be done by pre-processing or wrapper
+        is_image_input = is_image_space(observation_space) or is_image_space(observation_space, channels_last=False)
+        assert is_image_input, ('You should use NatureCNN '
+                                f'only with images, not with {observation_space} '
+                                '(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
+        n_input_channels = observation_space.shape[0]
+        self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
+                                 nn.ReLU(),
+                                 nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
+                                 nn.ReLU(),
+                                 nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
+                                 nn.ReLU(),
+                                 nn.Flatten())
+
+        # Compute shape by doing one forward pass
+        with th.no_grad():
+            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
+
+        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
+
+    def forward(self, observations: th.Tensor) -> th.Tensor:
+        return self.linear(self.cnn(observations))
+
+
 class BasePolicy(nn.Module):
     """
     The base policy object
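
The ``n_flatten`` trick above (inferring the flattened CNN output size with one dummy forward pass instead of computing it by hand) can be checked in isolation. A minimal sketch, assuming a channel-first 3x40x40 image space like the one ``FakeImageEnv`` produces in the tests:

    import gym
    import numpy as np
    import torch as th
    import torch.nn as nn

    obs_space = gym.spaces.Box(low=0, high=255, shape=(3, 40, 40), dtype=np.uint8)
    cnn = nn.Sequential(nn.Conv2d(3, 32, kernel_size=8, stride=4),
                        nn.ReLU(),
                        nn.Conv2d(32, 64, kernel_size=4, stride=2),
                        nn.ReLU(),
                        nn.Conv2d(64, 32, kernel_size=3, stride=1),
                        nn.ReLU(),
                        nn.Flatten())
    with th.no_grad():
        # One forward pass on a sampled observation yields the flattened size
        n_flatten = cnn(th.as_tensor(obs_space.sample()[None]).float()).shape[1]
    print(n_flatten)  # this value feeds nn.Linear(n_flatten, features_dim)
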
@@ -21,26 +94,50 @@ class BasePolicy(nn.Module):
     :param device: (Union[th.device, str]) Device on which the code should run.
     :param squash_output: (bool) For continuous actions, whether the output is squashed
         or not using a ``tanh()`` function.
+    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
+    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
+        to pass to the feature extractor.
     :param features_extractor: (nn.Module) Network to extract features
         (a CNN when using images, a nn.Flatten() layer otherwise)
     :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
+    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
+        ``th.optim.Adam`` by default
+    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
+        excluding the learning rate, to pass to the optimizer
     """
     def __init__(self, observation_space: gym.spaces.Space,
                  action_space: gym.spaces.Space,
                  device: Union[th.device, str] = 'auto',
-                 squash_output: bool = False,
+                 features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+                 features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                  features_extractor: Optional[nn.Module] = None,
-                 normalize_images: bool = True):
+                 normalize_images: bool = True,
+                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+                 optimizer_kwargs: Optional[Dict[str, Any]] = None,
+                 squash_output: bool = False):
         super(BasePolicy, self).__init__()
+
+        if optimizer_kwargs is None:
+            optimizer_kwargs = {}
+
+        if features_extractor_kwargs is None:
+            features_extractor_kwargs = {}
+
         self.observation_space = observation_space
         self.action_space = action_space
         self.device = get_device(device)
         self.features_extractor = features_extractor
         self.normalize_images = normalize_images
         self._squash_output = squash_output

+        self.optimizer_class = optimizer_class
+        self.optimizer_kwargs = optimizer_kwargs
+        self.optimizer = None  # type: Optional[th.optim.Optimizer]
+
+        self.features_extractor_class = features_extractor_class
+        self.features_extractor_kwargs = features_extractor_kwargs
+
     def extract_features(self, obs: th.Tensor) -> th.Tensor:
         """
         Preprocess the observation if needed and extract features.
@@ -62,7 +159,7 @@ class BasePolicy(nn.Module):
         """
         Orthogonal initialization (used in PPO and A2C)
         """
-        if isinstance(module, nn.Linear):
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
             nn.init.orthogonal_(module.weight, gain=gain)
             module.bias.data.fill_(0.0)
@@ -104,11 +201,18 @@ class BasePolicy(nn.Module):
         # mask = [False for _ in range(self.n_envs)]
         observation = np.array(observation)

+        # Handle the different cases for images,
+        # as PyTorch uses the channel-first format
         if is_image_space(self.observation_space, channels_last=False):
-            # TODO: handle the different cases
-            if (observation.shape != self.observation_space.shape
-                    and observation.shape[1:] != self.observation_space.shape):
-                observation = VecTransposeImage.transpose_image(observation)
+            if (observation.shape == self.observation_space.shape or
+                    observation.shape[1:] == self.observation_space.shape):
+                pass
+            else:
+                # Try to re-order the channels
+                transpose_obs = VecTransposeImage.transpose_image(observation)
+                if (transpose_obs.shape == self.observation_space.shape
+                        or transpose_obs.shape[1:] == self.observation_space.shape):
+                    observation = transpose_obs

        vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
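
The re-ordering branch above relies on ``VecTransposeImage.transpose_image``; the underlying operation is a channels-last to channels-first transpose. A sketch of the equivalent numpy calls, as our own illustration rather than the wrapper's exact code:

    import numpy as np

    obs_hwc = np.zeros((40, 40, 3), dtype=np.uint8)       # single H x W x C observation
    batch_hwc = np.zeros((8, 40, 40, 3), dtype=np.uint8)  # batched N x H x W x C

    obs_chw = np.transpose(obs_hwc, (2, 0, 1))            # -> C x H x W
    batch_chw = np.transpose(batch_hwc, (0, 3, 1, 2))     # -> N x C x H x W
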
@@ -463,54 +567,3 @@ class MlpExtractor(nn.Module):
         """
         shared_latent = self.shared_net(features)
         return self.policy_net(shared_latent), self.value_net(shared_latent)
-
-
-class BaseFeaturesExtractor(nn.Module):
-    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
-        super(BaseFeaturesExtractor, self).__init__()
-        assert features_dim > 0
-        self._observation_space = observation_space
-        self._features_dim = features_dim
-
-    @property
-    def features_dim(self) -> int:
-        return self._features_dim
-
-    def forward(self, observations: th.Tensor) -> th.Tensor:
-        raise NotImplementedError()
-
-
-class FlattenExtractor(BaseFeaturesExtractor):
-    def __init__(self, observation_space: gym.Space):
-        super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
-        self.flatten = nn.Flatten()
-
-    def forward(self, observations: th.Tensor) -> th.Tensor:
-        return self.flatten(observations)
-
-
-class NatureCNN(BaseFeaturesExtractor):
-    def __init__(self, observation_space: gym.Space,
-                 features_dim: int = 512):
-        super(NatureCNN, self).__init__(observation_space, features_dim)
-        # TODO: custom init?
-        # TODO: check that the observation space is an image
-        # we assume CxWxH images
-        # assert is_image_space(observation_space), observation_space
-        n_input_channels = observation_space.shape[0]
-        self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
-                                 nn.ReLU(),
-                                 nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
-                                 nn.ReLU(),
-                                 nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
-                                 nn.ReLU(),
-                                 nn.Flatten())
-
-        # Compute shape by doing one forward pass
-        with th.no_grad():
-            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
-
-        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
-
-    def forward(self, observations: th.Tensor) -> th.Tensor:
-        return self.linear(self.cnn(observations))
@@ -38,9 +38,11 @@ class PPOPolicy(BasePolicy):
     :param squash_output: (bool) Whether to squash the output using a tanh function,
         this allows to ensure boundaries when using SDE.
     :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
+    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
+        to pass to the feature extractor.
     :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
-    :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
+    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
         ``th.optim.Adam`` by default
     :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
         excluding the learning rate, to pass to the optimizer
@@ -60,10 +62,24 @@ class PPOPolicy(BasePolicy):
                  use_expln: bool = False,
                  squash_output: bool = False,
                  features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+                 features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                  normalize_images: bool = True,
-                 optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
+                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                  optimizer_kwargs: Optional[Dict[str, Any]] = None):
-        super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
+
+        if optimizer_kwargs is None:
+            optimizer_kwargs = {}
+            # Small values to avoid NaN in Adam optimizer
+            if optimizer_class == th.optim.Adam:
+                optimizer_kwargs['eps'] = 1e-5
+
+        super(PPOPolicy, self).__init__(observation_space, action_space,
+                                        device,
+                                        features_extractor_class,
+                                        features_extractor_kwargs,
+                                        optimizer_class=optimizer_class,
+                                        optimizer_kwargs=optimizer_kwargs,
+                                        squash_output=squash_output)

         # Default network architecture, from stable-baselines
         if net_arch is None:
@@ -74,19 +90,10 @@ class PPOPolicy(BasePolicy):

         self.net_arch = net_arch
         self.activation_fn = activation_fn
-
-        if optimizer_kwargs is None:
-            optimizer_kwargs = {}
-            # Small values to avoid NaN in Adam optimizer
-            if optimizer == th.optim.Adam:
-                optimizer_kwargs['eps'] = 1e-5
-
-        self.optimizer_class = optimizer
-        self.optimizer_kwargs = optimizer_kwargs
         self.ortho_init = ortho_init

-        self.features_extractor_class = features_extractor_class
-        self.features_extractor = features_extractor_class(self.observation_space)
+        self.features_extractor = features_extractor_class(self.observation_space,
+                                                           **self.features_extractor_kwargs)
         self.features_dim = self.features_extractor.features_dim

         self.normalize_images = normalize_images
@@ -124,10 +131,11 @@ class PPOPolicy(BasePolicy):
             sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
             use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
             lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
-            optimizer=self.optimizer_class,
-            optimizer_kwargs=self.optimizer_kwargs,
             ortho_init=self.ortho_init,
-            features_extractor_class=self.features_extractor_class
+            optimizer_class=self.optimizer_class,
+            optimizer_kwargs=self.optimizer_kwargs,
+            features_extractor_class=self.features_extractor_class,
+            features_extractor_kwargs=self.features_extractor_kwargs
         ))
         return data
@@ -173,9 +181,12 @@ class PPOPolicy(BasePolicy):
         # Init weights: use orthogonal initialization
         # with small initial weight for the output
         if self.ortho_init:
-            for module in [self.mlp_extractor, self.action_net, self.value_net]:
+            # TODO: check for features_extractor
+            for module in [self.features_extractor, self.mlp_extractor,
+                           self.action_net, self.value_net]:
                 # Values from stable-baselines, TODO: check why
                 gain = {
+                    self.features_extractor: np.sqrt(2),
                     self.mlp_extractor: np.sqrt(2),
                     self.action_net: 0.01,
                     self.value_net: 1
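
The gain table above pairs each sub-network with its own orthogonal-init gain; applying it typically goes through ``nn.Module.apply`` with a partially applied init function. A sketch of that pattern under our own names, where only the gain values come from the hunk:

    from functools import partial

    import torch.nn as nn

    def init_weights(module: nn.Module, gain: float = 1.0) -> None:
        # Orthogonal initialization for linear and convolutional layers
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.orthogonal_(module.weight, gain=gain)
            module.bias.data.fill_(0.0)

    action_net = nn.Linear(64, 4)  # stand-ins for the policy's sub-networks
    value_net = nn.Linear(64, 1)
    for module, gain in {action_net: 0.01, value_net: 1}.items():
        module.apply(partial(init_weights, gain=gain))
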
@@ -300,9 +311,11 @@ class CnnPolicy(PPOPolicy):
     :param squash_output: (bool) Whether to squash the output using a tanh function,
         this allows to ensure boundaries when using SDE.
     :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
+    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
+        to pass to the feature extractor.
     :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
-    :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
+    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
         ``th.optim.Adam`` by default
     :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
         excluding the learning rate, to pass to the optimizer
@@ -322,8 +335,9 @@ class CnnPolicy(PPOPolicy):
                  use_expln: bool = False,
                  squash_output: bool = False,
                  features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
+                 features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                  normalize_images: bool = True,
-                 optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
+                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                  optimizer_kwargs: Optional[Dict[str, Any]] = None):
         super(CnnPolicy, self).__init__(observation_space,
                                         action_space,
@@ -339,8 +353,9 @@ class CnnPolicy(PPOPolicy):
                                         use_expln,
                                         squash_output,
                                         features_extractor_class,
+                                        features_extractor_kwargs,
                                         normalize_images,
-                                        optimizer,
+                                        optimizer_class,
                                         optimizer_kwargs)
@@ -251,7 +251,7 @@ class SACPolicy(BasePolicy):
         to pass to the feature extractor.
     :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
-    :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
+    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
         ``th.optim.Adam`` by default
     :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
         excluding the learning rate, to pass to the optimizer
@@ -270,9 +270,15 @@ class SACPolicy(BasePolicy):
                  features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
                  features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                  normalize_images: bool = True,
-                 optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
+                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                  optimizer_kwargs: Optional[Dict[str, Any]] = None):
-        super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True)
+        super(SACPolicy, self).__init__(observation_space, action_space,
+                                        device,
+                                        features_extractor_class,
+                                        features_extractor_kwargs,
+                                        optimizer_class=optimizer_class,
+                                        optimizer_kwargs=optimizer_kwargs,
+                                        squash_output=True)

         if net_arch is None:
             if features_extractor_class == FlattenExtractor:
@@ -280,18 +286,9 @@ class SACPolicy(BasePolicy):
         else:
             net_arch = []

-        if optimizer_kwargs is None:
-            optimizer_kwargs = {}
-
-        if features_extractor_kwargs is None:
-            features_extractor_kwargs = {}
-
-        self.optimizer_class = optimizer
-        self.optimizer_kwargs = optimizer_kwargs
-
-        self.features_extractor_class = features_extractor_class
-        self.features_extractor_kwargs = features_extractor_kwargs
-        self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs)
+        # Create shared features extractor
+        self.features_extractor = features_extractor_class(self.observation_space,
+                                                           **self.features_extractor_kwargs)
         self.features_dim = self.features_extractor.features_dim

         self.net_arch = net_arch
@@ -347,7 +344,7 @@ class SACPolicy(BasePolicy):
             use_expln=self.actor_kwargs['use_expln'],
             clip_mean=self.actor_kwargs['clip_mean'],
             lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
-            optimizer=self.optimizer_class,
+            optimizer_class=self.optimizer_class,
             optimizer_kwargs=self.optimizer_kwargs,
             features_extractor_class=self.features_extractor_class,
             features_extractor_kwargs=self.features_extractor_kwargs
@@ -392,7 +389,7 @@ class CnnPolicy(SACPolicy):
     :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
     :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
-    :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
+    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
         ``th.optim.Adam`` by default
     :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
         excluding the learning rate, to pass to the optimizer
@@ -411,7 +408,7 @@ class CnnPolicy(SACPolicy):
                  features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
                  features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                  normalize_images: bool = True,
-                 optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
+                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                  optimizer_kwargs: Optional[Dict[str, Any]] = None):
         super(CnnPolicy, self).__init__(observation_space,
                                         action_space,
@@ -427,7 +424,7 @@ class CnnPolicy(SACPolicy):
                                         features_extractor_class,
                                         features_extractor_kwargs,
                                         normalize_images,
-                                        optimizer,
+                                        optimizer_class,
                                         optimizer_kwargs)
@@ -294,7 +294,7 @@ class TD3Policy(BasePolicy):
         to pass to the feature extractor.
     :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
-    :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
+    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
         ``th.optim.Adam`` by default
     :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
         excluding the learning rate, to pass to the optimizer
@@ -314,9 +314,15 @@ class TD3Policy(BasePolicy):
                  features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
                  features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                  normalize_images: bool = True,
-                 optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
+                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                  optimizer_kwargs: Optional[Dict[str, Any]] = None):
-        super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
+        super(TD3Policy, self).__init__(observation_space, action_space,
+                                        device,
+                                        features_extractor_class,
+                                        features_extractor_kwargs,
+                                        optimizer_class=optimizer_class,
+                                        optimizer_kwargs=optimizer_kwargs,
+                                        squash_output=True)

         # Default network architecture, from the original paper
         if net_arch is None:
@@ -325,18 +331,8 @@ class TD3Policy(BasePolicy):
         else:
             net_arch = []

-        if optimizer_kwargs is None:
-            optimizer_kwargs = {}
-
-        if features_extractor_kwargs is None:
-            features_extractor_kwargs = {}
-
-        self.optimizer_class = optimizer
-        self.optimizer_kwargs = optimizer_kwargs
-
-        self.features_extractor_class = features_extractor_class
-        self.features_extractor_kwargs = features_extractor_kwargs
-        self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs)
+        self.features_extractor = features_extractor_class(self.observation_space,
+                                                           **self.features_extractor_kwargs)
         self.features_dim = self.features_extractor.features_dim

         self.net_arch = net_arch
@@ -400,7 +396,7 @@ class TD3Policy(BasePolicy):
             sde_net_arch=self.actor_kwargs['sde_net_arch'],
             use_expln=self.actor_kwargs['use_expln'],
             lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
-            optimizer=self.optimizer_class,
+            optimizer_class=self.optimizer_class,
             optimizer_kwargs=self.optimizer_kwargs,
             features_extractor_class=self.features_extractor_class,
             features_extractor_kwargs=self.features_extractor_kwargs
@@ -449,7 +445,7 @@ class CnnPolicy(TD3Policy):
         to pass to the feature extractor.
     :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
-    :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
+    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
         ``th.optim.Adam`` by default
     :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
         excluding the learning rate, to pass to the optimizer
@@ -469,7 +465,7 @@ class CnnPolicy(TD3Policy):
                  features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
                  features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                  normalize_images: bool = True,
-                 optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
+                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                  optimizer_kwargs: Optional[Dict[str, Any]] = None):
         super(CnnPolicy, self).__init__(observation_space,
                                         action_space,
@@ -486,7 +482,7 @@ class CnnPolicy(TD3Policy):
                                         features_extractor_class,
                                         features_extractor_kwargs,
                                         normalize_images,
-                                        optimizer,
+                                        optimizer_class,
                                         optimizer_kwargs)

 register_policy("MlpPolicy", MlpPolicy)