Policy Base for On-policy Algorithms (#412) (#415)

* add policy_base input to OnPolicyAlgorithms

* update changelog

* Fix pytype error

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Rohan Tangri 2021-05-04 10:59:36 +01:00 committed by GitHub
parent 6f822b9ed7
commit 35da0b59b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 4 deletions

View file

@ -18,6 +18,7 @@ New Features:
to handle gym3-style vectorized environments (@vwxyzjn)
- Ignored the terminal observation if the it is not provided by the environment
such as the gym3-style vectorized environments. (@vwxyzjn)
- Add policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro)
Bug Fixes:
^^^^^^^^^^
@ -655,4 +656,4 @@ And all the contributors:
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro

View file

@ -376,7 +376,7 @@ class BaseAlgorithm(ABC):
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset()
self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:

View file

@ -9,7 +9,7 @@ from stable_baselines3.common import logger
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.vec_env import VecEnv
@ -35,6 +35,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param policy_base: The base policy used by this method
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
@ -62,6 +63,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
policy_base: Type[BasePolicy] = ActorCriticPolicy,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
monitor_wrapper: bool = True,
@ -76,7 +78,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
super(OnPolicyAlgorithm, self).__init__(
policy=policy,
env=env,
policy_base=ActorCriticPolicy,
policy_base=policy_base,
learning_rate=learning_rate,
policy_kwargs=policy_kwargs,
verbose=verbose,