diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4ca0bb5..3a2613a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -18,6 +18,7 @@ New Features: to handle gym3-style vectorized environments (@vwxyzjn) - Ignored the terminal observation if the it is not provided by the environment such as the gym3-style vectorized environments. (@vwxyzjn) +- Added ``policy_base`` as input to the OnPolicyAlgorithm for more flexibility (@09tangriro) Bug Fixes: ^^^^^^^^^^ @@ -655,4 +656,4 @@ And all the contributors: @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn -@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida +@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 819a67a..4394797 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -376,7 +376,7 @@ class BaseAlgorithm(ABC): # Avoid resetting the environment when calling ``.learn()`` consecutive times if reset_num_timesteps or self._last_obs is None: - self._last_obs = self.env.reset() + self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool) # Retrieve unnormalized observation for saving into the buffer if self._vec_normalize_env is not None: diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 64300dc..016954d 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -9,7 +9,7 @@ from stable_baselines3.common import logger from 
stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -35,6 +35,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) + :param policy_base: The base policy class used by this method (default: ActorCriticPolicy) :param tensorboard_log: the log location for tensorboard (if None, no logging) :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) @@ -62,6 +63,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): max_grad_norm: float, use_sde: bool, sde_sample_freq: int, + policy_base: Type[BasePolicy] = ActorCriticPolicy, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, monitor_wrapper: bool = True, @@ -76,7 +78,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): super(OnPolicyAlgorithm, self).__init__( policy=policy, env=env, - policy_base=ActorCriticPolicy, + policy_base=policy_base, learning_rate=learning_rate, policy_kwargs=policy_kwargs, verbose=verbose,