Remove saved device + update doc

This commit is contained in:
Antonin RAFFIN 2020-05-05 17:19:21 +02:00
parent 0481fbe727
commit 4a4da90671

View file

@ -219,9 +219,9 @@ class BaseRLModel(ABC):
def get_vec_normalize_env(self) -> Optional[VecNormalize]:
"""
Return the `VecNormalize` wrapper of the training env
Return the ``VecNormalize`` wrapper of the training env
if it exists.
:return: Optional[VecNormalize] The `VecNormalize` env.
:return: Optional[VecNormalize] The ``VecNormalize`` env.
"""
return self._vec_normalize_env
@ -267,7 +267,7 @@ class BaseRLModel(ABC):
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variable that will be saved.
`th.save` and `th.load` will be used with the right device
``th.save`` and ``th.load`` will be used with the right device
instead of the default pickling strategy.
:return: (Tuple[List[str], List[str]])
@ -297,7 +297,7 @@ class BaseRLModel(ABC):
:param tb_log_name: (str) the name of the run for tensorboard log
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
:param eval_env: (gym.Env) Environment that will be used to evaluate the agent
:param eval_freq: (int) Evaluate the agent every `eval_freq` timesteps (this may vary a little)
:param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
:param n_eval_episodes: (int) Number of episode to evaluate the agent
:param eval_log_path: (Optional[str]) Path to a folder where the evaluations will be saved
:param reset_num_timesteps: (bool)
@ -333,6 +333,11 @@ class BaseRLModel(ABC):
"""
data, params, tensors = cls._load_from_file(load_path)
if 'policy_kwargs' in data:
for arg_to_remove in ['device']:
if arg_to_remove in data['policy_kwargs']:
del data['policy_kwargs'][arg_to_remove]
if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs."
f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}")
@ -354,7 +359,7 @@ class BaseRLModel(ABC):
model.__dict__.update(data)
model.__dict__.update(kwargs)
if not hasattr(model, "_setup_model") and len(params) > 0:
raise NotImplementedError(f"{cls} has no `_setup_model()` method")
raise NotImplementedError(f"{cls} has no ``_setup_model()`` method")
model._setup_model()
# put state_dicts back in place
@ -417,7 +422,7 @@ class BaseRLModel(ABC):
file_content.write(tensor_file.read())
# go to start of file
file_content.seek(0)
# load the parameters with the right `map_location`
# load the parameters with the right ``map_location``
tensors = th.load(file_content, map_location=device)
# check for all other .pth files
@ -434,7 +439,7 @@ class BaseRLModel(ABC):
file_content.write(opt_param_file.read())
# go to start of file
file_content.seek(0)
# load the parameters with the right `map_location`
# load the parameters with the right ``map_location``
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
except zipfile.BadZipFile:
@ -502,7 +507,7 @@ class BaseRLModel(ABC):
:param eval_freq: (int)
:param n_eval_episodes: (int)
:param log_path (Optional[str]): Path to a log folder
:param reset_num_timesteps: (bool) Whether to reset or not the `num_timesteps` attribute
:param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
:return: (BaseCallback)
"""
self.start_time = time.time()
@ -516,7 +521,7 @@ class BaseRLModel(ABC):
self.num_timesteps = 0
self._episode_num = 0
# Avoid resetting the environment when calling `.learn()` consecutive times
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset()
# Retrieve unnormalized observation for saving into the buffer
@ -762,9 +767,9 @@ class OffPolicyRLModel(BaseRLModel):
:param env: (VecEnv) The training environment
:param n_episodes: (int) Number of episodes to use to collect rollout data
You can also specify a `n_steps` instead
You can also specify a ``n_steps`` instead
:param n_steps: (int) Number of steps to use to collect rollout data
You can also specify a `n_episodes` instead.
You can also specify a ``n_episodes`` instead.
:param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration
Required for deterministic policy (e.g. TD3). This can also be used
in addition to the stochastic policy for SAC.
@ -772,7 +777,7 @@ class OffPolicyRLModel(BaseRLModel):
(and at the beginning and end of the rollout)
:param learning_starts: (int) Number of steps before learning for the warm-up phase.
:param replay_buffer: (ReplayBuffer)
:param log_interval: (int) Log data every `log_interval` episodes
:param log_interval: (int) Log data every ``log_interval`` episodes
:return: (RolloutReturn)
"""
episode_rewards, total_timesteps = [], []