# stable-baselines3/tests/test_train_eval_mode.py

from typing import Tuple, Union

import gym
import numpy as np
import pytest
import torch as th
import torch.nn as nn

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# Algorithm classes used to parametrize the tests that run over every model.
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN]
class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
    """
    Features extractor that flattens the input, then applies batch normalization and dropout.

    Batch normalization and dropout behave differently in train and eval mode,
    so the tests in this file plug this extractor into the policies to observe
    whether the batch norm statistics change and whether predictions are stable.

    :param observation_space: The observation space; its flattened dimension
        is passed to the parent class as the features dimension.
    """

    def __init__(self, observation_space: gym.Space):
        super().__init__(
            observation_space,
            get_flattened_obs_dim(observation_space),
        )
        self.flatten = nn.Flatten()
        # ``_features_dim`` is not assigned in this class; presumably it is set
        # by the parent constructor from the second argument above.
        self.batch_norm = nn.BatchNorm1d(self._features_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """
        Flatten the observations, then apply batch normalization followed by dropout.

        :param observations: The batch of observations.
        :return: The extracted features.
        """
        result = self.flatten(observations)
        result = self.batch_norm(result)
        result = self.dropout(result)
        return result
def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> Tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the given batch norm layer.

    The returned tensors are copies, so later in-place updates of the layer
    do not alter them.

    :param batch_norm: The batch norm layer to read from.
    :return: A copy of the bias and a copy of the running mean, in that order.
    """
    return batch_norm.bias.clone(), batch_norm.running_mean.clone()
def clone_dqn_batch_norm_stats(model: DQN) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the Q-network and target network.

    The batch norm layers are read from the ``features_extractor`` of
    ``model.policy.q_net`` and ``model.policy.q_net_target``.

    :param model: The DQN model to inspect.
    :return: The Q-network bias and running mean, followed by the target
        network bias and running mean.
    """
    q_net_batch_norm = model.policy.q_net.features_extractor.batch_norm
    q_net_bias, q_net_running_mean = clone_batch_norm_stats(q_net_batch_norm)

    q_net_target_batch_norm = model.policy.q_net_target.features_extractor.batch_norm
    q_net_target_bias, q_net_target_running_mean = clone_batch_norm_stats(q_net_target_batch_norm)

    return q_net_bias, q_net_running_mean, q_net_target_bias, q_net_target_running_mean
def clone_td3_batch_norm_stats(
    model: TD3,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the actor and critic networks and actor-target and critic-target networks.

    Each batch norm layer is read from the ``features_extractor`` of the
    corresponding network attribute of ``model``.

    :param model: The TD3 model to inspect.
    :return: The bias and running mean of, in order, the actor, the critic,
        the actor target and the critic target (eight tensors in total).
    """
    actor_batch_norm = model.actor.features_extractor.batch_norm
    actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

    critic_batch_norm = model.critic.features_extractor.batch_norm
    critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

    actor_target_batch_norm = model.actor_target.features_extractor.batch_norm
    actor_target_bias, actor_target_running_mean = clone_batch_norm_stats(actor_target_batch_norm)

    critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
    critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

    return (
        actor_bias,
        actor_running_mean,
        critic_bias,
        critic_running_mean,
        actor_target_bias,
        actor_target_running_mean,
        critic_target_bias,
        critic_target_running_mean,
    )
def clone_sac_batch_norm_stats(
    model: SAC,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the actor and critic networks and critic-target networks.

    Each batch norm layer is read from the ``features_extractor`` of the
    corresponding network attribute of ``model``.

    :param model: The SAC model to inspect.
    :return: The bias and running mean of, in order, the actor, the critic
        and the critic target (six tensors in total).
    """
    actor_batch_norm = model.actor.features_extractor.batch_norm
    actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

    critic_batch_norm = model.critic.features_extractor.batch_norm
    critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

    critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
    critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

    return (
        actor_bias,
        actor_running_mean,
        critic_bias,
        critic_running_mean,
        critic_target_bias,
        critic_target_running_mean,
    )
def clone_on_policy_batch_norm(model: Union[A2C, PPO]) -> Tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the policy's features extractor batch norm layer.

    :param model: The A2C or PPO model to inspect.
    :return: A copy of the bias and a copy of the running mean, in that order.
    """
    return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)
# Maps each algorithm class to the helper that clones its batch norm statistics.
# A2C and PPO share the on-policy helper.
CLONE_HELPERS = {
    A2C: clone_on_policy_batch_norm,
    DQN: clone_dqn_batch_norm_stats,
    SAC: clone_sac_batch_norm_stats,
    TD3: clone_td3_batch_norm_stats,
    PPO: clone_on_policy_batch_norm,
}
def test_dqn_train_with_batch_norm():
    """
    Train DQN with the batch norm extractor and check which statistics change.

    After training and a forced target update, the Q-network bias and running
    mean must differ from their initial values, the target bias must still
    equal the initial values, and the target running mean must match the
    Q-network running mean both before and after training.
    """
    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=0,
        seed=1,
        tau=0.0,  # do not clone the target
        target_update_interval=100,  # Copy the stats to the target
    )

    (
        q_net_bias_before,
        q_net_running_mean_before,
        q_net_target_bias_before,
        q_net_target_running_mean_before,
    ) = clone_dqn_batch_norm_stats(model)

    model.learn(total_timesteps=200)
    # Force stats copy
    model.target_update_interval = 1
    model._on_step()

    (
        q_net_bias_after,
        q_net_running_mean_after,
        q_net_target_bias_after,
        q_net_target_running_mean_after,
    ) = clone_dqn_batch_norm_stats(model)

    # The online network must have been updated by training
    assert not th.isclose(q_net_bias_before, q_net_bias_after).all()
    assert not th.isclose(q_net_running_mean_before, q_net_running_mean_after).all()

    # No weight update
    assert th.isclose(q_net_bias_before, q_net_target_bias_after).all()
    assert th.isclose(q_net_target_bias_before, q_net_target_bias_after).all()

    # Running stat should be copied even when tau=0
    assert th.isclose(q_net_running_mean_before, q_net_target_running_mean_before).all()
    assert th.isclose(q_net_running_mean_after, q_net_target_running_mean_after).all()
def test_td3_train_with_batch_norm():
    """
    Train TD3 with the batch norm extractor and check which statistics change.

    After training, the actor and critic bias and running mean must differ
    from their initial values, the target biases must be unchanged, and the
    target running means must match those of the actor and critic.
    """
    model = TD3(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    (
        actor_bias_before,
        actor_running_mean_before,
        critic_bias_before,
        critic_running_mean_before,
        actor_target_bias_before,
        actor_target_running_mean_before,
        critic_target_bias_before,
        critic_target_running_mean_before,
    ) = clone_td3_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        actor_bias_after,
        actor_running_mean_after,
        critic_bias_after,
        critic_running_mean_after,
        actor_target_bias_after,
        actor_target_running_mean_after,
        critic_target_bias_after,
        critic_target_running_mean_after,
    ) = clone_td3_batch_norm_stats(model)

    # The actor and critic must have been updated by training
    assert not th.isclose(actor_bias_before, actor_bias_after).all()
    assert not th.isclose(actor_running_mean_before, actor_running_mean_after).all()

    assert not th.isclose(critic_bias_before, critic_bias_after).all()
    assert not th.isclose(critic_running_mean_before, critic_running_mean_after).all()

    assert th.isclose(actor_target_bias_before, actor_target_bias_after).all()
    # Running stat should be copied even when tau=0
    assert th.isclose(actor_running_mean_after, actor_target_running_mean_after).all()

    assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
    # Running stat should be copied even when tau=0
    assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()
def test_sac_train_with_batch_norm():
    """
    Train SAC with the batch norm extractor and check which statistics change.

    After training, the actor bias and running mean and the critic bias must
    differ from their initial values, the critic target bias must be
    unchanged, and the critic target running mean must match the critic
    running mean both before and after training.
    """
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    (
        actor_bias_before,
        actor_running_mean_before,
        critic_bias_before,
        critic_running_mean_before,
        critic_target_bias_before,
        critic_target_running_mean_before,
    ) = clone_sac_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        actor_bias_after,
        actor_running_mean_after,
        critic_bias_after,
        critic_running_mean_after,
        critic_target_bias_after,
        critic_target_running_mean_after,
    ) = clone_sac_batch_norm_stats(model)

    # The actor and critic must have been updated by training
    assert not th.isclose(actor_bias_before, actor_bias_after).all()
    assert not th.isclose(actor_running_mean_before, actor_running_mean_after).all()

    assert not th.isclose(critic_bias_before, critic_bias_after).all()
    # Running stat should be copied even when tau=0
    assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all()

    assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
    # Running stat should be copied even when tau=0
    assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_a2c_ppo_train_with_batch_norm(model_class, env_id):
    """
    Check that training A2C/PPO updates the batch norm bias and running mean.

    :param model_class: The on-policy algorithm class under test.
    :param env_id: The id of the environment to train on.
    """
    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        seed=1,
    )

    bias_before, running_mean_before = clone_on_policy_batch_norm(model)

    model.learn(total_timesteps=200)

    bias_after, running_mean_after = clone_on_policy_batch_norm(model)

    # Both the learned bias and the running statistics must have moved
    assert not th.isclose(bias_before, bias_after).all()
    assert not th.isclose(running_mean_before, running_mean_after).all()
@pytest.mark.parametrize("model_class", [DQN, TD3, SAC])
def test_offpolicy_collect_rollout_batch_norm(model_class):
    """
    Check that collecting rollouts alone leaves the batch norm values untouched.

    The model is created with ``gradient_steps=0``, so ``learn()`` is expected
    to only collect data; every cloned bias and running mean must then be
    equal before and after.

    :param model_class: The off-policy algorithm class under test.
    """
    # DQN is run on CartPole, the other algorithms on Pendulum
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"

    clone_helper = CLONE_HELPERS[model_class]

    learning_starts = 10
    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=learning_starts,
        seed=1,
        gradient_steps=0,
        train_freq=1,
    )

    batch_norm_stats_before = clone_helper(model)

    model.learn(total_timesteps=100)

    batch_norm_stats_after = clone_helper(model)

    # No change in batch norm params
    for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
        assert th.isclose(param_before, param_after).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id):
    """
    Check that A2C/PPO rollout collection leaves the batch norm values untouched.

    Two rollouts are collected without any training step; the bias and the
    running mean must be equal before and after.

    :param model_class: The on-policy algorithm class under test.
    :param env_id: The id of the environment to collect rollouts from.
    """
    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        seed=1,
        n_steps=64,
    )

    bias_before, running_mean_before = clone_on_policy_batch_norm(model)

    # Only the callback is needed here; the returned timestep count is unused
    _, callback = model._setup_learn(total_timesteps=2 * 64)

    for _ in range(2):
        model.collect_rollouts(model.get_env(), callback, model.rollout_buffer, n_rollout_steps=model.n_steps)

    bias_after, running_mean_after = clone_on_policy_batch_norm(model)

    assert th.isclose(bias_before, bias_after).all()
    assert th.isclose(running_mean_before, running_mean_after).all()
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id):
    """
    Check that deterministic predictions are repeatable and do not alter batch norm values.

    The same observation is predicted several times; every prediction must
    equal the first one, and the cloned batch norm bias and running mean must
    be unchanged afterwards.

    :param model_class: The algorithm class under test.
    :param env_id: The id of the environment to predict on.
    """
    # SAC/TD3 are not run on CartPole and DQN is not run on Pendulum
    # (presumably because of unsupported action spaces - TODO confirm).
    # Skip explicitly so these combinations are not reported as passing.
    if env_id == "CartPole-v1":
        if model_class in [SAC, TD3]:
            pytest.skip(f"{model_class.__name__} is not tested on {env_id}")
    elif model_class is DQN:
        pytest.skip(f"{model_class.__name__} is not tested on {env_id}")

    model_kwargs = dict(seed=1)
    clone_helper = CLONE_HELPERS[model_class]

    if model_class in [DQN, TD3, SAC]:
        model_kwargs["learning_starts"] = 0
    else:
        model_kwargs["n_steps"] = 64

    policy_kwargs = dict(
        features_extractor_class=FlattenBatchNormDropoutExtractor,
        net_arch=[16, 16],
    )
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs)

    batch_norm_stats_before = clone_helper(model)

    env = model.get_env()
    observation = env.reset()
    first_prediction, _ = model.predict(observation, deterministic=True)
    for _ in range(5):
        prediction, _ = model.predict(observation, deterministic=True)
        np.testing.assert_allclose(first_prediction, prediction)

    batch_norm_stats_after = clone_helper(model)

    # No change in batch norm params
    for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
        assert th.isclose(param_before, param_after).all()