diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 1aee061..12320a8 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -219,9 +219,9 @@ class BaseRLModel(ABC): def get_vec_normalize_env(self) -> Optional[VecNormalize]: """ - Return the `VecNormalize` wrapper of the training env + Return the ``VecNormalize`` wrapper of the training env if it exists. - :return: Optional[VecNormalize] The `VecNormalize` env. + :return: Optional[VecNormalize] The ``VecNormalize`` env. """ return self._vec_normalize_env @@ -267,7 +267,7 @@ class BaseRLModel(ABC): def get_torch_variables(self) -> Tuple[List[str], List[str]]: """ Get the name of the torch variable that will be saved. - `th.save` and `th.load` will be used with the right device + ``th.save`` and ``th.load`` will be used with the right device instead of the default pickling strategy. :return: (Tuple[List[str], List[str]]) @@ -297,7 +297,7 @@ class BaseRLModel(ABC): :param tb_log_name: (str) the name of the run for tensorboard log :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging) :param eval_env: (gym.Env) Environment that will be used to evaluate the agent - :param eval_freq: (int) Evaluate the agent every `eval_freq` timesteps (this may vary a little) + :param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little) :param n_eval_episodes: (int) Number of episode to evaluate the agent :param eval_log_path: (Optional[str]) Path to a folder where the evaluations will be saved :param reset_num_timesteps: (bool) @@ -333,6 +333,11 @@ class BaseRLModel(ABC): """ data, params, tensors = cls._load_from_file(load_path) + if 'policy_kwargs' in data: + for arg_to_remove in ['device']: + if arg_to_remove in data['policy_kwargs']: + del data['policy_kwargs'][arg_to_remove] + if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs." f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}") @@ -354,7 +359,7 @@ class BaseRLModel(ABC): model.__dict__.update(data) model.__dict__.update(kwargs) if not hasattr(model, "_setup_model") and len(params) > 0: - raise NotImplementedError(f"{cls} has no `_setup_model()` method") + raise NotImplementedError(f"{cls} has no ``_setup_model()`` method") model._setup_model() # put state_dicts back in place @@ -417,7 +422,7 @@ class BaseRLModel(ABC): file_content.write(tensor_file.read()) # go to start of file file_content.seek(0) - # load the parameters with the right `map_location` + # load the parameters with the right ``map_location`` tensors = th.load(file_content, map_location=device) # check for all other .pth files @@ -434,7 +439,7 @@ class BaseRLModel(ABC): file_content.write(opt_param_file.read()) # go to start of file file_content.seek(0) - # load the parameters with the right `map_location` + # load the parameters with the right ``map_location`` params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device) except zipfile.BadZipFile: @@ -502,7 +507,7 @@ class BaseRLModel(ABC): :param eval_freq: (int) :param n_eval_episodes: (int) :param log_path (Optional[str]): Path to a log folder - :param reset_num_timesteps: (bool) Whether to reset or not the `num_timesteps` attribute + :param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute :return: (BaseCallback) """ self.start_time = time.time() @@ -516,7 +521,7 @@ class BaseRLModel(ABC): self.num_timesteps = 0 self._episode_num = 0 - # Avoid resetting the environment when calling `.learn()` consecutive times + # Avoid resetting the environment when calling ``.learn()`` consecutive times if reset_num_timesteps or self._last_obs is None: self._last_obs = self.env.reset() # Retrieve unnormalized observation for saving into the buffer @@ -762,9 +767,9 @@ class OffPolicyRLModel(BaseRLModel): :param env: (VecEnv) The training environment :param n_episodes: (int) Number of episodes to use to collect rollout data - You can also specify a `n_steps` instead + You can also specify a ``n_steps`` instead :param n_steps: (int) Number of steps to use to collect rollout data - You can also specify a `n_episodes` instead. + You can also specify a ``n_episodes`` instead. :param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration Required for deterministic policy (e.g. TD3). This can also be used in addition to the stochastic policy for SAC. @@ -772,7 +777,7 @@ class OffPolicyRLModel(BaseRLModel): (and at the beginning and end of the rollout) :param learning_starts: (int) Number of steps before learning for the warm-up phase. :param replay_buffer: (ReplayBuffer) - :param log_interval: (int) Log data every `log_interval` episodes + :param log_interval: (int) Log data every ``log_interval`` episodes :return: (RolloutReturn) """ episode_rewards, total_timesteps = [], []