mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Avoid putting target networks into training mode (#553)
* make sure DQN policy is always in correct mode - train or eval * make set_training_mode an abstract method of the base policy - safer * update docstring of _build method to note that the target network is put into eval mode * use set_training_mode to put the dqn target network into eval mode * use set_training_mode to set the training mode of the q-network * move set_training_mode abstract method from BasePolicy to BaseModel * set train and eval mode for TD3 * make sure critic is always in correct mode during train * set train and eval mode for SAC * add comment re batch norm and dropout * set train and eval mode for A2C and PPO * add tests for collect rollouts with batch norm * fix formatting * update change log * update version * remove Optional typing for batch size - causing type check to fail * Fix scipy dependency for toy text envs * implement set_training_mode method in BaseModel * move all tests of train/eval mode to test_train_eval_mode * call learn with learning_starts = total_timesteps to test that collect_rollouts does not update batch norm * remove extra calls to set_training_mode in train method of TD3 and SAC * Allow gradient_steps=0 * Refactor tests * Add comment + use aliases * Typos Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
3efab0d267
commit
1afc2f3abe
16 changed files with 446 additions and 89 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.2.0a2 (WIP)
|
||||
Release 1.2.0a3 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -17,6 +17,9 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed model predictions when using batch normalization and dropout layers by calling ``train()`` and ``eval()`` (@davidblom603)
|
||||
- Fixed model training for DQN, TD3 and SAC so that their target nets always remain in evaluation mode (@ayeright)
|
||||
- Passing ``gradient_steps=0`` to an off-policy algorithm will result in no gradient steps being taken (vs as many gradient steps as steps done in the environment
|
||||
during the rollout in previous versions)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -738,4 +741,4 @@ And all the contributors:
|
|||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615
|
||||
@c-rizz @skandermoalla @MihaiAnca13 @davidblom603
|
||||
@c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -100,6 +100,8 @@ setup(
|
|||
"isort>=5.0",
|
||||
# Reformat
|
||||
"black",
|
||||
# For toy text Gym envs
|
||||
"scipy>=1.4.1",
|
||||
],
|
||||
"docs": [
|
||||
"sphinx",
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ class A2C(OnPolicyAlgorithm):
|
|||
rollout buffer (one gradient step over whole data).
|
||||
"""
|
||||
# Switch to train mode (this affects batch norm / dropout)
|
||||
self.policy.train()
|
||||
self.policy.set_training_mode(True)
|
||||
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
|
|
|
|||
|
|
@ -365,8 +365,10 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
# If no `gradient_steps` is specified,
|
||||
# do as many gradients steps as steps performed during the rollout
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
|
||||
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
|
||||
# Special case when the user passes `gradient_steps=0`
|
||||
if gradient_steps > 0:
|
||||
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
|
||||
|
||||
callback.on_training_end()
|
||||
|
||||
|
|
@ -537,7 +539,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
:return:
|
||||
"""
|
||||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.policy.eval()
|
||||
self.policy.set_training_mode(False)
|
||||
|
||||
episode_rewards, total_timesteps = [], []
|
||||
num_collected_steps, num_collected_episodes = 0, 0
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
"""
|
||||
assert self._last_obs is not None, "No previous observation was provided"
|
||||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.policy.eval()
|
||||
self.policy.set_training_mode(False)
|
||||
|
||||
n_steps = 0
|
||||
rollout_buffer.reset()
|
||||
|
|
|
|||
|
|
@ -194,6 +194,16 @@ class BaseModel(nn.Module, ABC):
|
|||
"""
|
||||
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
|
||||
|
||||
def set_training_mode(self, mode: bool) -> None:
    """
    Switch this model between training and evaluation mode.

    The mode matters for layers that behave differently at train and
    test time, for instance batch normalisation and dropout.

    :param mode: ``True`` selects training mode, ``False`` selects evaluation mode
    """
    # Delegate to the ``train`` method, which takes the mode as a boolean flag
    self.train(mode)
|
||||
|
||||
|
||||
class BasePolicy(BaseModel):
|
||||
"""The base policy object.
|
||||
|
|
@ -268,7 +278,7 @@ class BasePolicy(BaseModel):
|
|||
# if mask is None:
|
||||
# mask = [False for _ in range(self.n_envs)]
|
||||
# Switch to eval mode (this affects batch norm / dropout)
|
||||
self.eval()
|
||||
self.set_training_mode(False)
|
||||
|
||||
vectorized_env = False
|
||||
if isinstance(observation, dict):
|
||||
|
|
|
|||
|
|
@ -153,7 +153,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
|
||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
||||
# Switch to train mode (this affects batch norm / dropout)
|
||||
self.policy.train()
|
||||
self.policy.set_training_mode(True)
|
||||
# Update learning rate according to schedule
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
|
||||
|
|
|
|||
|
|
@ -152,6 +152,8 @@ class DQNPolicy(BasePolicy):
|
|||
"""
|
||||
Create the network and the optimizer.
|
||||
|
||||
Put the target network into evaluation mode.
|
||||
|
||||
:param lr_schedule: Learning rate schedule
|
||||
lr_schedule(1) is the initial learning rate
|
||||
"""
|
||||
|
|
@ -159,6 +161,7 @@ class DQNPolicy(BasePolicy):
|
|||
self.q_net = self.make_q_net()
|
||||
self.q_net_target = self.make_q_net()
|
||||
self.q_net_target.load_state_dict(self.q_net.state_dict())
|
||||
self.q_net_target.set_training_mode(False)
|
||||
|
||||
# Setup optimizer with initial learning rate
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
|
@ -190,6 +193,17 @@ class DQNPolicy(BasePolicy):
|
|||
)
|
||||
return data
|
||||
|
||||
def set_training_mode(self, mode: bool) -> None:
    """
    Switch the policy between training and evaluation mode.

    The mode matters for layers such as batch normalisation and dropout.
    Only ``q_net`` is switched here; ``q_net_target`` is not touched by this method.

    :param mode: ``True`` selects training mode, ``False`` selects evaluation mode
    """
    online_q_net = self.q_net
    online_q_net.set_training_mode(mode)
    # Record the requested mode on the policy object itself
    self.training = mode
|
||||
|
||||
|
||||
MlpPolicy = DQNPolicy
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Schedule] = 3e-4,
|
||||
n_steps: int = 2048,
|
||||
batch_size: Optional[int] = 64,
|
||||
batch_size: int = 64,
|
||||
n_epochs: int = 10,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95,
|
||||
|
|
@ -167,7 +167,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
Update policy using the currently gathered rollout buffer.
|
||||
"""
|
||||
# Switch to train mode (this affects batch norm / dropout)
|
||||
self.policy.train()
|
||||
self.policy.set_training_mode(True)
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# Compute current clip range
|
||||
|
|
|
|||
|
|
@ -317,6 +317,9 @@ class SACPolicy(BasePolicy):
|
|||
|
||||
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
# Target networks should always be in eval mode
|
||||
self.critic_target.set_training_mode(False)
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
data = super()._get_constructor_parameters()
|
||||
|
||||
|
|
@ -361,6 +364,18 @@ class SACPolicy(BasePolicy):
|
|||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.actor(observation, deterministic)
|
||||
|
||||
def set_training_mode(self, mode: bool) -> None:
    """
    Switch the policy between training and evaluation mode.

    The mode matters for layers such as batch normalisation and dropout.
    The actor and the critic are switched; ``critic_target`` is not touched by this method.

    :param mode: ``True`` selects training mode, ``False`` selects evaluation mode
    """
    # Actor first, then critic
    for network in (self.actor, self.critic):
        network.set_training_mode(mode)
    # Record the requested mode on the policy object itself
    self.training = mode
|
||||
|
||||
|
||||
MlpPolicy = SACPolicy
|
||||
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
|
||||
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
|
||||
# Switch to train mode (this affects batch norm / dropout)
|
||||
self.policy.train()
|
||||
self.policy.set_training_mode(True)
|
||||
# Update optimizers learning rate
|
||||
optimizers = [self.actor.optimizer, self.critic.optimizer]
|
||||
if self.ent_coef_optimizer is not None:
|
||||
|
|
|
|||
|
|
@ -191,6 +191,10 @@ class TD3Policy(BasePolicy):
|
|||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
# Target networks should always be in eval mode
|
||||
self.actor_target.set_training_mode(False)
|
||||
self.critic_target.set_training_mode(False)
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
data = super()._get_constructor_parameters()
|
||||
|
||||
|
|
@ -225,6 +229,18 @@ class TD3Policy(BasePolicy):
|
|||
# Predictions are always deterministic.
|
||||
return self.actor(observation)
|
||||
|
||||
def set_training_mode(self, mode: bool) -> None:
    """
    Switch the policy between training and evaluation mode.

    The mode matters for layers such as batch normalisation and dropout.
    The actor and the critic are switched; ``actor_target`` and ``critic_target``
    are not touched by this method.

    :param mode: ``True`` selects training mode, ``False`` selects evaluation mode
    """
    # Actor first, then critic
    for network in (self.actor, self.critic):
        network.set_training_mode(mode)
    # Record the requested mode on the policy object itself
    self.training = mode
|
||||
|
||||
|
||||
MlpPolicy = TD3Policy
|
||||
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
|
||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
||||
# Switch to train mode (this affects batch norm / dropout)
|
||||
self.policy.train()
|
||||
self.policy.set_training_mode(True)
|
||||
|
||||
# Update learning rate according to lr schedule
|
||||
self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.2.0a2
|
||||
1.2.0a3
|
||||
|
|
|
|||
|
|
@ -1,12 +1,8 @@
|
|||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
||||
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
|
|
@ -73,74 +69,3 @@ def test_predict(model_class, env_id, device):
|
|||
|
||||
action, _ = model.predict(vec_env_obs, deterministic=False)
|
||||
assert action.shape[0] == vec_env_obs.shape[0]
|
||||
|
||||
|
||||
class FlattenBatchNormExtractor(BaseFeaturesExtractor):
    """
    Features extractor that flattens its input, then applies batch normalization
    followed by dropout.

    :param observation_space:
    """

    def __init__(self, observation_space: gym.Space):
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()
        self.batch_norm = nn.BatchNorm1d(self._features_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        # flatten -> batch norm -> dropout
        return self.dropout(self.batch_norm(self.flatten(observations)))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_batch_norm_dropout(model_class, env_id):
    """
    Train briefly with a batch norm / dropout features extractor and check that
    the batch norm parameters moved and that deterministic predictions are repeatable.
    """
    on_cartpole = env_id == "CartPole-v1"
    # Combinations that are not exercised by this test
    if on_cartpole and model_class in [SAC, TD3]:
        return
    if not on_cartpole and model_class in [DQN]:
        return

    is_off_policy = model_class in [DQN, TD3, SAC]
    model_kwargs = dict(seed=1)
    if is_off_policy:
        model_kwargs["learning_starts"] = 0
    else:
        model_kwargs["n_steps"] = 64

    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=dict(features_extractor_class=FlattenBatchNormExtractor, net_arch=[16, 16]),
        verbose=1,
        **model_kwargs,
    )

    # Locate the batch norm layer of the network used for acting
    policy = model.policy
    if model_class in [SAC, TD3]:
        extractor = policy.actor.features_extractor
    elif model_class in [PPO, A2C]:
        extractor = policy.features_extractor
    else:
        # DQN
        extractor = policy.q_net.features_extractor
    batch_norm = extractor.batch_norm

    # Snapshot the batch norm parameters before training
    bias_before_learn = batch_norm.bias.detach().cpu().numpy().copy()
    running_mean_before_learn = batch_norm.running_mean.detach().cpu().numpy().copy()

    model.learn(100)
    observation = model.get_env().reset()

    bias_after_learn = batch_norm.bias.detach().cpu().numpy()
    running_mean_after_learn = batch_norm.running_mean.detach().cpu().numpy().copy()

    # Predict twice on the same observation: the result must be deterministic
    predictions = [model.predict(observation, deterministic=True)[0] for _ in range(2)]
    np.testing.assert_allclose(predictions[0], predictions[1])

    assert not np.allclose(bias_before_learn, bias_after_learn)
    assert not np.allclose(running_mean_before_learn, running_mean_after_learn)
|
||||
|
|
|
|||
370
tests/test_train_eval_mode.py
Normal file
370
tests/test_train_eval_mode.py
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
from typing import Tuple, Union

import gym
import numpy as np
import pytest
import torch as th
import torch.nn as nn

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
# Algorithms covered by the parametrized tests of this module
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN]
|
||||
|
||||
|
||||
class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
    """
    Features extractor that flattens its input, then applies batch normalization
    followed by dropout.

    :param observation_space:
    """

    def __init__(self, observation_space: gym.Space):
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()
        self.batch_norm = nn.BatchNorm1d(self._features_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        # flatten -> batch norm -> dropout
        return self.dropout(self.batch_norm(self.flatten(observations)))
|
||||
|
||||
|
||||
def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> Tuple[th.Tensor, th.Tensor]:
    """
    Snapshot the bias and the running mean of the given batch norm layer.

    Both tensors are detached before being cloned, so the snapshot is a plain
    copy that does not track gradients and does not change when the layer is
    updated afterwards.

    :param batch_norm: layer to read the bias and running mean from
    :return: copies of the bias and of the running mean, in that order
    """
    bias_copy = batch_norm.bias.detach().clone()
    running_mean_copy = batch_norm.running_mean.detach().clone()
    return bias_copy, running_mean_copy
|
||||
|
||||
|
||||
def clone_dqn_batch_norm_stats(model: DQN) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the batch norm bias and running mean of the Q-network and of the target network.

    :param model: DQN model whose policy networks use a features extractor
        exposing a ``batch_norm`` layer
    :return: ``(q_net_bias, q_net_running_mean, q_net_target_bias, q_net_target_running_mean)``
    """
    policy = model.policy
    q_net_stats = clone_batch_norm_stats(policy.q_net.features_extractor.batch_norm)
    q_net_target_stats = clone_batch_norm_stats(policy.q_net_target.features_extractor.batch_norm)
    # Online network first, target network second
    return (*q_net_stats, *q_net_target_stats)
|
||||
|
||||
|
||||
def clone_td3_batch_norm_stats(model: TD3) -> Tuple[th.Tensor, ...]:
    """
    Clone the batch norm bias and running mean of the actor, critic, actor-target
    and critic-target networks.

    :param model: TD3 model whose networks use a features extractor exposing
        a ``batch_norm`` layer
    :return: eight tensors, a ``(bias, running_mean)`` pair per network, ordered as
        actor, critic, actor target, critic target
    """
    # The order of the networks defines the order of the returned tensors
    networks = (model.actor, model.critic, model.actor_target, model.critic_target)
    stats = []
    for network in networks:
        stats.extend(clone_batch_norm_stats(network.features_extractor.batch_norm))
    return tuple(stats)
|
||||
|
||||
|
||||
def clone_sac_batch_norm_stats(model: SAC) -> Tuple[th.Tensor, ...]:
    """
    Clone the batch norm bias and running mean of the actor, critic and
    critic-target networks.

    :param model: SAC model whose networks use a features extractor exposing
        a ``batch_norm`` layer
    :return: six tensors, a ``(bias, running_mean)`` pair per network, ordered as
        actor, critic, critic target
    """
    # The order of the networks defines the order of the returned tensors
    networks = (model.actor, model.critic, model.critic_target)
    stats = []
    for network in networks:
        stats.extend(clone_batch_norm_stats(network.features_extractor.batch_norm))
    return tuple(stats)
|
||||
|
||||
|
||||
def clone_on_policy_batch_norm(model: Union[A2C, PPO]) -> Tuple[th.Tensor, th.Tensor]:
    """
    Clone the batch norm bias and running mean of an on-policy model.

    :param model: A2C or PPO model whose policy uses a features extractor
        exposing a ``batch_norm`` layer
    :return: the bias and the running mean, in that order
    """
    return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)
|
||||
|
||||
|
||||
# Maps each algorithm class to the function that snapshots its batch norm statistics
CLONE_HELPERS = {
    A2C: clone_on_policy_batch_norm,
    DQN: clone_dqn_batch_norm_stats,
    SAC: clone_sac_batch_norm_stats,
    TD3: clone_td3_batch_norm_stats,
    PPO: clone_on_policy_batch_norm,
}
|
||||
|
||||
|
||||
def test_dqn_train_with_batch_norm():
    """
    Train DQN with a batch norm / dropout features extractor and check that the
    Q-network batch norm statistics change while those of the target network do not.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=extractor_kwargs,
        learning_starts=0,
        seed=1,
        tau=0,  # do not clone the target
    )

    stats_before = clone_dqn_batch_norm_stats(model)
    model.learn(total_timesteps=200)
    stats_after = clone_dqn_batch_norm_stats(model)

    # The first two entries belong to the Q-network, the last two to the target network
    n_online = 2
    for old, new in zip(stats_before[:n_online], stats_after[:n_online]):
        assert not th.isclose(old, new).all()
    for old, new in zip(stats_before[n_online:], stats_after[n_online:]):
        assert th.isclose(old, new).all()
|
||||
|
||||
|
||||
def test_td3_train_with_batch_norm():
    """
    Train TD3 with a batch norm / dropout features extractor and check that the
    actor and critic batch norm statistics change while those of the target
    networks do not.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = TD3(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=extractor_kwargs,
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    stats_before = clone_td3_batch_norm_stats(model)
    model.learn(total_timesteps=200)
    stats_after = clone_td3_batch_norm_stats(model)

    # The first four entries belong to the actor and critic,
    # the last four to the actor target and critic target
    n_online = 4
    for old, new in zip(stats_before[:n_online], stats_after[:n_online]):
        assert not th.isclose(old, new).all()
    for old, new in zip(stats_before[n_online:], stats_after[n_online:]):
        assert th.isclose(old, new).all()
|
||||
|
||||
|
||||
def test_sac_train_with_batch_norm():
    """
    Train SAC with a batch norm / dropout features extractor and check that the
    actor and critic batch norm statistics change while those of the critic
    target do not.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = SAC(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=extractor_kwargs,
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    stats_before = clone_sac_batch_norm_stats(model)
    model.learn(total_timesteps=200)
    stats_after = clone_sac_batch_norm_stats(model)

    # The first four entries belong to the actor and critic, the last two to the critic target
    n_online = 4
    for old, new in zip(stats_before[:n_online], stats_after[:n_online]):
        assert not th.isclose(old, new).all()
    for old, new in zip(stats_before[n_online:], stats_after[n_online:]):
        assert th.isclose(old, new).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_a2c_ppo_train_with_batch_norm(model_class, env_id):
    """
    Train A2C / PPO with a batch norm / dropout features extractor and check
    that the batch norm bias and running mean both change.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class("MlpPolicy", env_id, policy_kwargs=extractor_kwargs, seed=1)

    stats_before = clone_on_policy_batch_norm(model)
    model.learn(total_timesteps=200)
    stats_after = clone_on_policy_batch_norm(model)

    # Bias first, running mean second: both must have moved
    for old, new in zip(stats_before, stats_after):
        assert not th.isclose(old, new).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [DQN, TD3, SAC])
def test_offpolicy_collect_rollout_batch_norm(model_class):
    """
    Run ``learn`` with ``gradient_steps=0`` and check that none of the cloned
    batch norm statistics changed.
    """
    env_id = "CartPole-v1" if model_class in [DQN] else "Pendulum-v0"
    snapshot = CLONE_HELPERS[model_class]

    learning_starts = 10
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=extractor_kwargs,
        learning_starts=learning_starts,
        seed=1,
        gradient_steps=0,
        train_freq=1,
    )

    stats_before = snapshot(model)
    model.learn(total_timesteps=100)
    stats_after = snapshot(model)

    # No change in batch norm params
    for old, new in zip(stats_before, stats_after):
        assert th.isclose(old, new).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id):
    """
    Call ``collect_rollouts`` twice on A2C / PPO and check that the batch norm
    bias and running mean are left unchanged.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class("MlpPolicy", env_id, policy_kwargs=extractor_kwargs, seed=1, n_steps=64)

    stats_before = clone_on_policy_batch_norm(model)

    # Only the callback is needed from the setup step
    _, callback = model._setup_learn(total_timesteps=2 * 64, eval_env=model.get_env())

    n_rollouts = 2
    for _ in range(n_rollouts):
        model.collect_rollouts(model.get_env(), callback, model.rollout_buffer, n_rollout_steps=model.n_steps)

    stats_after = clone_on_policy_batch_norm(model)

    # Bias first, running mean second: neither may have moved
    for old, new in zip(stats_before, stats_after):
        assert th.isclose(old, new).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id):
    """
    Check that ``predict`` is deterministic with batch norm and dropout layers
    and that it does not modify the batch norm statistics.

    Combinations of algorithm and environment that this test does not cover are
    reported as skipped instead of silently passing.
    """
    if env_id == "CartPole-v1":
        if model_class in [SAC, TD3]:
            pytest.skip(f"{model_class.__name__} is not exercised on {env_id}")
    elif model_class in [DQN]:
        pytest.skip(f"{model_class.__name__} is not exercised on {env_id}")

    model_kwargs = dict(seed=1)
    clone_helper = CLONE_HELPERS[model_class]

    if model_class in [DQN, TD3, SAC]:
        model_kwargs["learning_starts"] = 0
    else:
        model_kwargs["n_steps"] = 64

    policy_kwargs = dict(
        features_extractor_class=FlattenBatchNormDropoutExtractor,
        net_arch=[16, 16],
    )
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs)

    batch_norm_stats_before = clone_helper(model)

    env = model.get_env()
    observation = env.reset()
    # Every further prediction on the same observation must match the first one
    first_prediction, _ = model.predict(observation, deterministic=True)
    for _ in range(5):
        prediction, _ = model.predict(observation, deterministic=True)
        np.testing.assert_allclose(first_prediction, prediction)

    batch_norm_stats_after = clone_helper(model)

    # No change in batch norm params
    for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
        assert th.isclose(param_before, param_after).all()
|
||||
Loading…
Reference in a new issue