mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-05 00:00:04 +00:00
* add policy_base input to OnPolicyAlgorithms * update changelog * Fix pytype error Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
6f822b9ed7
commit
35da0b59b9
3 changed files with 7 additions and 4 deletions
|
|
@ -18,6 +18,7 @@ New Features:
|
|||
to handle gym3-style vectorized environments (@vwxyzjn)
|
||||
- Ignored the terminal observation if it is not provided by the environment
|
||||
such as the gym3-style vectorized environments. (@vwxyzjn)
|
||||
- Add policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -655,4 +656,4 @@ And all the contributors:
|
|||
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
|
||||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro
|
||||
|
|
|
|||
|
|
@ -376,7 +376,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
||||
if reset_num_timesteps or self._last_obs is None:
|
||||
self._last_obs = self.env.reset()
|
||||
self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch
|
||||
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
if self._vec_normalize_env is not None:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from stable_baselines3.common import logger
|
|||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
|
@ -35,6 +35,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
|
|
@ -62,6 +63,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
max_grad_norm: float,
|
||||
use_sde: bool,
|
||||
sde_sample_freq: int,
|
||||
policy_base: Type[BasePolicy] = ActorCriticPolicy,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
monitor_wrapper: bool = True,
|
||||
|
|
@ -76,7 +78,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
super(OnPolicyAlgorithm, self).__init__(
|
||||
policy=policy,
|
||||
env=env,
|
||||
policy_base=ActorCriticPolicy,
|
||||
policy_base=policy_base,
|
||||
learning_rate=learning_rate,
|
||||
policy_kwargs=policy_kwargs,
|
||||
verbose=verbose,
|
||||
|
|
|
|||
Loading…
Reference in a new issue