Cleanup, bug fixes + more tests

This commit is contained in:
Antonin RAFFIN 2020-04-22 13:14:22 +02:00
parent 73fb8d1c63
commit 041f2bc59a
8 changed files with 208 additions and 132 deletions

View file

@ -12,10 +12,10 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``optimizer`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
- Added ``optimizer_class`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
customize optimizers
- Complete independent save/load for policies
- Add ``CnnPolicies`` to support images as input
- Add ``CnnPolicy`` and ``VecTransposeImage`` to support images as input
Bug Fixes:

View file

@ -22,7 +22,8 @@ def test_cnn(model_class):
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=512)))
kwargs = dict(buffer_size=250,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
model = model_class('CnnPolicy', env, **kwargs).learn(250)
obs = env.reset()

View file

@ -8,6 +8,8 @@ import torch as th
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.identity_env import IdentityEnvBox
from torchy_baselines.common.vec_env import DummyVecEnv
from torchy_baselines.common.identity_env import FakeImageEnv
MODEL_LIST = [
PPO,
@ -30,12 +32,11 @@ def test_save_load(model_class):
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
# create model
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
observations = observations.reshape(10, -1)
observations = np.concatenate([env.step(env.action_space.sample())[0] for _ in range(10)], axis=0)
# Get dictionary of current parameters
params = deepcopy(model.policy.state_dict())
@ -90,7 +91,7 @@ def test_set_env(model_class):
env3 = IdentityEnvBox(10)
# create model
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), create_eval_env=True)
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]))
# learn
model.learn(total_timesteps=1000, eval_freq=500)
@ -115,7 +116,7 @@ def test_exclude_include_saved_params(model_class):
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
# create model, set verbose as 2, which is not standard
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2, create_eval_env=True)
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2)
# Check if exclude works
model.save("test_save.zip", exclude=["verbose"])
@ -163,21 +164,34 @@ def test_save_load_replay_buffer(model_class):
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load_policy(model_class):
@pytest.mark.parametrize("policy_str", ['MlpPolicy', 'CnnPolicy'])
def test_save_load_policy(model_class, policy_str):
"""
Test saving and loading policy only.
:param model_class: (BaseRLModel) A RL model
:param policy_str: (str) Name of the policy.
"""
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
kwargs = {}
if policy_str == 'MlpPolicy':
env = IdentityEnvBox(10)
else:
if model_class in [SAC, TD3]:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=3,
discrete=False)
env = DummyVecEnv([lambda: env])
# create model
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]),
verbose=1, **kwargs)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
observations = observations.reshape(10, -1)
observations = np.concatenate([env.step(env.action_space.sample())[0] for _ in range(10)], axis=0)
policy = model.policy
policy_class = policy.__class__

View file

@ -82,8 +82,8 @@ class A2C(PPO):
self.normalize_advantage = normalize_advantage
# Override PPO optimizer to match original implementation
if use_rms_prop and 'optimizer' not in self.policy_kwargs:
self.policy_kwargs['optimizer'] = th.optim.RMSprop
if use_rms_prop and 'optimizer_class' not in self.policy_kwargs:
self.policy_kwargs['optimizer_class'] = th.optim.RMSprop
self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps,
weight_decay=0)

View file

@ -12,6 +12,79 @@ from torchy_baselines.common.utils import get_device, get_schedule_fn
from torchy_baselines.common.vec_env import VecTransposeImage
class BaseFeaturesExtractor(nn.Module):
"""
Base class that represents a features extractor.
:param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
"""
def __init__(self, observation_space: gym.Space, features_dim: int = 0):
super(BaseFeaturesExtractor, self).__init__()
assert features_dim > 0
self._observation_space = observation_space
self._features_dim = features_dim
@property
def features_dim(self) -> int:
return self._features_dim
def forward(self, observations: th.Tensor) -> th.Tensor:
raise NotImplementedError()
class FlattenExtractor(BaseFeaturesExtractor):
"""
    Feature extractor that flattens the input.
Used as a placeholder when feature extraction is not needed.
:param observation_space: (gym.Space)
"""
def __init__(self, observation_space: gym.Space):
super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
self.flatten = nn.Flatten()
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.flatten(observations)
class NatureCNN(BaseFeaturesExtractor):
"""
CNN from DQN nature paper: https://arxiv.org/abs/1312.5602
:param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units of the last layer.
"""
def __init__(self, observation_space: gym.spaces.Box,
features_dim: int = 512):
super(NatureCNN, self).__init__(observation_space, features_dim)
# TODO: custom init?
# We assume CxWxH images (channels first)
        # Re-ordering will be done by preprocessing or by a wrapper
