From ff0eddfb1796e84dde7560db1f56f2ebefe83737 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 22 Jan 2020 17:51:27 +0100 Subject: [PATCH] Partially type base class --- docs/conf.py | 1 + setup.py | 4 +- torchy_baselines/common/base_class.py | 149 +++++++++++++++----------- 3 files changed, 88 insertions(+), 66 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 06c977b..624e305 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -70,6 +70,7 @@ release = torchy_baselines.__version__ # ones. extensions = [ 'sphinx.ext.autodoc', + 'sphinx_autodoc_typehints', 'sphinx.ext.autosummary', 'sphinx.ext.mathjax', 'sphinx.ext.ifconfig', diff --git a/setup.py b/setup.py index 343967a..e598108 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,9 @@ setup(name='torchy_baselines', 'sphinx-autobuild', 'sphinx-rtd-theme', # For spelling - 'sphinxcontrib.spelling' + 'sphinxcontrib.spelling', + # Type hints support + 'sphinx-autodoc-typehints' ], 'extra': [ # For render diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index a745d4c..d714300 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -3,7 +3,7 @@ import os import io import zipfile import typing -from typing import Union, Type, Optional +from typing import Union, Type, Optional, Dict, Any, List, Tuple from abc import ABC, abstractmethod from collections import deque @@ -19,6 +19,7 @@ from torchy_baselines.common.monitor import Monitor from torchy_baselines.common.evaluation import evaluate_policy from torchy_baselines.common.save_util import data_to_json, json_to_data +# TODO: define aliases, ex GymEnv = Union[gym.Env, VecEnv] if typing.TYPE_CHECKING: from torchy_baselines.common.noise import ActionNoise @@ -27,31 +28,41 @@ class BaseRLModel(ABC): """ The base RL model - :param policy: (BasePolicy) Policy object - :param env: (Gym environment) The environment to learn from + :param policy: Policy object + :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) - :param policy_base: (BasePolicy) the base policy used by this method - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug - :param device: (str or th.device) Device on which the code should run. + :param policy_base: The base policy used by this method + :param policy_kwargs: Additional arguments to be passed to the policy on creation + :param verbose: The verbosity level: 0 none, 1 training information, 2 debug + :param device: Device on which the code should run. By default, it will try to use a Cuda compatible device and fallback to cpu if it is not possible. - :param support_multi_env: (bool) Whether the algorithm supports training + :param support_multi_env: Whether the algorithm supports training with multiple environments (as in A2C) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param monitor_wrapper: (bool) When creating an environment, whether to wrap it + :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. - :param seed: (int) Seed for the pseudo random generators - :param use_sde: (bool) Whether to use State Dependent Exploration (SDE) + :param seed: Seed for the pseudo random generators + :param use_sde: Whether to use State Dependent Exploration (SDE) instead of action noise exploration (default: False) - :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE + :param sde_sample_freq: Sample a new noise matrix every n steps when using SDE Default: -1 (only sample at the beginning of the rollout) """ - def __init__(self, policy: Type[BasePolicy], env: Union[gym.Env, VecEnv, str], policy_base, policy_kwargs=None, - verbose=0, device='auto', support_multi_env=False, - create_eval_env=False, monitor_wrapper=True, seed=None, - use_sde=False, sde_sample_freq=-1): + def __init__(self, + policy: Type[BasePolicy], + env: Union[gym.Env, VecEnv, str], + policy_base: Type[BasePolicy], + policy_kwargs : Dict[str, Any] = None, + verbose: int = 0, + device: Union[th.device, str] = 'auto', + support_multi_env: bool = False, + create_eval_env: bool = False, + monitor_wrapper: bool = True, + seed: Optional[int] = None, + use_sde: bool = False, + sde_sample_freq: int = -1): + if isinstance(policy, str) and policy_base is not None: self.policy_class = get_policy_from_name(policy_base, policy) else: @@ -118,12 +129,12 @@ class BaseRLModel(ABC): raise ValueError("Error: the model does not support multiple envs requires a single vectorized" " environment.") - def _get_eval_env(self, eval_env): + def _get_eval_env(self, eval_env: Union[gym.Env, VecEnv, None]) -> Union[gym.Env, VecEnv, None]: """ Return the environment that will be used for evaluation. - :param eval_env: (gym.Env or VecEnv) - :return: (VecEnv) + :param eval_env: + :return: """ if eval_env is None: eval_env = self.eval_env @@ -134,48 +145,47 @@ class BaseRLModel(ABC): assert eval_env.num_envs == 1 return eval_env - def scale_action(self, action): + def scale_action(self, action: np.ndarray) -> np.ndarray: """ Rescale the action from [low, high] to [-1, 1] (no need for symmetric action space) - :param action: (np.ndarray) - :return: (np.ndarray) + :param action: + :return: """ low, high = self.action_space.low, self.action_space.high return 2.0 * ((action - low) / (high - low)) - 1.0 - def unscale_action(self, scaled_action): + def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: """ Rescale the action from [-1, 1] to [low, high] (no need for symmetric action space) - :param scaled_action: (np.ndarray) - :return: (np.ndarray) + :param scaled_action: + :return: """ low, high = self.action_space.low, self.action_space.high return low + (0.5 * (scaled_action + 1.0) * (high - low)) - def _setup_learning_rate(self): + def _setup_learning_rate(self) -> None: """Transform to callable if needed.""" self.learning_rate = get_schedule_fn(self.learning_rate) - def _update_current_progress(self, num_timesteps, total_timesteps): + def _update_current_progress(self, num_timesteps: int, total_timesteps: int) -> None: """ Compute current progress (from 1 to 0) - :param num_timesteps: (int) current number of timesteps - :param total_timesteps: (int) + :param num_timesteps: current number of timesteps + :param total_timesteps: """ self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps) - def _update_learning_rate(self, optimizers): + def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None: """ Update the optimizers learning rate using the current learning rate schedule and the current progress (from 1 to 0). - :param optimizers: ([th.optim.Optimizer] or Optimizer) An optimizer - or a list of optimizer. + :param optimizers: An optimizer or a list of optimizer. """ # Log the current learning rate logger.logkv("learning_rate", self.learning_rate(self._current_progress)) @@ -186,32 +196,32 @@ class BaseRLModel(ABC): update_learning_rate(optimizer, self.learning_rate(self._current_progress)) @staticmethod - def safe_mean(arr): + def safe_mean(arr: Union[np.ndarray, list]) -> np.ndarray: """ Compute the mean of an array if there is at least one element. For empty array, return NaN. It is used for logging only. - :param arr: (np.ndarray) - :return: (float) + :param arr: + :return: """ return np.nan if len(arr) == 0 else np.mean(arr) - def get_env(self): + def get_env(self) -> Union[VecEnv, None]: """ Returns the current environment (can be None if not defined). - :return: (gym.Env) The current environment + :return: The current environment """ return self.env @staticmethod - def check_env(env, observation_space, action_space): + def check_env(env, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> bool: """ Checks the validity of the environment and returns if it is consistent. Checked parameters: - observation_space - action_space - :return: (bool) True if environment seems to be coherent + :return: True if environment seems to be coherent """ if observation_space != env.observation_space: return False @@ -220,7 +230,7 @@ class BaseRLModel(ABC): # return true if no check failed return True - def set_env(self, env): + def set_env(self, env: Union[gym.Env, VecEnv]) -> None: """ Checks the validity of the environment, and if it is coherent, set it as the current environment. Furthermore wrap any non vectorized env into a vectorized @@ -228,7 +238,7 @@ class BaseRLModel(ABC): - observation_space - action_space - :param env: (gym.Env) The environment for learning a policy + :param env: The environment for learning a policy """ if self.check_env(env, self.observation_space, self.action_space) is False: raise ValueError("The given environment is not compatible with model: " @@ -242,23 +252,24 @@ class BaseRLModel(ABC): self.n_envs = env.num_envs self.env = env - def get_parameters(self): + def get_parameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Returns policy and optimizer parameters as a tuple - :return: (dict,dict) policy_parameters, opt_parameters + + :return: policy_parameters, opt_parameters """ return self.get_policy_parameters(), self.get_opt_parameters() - def get_policy_parameters(self): + def get_policy_parameters(self) -> Dict[str, Any]: """ Get current model policy parameters as dictionary of variable name -> tensors. - :return: (dict) Dictionary of variable name -> tensor of model's policy parameters. + :return: Dictionary of variable name -> tensor of model's policy parameters. """ return self.policy.state_dict() @abstractmethod - def get_opt_parameters(self): + def get_opt_parameters(self)-> Dict[str, Any]: """ Get current model optimizer parameters as dictionary of variable names -> tensors :return: (dict) Dictionary of variable name -> tensor of model's optimizer parameters @@ -266,8 +277,13 @@ class BaseRLModel(ABC): raise NotImplementedError() @abstractmethod - def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run", - eval_env=None, eval_freq=-1, n_eval_episodes=5, reset_num_timesteps=True): + def learn(self, total_timesteps: int, + callback=None, log_interval: int = 100, + tb_log_name: str = "run", + eval_env: Union[gym.Env, VecEnv, None] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + reset_num_timesteps: bool = True): """ Return a trained model. @@ -285,19 +301,22 @@ class BaseRLModel(ABC): raise NotImplementedError() @abstractmethod - def predict(self, observation, state=None, mask=None, deterministic=False): + def predict(self, observation: np.ndarray, + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False) -> np.ndarray: """ Get the model's action from an observation - :param observation: (np.ndarray) the input observation - :param state: (np.ndarray) The last states (can be None, used in recurrent policies) - :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) - :param deterministic: (bool) Whether or not to return deterministic actions. - :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next state (used in recurrent policies) """ raise NotImplementedError() - def load_parameters(self, load_dict, opt_params): + def load_parameters(self, load_dict: Dict[str, Any], opt_params: Dict[str, Any]) -> None: """ Load model parameters from a dictionary load_dict should contain all keys from torch.model.state_dict() @@ -305,20 +324,20 @@ class BaseRLModel(ABC): but can only be handled in child classes. - :param load_dict: (dict) dict of parameters from model.state_dict() - :param opt_params: (dict of dicts) dict of optimizer state_dicts should be handled in child_class + :param load_dict: dict of parameters from model.state_dict() + :param opt_params: dict of optimizer state_dicts should be handled in child_class """ if opt_params is not None: raise ValueError("Optimizer Parameters where given but no overloaded load function exists for this class") self.policy.load_state_dict(load_dict) @classmethod - def load(cls, load_path, env=None, **kwargs): + def load(cls, load_path: str, env: Union[gym.Env, VecEnv, None] = None, **kwargs): """ Load the model from a zip-file - :param load_path: (str) the location of the saved data - :param env: (Gym Environment) the new environment to run the loaded model on + :param load_path: the location of the saved data + :param env: the new environment to run the loaded model on (can be None if you only need prediction from a trained model) has priority over any saved environment :param kwargs: extra arguments to change the model when loading """ @@ -349,11 +368,11 @@ class BaseRLModel(ABC): return model @staticmethod - def _load_from_file(load_path, load_data=True): + def _load_from_file(load_path: str, load_data: bool = True): """ Load model data from a .zip archive - :param load_path: (str) Where to load the model from - :param load_data: (bool) Whether we should load and return data + :param load_path: Where to load the model from + :param load_data: Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) :return: (dict),(dict),(dict) Class parameters, model parameters (state_dict) and dict of optimizer parameters (dict of state_dict) @@ -415,7 +434,7 @@ class BaseRLModel(ABC): return data, params, opt_params - def set_random_seed(self, seed: Optional[int] = None): + def set_random_seed(self, seed: Optional[int] = None) -> None: """ Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)