stable-baselines3/tests
Scott Brownlie 1afc2f3abe
Avoid putting target networks into training mode (#553)
* make sure DQN policy is always in correct mode - train or eval

* make set_training_mode an abstract method of the base policy - safer

* update docstring of _build method to note that the target network is put into eval mode

* use set_training_mode to put the dqn target network into eval mode

* use set_training_mode to set the training mode of the q-network

* move set_training_mode abstract method from BasePolicy to BaseModel

* set train and eval mode for TD3

* make sure critic is always in correct mode during train

* set train and eval mode for SAC

* add comment re batch norm and dropout

* set train and eval mode for A2C and PPO

* add tests for collect rollouts with batch norm

* fix formatting

* update change log

* update version

* remove Optional typing for batch size - causing type check to fail

* Fix scipy dependency for toy text envs

* implement set_training_mode method in BaseModel

* move all tests of train/eval mode to test_train_eval_mode

* call learn with learning_starts = total_timesteps to test that collect_rollouts does not update batch norm

* remove extra calls to set_training_mode in train method of TD3 and SAC

* Allow gradient_steps=0

* Refactor tests

* Add comment + use aliases

* Typos

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2021-08-30 17:42:41 +02:00

| File | Last commit | Date |
| --- | --- | --- |
| __init__.py | Init: TD3 | 2019-09-05 17:29:41 +02:00 |
| test_callbacks.py | Fix Inconsistencies with EvalCallback tensorboard logs (#492) | 2021-07-01 15:43:08 +02:00 |
| test_cnn.py | Documentation update (#450) | 2021-05-23 13:13:11 +02:00 |
| test_custom_policy.py | Fix discrete obs support (#296) | 2021-01-21 02:42:33 +02:00 |
| test_deterministic.py | Auto-formatting with black and isort (#97) | 2020-07-16 16:12:16 +02:00 |
| test_dict_env.py | Dictionary Observations (#243) | 2021-05-11 12:29:30 +02:00 |
| test_distributions.py | KL Divergence Helper Function (#431) | 2021-05-20 19:01:07 +02:00 |
| test_env_checker.py | add check to ensure action space is non-dict non-tuple for env_checker nan check (#192) | 2020-10-19 00:23:51 +03:00 |
| test_envs.py | Dictionary Observations (#243) | 2021-05-11 12:29:30 +02:00 |
| test_gae.py | Add test for GAE + rename RolloutBuffer.dones for clarification (#375) | 2021-04-16 15:52:55 +02:00 |
| test_her.py | Dictionary Observations (#243) | 2021-05-11 12:29:30 +02:00 |
| test_identity.py | Dictionary Observations (#243) | 2021-05-11 12:29:30 +02:00 |
| test_logger.py | Fix logger setup (#469) | 2021-06-14 15:17:48 +02:00 |
| test_monitor.py | Auto-formatting with black and isort (#97) | 2020-07-16 16:12:16 +02:00 |
| test_predict.py | Avoid putting target networks into training mode (#553) | 2021-08-30 17:42:41 +02:00 |
| test_run.py | Fix train_freq at load time (#332) | 2021-02-27 19:53:13 +01:00 |
| test_save_load.py | Fix ent coef loading for SAC (#429) | 2021-05-12 12:21:54 +03:00 |
| test_sde.py | Implement HER (#120) | 2020-10-22 11:56:43 +02:00 |
| test_spaces.py | Fix discrete obs support (#296) | 2021-01-21 02:42:33 +02:00 |
| test_tensorboard.py | Auto-formatting with black and isort (#97) | 2020-07-16 16:12:16 +02:00 |
| test_train_eval_mode.py | Avoid putting target networks into training mode (#553) | 2021-08-30 17:42:41 +02:00 |
| test_utils.py | Added support for vector envs in evaluation (#447) | 2021-05-28 12:40:29 +02:00 |
| test_vec_check_nan.py | Auto-formatting with black and isort (#97) | 2020-07-16 16:12:16 +02:00 |
| test_vec_envs.py | Use Monitor episode reward/length for evaluate_policy (#220) | 2020-11-16 11:52:28 +01:00 |
| test_vec_extract_dict_obs.py | Support for VecMonitor for gym3-style environments (#311) | 2021-04-13 18:09:31 +02:00 |
| test_vec_monitor.py | Support for VecMonitor for gym3-style environments (#311) | 2021-04-13 18:09:31 +02:00 |
| test_vec_normalize.py | Dictionary Observations (#243) | 2021-05-11 12:29:30 +02:00 |