Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-05-16 21:10:08 +00:00)

Partially type base class

commit ff0eddfb17
parent 0328a39d1b

3 changed files with 88 additions and 66 deletions
docs/conf.py

@@ -70,6 +70,7 @@ release = torchy_baselines.__version__
 # ones.
 extensions = [
     'sphinx.ext.autodoc',
+    'sphinx_autodoc_typehints',
     'sphinx.ext.autosummary',
     'sphinx.ext.mathjax',
     'sphinx.ext.ifconfig',
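For context: sphinx_autodoc_typehints renders parameter and return types directly from function annotations, which is why the hunks below delete the handwritten "(type)" prefixes from the docstrings. A sketch of the resulting convention, on a hypothetical function:

def scale(value: float, factor: float = 2.0) -> float:
    """
    Scale a value.

    :param value: the input value
    :param factor: multiplicative factor
    :return: the scaled value
    """
    return value * factor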
setup.py
@@ -25,7 +25,9 @@ setup(name='torchy_baselines',
         'sphinx-autobuild',
         'sphinx-rtd-theme',
         # For spelling
-        'sphinxcontrib.spelling'
+        'sphinxcontrib.spelling',
+        # Type hints support
+        'sphinx-autodoc-typehints'
       ],
       'extra': [
         # For render
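With this change the type-hints extension is installed alongside the other documentation dependencies via the standard setuptools extras syntax, e.g. pip install 'torchy_baselines[docs]' (command assumes the package is installed from this repository).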
torchy_baselines/common/base_class.py

@@ -3,7 +3,7 @@ import os
 import io
 import zipfile
 import typing
-from typing import Union, Type, Optional
+from typing import Union, Type, Optional, Dict, Any, List, Tuple
 from abc import ABC, abstractmethod
 from collections import deque
 
@@ -19,6 +19,7 @@ from torchy_baselines.common.monitor import Monitor
 from torchy_baselines.common.evaluation import evaluate_policy
 from torchy_baselines.common.save_util import data_to_json, json_to_data
 
+# TODO: define aliases, ex GymEnv = Union[gym.Env, VecEnv]
 if typing.TYPE_CHECKING:
     from torchy_baselines.common.noise import ActionNoise
 
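A note on the typing.TYPE_CHECKING guard above: the ActionNoise import exists only for static analysis, so annotations that use it must be string forward references. A minimal, runnable sketch of the pattern (set_action_noise is a hypothetical helper for illustration):

import typing

if typing.TYPE_CHECKING:
    # Evaluated only by type checkers; no runtime import cost or import cycle.
    from torchy_baselines.common.noise import ActionNoise

def set_action_noise(noise: 'ActionNoise') -> None:
    # The quoted annotation is resolved lazily, so this runs even though
    # ActionNoise is never imported at runtime.
    print(noise)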
@@ -27,31 +28,41 @@ class BaseRLModel(ABC):
     """
     The base RL model
 
-    :param policy: (BasePolicy) Policy object
-    :param env: (Gym environment) The environment to learn from
+    :param policy: Policy object
+    :param env: The environment to learn from
         (if registered in Gym, can be str. Can be None for loading trained models)
-    :param policy_base: (BasePolicy) the base policy used by this method
-    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
-    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug
-    :param device: (str or th.device) Device on which the code should run.
+    :param policy_base: The base policy used by this method
+    :param policy_kwargs: Additional arguments to be passed to the policy on creation
+    :param verbose: The verbosity level: 0 none, 1 training information, 2 debug
+    :param device: Device on which the code should run.
         By default, it will try to use a Cuda compatible device and fallback to cpu
        if it is not possible.
-    :param support_multi_env: (bool) Whether the algorithm supports training
+    :param support_multi_env: Whether the algorithm supports training
        with multiple environments (as in A2C)
-    :param create_eval_env: (bool) Whether to create a second environment that will be
+    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
-    :param monitor_wrapper: (bool) When creating an environment, whether to wrap it
+    :param monitor_wrapper: When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
-    :param seed: (int) Seed for the pseudo random generators
-    :param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
+    :param seed: Seed for the pseudo random generators
+    :param use_sde: Whether to use State Dependent Exploration (SDE)
        instead of action noise exploration (default: False)
-    :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
+    :param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
        Default: -1 (only sample at the beginning of the rollout)
     """
-    def __init__(self, policy: Type[BasePolicy], env: Union[gym.Env, VecEnv, str], policy_base, policy_kwargs=None,
-                 verbose=0, device='auto', support_multi_env=False,
-                 create_eval_env=False, monitor_wrapper=True, seed=None,
-                 use_sde=False, sde_sample_freq=-1):
+    def __init__(self,
+                 policy: Type[BasePolicy],
+                 env: Union[gym.Env, VecEnv, str],
+                 policy_base: Type[BasePolicy],
+                 policy_kwargs : Dict[str, Any] = None,
+                 verbose: int = 0,
+                 device: Union[th.device, str] = 'auto',
+                 support_multi_env: bool = False,
+                 create_eval_env: bool = False,
+                 monitor_wrapper: bool = True,
+                 seed: Optional[int] = None,
+                 use_sde: bool = False,
+                 sde_sample_freq: int = -1):
+
         if isinstance(policy, str) and policy_base is not None:
             self.policy_class = get_policy_from_name(policy_base, policy)
         else:
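One observation on the new signature: policy_kwargs pairs a None default with a plain Dict[str, Any] annotation (and carries a stray space before the colon). Strict checkers such as mypy flag this; Optional[Dict[str, Any]] is the precise type. A small sketch of the stricter spelling (make_policy_kwargs is a hypothetical helper):

from typing import Any, Dict, Optional

def make_policy_kwargs(policy_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Normalize the None default to an empty dict.
    return {} if policy_kwargs is None else dict(policy_kwargs)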
@@ -118,12 +129,12 @@ class BaseRLModel(ABC):
             raise ValueError("Error: the model does not support multiple envs requires a single vectorized"
                              " environment.")
 
-    def _get_eval_env(self, eval_env):
+    def _get_eval_env(self, eval_env: Union[gym.Env, VecEnv, None]) -> Union[gym.Env, VecEnv, None]:
         """
         Return the environment that will be used for evaluation.
 
-        :param eval_env: (gym.Env or VecEnv)
-        :return: (VecEnv)
+        :param eval_env:
+        :return:
         """
         if eval_env is None:
             eval_env = self.eval_env
@@ -134,48 +145,47 @@ class BaseRLModel(ABC):
         assert eval_env.num_envs == 1
         return eval_env
 
-    def scale_action(self, action):
+    def scale_action(self, action: np.ndarray) -> np.ndarray:
         """
         Rescale the action from [low, high] to [-1, 1]
         (no need for symmetric action space)
 
-        :param action: (np.ndarray)
-        :return: (np.ndarray)
+        :param action:
+        :return:
         """
         low, high = self.action_space.low, self.action_space.high
         return 2.0 * ((action - low) / (high - low)) - 1.0
 
-    def unscale_action(self, scaled_action):
+    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
         """
         Rescale the action from [-1, 1] to [low, high]
         (no need for symmetric action space)
 
-        :param scaled_action: (np.ndarray)
-        :return: (np.ndarray)
+        :param scaled_action:
+        :return:
         """
         low, high = self.action_space.low, self.action_space.high
         return low + (0.5 * (scaled_action + 1.0) * (high - low))
 
-    def _setup_learning_rate(self):
+    def _setup_learning_rate(self) -> None:
         """Transform to callable if needed."""
         self.learning_rate = get_schedule_fn(self.learning_rate)
 
-    def _update_current_progress(self, num_timesteps, total_timesteps):
+    def _update_current_progress(self, num_timesteps: int, total_timesteps: int) -> None:
         """
         Compute current progress (from 1 to 0)
 
-        :param num_timesteps: (int) current number of timesteps
-        :param total_timesteps: (int)
+        :param num_timesteps: current number of timesteps
+        :param total_timesteps:
         """
         self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps)
 
-    def _update_learning_rate(self, optimizers):
+    def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
         """
         Update the optimizers learning rate using the current learning rate schedule
         and the current progress (from 1 to 0).
 
-        :param optimizers: ([th.optim.Optimizer] or Optimizer) An optimizer
-        or a list of optimizer.
+        :param optimizers: An optimizer or a list of optimizer.
         """
         # Log the current learning rate
         logger.logkv("learning_rate", self.learning_rate(self._current_progress))
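The two rescaling helpers above are mutually inverse affine maps, and the learning-rate plumbing composes _update_current_progress with a schedule callable. A standalone numeric check (illustrative sketch; linear_schedule stands in for whatever get_schedule_fn returns):

import numpy as np

low, high = np.array([-2.0, 0.0]), np.array([2.0, 1.0])

def scale_action(action: np.ndarray) -> np.ndarray:
    # [low, high] -> [-1, 1], as in BaseRLModel.scale_action
    return 2.0 * ((action - low) / (high - low)) - 1.0

def unscale_action(scaled_action: np.ndarray) -> np.ndarray:
    # [-1, 1] -> [low, high], as in BaseRLModel.unscale_action
    return low + (0.5 * (scaled_action + 1.0) * (high - low))

action = np.array([1.0, 0.25])
assert np.allclose(unscale_action(scale_action(action)), action)  # round trip

def linear_schedule(initial_lr: float):
    # Maps progress (1.0 at the start of training, 0.0 at the end) to a rate.
    return lambda progress: initial_lr * progress

schedule = linear_schedule(3e-4)
current_progress = 1.0 - 50_000 / 100_000  # _update_current_progress at midpoint
assert abs(schedule(current_progress) - 1.5e-4) < 1e-12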
@@ -186,32 +196,32 @@ class BaseRLModel(ABC):
             update_learning_rate(optimizer, self.learning_rate(self._current_progress))
 
     @staticmethod
-    def safe_mean(arr):
+    def safe_mean(arr: Union[np.ndarray, list]) -> np.ndarray:
         """
         Compute the mean of an array if there is at least one element.
         For empty array, return NaN. It is used for logging only.
 
-        :param arr: (np.ndarray)
-        :return: (float)
+        :param arr:
+        :return:
         """
         return np.nan if len(arr) == 0 else np.mean(arr)
 
-    def get_env(self):
+    def get_env(self) -> Union[VecEnv, None]:
         """
         Returns the current environment (can be None if not defined).
 
-        :return: (gym.Env) The current environment
+        :return: The current environment
         """
         return self.env
 
     @staticmethod
-    def check_env(env, observation_space, action_space):
+    def check_env(env, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> bool:
         """
         Checks the validity of the environment and returns if it is consistent.
         Checked parameters:
         - observation_space
         - action_space
-        :return: (bool) True if environment seems to be coherent
+        :return: True if environment seems to be coherent
         """
         if observation_space != env.observation_space:
             return False
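Worth noting: safe_mean returns np.nan (a float) for empty input and a NumPy scalar otherwise, so the new -> np.ndarray annotation is arguably loose. Its behaviour in isolation (runnable sketch):

import numpy as np

def safe_mean(arr):
    # Empty logging buffers yield NaN instead of np.mean's empty-slice warning.
    return np.nan if len(arr) == 0 else np.mean(arr)

print(safe_mean([]))          # nan
print(safe_mean([1.0, 2.0]))  # 1.5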
@@ -220,7 +230,7 @@ class BaseRLModel(ABC):
         # return true if no check failed
         return True
 
-    def set_env(self, env):
+    def set_env(self, env: Union[gym.Env, VecEnv]) -> None:
         """
         Checks the validity of the environment, and if it is coherent, set it as the current environment.
         Furthermore wrap any non vectorized env into a vectorized
@@ -228,7 +238,7 @@ class BaseRLModel(ABC):
         - observation_space
         - action_space
 
-        :param env: (gym.Env) The environment for learning a policy
+        :param env: The environment for learning a policy
         """
         if self.check_env(env, self.observation_space, self.action_space) is False:
             raise ValueError("The given environment is not compatible with model: "
@@ -242,23 +252,24 @@ class BaseRLModel(ABC):
         self.n_envs = env.num_envs
         self.env = env
 
-    def get_parameters(self):
+    def get_parameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
         Returns policy and optimizer parameters as a tuple
-        :return: (dict,dict) policy_parameters, opt_parameters
+
+        :return: policy_parameters, opt_parameters
         """
         return self.get_policy_parameters(), self.get_opt_parameters()
 
-    def get_policy_parameters(self):
+    def get_policy_parameters(self) -> Dict[str, Any]:
         """
         Get current model policy parameters as dictionary of variable name -> tensors.
 
-        :return: (dict) Dictionary of variable name -> tensor of model's policy parameters.
+        :return: Dictionary of variable name -> tensor of model's policy parameters.
         """
         return self.policy.state_dict()
 
     @abstractmethod
-    def get_opt_parameters(self):
+    def get_opt_parameters(self)-> Dict[str, Any]:
         """
         Get current model optimizer parameters as dictionary of variable names -> tensors
         :return: (dict) Dictionary of variable name -> tensor of model's optimizer parameters
@@ -266,8 +277,13 @@ class BaseRLModel(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run",
-              eval_env=None, eval_freq=-1, n_eval_episodes=5, reset_num_timesteps=True):
+    def learn(self, total_timesteps: int,
+              callback=None, log_interval: int = 100,
+              tb_log_name: str = "run",
+              eval_env: Union[gym.Env, VecEnv, None] = None,
+              eval_freq: int = -1,
+              n_eval_episodes: int = 5,
+              reset_num_timesteps: bool = True):
         """
         Return a trained model.
 
@@ -285,19 +301,22 @@ class BaseRLModel(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def predict(self, observation, state=None, mask=None, deterministic=False):
+    def predict(self, observation: np.ndarray,
+                state: Optional[np.ndarray] = None,
+                mask: Optional[np.ndarray] = None,
+                deterministic: bool = False) -> np.ndarray:
         """
         Get the model's action from an observation
 
-        :param observation: (np.ndarray) the input observation
-        :param state: (np.ndarray) The last states (can be None, used in recurrent policies)
-        :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
-        :param deterministic: (bool) Whether or not to return deterministic actions.
-        :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
+        :param observation: the input observation
+        :param state: The last states (can be None, used in recurrent policies)
+        :param mask: The last masks (can be None, used in recurrent policies)
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next state (used in recurrent policies)
         """
         raise NotImplementedError()
 
-    def load_parameters(self, load_dict, opt_params):
+    def load_parameters(self, load_dict: Dict[str, Any], opt_params: Dict[str, Any]) -> None:
         """
         Load model parameters from a dictionary
         load_dict should contain all keys from torch.model.state_dict()
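A remark on predict: the declared return type (np.ndarray) is narrower than the docstring, which promises both the action and the next state, so a Tuple annotation would match the documented behaviour. Typical use, sketched as a self-contained function (model and env are assumed to follow the interfaces shown in this diff):

def run_episode(model, env) -> float:
    # Roll out one evaluation episode with deterministic actions.
    obs = env.reset()
    done, episode_return = False, 0.0
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, _info = env.step(action)
        episode_return += float(reward)
    return episode_return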
@@ -305,20 +324,20 @@ class BaseRLModel(ABC):
         but can only be handled in child classes.
 
 
-        :param load_dict: (dict) dict of parameters from model.state_dict()
-        :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
+        :param load_dict: dict of parameters from model.state_dict()
+        :param opt_params: dict of optimizer state_dicts should be handled in child_class
         """
         if opt_params is not None:
             raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class")
         self.policy.load_state_dict(load_dict)
 
     @classmethod
-    def load(cls, load_path, env=None, **kwargs):
+    def load(cls, load_path: str, env: Union[gym.Env, VecEnv, None] = None, **kwargs):
         """
         Load the model from a zip-file
 
-        :param load_path: (str) the location of the saved data
-        :param env: (Gym Environment) the new environment to run the loaded model on
+        :param load_path: the location of the saved data
+        :param env: the new environment to run the loaded model on
         (can be None if you only need prediction from a trained model) has priority over any saved environment
         :param kwargs: extra arguments to change the model when loading
         """
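load_parameters is a thin wrapper over torch's state_dict machinery; a toy round trip showing the dictionary shape it expects (standalone sketch with a throwaway module):

import torch as th

net = th.nn.Linear(4, 2)
params = net.state_dict()    # dict of parameter name -> tensor ('weight', 'bias')
net.load_state_dict(params)  # restores the exact same tensors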
@@ -349,11 +368,11 @@ class BaseRLModel(ABC):
         return model
 
     @staticmethod
-    def _load_from_file(load_path, load_data=True):
+    def _load_from_file(load_path: str, load_data: bool = True):
         """ Load model data from a .zip archive
 
-        :param load_path: (str) Where to load the model from
-        :param load_data: (bool) Whether we should load and return data
+        :param load_path: Where to load the model from
+        :param load_data: Whether we should load and return data
         (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
         :return: (dict),(dict),(dict) Class parameters, model parameters (state_dict)
         and dict of optimizer parameters (dict of state_dict)
@@ -415,7 +434,7 @@ class BaseRLModel(ABC):
 
         return data, params, opt_params
 
-    def set_random_seed(self, seed: Optional[int] = None):
+    def set_random_seed(self, seed: Optional[int] = None) -> None:
         """
         Set the seed of the pseudo-random generators
         (python, numpy, pytorch, gym, action_space)