Commit graph

24 commits

Author SHA1 Message Date
Antonin RAFFIN
2ca94cb73d
Add check for common mistake when mixing Gym/VecEnv API (#1696) 2023-09-25 12:39:22 +02:00
Kallinteris Andreas
9c338f917a
vec_envs fix seed() causing a reset (#1486)
* `dummy_vec_env` fix `seed()` causing a reset

* rename `seed`

* fixes

* bug fix

* fix seed return type

* Cleanup seeding, add test and remove compat wrapper

* Update env checker and tests

* Add deterministic test for make_vec_env

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2023-05-20 10:30:54 +02:00
Antonin RAFFIN
40e0b9d2c8
Add Gymnasium support (#1327)
* Fix failing set_env test

* Fix test failiing due to deprectation of env.seed

* Adjust mean reward threshold in failing test

* Fix her test failing due to rng

* Change seed and revert reward threshold to 90

* Pin gym version

* Make VecEnv compatible with gym seeding change

* Revert change to VecEnv reset signature

* Change subprocenv seed cmd to call reset instead

* Fix type check

* Add backward compat

* Add `compat_gym_seed` helper

* Add goal env checks in env_checker

* Add docs on  HER requirements for envs

* Capture user warning in test with inverted box space

* Update ale-py version

* Fix randint

* Allow noop_max to be zero

* Update changelog

* Update docker image

* Update doc conda env and dockerfile

* Custom envs should not have any warnings

* Fix test for numpy >= 1.21

* Add check for vectorized compute reward

* Bump to gym 0.24

* Fix gym default step docstring

* Test downgrading gym

* Revert "Test downgrading gym"

This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb.

* Fix protobuf error

* Fix in dependencies

* Fix protobuf dep

* Use newest version of cartpole

* Update gym

* Fix warning

* Loosen required scipy version

* Scipy no longer needed

* Try gym 0.25

* Silence warnings from gym

* Filter warnings during tests

* Update doc

* Update requirements

* Add gym 26 compat in vec env

* Fixes in envs and tests for gym 0.26+

* Enforce gym 0.26 api

* format

* Fix formatting

* Fix dependencies

* Fix syntax

* Cleanup doc and warnings

* Faster tests

* Higher budget for HER perf test (revert prev change)

* Fixes and update doc

* Fix doc build

* Fix breaking change

* Fixes for rendering

* Rename variables in monitor

* update render method for gym 0.26 API

backwards compatible (mode argument is allowed) while using the gym 0.26 API (render mode is determined at environment creation)

* update tests and docs to new gym render API

* undo removal of render modes metatadata check

* set rgb_array as default render mode for gym.make

* undo changes & raise warning if not 'rgb_array'

* Fix type check

* Remove recursion and fix type checking

* Remove hacks for protobuf and gym 0.24

* Fix type annotations

* reuse existing render_mode attribute

* return tiled images for 'human' render mode

* Allow to use opencv for human render, fix typos

* Add warning when using non-zero start with Discrete (fixes #1197)

* Fix type checking

* Bug fixes and handle more cases

* Throw proper warnings

* Update test

* Fix new metadata name

* Ignore numpy warnings

* Fixes in vec recorder

* Global ignore

* Filter local warning too

* Monkey patch not needed for gym 26

* Add doc of VecEnv vs Gym API

* Add render test

* Fix return type

* Update VecEnv vs Gym API doc

* Fix for custom render mode

* Fix return type

* Fix type checking

* check test env test_buffer

* skip render check

* check env test_dict_env

* test_env test_gae

* check envs in remaining tests

* Update tests

* Add warning for Discrete action space with non-zero (#1295)

* Fix atari annotation

* ignore get_action_meanings [attr-defined]

* Fix mypy issues

* Add patch for gym/gymnasium transition

* Switch to gymnasium

* Rely on signature instead of version

* More patches

* Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39

* Fix doc build

* Fix pytype errors

* Fix atari requirement

* Update env checker due to change in dtype for Discrete

* Fix type hint

* Convert spaces for saved models

* Ignore pytype

* Remove gitlab CI

* Disable pytype for convert space

* Fix undefined info

* Fix undefined info

* Upgrade shimmy

* Fix wrappers type annotation (need PR from Gymnasium)

* Fix gymnasium dependency

* Fix dependency declaration

* Cap pygame version for python 3.7

* Point to master branch (v0.28.0)

* Fix: use main not master branch

* Rename done to terminated

* Fix pygame dependency for python 3.7

* Rename gym to gymnasium

* Update Gymnasium

* Fix test

* Fix tests

* Forks don't have access to private variables

* Fix linter warnings

* Update read the doc env

* Fix env checker for GoalEnv

* Fix import

* Update env checker (more info) and fix dtype

* Use micromamab for Docker

* Update dependencies

* Clarify VecEnv doc

* Fix Gymnasium version

* Copy file only after mamba install

* [ci skip] Update docker doc

* Polish code

* Reformat

* Remove deprecated features

* Ignore warning

* Update doc

* Update examples and changelog

* Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436)

* Fix SAC type hints, improve DQN ones

* Fix A2C and TD3 type hints

* Fix PPO type hints

* Fix on-policy type hints

* Fix base class type annotation, do not use defaults

* Update version

* Disable mypy for python 3.7

* Rename Gym26StepReturn

* Update continuous critic type annotation

* Fix pytype complain

---------

Co-authored-by: Carlos Luis <carlos.luisgonc@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Thomas Lips <37955681+tlpss@users.noreply.github.com>
Co-authored-by: tlips <thomas.lips@ugent.be>
Co-authored-by: tlpss <thomas17.lips@gmail.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
2023-04-14 13:13:59 +02:00
Quentin Gallouédec
4fa17dcf0f
Standardize the use of from gym import spaces (#1240)
* generalize the use of `from gym import spaces`

* command line get system info

* Documentation line length for doc

* update changelog

* add space before os plateform to avoid ref to other issue

* format

* get_system_info update in changelog

* fix type check error

* fix get system info

* add comment about regex

* update version
2023-01-02 14:51:11 +01:00
Quentin Gallouédec
96b1a7cf01
env_id consistency in tests (#1224)
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-12-20 16:01:26 +01:00
Quentin Gallouédec
e3b24829a5
Drop gym.GoalEnv and other minor changes initally from #780 (#1184)
* Various changes from #780

* Fix env_checker for goal_env detection
2022-11-28 18:22:31 +01:00
Antonin RAFFIN
508f8ffd59
Remove deprecated features and attributes (#1104)
* Remove deprecated eval env

* Remove deprecated ret attribute

* Remove sde net arch

* Remove unused code

* Update test comment

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2022-10-11 10:55:16 +02:00
tobirohrer
d8a430e088
Deprecate create_eval_env, eval_env and eval_freq parameter (#1082)
* Adds deprecation warning if `eval_env` or `eval_freq` parameters are used. See #925

* added changelog entry

* added missing backtick

* deprecating `create_eval_env` parameter as well and adding comments to explain the `stacklevel` parameter used

* Updated tests to ignore DeprecationWarnings

* Updated changelog entry

* - Removed the `create_eval_env` parameter from the examples in the docs
- Removed information about the `create_eval_env` parameter from the migration docs
- Added information about deprecation of the `create_eval_env` parameter in the docs

* Add alternative in docstring

* Update docstrings

* `eval_freq` warning in docstring

* Add deprecation comments in tests

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
2022-10-10 15:39:38 +02:00
Quentin Gallouédec
fda3d4d748
Fix returned type in predict (#964)
* `arr[0]` to `arr.squeeze(0)`

* `squeeze(axis=0)` to `squeeze(0)`

* Type testing

* Add type test for unvectorized observation

* `squeeze(0)` to `squeeze(axis=0)`

* Treatment of the laziness symptoms

* Update changelog

* Udate changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-07-18 11:22:19 +02:00
Carlos Luis
5143cd19f7
Gym fixes - Follow up from #705 (#734)
* fix Atari in CI

* fix dtype and atari extra

* Update setup.py

* remove 3.6

* note about how to install Atari

* pendulum-v1

* atari v5

* black

* fix pendulum capitalization

* add minimum version

* moved things in changelog to breaking changes

* partial v5 fix

* env update to pass tests

* mismatch env version fixed

* Fix tests after merge

* Include autorom in setup.py

* Blacken code

* Fix dtype issue in more robust way

* Fix GitLab CI: switch to Docker container with new black version

* Remove workaround from GitLab. (May need to rebuild Docker for this though.)

* Revert to v4

* Update setup.py

* Apply suggestions from code review

* Remove unnecessary autorom

* Consistent gym versions

Co-authored-by: J K Terry <justinkterry@gmail.com>
Co-authored-by: Anssi <kaneran21@hotmail.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: modanesh <mohamad4danesh@gmail.com>
Co-authored-by: Adam Gleave <adam@gleave.me>
2022-02-04 15:13:57 -08:00
Antonin RAFFIN
306e49fda6
Fixes in is_vectorized_observation (#587)
* Fix is vectorized bug in DQN

* Fix sub-classed obs
2021-09-28 21:57:49 +02:00
Scott Brownlie
1afc2f3abe
Avoid putting target networks into training mode (#553)
* make sure DQN policy is always in correct mode - train or eval

* make set_training_mode an abstract method of the base policy - safer

* update docstring of _build method to note that the target network is put into eval mode

* use set_training_mode to put the dqn target network into eval mode

* use set_training_mode to set the training model of the q-network

* move set_training_mode abstract method from BasePolicy to BaseModel

* set train and eval mode for TD3

* make sure critic is always in correct mode during train

* set train and eval mode for SAC

* add comment re batch norm and dropout

* set train and eval mode for A2C and PPO

* add tests for collect rollouts with batch norm

* fix formatting

* update change log

* update version

* remove Optional typing for batch size - causing type check to fail

* Fix scipy dependency for toy text envs

* implement set_training_mode method in BaseModel

* move all tests of train/eval mode to test_train_eval_mode

* call learn with learning_starts = total_timesteps to test that collect_rollouts does not update batch norm

* remove extra calls to set_training_mode in train method of TD3 and SAC

* Allow gradient_steps=0

* Refactor tests

* Add comment + use aliases

* Typos

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2021-08-30 17:42:41 +02:00
David Blom
3efab0d267
Training and evaluation: call model.train() and model.eval() (#537)
* training and evaluation: call model.train() and model.eval() to enable and disable dropout and batchnorm

* Add comment documentation

* Fix train and eval for the Actor class

* Run black

* Add github handle to changelog

* Add unit tests for PPO and DQN

* Refactor unit test

* Run black

* unit test: add a dropout layer and check that calling predict with deterministic=True is deterministic

* documentation: add bugfix description to changelog

* unit test: use learning_starts=0, decrease the size of the network and use more training steps

* on policy algorithms: call policy.train() and policy.eval() instead of disable_training and enable_training as it is a th.nn.module

* Rename unit test

* unit test: use drop out probability of 0.5

* Call policy.train and policy.eval

* Fixes + update tests

* Remove unneeded eval

Co-authored-by: David Blom <davidsblom@gmail.com>
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
2021-08-14 14:08:27 +02:00
Antonin RAFFIN
9069cf55f1
Fix DQN predict shape for single Gym env (#222)
* Fix DQN predict shape for single Gym env

* Remove unused imports
2020-11-17 00:43:26 +02:00
liorcohen5
f5104a5efc
Allow to set a device when loading a model (#154)
* Added a 'device' keyword argument to BaseAlgorithm.load().
Edited the save and load test to also test the load method with all possible devices.
Added the changes to the changelog

* improved the load test to ensure that the model loads to the correct device.

* improved the test: now the correctness is improved. If the get_device policy would change, it wouldn't break the test.

* Update tests/test_save_load.py

@araffin's suggestion during the PR process

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update tests/test_save_load.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Bug fixes: when comparing devices, comparing only device type since get_device() doesn't provide device index.
Now the code loads all of the model parameters from the saved state dict straight into the required device. (fixed load_from_zip_file).

* PR fixes: bug fix - a non-related test failed when running on GPU. updated the assertion to consider only types of devices. Also corrected a related bug in 'get_device()' method.

* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-09-20 19:13:18 +02:00
Sam Toyer
42ef6d4677
Remove "device" argument from policies (#141)
* Remove device arg from policies

* Clean up for PR

* Update test and doc

* Fix codestyle

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-08-23 13:27:52 +02:00
Antonin RAFFIN
23afedb254
Auto-formatting with black and isort (#97)
* Add auto formatting with black and isort

* Reformat code

* Ignore typing errors

* Add note about line length

* Add minimum version for isort

* Add commit-checks

* Update docker image

* Fixed lost import (during last merge)

* Fix opencv dependency
2020-07-16 16:12:16 +02:00
Noah
96b771f24e
Implement DQN (#28)
* Created DQN template according to the paper.
Next steps:
- Create Policy
- Complete Training
- Debug

* Changed Base Class

* refactor save, to be consistence with overriding the excluded_save_params function. Do not try to exclude the parameters twice.

* Added simple DQN policy

* Finished learn and train function
- missing correct loss computation

* changed collect_rollouts to work with discrete space

* moved discrete space collect_rollouts to dqn

* basic dqn working

* deleted SDE related code

* added gradient clipping and moved greedy policy to policy

* changed policy to implement target network
and added soft update(in fact standart tau is 1 so hard update)

* fixed policy setup

* rebase target_update_intervall on _n_updates

* adapted all tests
all tests passing

* Move to stable-baseline3

* Fixes for DQN

* Fix tests + add CNNPolicy

* Allow any optimizer for DQN

* added some util functions to create a arbitrary linear schedule, fixed pickle problem with old exploration schedule

* more documentation

* changed buffer dtype

* refactor and document

* Added Sphinx Documentation
Updated changelog.rst

* removed custom collect_rollouts as it is no longer necessary

* Implemented suggestions to clean code and documentation.

* extracted some functions on tests to reduce duplicated code

* added support for exploration_fraction

* Fixed exploration_fraction

* Added documentation

* Fixed get_linear_fn -> proper progress scaling

* Merged master

* Added nature reference

* Changed default parameters to https://www.nature.com/articles/nature14236/tables/1

* Fixed n_updates to be incremented correctly

* Correct train_freq

* Doc update

* added special parameter for DQN in tests

* different fix for test_discrete

* Update docs/modules/dqn.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update docs/modules/dqn.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update docs/modules/dqn.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Added RMSProp in optimizer_kwargs, as described in nature paper

* Exploration fraction is inverse of 50.000.000 (total frames) / 1.000.000 (frames with linear schedule) according to nature paper

* Changelog update for buffer dtype

* standard exlude parameters should be always excluded to assure proper saving only if intentionally included by ``include`` parameter

* slightly more iterations on test_discrete to pass the test

* added param use_rms_prop instead of mutable default argument

* forgot alpha

* using huber loss, adam and learning rate 1e-4

* account for train_freq in update_target_network

* Added memory check for both buffers

* Doc updated for buffer allocation

* Added psutil Requirement

* Adapted test_identity.py

* Fixes with new SB3 version

* Fix for tensorboard name

* Convert assert to warning and fix tests

* Refactor off-policy algorithms

* Fixes

* test: remove next_obs in replay buffer

* Update changelog

* Fix tests and use tmp_path where possible

* Fix sampling bug in buffer

* Do not store next obs on episode termination

* Fix replay buffer sampling

* Update comment

* moved epsilon from policy to model

* Update predict method

* Update atari wrappers to match SB2

* Minor edit in the buffers

* Update changelog

* Merge branch 'master' into dqn

* Update DQN to new structure

* Fix tests and remove hardcoded path

* Fix for DQN

* Disable memory efficient replay buffer by default

* Fix docstring

* Add tests for memory efficient buffer

* Update changelog

* Split collect rollout

* Move target update outside `train()` for DQN

* Update changelog

* Update linear schedule doc

* Cleanup DQN code

* Minor edit

* Update version and docker images

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-29 11:16:54 +02:00
Antonin RAFFIN
d542732c8d Rename to stable-baselines3 2020-05-05 15:02:35 +02:00
Antonin RAFFIN
dcb54b5301 Remove CEMRL 2020-03-23 14:48:38 +01:00
Antonin RAFFIN
9485b90a41 Sync predict with SB and add version file 2020-03-18 15:11:19 +01:00
Antonin Raffin
18f38f8cf5 Reformat 2020-03-12 11:12:10 +01:00
Antonin Raffin
4392759057 Comment unused code 2020-02-14 14:15:55 +01:00
Antonin Raffin
e31b139c47 Add test for predict method 2020-02-14 14:03:41 +01:00