stable-baselines3/stable_baselines3/common
Scott Brownlie 1afc2f3abe
Avoid putting target networks into training mode (#553)
* make sure DQN policy is always in correct mode - train or eval

* make set_training_mode an abstract method of the base policy - safer

* update docstring of _build method to note that the target network is put into eval mode

* use set_training_mode to put the dqn target network into eval mode

* use set_training_mode to set the training mode of the q-network

* move set_training_mode abstract method from BasePolicy to BaseModel

* set train and eval mode for TD3

* make sure critic is always in correct mode during train

* set train and eval mode for SAC

* add comment re batch norm and dropout

* set train and eval mode for A2C and PPO

* add tests for collect rollouts with batch norm

* fix formatting

* update change log

* update version

* remove Optional typing for batch size - causing type check to fail

* Fix scipy dependency for toy text envs

* implement set_training_mode method in BaseModel

* move all tests of train/eval mode to test_train_eval_mode

* call learn with learning_starts = total_timesteps to test that collect_rollouts does not update batch norm

* remove extra calls to set_training_mode in train method of TD3 and SAC

* Allow gradient_steps=0

* Refactor tests

* Add comment + use aliases

* Typos

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2021-08-30 17:42:41 +02:00
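The commit message above turns on one idea: layers such as batch norm and dropout behave differently in train and eval mode, so target networks (and any network used only to collect rollouts) should stay in eval mode. A minimal, dependency-free sketch of that behaviour follows. `RunningMeanNorm` is a hypothetical toy stand-in for a batch-norm layer, not code from stable-baselines3, and its `set_training_mode` only mirrors the name of the method the commit mentions; the real policies are PyTorch modules.

```python
class RunningMeanNorm:
    """Hypothetical toy layer: centers a batch and tracks a running mean.

    Like batch norm, it updates its running statistic only in training mode
    and uses the stored statistic in eval mode.
    """

    def __init__(self, momentum=0.1):
        self.momentum = momentum
        self.running_mean = 0.0
        self.training = True

    def set_training_mode(self, mode):
        # Name borrowed from the commit message; this body is an assumption.
        self.training = mode

    def forward(self, batch):
        if self.training:
            batch_mean = sum(batch) / len(batch)
            # Exponential moving average of the batch mean.
            self.running_mean = (
                (1 - self.momentum) * self.running_mean
                + self.momentum * batch_mean
            )
            center = batch_mean
        else:
            center = self.running_mean
        return [x - center for x in batch]


q_net = RunningMeanNorm()
q_net_target = RunningMeanNorm()
q_net_target.set_training_mode(False)  # the target network stays in eval mode

batch = [1.0, 2.0, 3.0]

# Gradient step: only the online network is in training mode.
q_net.set_training_mode(True)
q_net.forward(batch)
q_net_target.forward(batch)
mean_after_train = q_net.running_mean

# Rollout collection: switch to eval so statistics are not updated.
q_net.set_training_mode(False)
q_net.forward([10.0, 20.0, 30.0])
mean_after_rollout = q_net.running_mean
```

In this sketch the target's running mean never moves, and the online network's running mean changes only during the training step, which is the property the tests mentioned in the commit message are described as checking for `collect_rollouts`.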
envs Dictionary Observations (#243) 2021-05-11 12:29:30 +02:00
sb2_compat Add callable signatures to type annotations. (#215) 2020-11-15 17:50:28 +01:00
vec_env Fix logger setup (#469) 2021-06-14 15:17:48 +02:00
__init__.py Update docs (custom policy, type hints) (#167) 2020-09-29 20:41:14 +03:00
atari_wrappers.py Add SUMO-RL as example project in the docs (#257) 2020-12-13 17:15:45 +01:00
base_class.py Training and evaluation: call model.train() and model.eval() (#537) 2021-08-14 14:08:27 +02:00
buffers.py Corrected DictReplayBuffer observation dtype #484 (#486) 2021-06-22 13:41:26 +02:00
callbacks.py Fix type annotations (#522) 2021-07-29 13:02:09 +02:00
distributions.py KL Divergence Helper Function (#431) 2021-05-20 19:01:07 +02:00
env_checker.py Dictionary Observations (#243) 2021-05-11 12:29:30 +02:00
env_util.py Added wrapper_kwargs argument to make_vec_env (#448) 2021-05-23 11:33:34 +02:00
evaluation.py Added support for vector envs in evaluation (#447) 2021-05-28 12:40:29 +02:00
logger.py Fix logger setup (#469) 2021-06-14 15:17:48 +02:00
monitor.py Fix type annotations (#522) 2021-07-29 13:02:09 +02:00
noise.py Improve typing coverage (#175) 2020-10-07 10:51:49 +02:00
off_policy_algorithm.py Avoid putting target networks into training mode (#553) 2021-08-30 17:42:41 +02:00
on_policy_algorithm.py Avoid putting target networks into training mode (#553) 2021-08-30 17:42:41 +02:00
policies.py Avoid putting target networks into training mode (#553) 2021-08-30 17:42:41 +02:00
preprocessing.py Documentation update (#450) 2021-05-23 13:13:11 +02:00
results_plotter.py Fix default arguments + add bugbear (#363) 2021-03-25 11:35:21 +02:00
running_mean_std.py Cleanup docstring types (#169) 2020-10-02 20:05:55 +03:00
save_util.py Fix type annotations (#522) 2021-07-29 13:02:09 +02:00
torch_layers.py Documentation update (#450) 2021-05-23 13:13:11 +02:00
type_aliases.py Dictionary Observations (#243) 2021-05-11 12:29:30 +02:00
utils.py Fix logger setup (#469) 2021-06-14 15:17:48 +02:00