Partially type base class

This commit is contained in:
Antonin Raffin 2020-01-22 17:51:27 +01:00
parent 0328a39d1b
commit ff0eddfb17
3 changed files with 88 additions and 66 deletions

docs/conf.py View file

@@ -70,6 +70,7 @@ release = torchy_baselines.__version__
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx_autodoc_typehints',
'sphinx.ext.autosummary',
'sphinx.ext.mathjax',
'sphinx.ext.ifconfig',
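
With sphinx_autodoc_typehints enabled, autodoc reads parameter and return types from the function signatures, which is why the (type) markers are dropped from the docstrings in the rest of this commit. A minimal sketch of the resulting conf.py section; the set_type_checking_flag line is an assumption about the extension's configuration, included because the module below guards an import behind typing.TYPE_CHECKING:

extensions = [
    'sphinx.ext.autodoc',
    'sphinx_autodoc_typehints',  # render :param:/:return: types from the annotations
    'sphinx.ext.autosummary',
    'sphinx.ext.mathjax',
    'sphinx.ext.ifconfig',
]

# Assumption: make `if typing.TYPE_CHECKING:` blocks evaluate during doc builds
set_type_checking_flag = True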

setup.py View file

@@ -25,7 +25,9 @@ setup(name='torchy_baselines',
'sphinx-autobuild',
'sphinx-rtd-theme',
# For spelling
'sphinxcontrib.spelling'
'sphinxcontrib.spelling',
# Type hints support
'sphinx-autodoc-typehints'
],
'extra': [
# For render

torchy_baselines/common/base_class.py View file

@@ -3,7 +3,7 @@ import os
import io
import zipfile
import typing
from typing import Union, Type, Optional
from typing import Union, Type, Optional, Dict, Any, List, Tuple
from abc import ABC, abstractmethod
from collections import deque
@@ -19,6 +19,7 @@ from torchy_baselines.common.monitor import Monitor
from torchy_baselines.common.evaluation import evaluate_policy
from torchy_baselines.common.save_util import data_to_json, json_to_data
# TODO: define aliases, e.g. GymEnv = Union[gym.Env, VecEnv]
if typing.TYPE_CHECKING:
from torchy_baselines.common.noise import ActionNoise
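
Spelling out the alias from the TODO above gives a one-line sketch (taken directly from the comment; Union, gym.Env and VecEnv are all imported earlier in the module):

GymEnv = Union[gym.Env, VecEnv]  # alias suggested by the TODO, not yet defined in this commit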
@@ -27,31 +28,41 @@ class BaseRLModel(ABC):
"""
The base RL model
:param policy: (BasePolicy) Policy object
:param env: (Gym environment) The environment to learn from
:param policy: Policy object
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: (BasePolicy) the base policy used by this method
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug
:param device: (str or th.device) Device on which the code should run.
:param policy_base: The base policy used by this method
:param policy_kwargs: Additional arguments to be passed to the policy on creation
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
:param device: Device on which the code should run.
By default, it will try to use a Cuda compatible device and fallback to cpu
if it is not possible.
:param support_multi_env: (bool) Whether the algorithm supports training
:param support_multi_env: Whether the algorithm supports training
with multiple environments (as in A2C)
:param create_eval_env: (bool) Whether to create a second environment that will be
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param monitor_wrapper: (bool) When creating an environment, whether to wrap it
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param seed: (int) Seed for the pseudo random generators
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
:param seed: Seed for the pseudo random generators
:param use_sde: Whether to use State Dependent Exploration (SDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
:param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
Default: -1 (only sample at the beginning of the rollout)
"""
def __init__(self, policy: Type[BasePolicy], env: Union[gym.Env, VecEnv, str], policy_base, policy_kwargs=None,
verbose=0, device='auto', support_multi_env=False,
create_eval_env=False, monitor_wrapper=True, seed=None,
use_sde=False, sde_sample_freq=-1):
def __init__(self,
policy: Type[BasePolicy],
env: Union[gym.Env, VecEnv, str],
policy_base: Type[BasePolicy],
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
device: Union[th.device, str] = 'auto',
support_multi_env: bool = False,
create_eval_env: bool = False,
monitor_wrapper: bool = True,
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1):
if isinstance(policy, str) and policy_base is not None:
self.policy_class = get_policy_from_name(policy_base, policy)
else:
@@ -118,12 +129,12 @@ class BaseRLModel(ABC):
raise ValueError("Error: the model does not support multiple envs requires a single vectorized"
" environment.")
def _get_eval_env(self, eval_env):
def _get_eval_env(self, eval_env: Union[gym.Env, VecEnv, None]) -> Union[gym.Env, VecEnv, None]:
"""
Return the environment that will be used for evaluation.
:param eval_env: (gym.Env or VecEnv)
:return: (VecEnv)
:param eval_env: the environment used for evaluation (self.eval_env is used if None)
:return: the environment that will be used for evaluation
"""
if eval_env is None:
eval_env = self.eval_env
@@ -134,48 +145,47 @@ class BaseRLModel(ABC):
assert eval_env.num_envs == 1
return eval_env
def scale_action(self, action):
def scale_action(self, action: np.ndarray) -> np.ndarray:
"""
Rescale the action from [low, high] to [-1, 1]
(no need for symmetric action space)
:param action: (np.ndarray)
:return: (np.ndarray)
:param action: Action to rescale, in [low, high]
:return: Action rescaled to [-1, 1]
"""
low, high = self.action_space.low, self.action_space.high
return 2.0 * ((action - low) / (high - low)) - 1.0
def unscale_action(self, scaled_action):
def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
"""
Rescale the action from [-1, 1] to [low, high]
(no need for symmetric action space)
:param scaled_action: (np.ndarray)
:return: (np.ndarray)
:param scaled_action: Action in [-1, 1]
:return: Action rescaled to [low, high]
"""
low, high = self.action_space.low, self.action_space.high
return low + (0.5 * (scaled_action + 1.0) * (high - low))
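
A quick round-trip check of the two mappings above, as a self-contained sketch (standalone copies of the formulas, since exercising the methods requires a fully constructed model):

import numpy as np

low, high = np.array([0.0]), np.array([10.0])

def scale_action(action):
    # [low, high] -> [-1, 1], same formula as the method above
    return 2.0 * ((action - low) / (high - low)) - 1.0

def unscale_action(scaled_action):
    # [-1, 1] -> [low, high]
    return low + (0.5 * (scaled_action + 1.0) * (high - low))

assert np.allclose(scale_action(np.array([5.0])), 0.0)                  # midpoint maps to 0
assert np.allclose(unscale_action(scale_action(np.array([7.5]))), 7.5)  # round trip is lossless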
def _setup_learning_rate(self):
def _setup_learning_rate(self) -> None:
"""Transform to callable if needed."""
self.learning_rate = get_schedule_fn(self.learning_rate)
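
get_schedule_fn itself is not part of this diff; a plausible sketch of the "transform to callable" behaviour the docstring refers to (an assumption, not the actual implementation):

def get_schedule_fn(value_schedule):
    # A constant learning rate becomes a function of the remaining progress
    if isinstance(value_schedule, (int, float)):
        constant = float(value_schedule)
        return lambda progress: constant
    # Already a callable schedule: use it unchanged
    return value_schedule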
def _update_current_progress(self, num_timesteps, total_timesteps):
def _update_current_progress(self, num_timesteps: int, total_timesteps: int) -> None:
"""
Compute current progress (from 1 to 0)
:param num_timesteps: (int) current number of timesteps
:param total_timesteps: (int)
:param num_timesteps: current number of timesteps
:param total_timesteps: total number of timesteps for training
"""
self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps)
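
For example, with total_timesteps=1000 the progress is 1.0 at step 0, 0.5 after 500 steps and 0.0 at the end of training; this is the value the learning rate schedule below is called with.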
def _update_learning_rate(self, optimizers):
def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
"""
Update the optimizers learning rate using the current learning rate schedule
and the current progress (from 1 to 0).
:param optimizers: ([th.optim.Optimizer] or Optimizer) An optimizer
or a list of optimizer.
:param optimizers: An optimizer or a list of optimizer.
"""
# Log the current learning rate
logger.logkv("learning_rate", self.learning_rate(self._current_progress))
@@ -186,32 +196,32 @@ class BaseRLModel(ABC):
update_learning_rate(optimizer, self.learning_rate(self._current_progress))
@staticmethod
def safe_mean(arr):
def safe_mean(arr: Union[np.ndarray, list]) -> float:
"""
Compute the mean of an array if there is at least one element.
For empty array, return NaN. It is used for logging only.
:param arr: (np.ndarray)
:return: (float)
:param arr: array of values (may be empty)
:return: mean of the array, or NaN if it is empty
"""
return np.nan if len(arr) == 0 else np.mean(arr)
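
For instance, safe_mean([]) returns np.nan while safe_mean([1.0, 2.0, 3.0]) returns 2.0, so the logging code never trips over an empty episode-reward buffer.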
def get_env(self):
def get_env(self) -> Union[VecEnv, None]:
"""
Returns the current environment (can be None if not defined).
:return: (gym.Env) The current environment
:return: The current environment
"""
return self.env
@staticmethod
def check_env(env, observation_space, action_space):
def check_env(env: Union[gym.Env, VecEnv], observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> bool:
"""
Checks the validity of the environment and returns if it is consistent.
Checked parameters:
- observation_space
- action_space
:return: (bool) True if environment seems to be coherent
:return: True if environment seems to be coherent
"""
if observation_space != env.observation_space:
return False
@@ -220,7 +230,7 @@ class BaseRLModel(ABC):
# return true if no check failed
return True
def set_env(self, env):
def set_env(self, env: Union[gym.Env, VecEnv]) -> None:
"""
Checks the validity of the environment, and if it is coherent, set it as the current environment.
Furthermore wrap any non vectorized env into a vectorized
@@ -228,7 +238,7 @@ class BaseRLModel(ABC):
- observation_space
- action_space
:param env: (gym.Env) The environment for learning a policy
:param env: The environment for learning a policy
"""
if self.check_env(env, self.observation_space, self.action_space) is False:
raise ValueError("The given environment is not compatible with model: "
@@ -242,23 +252,24 @@ class BaseRLModel(ABC):
self.n_envs = env.num_envs
self.env = env
def get_parameters(self):
def get_parameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Returns policy and optimizer parameters as a tuple
:return: (dict,dict) policy_parameters, opt_parameters
:return: policy_parameters, opt_parameters
"""
return self.get_policy_parameters(), self.get_opt_parameters()
def get_policy_parameters(self):
def get_policy_parameters(self) -> Dict[str, Any]:
"""
Get current model policy parameters as dictionary of variable name -> tensors.
:return: (dict) Dictionary of variable name -> tensor of model's policy parameters.
:return: Dictionary of variable name -> tensor of model's policy parameters.
"""
return self.policy.state_dict()
@abstractmethod
def get_opt_parameters(self):
def get_opt_parameters(self) -> Dict[str, Any]:
"""
Get current model optimizer parameters as dictionary of variable names -> tensors
:return: (dict) Dictionary of variable name -> tensor of model's optimizer parameters
@@ -266,8 +277,13 @@ class BaseRLModel(ABC):
raise NotImplementedError()
@abstractmethod
def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run",
eval_env=None, eval_freq=-1, n_eval_episodes=5, reset_num_timesteps=True):
def learn(self, total_timesteps: int,
callback=None, log_interval: int = 100,
tb_log_name: str = "run",
eval_env: Union[gym.Env, VecEnv, None] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
reset_num_timesteps: bool = True):
"""
Return a trained model.
@@ -285,19 +301,22 @@ class BaseRLModel(ABC):
raise NotImplementedError()
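
A hypothetical call against a concrete subclass, using the keyword arguments from the signature above (model and eval_env are placeholder names):

model.learn(total_timesteps=100000,
            eval_env=eval_env,      # placeholder evaluation environment
            eval_freq=5000,         # evaluate every 5000 steps
            n_eval_episodes=5)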
@abstractmethod
def predict(self, observation, state=None, mask=None, deterministic=False):
def predict(self, observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Get the model's action from an observation
:param observation: (np.ndarray) the input observation
:param state: (np.ndarray) The last states (can be None, used in recurrent policies)
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state (used in recurrent policies)
"""
raise NotImplementedError()
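
A hypothetical use of the signature above, where obs and model are placeholders; per the docstring, the second element is only meaningful for recurrent policies:

action, next_state = model.predict(obs, deterministic=True)  # next_state is None for non-recurrent policies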
def load_parameters(self, load_dict, opt_params):
def load_parameters(self, load_dict: Dict[str, Any], opt_params: Dict[str, Any]) -> None:
"""
Load model parameters from a dictionary
load_dict should contain all keys from the model's state_dict()
@@ -305,20 +324,20 @@ class BaseRLModel(ABC):
but can only be handled in child classes.
:param load_dict: (dict) dict of parameters from model.state_dict()
:param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class
:param load_dict: dict of parameters from model.state_dict()
:param opt_params: dict of optimizer state_dicts; should be handled in the child class
"""
if opt_params is not None:
raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class")
self.policy.load_state_dict(load_dict)
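
A sketch of the round trip the two APIs allow (model stands in for an instance of a concrete subclass); since the base class only restores policy weights, the optimizer parameters are passed as None here:

policy_params, opt_params = model.get_parameters()
model.load_parameters(policy_params, None)  # opt_params must be handled by a child class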
@classmethod
def load(cls, load_path, env=None, **kwargs):
def load(cls, load_path: str, env: Union[gym.Env, VecEnv, None] = None, **kwargs):
"""
Load the model from a zip-file
:param load_path: (str) the location of the saved data
:param env: (Gym Environment) the new environment to run the loaded model on
:param load_path: the location of the saved data
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param kwargs: extra arguments to change the model when loading
"""
@@ -349,11 +368,11 @@ class BaseRLModel(ABC):
return model
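
A hypothetical load call, where SomeAlgo stands in for any concrete subclass and the path is a placeholder:

model = SomeAlgo.load("path/to/model.zip", env=env)  # env takes priority over any saved environment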
@staticmethod
def _load_from_file(load_path, load_data=True):
def _load_from_file(load_path: str, load_data: bool = True):
""" Load model data from a .zip archive
:param load_path: (str) Where to load the model from
:param load_data: (bool) Whether we should load and return data
:param load_path: Where to load the model from
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:return: (dict),(dict),(dict) Class parameters, model parameters (state_dict)
and dict of optimizer parameters (dict of state_dict)
@@ -415,7 +434,7 @@ class BaseRLModel(ABC):
return data, params, opt_params
def set_random_seed(self, seed: Optional[int] = None):
def set_random_seed(self, seed: Optional[int] = None) -> None:
"""
Set the seed of the pseudo-random generators
(python, numpy, pytorch, gym, action_space)