is_image_input = is_image_space(observation_space) or is_image_space(observation_space, channels_last=False)
assert is_image_input, ('You should use NatureCNN '
f'only with images not with {observation_space} '
'(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten())
# Compute shape by doing one forward pass
with th.no_grad():
n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.linear(self.cnn(observations))
class BasePolicy(nn.Module):
"""
The base policy object
@ -21,26 +94,50 @@ class BasePolicy(nn.Module):
:param device: (Union[th.device, str]) Device on which the code should run.
:param squash_output: (bool) For continuous actions, whether the output is squashed
or not using a ``tanh()`` function.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.
:param features_extractor: (nn.Module) Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
device: Union[th.device, str] = 'auto',
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor: Optional[nn.Module] = None,
normalize_images: bool = True):
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
squash_output: bool = False):
super(BasePolicy, self).__init__()
if optimizer_kwargs is None:
optimizer_kwargs = {}
if features_extractor_kwargs is None:
features_extractor_kwargs = {}
self.observation_space = observation_space
self.action_space = action_space
self.device = get_device(device)
self.features_extractor = features_extractor
self.normalize_images = normalize_images
self._squash_output = squash_output
self.optimizer_class = optimizer_class
self.optimizer_kwargs = optimizer_kwargs
self.optimizer = None # type: Optional[th.optim.Optimizer]
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
def extract_features(self, obs: th.Tensor) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
@ -62,7 +159,7 @@ class BasePolicy(nn.Module):
"""
Orthogonal initialization (used in PPO and A2C)
"""
if isinstance(module, nn.Linear):
if isinstance(module, (nn.Linear, nn.Conv2d)):
nn.init.orthogonal_(module.weight, gain=gain)
module.bias.data.fill_(0.0)
@ -104,11 +201,18 @@ class BasePolicy(nn.Module):
# mask = [False for _ in range(self.n_envs)]
observation = np.array(observation)
# Handle the different cases for images
# as PyTorch use channel first format
if is_image_space(self.observation_space, channels_last=False):
# TODO: handle the different cases
if (observation.shape != self.observation_space.shape
and observation.shape[1:] != self.observation_space.shape):
observation = VecTransposeImage.transpose_image(observation)
if (observation.shape == self.observation_space.shape or
observation.shape[1:] == self.observation_space.shape):
pass
else:
# Try to re-order the channels
transpose_obs = VecTransposeImage.transpose_image(observation)
if (transpose_obs.shape == self.observation_space.shape
or transpose_obs.shape[1:] == self.observation_space.shape):
observation = transpose_obs
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
@ -463,54 +567,3 @@ class MlpExtractor(nn.Module):
"""
shared_latent = self.shared_net(features)
return self.policy_net(shared_latent), self.value_net(shared_latent)
class BaseFeaturesExtractor(nn.Module):
def __init__(self, observation_space: gym.Space, features_dim: int = 0):
super(BaseFeaturesExtractor, self).__init__()
assert features_dim > 0
self._observation_space = observation_space
self._features_dim = features_dim
@property
def features_dim(self) -> int:
return self._features_dim
def forward(self, observations: th.Tensor) -> th.Tensor:
raise NotImplementedError()
class FlattenExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: gym.Space):
super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
self.flatten = nn.Flatten()
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.flatten(observations)
class NatureCNN(BaseFeaturesExtractor):
def __init__(self, observation_space: gym.Space,
features_dim: int = 512):
super(NatureCNN, self).__init__(observation_space, features_dim)
# TODO: custom init?
# TODO: check that the observation space is an image
# we assume CxWxH images
# assert is_image_space(observation_space), observation_space
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten())
# Compute shape by doing one forward pass
with th.no_grad():
n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.linear(self.cnn(observations))

View file

