mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Allow env_kwargs in make_vec_env when env ID string supplied (#189)
* Allow env_kwargs in make_vec_env when env ID string supplied Resolves #188 * Update docs/misc/changelog.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Add test for env kargs in make_vec_env * remove unnecessary args in test_vec_env_kwargs function * Fixes and reformat * Doc fix Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
2599f04940
commit
fe6ade3089
4 changed files with 9 additions and 4 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -6,6 +6,7 @@
|
|||
.pytest_cache
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.coverage
|
||||
.coverage.*
|
||||
__pycache__/
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena)
|
||||
- Make ``make_vec_env`` support the ``env_kwargs`` argument when using an env ID str (@ManifoldFR)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -25,6 +26,7 @@ Others:
|
|||
^^^^^^^
|
||||
- Improved typing coverage
|
||||
- Improved error messages for unsupported spaces
|
||||
- Added ``.vscode`` to the gitignore
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
|
||||
import gym
|
||||
|
|
@ -45,9 +44,7 @@ def make_vec_env(
|
|||
def make_env(rank):
|
||||
def _init():
|
||||
if isinstance(env_id, str):
|
||||
env = gym.make(env_id)
|
||||
if len(env_kwargs) > 0:
|
||||
warnings.warn("No environment class was passed (only an env ID) so ``env_kwargs`` will be ignored")
|
||||
env = gym.make(env_id, **env_kwargs)
|
||||
else:
|
||||
env = env_id(**env_kwargs)
|
||||
if seed is not None:
|
||||
|
|
|
|||
|
|
@ -65,6 +65,11 @@ def test_make_atari_env(env_id, n_envs, wrapper_kwargs):
|
|||
assert np.max(np.abs(reward)) < 1.0
|
||||
|
||||
|
||||
def test_vec_env_kwargs():
|
||||
env = make_vec_env("MountainCarContinuous-v0", n_envs=1, seed=0, env_kwargs={"goal_velocity": 0.11})
|
||||
assert env.get_attr("goal_velocity")[0] == 0.11
|
||||
|
||||
|
||||
def test_custom_vec_env(tmp_path):
|
||||
"""
|
||||
Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.
|
||||
|
|
|
|||
Loading…
Reference in a new issue