mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Remove saved device + update doc
This commit is contained in:
parent
0481fbe727
commit
4a4da90671
1 changed file with 17 additions and 12 deletions
|
|
@@ -219,9 +219,9 @@ class BaseRLModel(ABC):
|
|||
|
||||
def get_vec_normalize_env(self) -> Optional[VecNormalize]:
|
||||
"""
|
||||
Return the `VecNormalize` wrapper of the training env
|
||||
Return the ``VecNormalize`` wrapper of the training env
|
||||
if it exists.
|
||||
:return: Optional[VecNormalize] The `VecNormalize` env.
|
||||
:return: Optional[VecNormalize] The ``VecNormalize`` env.
|
||||
"""
|
||||
return self._vec_normalize_env
|
||||
|
||||
|
|
@@ -267,7 +267,7 @@ class BaseRLModel(ABC):
|
|||
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Get the name of the torch variable that will be saved.
|
||||
`th.save` and `th.load` will be used with the right device
|
||||
``th.save`` and ``th.load`` will be used with the right device
|
||||
instead of the default pickling strategy.
|
||||
|
||||
:return: (Tuple[List[str], List[str]])
|
||||
|
|
@@ -297,7 +297,7 @@ class BaseRLModel(ABC):
|
|||
:param tb_log_name: (str) the name of the run for tensorboard log
|
||||
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
|
||||
:param eval_env: (gym.Env) Environment that will be used to evaluate the agent
|
||||
:param eval_freq: (int) Evaluate the agent every `eval_freq` timesteps (this may vary a little)
|
||||
:param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
|
||||
:param n_eval_episodes: (int) Number of episode to evaluate the agent
|
||||
:param eval_log_path: (Optional[str]) Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: (bool)
|
||||
|
|
@@ -333,6 +333,11 @@ class BaseRLModel(ABC):
|
|||
"""
|
||||
data, params, tensors = cls._load_from_file(load_path)
|
||||
|
||||
if 'policy_kwargs' in data:
|
||||
for arg_to_remove in ['device']:
|
||||
if arg_to_remove in data['policy_kwargs']:
|
||||
del data['policy_kwargs'][arg_to_remove]
|
||||
|
||||
if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
|
||||
raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs."
|
||||
f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}")
|
||||
|
|
@@ -354,7 +359,7 @@ class BaseRLModel(ABC):
|
|||
model.__dict__.update(data)
|
||||
model.__dict__.update(kwargs)
|
||||
if not hasattr(model, "_setup_model") and len(params) > 0:
|
||||
raise NotImplementedError(f"{cls} has no `_setup_model()` method")
|
||||
raise NotImplementedError(f"{cls} has no ``_setup_model()`` method")
|
||||
model._setup_model()
|
||||
|
||||
# put state_dicts back in place
|
||||
|
|
@@ -417,7 +422,7 @@ class BaseRLModel(ABC):
|
|||
file_content.write(tensor_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
# load the parameters with the right `map_location`
|
||||
# load the parameters with the right ``map_location``
|
||||
tensors = th.load(file_content, map_location=device)
|
||||
|
||||
# check for all other .pth files
|
||||
|
|
@@ -434,7 +439,7 @@ class BaseRLModel(ABC):
|
|||
file_content.write(opt_param_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
# load the parameters with the right `map_location`
|
||||
# load the parameters with the right ``map_location``
|
||||
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
|
|
@@ -502,7 +507,7 @@ class BaseRLModel(ABC):
|
|||
:param eval_freq: (int)
|
||||
:param n_eval_episodes: (int)
|
||||
:param log_path (Optional[str]): Path to a log folder
|
||||
:param reset_num_timesteps: (bool) Whether to reset or not the `num_timesteps` attribute
|
||||
:param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
|
||||
:return: (BaseCallback)
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
|
|
@@ -516,7 +521,7 @@ class BaseRLModel(ABC):
|
|||
self.num_timesteps = 0
|
||||
self._episode_num = 0
|
||||
|
||||
# Avoid resetting the environment when calling `.learn()` consecutive times
|
||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
||||
if reset_num_timesteps or self._last_obs is None:
|
||||
self._last_obs = self.env.reset()
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
|
|
@@ -762,9 +767,9 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
|
||||
:param env: (VecEnv) The training environment
|
||||
:param n_episodes: (int) Number of episodes to use to collect rollout data
|
||||
You can also specify a `n_steps` instead
|
||||
You can also specify a ``n_steps`` instead
|
||||
:param n_steps: (int) Number of steps to use to collect rollout data
|
||||
You can also specify a `n_episodes` instead.
|
||||
You can also specify a ``n_episodes`` instead.
|
||||
:param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration
|
||||
Required for deterministic policy (e.g. TD3). This can also be used
|
||||
in addition to the stochastic policy for SAC.
|
||||
|
|
@@ -772,7 +777,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
(and at the beginning and end of the rollout)
|
||||
:param learning_starts: (int) Number of steps before learning for the warm-up phase.
|
||||
:param replay_buffer: (ReplayBuffer)
|
||||
:param log_interval: (int) Log data every `log_interval` episodes
|
||||
:param log_interval: (int) Log data every ``log_interval`` episodes
|
||||
:return: (RolloutReturn)
|
||||
"""
|
||||
episode_rewards, total_timesteps = [], []
|
||||
|
|
|
|||
Loading…
Reference in a new issue