stable-baselines3/tests/test_train_eval_mode.py
Scott Brownlie 1afc2f3abe
Avoid putting target networks into training mode (#553)
* make sure DQN policy is always in correct mode - train or eval

* make set_training_mode an abstract method of the base policy - safer

* update docstring of _build method to note that the target network is put into eval mode

* use set_training_mode to put the dqn target network into eval mode

* use set_training_mode to set the training mode of the q-network

* move set_training_mode abstract method from BasePolicy to BaseModel

* set train and eval mode for TD3

* make sure critic is always in correct mode during train

* set train and eval mode for SAC

* add comment re batch norm and dropout

* set train and eval mode for A2C and PPO

* add tests for collect rollouts with batch norm

* fix formatting

* update change log

* update version

* remove Optional typing for batch size - causing type check to fail

* Fix scipy dependency for toy text envs

* implement set_training_mode method in BaseModel

* move all tests of train/eval mode to test_train_eval_mode

* call learn with learning_starts = total_timesteps to test that collect_rollouts does not update batch norm

* remove extra calls to set_training_mode in train method of TD3 and SAC

* Allow gradient_steps=0

* Refactor tests

* Add comment + use aliases

* Typos

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2021-08-30 17:42:41 +02:00
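The sketch below is for illustration only: it is not the stable-baselines3 implementation, and the class name is made up (the `q_net` / `q_net_target` attribute names mirror the ones used in the tests). It shows the pattern the commit describes: set_training_mode switches the online network between train and eval mode, while the target network is put into eval mode once and never switched, so its BatchNorm statistics and Dropout behaviour are never affected by target-network forward passes.

# Illustrative sketch only -- not the stable-baselines3 implementation.
import torch.nn as nn


class SketchQPolicy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.q_net = nn.Sequential(nn.Linear(4, 16), nn.BatchNorm1d(16), nn.ReLU())
        self.q_net_target = nn.Sequential(nn.Linear(4, 16), nn.BatchNorm1d(16), nn.ReLU())
        # The target network is put into eval mode once and stays there
        self.q_net_target.train(False)

    def set_training_mode(self, mode: bool) -> None:
        # Only the online network switches between train and eval mode,
        # so BatchNorm/Dropout behave correctly during training,
        # rollout collection, and prediction
        self.q_net.train(mode)
        self.training = mode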


from typing import Union

import gym
import numpy as np
import pytest
import torch as th
import torch.nn as nn

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

MODEL_LIST = [
    PPO,
    A2C,
    TD3,
    SAC,
    DQN,
]


class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
    """
    Feature extractor that flattens the input and applies batch normalization and dropout.
    Used as a placeholder when feature extraction is not needed.

    :param observation_space:
    """

    def __init__(self, observation_space: gym.Space):
        super(FlattenBatchNormDropoutExtractor, self).__init__(
            observation_space,
            get_flattened_obs_dim(observation_space),
        )
        self.flatten = nn.Flatten()
        self.batch_norm = nn.BatchNorm1d(self._features_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        result = self.flatten(observations)
        result = self.batch_norm(result)
        result = self.dropout(result)
        return result
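

# Note: BatchNorm1d only updates its running statistics and Dropout only drops
# activations while a module is in training mode. The helpers and tests below
# rely on this to detect which networks were left in train vs eval mode.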


def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> (th.Tensor, th.Tensor):
    """
    Clone the bias and running mean from the given batch norm layer.

    :param batch_norm:
    :return: the bias and running mean
    """
    return batch_norm.bias.clone(), batch_norm.running_mean.clone()


def clone_dqn_batch_norm_stats(model: DQN) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor):
    """
    Clone the bias and running mean from the Q-network and target network.

    :param model:
    :return: the bias and running mean from the Q-network and target network
    """
    q_net_batch_norm = model.policy.q_net.features_extractor.batch_norm
    q_net_bias, q_net_running_mean = clone_batch_norm_stats(q_net_batch_norm)

    q_net_target_batch_norm = model.policy.q_net_target.features_extractor.batch_norm
    q_net_target_bias, q_net_target_running_mean = clone_batch_norm_stats(q_net_target_batch_norm)

    return q_net_bias, q_net_running_mean, q_net_target_bias, q_net_target_running_mean


def clone_td3_batch_norm_stats(
    model: TD3,
) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor):
    """
    Clone the bias and running mean from the actor and critic networks and actor-target and critic-target networks.

    :param model:
    :return: the bias and running mean from the actor and critic networks and actor-target and critic-target networks
    """
    actor_batch_norm = model.actor.features_extractor.batch_norm
    actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

    critic_batch_norm = model.critic.features_extractor.batch_norm
    critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

    actor_target_batch_norm = model.actor_target.features_extractor.batch_norm
    actor_target_bias, actor_target_running_mean = clone_batch_norm_stats(actor_target_batch_norm)

    critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
    critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

    return (
        actor_bias,
        actor_running_mean,
        critic_bias,
        critic_running_mean,
        actor_target_bias,
        actor_target_running_mean,
        critic_target_bias,
        critic_target_running_mean,
    )


def clone_sac_batch_norm_stats(
    model: SAC,
) -> (th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor):
    """
    Clone the bias and running mean from the actor and critic networks and critic-target networks.

    :param model:
    :return: the bias and running mean from the actor and critic networks and critic-target networks
    """
    actor_batch_norm = model.actor.features_extractor.batch_norm
    actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

    critic_batch_norm = model.critic.features_extractor.batch_norm
    critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

    critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
    critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

    return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean)


def clone_on_policy_batch_norm(model: Union[A2C, PPO]) -> (th.Tensor, th.Tensor):
    return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)


CLONE_HELPERS = {
    A2C: clone_on_policy_batch_norm,
    DQN: clone_dqn_batch_norm_stats,
    SAC: clone_sac_batch_norm_stats,
    TD3: clone_td3_batch_norm_stats,
    PPO: clone_on_policy_batch_norm,
}


def test_dqn_train_with_batch_norm():
    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=0,
        seed=1,
        tau=0,  # do not clone the target
    )

    (
        q_net_bias_before,
        q_net_running_mean_before,
        q_net_target_bias_before,
        q_net_target_running_mean_before,
    ) = clone_dqn_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        q_net_bias_after,
        q_net_running_mean_after,
        q_net_target_bias_after,
        q_net_target_running_mean_after,
    ) = clone_dqn_batch_norm_stats(model)

    # Training updates the Q-network's batch norm parameters and running stats...
    assert ~th.isclose(q_net_bias_before, q_net_bias_after).all()
    assert ~th.isclose(q_net_running_mean_before, q_net_running_mean_after).all()

    # ...but the target network is kept in eval mode, so its stats do not change
    assert th.isclose(q_net_target_bias_before, q_net_target_bias_after).all()
    assert th.isclose(q_net_target_running_mean_before, q_net_target_running_mean_after).all()


def test_td3_train_with_batch_norm():
    model = TD3(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    (
        actor_bias_before,
        actor_running_mean_before,
        critic_bias_before,
        critic_running_mean_before,
        actor_target_bias_before,
        actor_target_running_mean_before,
        critic_target_bias_before,
        critic_target_running_mean_before,
    ) = clone_td3_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        actor_bias_after,
        actor_running_mean_after,
        critic_bias_after,
        critic_running_mean_after,
        actor_target_bias_after,
        actor_target_running_mean_after,
        critic_target_bias_after,
        critic_target_running_mean_after,
    ) = clone_td3_batch_norm_stats(model)

    assert ~th.isclose(actor_bias_before, actor_bias_after).all()
    assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()

    assert ~th.isclose(critic_bias_before, critic_bias_after).all()
    assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()

    assert th.isclose(actor_target_bias_before, actor_target_bias_after).all()
    assert th.isclose(actor_target_running_mean_before, actor_target_running_mean_after).all()

    assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
    assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()


def test_sac_train_with_batch_norm():
    model = SAC(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    (
        actor_bias_before,
        actor_running_mean_before,
        critic_bias_before,
        critic_running_mean_before,
        critic_target_bias_before,
        critic_target_running_mean_before,
    ) = clone_sac_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        actor_bias_after,
        actor_running_mean_after,
        critic_bias_after,
        critic_running_mean_after,
        critic_target_bias_after,
        critic_target_running_mean_after,
    ) = clone_sac_batch_norm_stats(model)

    assert ~th.isclose(actor_bias_before, actor_bias_after).all()
    assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()

    assert ~th.isclose(critic_bias_before, critic_bias_after).all()
    assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()

    assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
    assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_a2c_ppo_train_with_batch_norm(model_class, env_id):
model = model_class(
"MlpPolicy",
env_id,
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
seed=1,
)
bias_before, running_mean_before = clone_on_policy_batch_norm(model)
model.learn(total_timesteps=200)
bias_after, running_mean_after = clone_on_policy_batch_norm(model)
assert ~th.isclose(bias_before, bias_after).all()
assert ~th.isclose(running_mean_before, running_mean_after).all()
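

# With gradient_steps=0, learn() performs no gradient updates and only collects
# rollouts; data collection runs the policy in eval mode, so the batch norm
# statistics must stay unchanged.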
@pytest.mark.parametrize("model_class", [DQN, TD3, SAC])
def test_offpolicy_collect_rollout_batch_norm(model_class):
if model_class in [DQN]:
env_id = "CartPole-v1"
else:
env_id = "Pendulum-v0"
clone_helper = CLONE_HELPERS[model_class]
learning_starts = 10
model = model_class(
"MlpPolicy",
env_id,
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=learning_starts,
seed=1,
gradient_steps=0,
train_freq=1,
)
batch_norm_stats_before = clone_helper(model)
model.learn(total_timesteps=100)
batch_norm_stats_after = clone_helper(model)
# No change in batch norm params
for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
assert th.isclose(param_before, param_after).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id):
model = model_class(
"MlpPolicy",
env_id,
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
seed=1,
n_steps=64,
)
bias_before, running_mean_before = clone_on_policy_batch_norm(model)
total_timesteps, callback = model._setup_learn(total_timesteps=2 * 64, eval_env=model.get_env())
for _ in range(2):
model.collect_rollouts(model.get_env(), callback, model.rollout_buffer, n_rollout_steps=model.n_steps)
bias_after, running_mean_after = clone_on_policy_batch_norm(model)
assert th.isclose(bias_before, bias_after).all()
assert th.isclose(running_mean_before, running_mean_after).all()
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id):
if env_id == "CartPole-v1":
if model_class in [SAC, TD3]:
return
elif model_class in [DQN]:
return
model_kwargs = dict(seed=1)
clone_helper = CLONE_HELPERS[model_class]
if model_class in [DQN, TD3, SAC]:
model_kwargs["learning_starts"] = 0
else:
model_kwargs["n_steps"] = 64
policy_kwargs = dict(
features_extractor_class=FlattenBatchNormDropoutExtractor,
net_arch=[16, 16],
)
model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs)
batch_norm_stats_before = clone_helper(model)
env = model.get_env()
observation = env.reset()
first_prediction, _ = model.predict(observation, deterministic=True)
for _ in range(5):
prediction, _ = model.predict(observation, deterministic=True)
np.testing.assert_allclose(first_prediction, prediction)
batch_norm_stats_after = clone_helper(model)
# No change in batch norm params
for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
assert th.isclose(param_before, param_after).all()