Commit graph

34 commits

Author SHA1 Message Date
Antonin RAFFIN
c62e9259db
Add custom objects support + bug fix (#336)
* Add support for custom objects

* Add python 3.8 to the CI

* Bump version

* PyType fixes

* [ci skip] Fix typo

* Add note about slow-down + fix typos

* Minor edits to the doc

* Bug fix for DQN

* Update test

* Add test for custom objects
2021-03-06 15:17:43 +02:00
M. Ernestus
0c50d75ecb
TD3 Code review (#245)
* Removed unneeded overrides of feature_extractor and normalize_images in the TD3 Actor.

* Add learning rate schedule example (#248)

* Add learning rate schedule example

* Update docs/guide/examples.rst

Co-authored-by: Adam Gleave <adam@gleave.me>

* Address comments

Co-authored-by: Adam Gleave <adam@gleave.me>

* Add supported action spaces checks (#254)

* Add supported action spaces checks

* Address comment

* Use `pass` in an abstractmethod instead of deleting the arguments.

* Remove the "deterministic" keyword from the forward method of the TD3 Actor since it always is deterministic anyways.

* Rename _get_data to _get_data_to_reconstruct_model.

_get_data was too generic and could have meant anything.

* Remove the n_episodes_rollout parameter and allow passing tuples as train_freq instead.

* Fix docstring of `train_freq` parameter.

* Black fixes.

* Fix TD3 delayed update + rename `_get_data()`

* Fix TD3 test

* Normalize `train_freq` to a tuple in the constructor and turn the warning into an assert.

* Make one step the default train frequency.

* Black fixes.

* Change np.bool to bool.

* Use the tuple format to specify an amount of steps in terms of steps or episodes in the collect_collouts of the off policy algorithm.

* Use the tuple format to specify an amount of steps in terms of steps or episodes in the collect_collouts of HER.

* Use named tuple for train freq

* Rename train_freq to train_every and TrainFreq to ExperienceDuration. Also add some type annotations and documentation.

* Black fixes.

* Revert to train_freq

* Fix terminal observation issues

* Typo

* Fix action noise bug in HER

* Add assert when loading HER models

* Update version

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Adam Gleave <adam@gleave.me>
2021-02-27 17:33:50 +01:00
Antonin RAFFIN
c722c4f5bd
Fix numpy warning and update migration guide (#307) 2021-02-01 11:24:44 +01:00
Antonin RAFFIN
2b9fc1f923
Add supported action spaces checks (#254)
* Add supported action spaces checks

* Address comment
2020-12-06 14:05:10 +02:00
Antonin RAFFIN
3207bdab17
Automatically wrap with a Monitor when possible (#237)
* Automatically wrap with a Monitor when possible

* Update stable_baselines3/common/base_class.py

Co-authored-by: Anssi <kaneran21@hotmail.com>
2020-11-20 18:08:00 +02:00
Antonin RAFFIN
d04aad2a20
Doc fixes and add monitor_kwargs parameter (#230)
* Fix type annotation

* Fix migration doc for A2C

* Update version

* Add `monitor_kwargs` argument

* Update docs/guide/migration.rst

Co-authored-by: Adam Gleave <adam@gleave.me>

* Fix make atari env

* Fix docstring

* Renamed LearningRateSchedule

Co-authored-by: Adam Gleave <adam@gleave.me>
2020-11-20 10:28:54 +01:00
Anssi
18d10dbf42
Use Monitor episode reward/length for evaluate_policy (#220)
* Update evaluate_policy to use monitor data if available

* Update documentation

* Cleaning up

* Remove unnecessary typing trickery

* Update doc

* Rename is_wrapped to clarify it is for vecenvs

* Add is_wrapped for regular envs

* Add is_wrapped call for subprocvecenv and update code for circular imports

* Move new functions back to env_util and fix imports

* Update changelog

* Clarify evaluate_policy docs

* Add tests for wrapped modifying episode lengths

* Fix tests

* Update changelog

* Minor edits

* Add warn switch to evaluate_policy and update tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-11-16 11:52:28 +01:00
M. Ernestus
c74509ae9d
Add callable signatures to type annotations. (#215)
* Add callback signature to the learning rate type annotations.

* Add callback signature to the learning rate schedule type annotations.

* Add missing type annotations for learning rate callbacks.

* Add signature to old-style learning and evaluation callbacks.

* Add signature to env wrapper callback.

* Add type annotation to closure function.

* Use MaybeCallback more consistently.

* Update changelog.

* Remove now unused List import.

* Fix import order.

* Add type alias for learning rate schedules.

* Optimize imports.

* Fix messed up import.

* Remove resolved TODO.

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-11-15 17:50:28 +01:00
Anssi
e2b6f5460f
Avoid transposing channel-first envs (#213)
* Add test for channel-first environments

* Add support for channel-first envs, including more tests

* Update changelog

* Run black

* Run black, again

* Improve NatureCNN error message

* Update image checks and FrameStack wrapper

* Update tests

* Update docs

* Run isort

* Reformat

* Fixes: avoid breaking changes for non-image env

* Add additional checks

* Update docstring

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-11-03 12:34:09 +01:00
Antonin RAFFIN
6327cc6156
Fix env loading (#203)
* Fix loading bug

* [ci skip] Fix docstring
2020-10-27 23:12:52 +02:00
Megan Klaiber
dd6e361204
Implement HER (#120)
* Added working her version, Online sampling is missing.

* Updated test_her.

* Added first version of online her sampling. Still problems with tensor dimensions.

* Reformat

* Fixed tests

* Added some comments.

* Updated changelog.

* Add missing init file

* Fixed some small bugs.

* Reduced arguments for HER, small changes.

* Added getattr. Fixed bug for online sampling.

* Updated save/load funtions. Small changes.

* Added her to init.

* Updated save method.

* Updated her ratio.

* Move obs_wrapper

* Added DQN test.

* Fix potential bug

* Offline and online her share same sample_goal function.

* Changed lists into arrays.

* Updated her test.

* Fix online sampling

* Fixed action bug. Updated time limit for episodes.

* Updated convert_dict method to take keys as arguments.

* Renamed obs dict wrapper.

* Seed bit flipping env

* Remove get_episode_dict

* Add fast online sampling version

* Added documentation.

* Vectorized reward computation

* Vectorized goal sampling

* Update time limit for episodes in online her sampling.

* Fix max episode length inference

* Bug fix for Fetch envs

* Fix for HER + gSDE

* Reformat (new black version)

* Added info dict to compute new reward. Check her_replay_buffer again.

* Fix info buffer

* Updated done flag.

* Fixes for gSDE

* Offline her version uses now HerReplayBuffer as episode storage.

* Fix num_timesteps computation

* Fix get torch params

* Vectorized version for offline sampling.

* Modified offline her sampling to use sample method of her_replay_buffer

* Updated HER tests.

* Updated documentation

* Cleanup docstrings

* Updated to review comments

* Fix pytype

* Update according to review comments.

* Removed random goal strategy. Updated sample transitions.

* Updated migration. Removed time signal removal.

* Update doc

* Fix potential load issue

* Add VecNormalize support for dict obs

* Updated saving/loading replay buffer for HER.

* Fix test memory usage

* Fixed save/load replay buffer.

* Fixed save/load replay buffer

* Fixed transition index after loading replay buffer in online sampling

* Better error handling

* Add tests for get_time_limit

* More tests for VecNormalize with dict obs

* Update doc

* Improve HER description

* Add test for sde support

* Add comments

* Add comments

* Remove check that was always valid

* Fix for terminal observation

* Updated buffer size in offline version and reset of HER buffer

* Reformat

* Update doc

* Remove np.empty + add doc

* Fix loading

* Updated loading replay buffer

* Separate online and offline sampling + bug fixes

* Update tensorboard log name

* Version bump

* Bug fix for special case

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-10-22 11:56:43 +02:00
Antonin RAFFIN
a1e055695c
Improve typing coverage (#175)
* Improve typing coverage

* Even more types

* Fixes

* Update changelog

* Unified docstrings

* Improve error messages for unsupported spaces
2020-10-07 10:51:49 +02:00
Antonin RAFFIN
55912576ed
Cleanup docstring types (#169)
* Cleanup docstring types

* Update style

* Test with js hack

* Revert "Test with js hack"

This reverts commit d091f438e8851ab8d01b66628e06a104f5e5ec69.

* Fix types

* Fix typo

* Update CONTRIBUTING example
2020-10-02 20:05:55 +03:00
Antonin RAFFIN
8b16324ba7 Hotfix: properly delete device in policy_kwargs 2020-09-25 15:34:53 +02:00
Anssi
9855486488
Get/set parameters and review of saving and loading (#138)
* Update comments and docstrings

* Rename get_torch_variables to private and update docs

* Clarify documentation on data, params and tensors

* Make excluded_save_params private and update docs

* Update get_torch_variable_names to get_torch_save_params for description

* Simplify saving code and update docs on params vs tensors

* Rename saved item tensors to pytorch_variables for clarity

* Reformat

* Fix a typo

* Add get/set_parameters, update tests accordingly

* Use f-strings for formatting

* Fix load docstring

* Reorganize functions in BaseClass

* Update changelog

* Add library version to the stored models

* Actually run isort this time

* Fix flake8 complaints and also fix testing code

* Fix isort

* ...and black

* Fix set_random_seed

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
2020-09-24 14:28:27 +02:00
liorcohen5
f5104a5efc
Allow to set a device when loading a model (#154)
* Added a 'device' keyword argument to BaseAlgorithm.load().
Edited the save and load test to also test the load method with all possible devices.
Added the changes to the changelog

* improved the load test to ensure that the model loads to the correct device.

* improved the test: now the correctness is improved. If the get_device policy would change, it wouldn't break the test.

* Update tests/test_save_load.py

@araffin's suggestion during the PR process

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update tests/test_save_load.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Bug fixes: when comparing devices, comparing only device type since get_device() doesn't provide device index.
Now the code loads all of the model parameters from the saved state dict straight into the required device. (fixed load_from_zip_file).

* PR fixes: bug fix - a non-related test failed when running on GPU. updated the assertion to consider only types of devices. Also corrected a related bug in 'get_device()' method.

* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-09-20 19:13:18 +02:00
Anssi
2cd6a4f93b
Match performance with stable-baselines (discrete case) (#110)
* Fix storing correct episode dones

* Fix number of filters in NatureCNN network

* Add TF-like RMSprop for matching performance with sb2

* Remove stuff that was accidentally included

* Reformat

* Clarify variable naming

* Update changelog

* Add comment on RMSprop implementations to A2C

* Add test for RMSpropTFLike

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-08-03 22:22:51 +02:00
Antonin RAFFIN
23afedb254
Auto-formatting with black and isort (#97)
* Add auto formatting with black and isort

* Reformat code

* Ignore typing errors

* Add note about line length

* Add minimum version for isort

* Add commit-checks

* Update docker image

* Fixed lost import (during last merge)

* Fix opencv dependency
2020-07-16 16:12:16 +02:00
Antonin RAFFIN
5ff176b2f1
Implement DDPG (#92)
* Add DDPG + TD3 with any number of critics

* Allow any number of critics for SAC

* Update doc

* [ci skip] Update DDPG example

* Remove unused parameter

* Add DDPG to identity test

* Fix computation with n_critics=1,3

* Update doc

* Apply suggestions from code review

Co-authored-by: Adam Gleave <adam@gleave.me>

* Update docstrings for off-policy algos

* Add check for sde

Co-authored-by: Adam Gleave <adam@gleave.me>
2020-07-16 14:14:22 +02:00
Antonin RAFFIN
208890dfc8
Ignore errors from new pytype version (#107) 2020-07-16 11:54:37 +02:00
Adam Gleave
91bbc28c0f Address minor issues after clarification by @araffin 2020-07-07 18:39:55 -07:00
Adam Gleave
1f0443f332 Review base_class 2020-07-02 18:49:59 -07:00
Stelios Tymvios
4aa66ed34a
Automatically create paths for saved objects (#80)
* automatically create paths for saved objects

* Minor Corrections, more tests

* linting

* typing

* Correct mode checking

* corrected tests to reflect new verbose functionality
2020-07-03 01:14:21 +03:00
Noah
96b771f24e
Implement DQN (#28)
* Created DQN template according to the paper.
Next steps:
- Create Policy
- Complete Training
- Debug

* Changed Base Class

* refactor save, to be consistence with overriding the excluded_save_params function. Do not try to exclude the parameters twice.

* Added simple DQN policy

* Finished learn and train function
- missing correct loss computation

* changed collect_rollouts to work with discrete space

* moved discrete space collect_rollouts to dqn

* basic dqn working

* deleted SDE related code

* added gradient clipping and moved greedy policy to policy

* changed policy to implement target network
and added soft update(in fact standart tau is 1 so hard update)

* fixed policy setup

* rebase target_update_intervall on _n_updates

* adapted all tests
all tests passing

* Move to stable-baseline3

* Fixes for DQN

* Fix tests + add CNNPolicy

* Allow any optimizer for DQN

* added some util functions to create a arbitrary linear schedule, fixed pickle problem with old exploration schedule

* more documentation

* changed buffer dtype

* refactor and document

* Added Sphinx Documentation
Updated changelog.rst

* removed custom collect_rollouts as it is no longer necessary

* Implemented suggestions to clean code and documentation.

* extracted some functions on tests to reduce duplicated code

* added support for exploration_fraction

* Fixed exploration_fraction

* Added documentation

* Fixed get_linear_fn -> proper progress scaling

* Merged master

* Added nature reference

* Changed default parameters to https://www.nature.com/articles/nature14236/tables/1

* Fixed n_updates to be incremented correctly

* Correct train_freq

* Doc update

* added special parameter for DQN in tests

* different fix for test_discrete

* Update docs/modules/dqn.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update docs/modules/dqn.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update docs/modules/dqn.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Added RMSProp in optimizer_kwargs, as described in nature paper

* Exploration fraction is inverse of 50.000.000 (total frames) / 1.000.000 (frames with linear schedule) according to nature paper

* Changelog update for buffer dtype

* standard exlude parameters should be always excluded to assure proper saving only if intentionally included by ``include`` parameter

* slightly more iterations on test_discrete to pass the test

* added param use_rms_prop instead of mutable default argument

* forgot alpha

* using huber loss, adam and learning rate 1e-4

* account for train_freq in update_target_network

* Added memory check for both buffers

* Doc updated for buffer allocation

* Added psutil Requirement

* Adapted test_identity.py

* Fixes with new SB3 version

* Fix for tensorboard name

* Convert assert to warning and fix tests

* Refactor off-policy algorithms

* Fixes

* test: remove next_obs in replay buffer

* Update changelog

* Fix tests and use tmp_path where possible

* Fix sampling bug in buffer

* Do not store next obs on episode termination

* Fix replay buffer sampling

* Update comment

* moved epsilon from policy to model

* Update predict method

* Update atari wrappers to match SB2

* Minor edit in the buffers

* Update changelog

* Merge branch 'master' into dqn

* Update DQN to new structure

* Fix tests and remove hardcoded path

* Fix for DQN

* Disable memory efficient replay buffer by default

* Fix docstring

* Add tests for memory efficient buffer

* Update changelog

* Split collect rollout

* Move target update outside `train()` for DQN

* Update changelog

* Update linear schedule doc

* Cleanup DQN code

* Minor edit

* Update version and docker images

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-29 11:16:54 +02:00
Anssi
44f8218df0
Review of code (A2C, PPO and refactoring) (#35)
* Split torch module code into torch_layers file

* Updated reference to CNN

* Change 'CxWxH' to 'CxHxW', as per common notion

* Fix missing import in policies.py

* Move PPOPolicy to OnlineActorCriticPolicy

* Create OnPolicyRLModel from PPO, and make A2C and PPO inherit

* Update A2C optimizer comment

* Clean weight init scales for clarity

* Fix A2C log_interval default parameter

* Rename 'progress' to 'progress_remaining

* Rename 'Models' to 'Algorithms'

* Rename 'OnlineActorCriticPolicy' to 'ActorCriticPolicy'

* Move static functions out from BaseAlgorithm

* Move on/off_policy base algorithms to their own files

* Add  files for A2C/PPO

* Fix docs

* Fix pytype

* Update documentation on OnPolicyAlgorithm

* Add proper doctstring for on_policy rollout gathering

* Add bit clarification on the mlppolicy/cnnpolicy naming

* Move static function is_vectorized_policies to utils.py

* Checking docstrings, pep8 fixes

* Update changelog

* Clean changelog

* Remove policy warnings for sac/td3

* Add monitor_wrapper for OnPolicyAlgorithm. Clean tb logging variables. Add parameter keywords to OffPolicyAlgorithm super init

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-09 13:54:18 +02:00
Antonin RAFFIN
11d33eb4ae
Fix gSDE loading issue in test mode (#45)
* Fix gSDE loading issue in test mode

* Forward `reset_noise` method

* Re-add `make_actor`

* Reformat
2020-06-08 11:15:10 +02:00
Antonin RAFFIN
403fff5d50
Pre-Release v0.6.0 (#39)
* Prepare release

* Update docker images
2020-06-01 13:09:47 +02:00
Roland Gavrilescu
bb01253261
Tensorboard integration (#30)
* init commit tensorboard-integration

* Added tb logger to ppo (with output exclusions)

* fixed truncated stdout

* categorize stdout outputs by tag

* separated exclusions from values, added missing logs

* saving exclusions as dict instead of list

* reformatting, auto run indexing

* included renaming suggestions, fixed tests

* tb support for sac

* linting

* moved logging to base class

* tb support for td3

* removed histograms, non-verbose output working

* modifed changelog

* linting

* fixed type error

* moved logger config to utils

* removed episode_rewards log from ppo

* Enable tensorboard in tests

* Remove unused import

* Update logger sub titles

* Minor edit for PPO

* Update logger and tb log folder

* Pass correct logger to Callbacks

* updated docs

* added tb example image to docs

* add support for continuing training in tensorboard

* added tensorboard to docs index

* added tb test

* moved logger config to _setup_learn, updated tests

* accessing verbose from base class

* Update doc and tests

* Rename session -> time

* Update version

* Update logger truncate

* Update types

* Remove duplicated code

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-01 11:55:44 +02:00
Antonin RAFFIN
9b42b9717a
Fix sde_sample_freq for SAC (#32)
* Fix `sde_sample_freq` for SAC

* [ci skip] Add Acknowledgments
2020-05-24 16:44:44 +02:00
Antonin RAFFIN
15ff6d47ee
Documentation update and style fixes (#21)
* Update doc: add gSDE

* Fix codestyle

* Remove travis script

* Add lint check to gitlab
2020-05-15 13:54:06 +02:00
Antonin RAFFIN
413a2386d9 Cleanup + reformat code 2020-05-08 15:10:46 +02:00
Antonin RAFFIN
c20af230f7 Remove SDE support for TD3 2020-05-08 15:00:34 +02:00
Antonin RAFFIN
4a4da90671 Remove saved device + update doc 2020-05-05 17:19:21 +02:00
Antonin RAFFIN
d542732c8d Rename to stable-baselines3 2020-05-05 15:02:35 +02:00
Renamed from torchy_baselines/common/base_class.py (Browse further)