From 1afc2f3abe09cd4c16a6a70e16b08e9e6d138ea1 Mon Sep 17 00:00:00 2001 From: Scott Brownlie Date: Mon, 30 Aug 2021 16:42:41 +0100 Subject: [PATCH] Avoid putting target networks into training mode (#553) * make sure DQN policy is always in correct mode - train or eval * make set_training_mode an abstract method of the base policy - safer * update docstring of _build method to note that the target network is put into eval mode * use set_training_mode to put the dqn target network into eval mode * use set_training_mode to set the training mode of the q-network * move set_training_mode abstract method from BasePolicy to BaseModel * set train and eval mode for TD3 * make sure critic is always in correct mode during train * set train and eval mode for SAC * add comment re batch norm and dropout * set train and eval mode for A2C and PPO * add tests for collect rollouts with batch norm * fix formatting * update change log * update version * remove Optional typing for batch size - causing type check to fail * Fix scipy dependency for toy text envs * implement set_training_mode method in BaseModel * move all tests of train/eval mode to test_train_eval_mode * call learn with learning_starts = total_timesteps to test that collect_rollouts does not update batch norm * remove extra calls to set_training_mode in train method of TD3 and SAC * Allow gradient_steps=0 * Refactor tests * Add comment + use aliases * Typos Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 7 +- setup.py | 2 + stable_baselines3/a2c/a2c.py | 2 +- .../common/off_policy_algorithm.py | 8 +- .../common/on_policy_algorithm.py | 2 +- stable_baselines3/common/policies.py | 12 +- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/dqn/policies.py | 14 + stable_baselines3/ppo/ppo.py | 4 +- stable_baselines3/sac/policies.py | 15 + stable_baselines3/sac/sac.py | 2 +- stable_baselines3/td3/policies.py | 16 + stable_baselines3/td3/td3.py | 2 +- stable_baselines3/version.txt | 2 +- tests/test_predict.py | 
75 ---- tests/test_train_eval_mode.py | 370 ++++++++++++++++++ 16 files changed, 446 insertions(+), 89 deletions(-) create mode 100644 tests/test_train_eval_mode.py diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 6a1965a..1872c23 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.2.0a2 (WIP) +Release 1.2.0a3 (WIP) --------------------------- Breaking Changes: @@ -17,6 +17,9 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed model predictions when using batch normalization and dropout layers by calling ``train()`` and ``eval()`` (@davidblom603) +- Fixed model training for DQN, TD3 and SAC so that their target nets always remain in evaluation mode (@ayeright) +- Passing ``gradient_steps=0`` to an off-policy algorithm will result in no gradient steps being taken (vs as many gradient steps as steps done in the environment + during the rollout in previous versions) Deprecations: ^^^^^^^^^^^^^ @@ -738,4 +741,4 @@ And all the contributors: @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn @ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 -@c-rizz @skandermoalla @MihaiAnca13 @davidblom603 +@c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright diff --git a/setup.py b/setup.py index 4c2b9e9..909e3b1 100644 --- a/setup.py +++ b/setup.py @@ -100,6 +100,8 @@ setup( "isort>=5.0", # Reformat "black", + # For toy text Gym envs + "scipy>=1.4.1", ], "docs": [ "sphinx", diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 88c8992..6641177 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -120,7 +120,7 @@ class A2C(OnPolicyAlgorithm): rollout buffer (one gradient step over whole data). 
""" # Switch to train mode (this affects batch norm / dropout) - self.policy.train() + self.policy.set_training_mode(True) # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index ae8569f..fce62e4 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -365,8 +365,10 @@ class OffPolicyAlgorithm(BaseAlgorithm): if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: # If no `gradient_steps` is specified, # do as many gradients steps as steps performed during the rollout - gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps - self.train(batch_size=self.batch_size, gradient_steps=gradient_steps) + gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps + # Special case when the user passes `gradient_steps=0` + if gradient_steps > 0: + self.train(batch_size=self.batch_size, gradient_steps=gradient_steps) callback.on_training_end() @@ -537,7 +539,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): :return: """ # Switch to eval mode (this affects batch norm / dropout) - self.policy.eval() + self.policy.set_training_mode(False) episode_rewards, total_timesteps = [], [] num_collected_steps, num_collected_episodes = 0, 0 diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index eb3417c..41e193d 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -148,7 +148,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): """ assert self._last_obs is not None, "No previous observation was provided" # Switch to eval mode (this affects batch norm / dropout) - self.policy.eval() + self.policy.set_training_mode(False) n_steps = 0 rollout_buffer.reset() diff --git 
a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 377b548..0ed6be5 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -194,6 +194,16 @@ class BaseModel(nn.Module, ABC): """ return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. + + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.train(mode) + class BasePolicy(BaseModel): """The base policy object. @@ -268,7 +278,7 @@ class BasePolicy(BaseModel): # if mask is None: # mask = [False for _ in range(self.n_envs)] # Switch to eval mode (this affects batch norm / dropout) - self.eval() + self.set_training_mode(False) vectorized_env = False if isinstance(observation, dict): diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 9f5214b..a99220b 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -153,7 +153,7 @@ class DQN(OffPolicyAlgorithm): def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Switch to train mode (this affects batch norm / dropout) - self.policy.train() + self.policy.set_training_mode(True) # Update learning rate according to schedule self._update_learning_rate(self.policy.optimizer) diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index d39d9f2..6a8e6e1 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -152,6 +152,8 @@ class DQNPolicy(BasePolicy): """ Create the network and the optimizer. + Put the target network into evaluation mode. 
+ :param lr_schedule: Learning rate schedule lr_schedule(1) is the initial learning rate """ @@ -159,6 +161,7 @@ class DQNPolicy(BasePolicy): self.q_net = self.make_q_net() self.q_net_target = self.make_q_net() self.q_net_target.load_state_dict(self.q_net.state_dict()) + self.q_net_target.set_training_mode(False) # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) @@ -190,6 +193,17 @@ class DQNPolicy(BasePolicy): ) return data + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. + + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.q_net.set_training_mode(mode) + self.training = mode + MlpPolicy = DQNPolicy diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index b3a8e99..ab1129a 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -70,7 +70,7 @@ class PPO(OnPolicyAlgorithm): env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 2048, - batch_size: Optional[int] = 64, + batch_size: int = 64, n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, @@ -167,7 +167,7 @@ class PPO(OnPolicyAlgorithm): Update policy using the currently gathered rollout buffer. 
""" # Switch to train mode (this affects batch norm / dropout) - self.policy.train() + self.policy.set_training_mode(True) # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) # Compute current clip range diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 86d2185..8bb67df 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -317,6 +317,9 @@ class SACPolicy(BasePolicy): self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) + # Target networks should always be in eval mode + self.critic_target.set_training_mode(False) + def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters() @@ -361,6 +364,18 @@ class SACPolicy(BasePolicy): def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: return self.actor(observation, deterministic) + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. 
+ + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.actor.set_training_mode(mode) + self.critic.set_training_mode(mode) + self.training = mode + MlpPolicy = SACPolicy diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index f53e399..605f086 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -181,7 +181,7 @@ class SAC(OffPolicyAlgorithm): def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Switch to train mode (this affects batch norm / dropout) - self.policy.train() + self.policy.set_training_mode(True) # Update optimizers learning rate optimizers = [self.actor.optimizer, self.critic.optimizer] if self.ent_coef_optimizer is not None: diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 1288d78..44f80d0 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -191,6 +191,10 @@ class TD3Policy(BasePolicy): self.critic_target.load_state_dict(self.critic.state_dict()) self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + # Target networks should always be in eval mode + self.actor_target.set_training_mode(False) + self.critic_target.set_training_mode(False) + def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters() @@ -225,6 +229,18 @@ class TD3Policy(BasePolicy): # Predictions are always deterministic. return self.actor(observation) + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. 
+ + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.actor.set_training_mode(mode) + self.critic.set_training_mode(mode) + self.training = mode + MlpPolicy = TD3Policy diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index d8ad25d..1eb28f7 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -133,7 +133,7 @@ class TD3(OffPolicyAlgorithm): def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Switch to train mode (this affects batch norm / dropout) - self.policy.train() + self.policy.set_training_mode(True) # Update learning rate according to lr schedule self._update_learning_rate([self.actor.optimizer, self.critic.optimizer]) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 7ce60fd..9b27bf0 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.2.0a2 +1.2.0a3 diff --git a/tests/test_predict.py b/tests/test_predict.py index 689c75d..2927796 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,12 +1,8 @@ import gym -import numpy as np import pytest import torch as th -import torch.nn as nn from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 -from stable_baselines3.common.preprocessing import get_flattened_obs_dim -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -73,74 +69,3 @@ def test_predict(model_class, env_id, device): action, _ = model.predict(vec_env_obs, deterministic=False) assert action.shape[0] == vec_env_obs.shape[0] - - -class FlattenBatchNormExtractor(BaseFeaturesExtractor): - """ - Feature extract that flatten the input and uses batch normalization. - Used as a placeholder when feature extraction is not needed. 
- - :param observation_space: - """ - - def __init__(self, observation_space: gym.Space): - super(FlattenBatchNormExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space)) - self.flatten = nn.Flatten() - self.batch_norm = nn.BatchNorm1d(self._features_dim) - self.dropout = nn.Dropout(0.5) - - def forward(self, observations: th.Tensor) -> th.Tensor: - result = self.flatten(observations) - result = self.batch_norm(result) - result = self.dropout(result) - return result - - -@pytest.mark.parametrize("model_class", MODEL_LIST) -@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"]) -def test_batch_norm_dropout(model_class, env_id): - - if env_id == "CartPole-v1": - if model_class in [SAC, TD3]: - return - elif model_class in [DQN]: - return - - model_kwargs = dict(seed=1) - - if model_class in [DQN, TD3, SAC]: - model_kwargs["learning_starts"] = 0 - else: - model_kwargs["n_steps"] = 64 - - policy_kwargs = dict( - features_extractor_class=FlattenBatchNormExtractor, - net_arch=[16, 16], - ) - model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs) - - if model_class in [SAC, TD3]: - batch_norm = model.policy.actor.features_extractor.batch_norm - elif model_class in [PPO, A2C]: - batch_norm = model.policy.features_extractor.batch_norm - else: - # DQN - batch_norm = model.policy.q_net.features_extractor.batch_norm - - # batch norm param before training - bias_before_learn = batch_norm.bias.detach().cpu().numpy().copy() - running_mean_before_learn = batch_norm.running_mean.detach().cpu().numpy().copy() - model.learn(100) - env = model.get_env() - observation = env.reset() - - bias_after_learn = batch_norm.bias.detach().cpu().numpy() - running_mean_after_learn = batch_norm.running_mean.detach().cpu().numpy().copy() - - # Run twice on the same observation to test if it is deterministic - first_prediction, _ = model.predict(observation, deterministic=True) - second_prediction, _ = 
model.predict(observation, deterministic=True) - - np.testing.assert_allclose(first_prediction, second_prediction) - assert not np.allclose(bias_before_learn, bias_after_learn) - assert not np.allclose(running_mean_before_learn, running_mean_after_learn) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py new file mode 100644 index 0000000..c5eb283 --- /dev/null +++ b/tests/test_train_eval_mode.py @@ -0,0 +1,370 @@ +from typing import Union + +import gym +import numpy as np +import pytest +import torch as th +import torch.nn as nn + +from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 +from stable_baselines3.common.preprocessing import get_flattened_obs_dim +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor + +MODEL_LIST = [ + PPO, + A2C, + TD3, + SAC, + DQN, +] + + +class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor): + """ + Feature extract that flatten the input and applies batch normalization and dropout. + Used as a placeholder when feature extraction is not needed. + + :param observation_space: + """ + + def __init__(self, observation_space: gym.Space): + super(FlattenBatchNormDropoutExtractor, self).__init__( + observation_space, + get_flattened_obs_dim(observation_space), + ) + self.flatten = nn.Flatten() + self.batch_norm = nn.BatchNorm1d(self._features_dim) + self.dropout = nn.Dropout(0.5) + + def forward(self, observations: th.Tensor) -> th.Tensor: + result = self.flatten(observations) + result = self.batch_norm(result) + result = self.dropout(result) + return result + + +def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> (th.Tensor, th.Tensor): + """ + Clone the bias and running mean from the given batch norm layer. 
+ + :param batch_norm: + :return: the bias and running mean + """ + return batch_norm.bias.clone(), batch_norm.running_mean.clone() + + +def clone_dqn_batch_norm_stats(model: DQN) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor): + """ + Clone the bias and running mean from the Q-network and target network. + + :param model: + :return: the bias and running mean from the Q-network and target network + """ + q_net_batch_norm = model.policy.q_net.features_extractor.batch_norm + q_net_bias, q_net_running_mean = clone_batch_norm_stats(q_net_batch_norm) + + q_net_target_batch_norm = model.policy.q_net_target.features_extractor.batch_norm + q_net_target_bias, q_net_target_running_mean = clone_batch_norm_stats(q_net_target_batch_norm) + + return q_net_bias, q_net_running_mean, q_net_target_bias, q_net_target_running_mean + + +def clone_td3_batch_norm_stats( + model: TD3, +) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor): + """ + Clone the bias and running mean from the actor and critic networks and actor-target and critic-target networks. 
+ + :param model: + :return: the bias and running mean from the actor and critic networks and actor-target and critic-target networks + """ + actor_batch_norm = model.actor.features_extractor.batch_norm + actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm) + + critic_batch_norm = model.critic.features_extractor.batch_norm + critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm) + + actor_target_batch_norm = model.actor_target.features_extractor.batch_norm + actor_target_bias, actor_target_running_mean = clone_batch_norm_stats(actor_target_batch_norm) + + critic_target_batch_norm = model.critic_target.features_extractor.batch_norm + critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm) + + return ( + actor_bias, + actor_running_mean, + critic_bias, + critic_running_mean, + actor_target_bias, + actor_target_running_mean, + critic_target_bias, + critic_target_running_mean, + ) + + +def clone_sac_batch_norm_stats( + model: SAC, +) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor): + """ + Clone the bias and running mean from the actor and critic networks and critic-target networks. 
+ + :param model: + :return: the bias and running mean from the actor and critic networks and critic-target networks + """ + actor_batch_norm = model.actor.features_extractor.batch_norm + actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm) + + critic_batch_norm = model.critic.features_extractor.batch_norm + critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm) + + critic_target_batch_norm = model.critic_target.features_extractor.batch_norm + critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm) + + return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean) + + +def clone_on_policy_batch_norm(model: Union[A2C, PPO]) -> (th.Tensor, th.Tensor): + return clone_batch_norm_stats(model.policy.features_extractor.batch_norm) + + +CLONE_HELPERS = { + A2C: clone_on_policy_batch_norm, + DQN: clone_dqn_batch_norm_stats, + SAC: clone_sac_batch_norm_stats, + TD3: clone_td3_batch_norm_stats, + PPO: clone_on_policy_batch_norm, +} + + +def test_dqn_train_with_batch_norm(): + model = DQN( + "MlpPolicy", + "CartPole-v1", + policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), + learning_starts=0, + seed=1, + tau=0, # do not clone the target + ) + + ( + q_net_bias_before, + q_net_running_mean_before, + q_net_target_bias_before, + q_net_target_running_mean_before, + ) = clone_dqn_batch_norm_stats(model) + + model.learn(total_timesteps=200) + + ( + q_net_bias_after, + q_net_running_mean_after, + q_net_target_bias_after, + q_net_target_running_mean_after, + ) = clone_dqn_batch_norm_stats(model) + + assert ~th.isclose(q_net_bias_before, q_net_bias_after).all() + assert ~th.isclose(q_net_running_mean_before, q_net_running_mean_after).all() + + assert th.isclose(q_net_target_bias_before, q_net_target_bias_after).all() + assert th.isclose(q_net_target_running_mean_before, 
q_net_target_running_mean_after).all() + + +def test_td3_train_with_batch_norm(): + model = TD3( + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), + learning_starts=0, + tau=0, # do not copy the target + seed=1, + ) + + ( + actor_bias_before, + actor_running_mean_before, + critic_bias_before, + critic_running_mean_before, + actor_target_bias_before, + actor_target_running_mean_before, + critic_target_bias_before, + critic_target_running_mean_before, + ) = clone_td3_batch_norm_stats(model) + + model.learn(total_timesteps=200) + + ( + actor_bias_after, + actor_running_mean_after, + critic_bias_after, + critic_running_mean_after, + actor_target_bias_after, + actor_target_running_mean_after, + critic_target_bias_after, + critic_target_running_mean_after, + ) = clone_td3_batch_norm_stats(model) + + assert ~th.isclose(actor_bias_before, actor_bias_after).all() + assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all() + + assert ~th.isclose(critic_bias_before, critic_bias_after).all() + assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all() + + assert th.isclose(actor_target_bias_before, actor_target_bias_after).all() + assert th.isclose(actor_target_running_mean_before, actor_target_running_mean_after).all() + + assert th.isclose(critic_target_bias_before, critic_target_bias_after).all() + assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all() + + +def test_sac_train_with_batch_norm(): + model = SAC( + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), + learning_starts=0, + tau=0, # do not copy the target + seed=1, + ) + + ( + actor_bias_before, + actor_running_mean_before, + critic_bias_before, + critic_running_mean_before, + critic_target_bias_before, + critic_target_running_mean_before, + ) = clone_sac_batch_norm_stats(model) 
+ + model.learn(total_timesteps=200) + + ( + actor_bias_after, + actor_running_mean_after, + critic_bias_after, + critic_running_mean_after, + critic_target_bias_after, + critic_target_running_mean_after, + ) = clone_sac_batch_norm_stats(model) + + assert ~th.isclose(actor_bias_before, actor_bias_after).all() + assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all() + + assert ~th.isclose(critic_bias_before, critic_bias_after).all() + assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all() + + assert th.isclose(critic_target_bias_before, critic_target_bias_after).all() + assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all() + + +@pytest.mark.parametrize("model_class", [A2C, PPO]) +@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"]) +def test_a2c_ppo_train_with_batch_norm(model_class, env_id): + model = model_class( + "MlpPolicy", + env_id, + policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), + seed=1, + ) + + bias_before, running_mean_before = clone_on_policy_batch_norm(model) + + model.learn(total_timesteps=200) + + bias_after, running_mean_after = clone_on_policy_batch_norm(model) + + assert ~th.isclose(bias_before, bias_after).all() + assert ~th.isclose(running_mean_before, running_mean_after).all() + + +@pytest.mark.parametrize("model_class", [DQN, TD3, SAC]) +def test_offpolicy_collect_rollout_batch_norm(model_class): + if model_class in [DQN]: + env_id = "CartPole-v1" + else: + env_id = "Pendulum-v0" + + clone_helper = CLONE_HELPERS[model_class] + + learning_starts = 10 + model = model_class( + "MlpPolicy", + env_id, + policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), + learning_starts=learning_starts, + seed=1, + gradient_steps=0, + train_freq=1, + ) + + batch_norm_stats_before = clone_helper(model) + + model.learn(total_timesteps=100) + + batch_norm_stats_after = 
clone_helper(model) + + # No change in batch norm params + for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after): + assert th.isclose(param_before, param_after).all() + + +@pytest.mark.parametrize("model_class", [A2C, PPO]) +@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"]) +def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id): + model = model_class( + "MlpPolicy", + env_id, + policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), + seed=1, + n_steps=64, + ) + + bias_before, running_mean_before = clone_on_policy_batch_norm(model) + + total_timesteps, callback = model._setup_learn(total_timesteps=2 * 64, eval_env=model.get_env()) + + for _ in range(2): + model.collect_rollouts(model.get_env(), callback, model.rollout_buffer, n_rollout_steps=model.n_steps) + + bias_after, running_mean_after = clone_on_policy_batch_norm(model) + + assert th.isclose(bias_before, bias_after).all() + assert th.isclose(running_mean_before, running_mean_after).all() + + +@pytest.mark.parametrize("model_class", MODEL_LIST) +@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"]) +def test_predict_with_dropout_batch_norm(model_class, env_id): + if env_id == "CartPole-v1": + if model_class in [SAC, TD3]: + return + elif model_class in [DQN]: + return + + model_kwargs = dict(seed=1) + clone_helper = CLONE_HELPERS[model_class] + + if model_class in [DQN, TD3, SAC]: + model_kwargs["learning_starts"] = 0 + else: + model_kwargs["n_steps"] = 64 + + policy_kwargs = dict( + features_extractor_class=FlattenBatchNormDropoutExtractor, + net_arch=[16, 16], + ) + model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs) + + batch_norm_stats_before = clone_helper(model) + + env = model.get_env() + observation = env.reset() + first_prediction, _ = model.predict(observation, deterministic=True) + for _ in range(5): + prediction, _ = 
model.predict(observation, deterministic=True) + np.testing.assert_allclose(first_prediction, prediction) + + batch_norm_stats_after = clone_helper(model) + + # No change in batch norm params + for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after): + assert th.isclose(param_before, param_after).all()