@ -38,9 +38,11 @@ class PPOPolicy(BasePolicy):
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries when using SDE.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
@ -60,10 +62,24 @@ class PPOPolicy(BasePolicy):
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in ADAM optimizer
if optimizer_class == th.optim.Adam:
optimizer_kwargs['eps'] = 1e-5
super(PPOPolicy, self).__init__(observation_space, action_space,
device,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=squash_output)
# Default network architecture, from stable-baselines
if net_arch is None:
@ -74,19 +90,10 @@ class PPOPolicy(BasePolicy):
self.net_arch = net_arch
self.activation_fn = activation_fn
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in ADAM optimizer
if optimizer == th.optim.Adam:
optimizer_kwargs['eps'] = 1e-5
self.optimizer_class = optimizer
self.optimizer_kwargs = optimizer_kwargs
self.ortho_init = ortho_init
self.features_extractor_class = features_extractor_class
self.features_extractor = features_extractor_class(self.observation_space)
self.features_extractor = features_extractor_class(self.observation_space,
**self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.normalize_images = normalize_images
@ -124,10 +131,11 @@ class PPOPolicy(BasePolicy):
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
ortho_init=self.ortho_init,
features_extractor_class=self.features_extractor_class
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
))
return data
@ -173,9 +181,12 @@ class PPOPolicy(BasePolicy):
# Init weights: use orthogonal initialization
# with small initial weight for the output
if self.ortho_init:
for module in [self.mlp_extractor, self.action_net, self.value_net]:
# TODO: check for features_extractor
for module in [self.features_extractor, self.mlp_extractor,
self.action_net, self.value_net]:
# Values from stable-baselines, TODO: check why
gain = {
self.features_extractor: np.sqrt(2),
self.mlp_extractor: np.sqrt(2),
self.action_net: 0.01,
self.value_net: 1
@ -300,9 +311,11 @@ class CnnPolicy(PPOPolicy):
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries when using SDE.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
@ -322,8 +335,9 @@ class CnnPolicy(PPOPolicy):
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(CnnPolicy, self).__init__(observation_space,
action_space,
@ -339,8 +353,9 @@ class CnnPolicy(PPOPolicy):
use_expln,
squash_output,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer,
optimizer_class,
optimizer_kwargs)

View file

@ -251,7 +251,7 @@ class SACPolicy(BasePolicy):
to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
@ -270,9 +270,15 @@ class SACPolicy(BasePolicy):
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True)
super(SACPolicy, self).__init__(observation_space, action_space,
device,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=True)
if net_arch is None:
if features_extractor_class == FlattenExtractor:
@ -280,18 +286,9 @@ class SACPolicy(BasePolicy):
else:
net_arch = []
if optimizer_kwargs is None:
optimizer_kwargs = {}
if features_extractor_kwargs is None:
features_extractor_kwargs = {}
self.optimizer_class = optimizer
self.optimizer_kwargs = optimizer_kwargs
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs)
# Create shared features extractor
self.features_extractor = features_extractor_class(self.observation_space,
**self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
@ -347,7 +344,7 @@ class SACPolicy(BasePolicy):
use_expln=self.actor_kwargs['use_expln'],
clip_mean=self.actor_kwargs['clip_mean'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer=self.optimizer_class,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
@ -392,7 +389,7 @@ class CnnPolicy(SACPolicy):
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
@ -411,7 +408,7 @@ class CnnPolicy(SACPolicy):
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(CnnPolicy, self).__init__(observation_space,
action_space,
@ -427,7 +424,7 @@ class CnnPolicy(SACPolicy):
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer,
optimizer_class,
optimizer_kwargs)

View file

@ -294,7 +294,7 @@ class TD3Policy(BasePolicy):
to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
@ -314,9 +314,15 @@ class TD3Policy(BasePolicy):
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
super(TD3Policy, self).__init__(observation_space, action_space,
device,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=True)
# Default network architecture, from the original paper
if net_arch is None:
@ -325,18 +331,8 @@ class TD3Policy(BasePolicy):
else:
net_arch = []
if optimizer_kwargs is None:
optimizer_kwargs = {}
if features_extractor_kwargs is None:
features_extractor_kwargs = {}
self.optimizer_class = optimizer
self.optimizer_kwargs = optimizer_kwargs
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs)
self.features_extractor = features_extractor_class(self.observation_space,
**self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
@ -400,7 +396,7 @@ class TD3Policy(BasePolicy):
sde_net_arch=self.actor_kwargs['sde_net_arch'],
use_expln=self.actor_kwargs['use_expln'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer=self.optimizer_class,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
@ -449,7 +445,7 @@ class CnnPolicy(TD3Policy):
to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
@ -469,7 +465,7 @@ class CnnPolicy(TD3Policy):
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(CnnPolicy, self).__init__(observation_space,
action_space,
@ -486,7 +482,7 @@ class CnnPolicy(TD3Policy):
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer,
optimizer_class,
optimizer_kwargs)
register_policy("MlpPolicy", MlpPolicy)