From 432b3f876d7c4a8bbbbf952f83110ec2ca041522 Mon Sep 17 00:00:00 2001 From: Juan Rocamonde Date: Mon, 26 Sep 2022 12:13:56 +0200 Subject: [PATCH] Fix return type for load, learn in BaseAlgorithm (#1043) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix return type for load, learn in BaseAlgorithm * Update changelog * Add typing extensions to dependencies * Import directly from typing for python >3.11 * Reorder changelog to reflect merge order * Roll back to typevar solution * Updated changelog * Remove typing extensions requirement * Update base_class.py * Remove final point in changelog * Additional type fixes across project Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 2 ++ stable_baselines3/a2c/a2c.py | 8 +++++--- stable_baselines3/common/base_class.py | 18 ++++++++++++------ .../common/off_policy_algorithm.py | 8 +++++--- .../common/on_policy_algorithm.py | 8 +++++--- stable_baselines3/common/policies.py | 6 ++++-- stable_baselines3/ddpg/ddpg.py | 9 +++++---- stable_baselines3/dqn/dqn.py | 8 +++++--- stable_baselines3/ppo/ppo.py | 8 +++++--- stable_baselines3/sac/sac.py | 8 +++++--- stable_baselines3/td3/td3.py | 8 +++++--- 11 files changed, 58 insertions(+), 33 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2c2e5ff..8fd853b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -31,6 +31,8 @@ Bug Fixes: - Fixed incorrect type annotation of the replay_buffer_class argument in ``common.OffPolicyAlgorithm`` initializer, where an instance instead of a class was required (@Rocamonde) - Fixed loading saved model with different number of envrionments - Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde) +- Fixed the return type of ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use ``TypeVar`` (@Rocamonde) + Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 8058f52..8b8cecb 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, TypeVar, Union import torch as th from gym import spaces @@ -9,6 +9,8 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance +A2CSelf = TypeVar("A2CSelf", bound="A2C") + class A2C(OnPolicyAlgorithm): """ @@ -183,7 +185,7 @@ class A2C(OnPolicyAlgorithm): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) def learn( - self, + self: A2CSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, @@ -193,7 +195,7 @@ class A2C(OnPolicyAlgorithm): tb_log_name: str = "A2C", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "A2C": + ) -> A2CSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index dc04874..33d3fac 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -5,7 +5,7 @@ import pathlib import time from abc import ABC, abstractmethod from collections import deque -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -53,6 +53,9 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE return env +BaseAlgorithmSelf = TypeVar("BaseAlgorithmSelf", bound="BaseAlgorithm") + + class BaseAlgorithm(ABC): """ The base of RL algorithms @@ -537,7 +540,7 @@ class BaseAlgorithm(ABC): @abstractmethod def learn( - self, + self: BaseAlgorithmSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, @@ -547,7 +550,7 @@ class BaseAlgorithm(ABC): n_eval_episodes: int = 5, eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "BaseAlgorithm": + ) -> BaseAlgorithmSelf: """ Return a trained model. @@ -671,7 +674,7 @@ class BaseAlgorithm(ABC): @classmethod def load( - cls, + cls: Type[BaseAlgorithmSelf], path: Union[str, pathlib.Path, io.BufferedIOBase], env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", @@ -679,7 +682,7 @@ class BaseAlgorithm(ABC): print_system_info: bool = False, force_reset: bool = True, **kwargs, - ) -> "BaseAlgorithm": + ) -> BaseAlgorithmSelf: """ Load the model from a zip-file. Warning: ``load`` re-creates the model from scratch, it does not update it in-place! @@ -709,7 +712,10 @@ class BaseAlgorithm(ABC): get_system_info() data, params, pytorch_variables = load_from_zip_file( - path, device=device, custom_objects=custom_objects, print_system_info=print_system_info + path, + device=device, + custom_objects=custom_objects, + print_system_info=print_system_info, ) # Remove stored device information and replace with ours diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index d2574ed..c23223d 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -4,7 +4,7 @@ import sys import time import warnings from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -21,6 +21,8 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.her.her_replay_buffer import HerReplayBuffer +OffPolicyAlgorithmSelf = TypeVar("OffPolicyAlgorithmSelf", bound="OffPolicyAlgorithm") + class OffPolicyAlgorithm(BaseAlgorithm): """ @@ -319,7 +321,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): ) def learn( - self, + self: OffPolicyAlgorithmSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -329,7 +331,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): tb_log_name: str = "run", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "OffPolicyAlgorithm": + ) -> OffPolicyAlgorithmSelf: total_timesteps, callback = self._setup_learn( total_timesteps, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index a9c0a41..4063680 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -1,6 +1,6 @@ import sys import time -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -14,6 +14,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv +OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithm") + class OnPolicyAlgorithm(BaseAlgorithm): """ @@ -225,7 +227,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): raise NotImplementedError def learn( - self, + self: OnPolicyAlgorithmSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -235,7 +237,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): tb_log_name: str = "OnPolicyAlgorithm", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "OnPolicyAlgorithm": + ) -> OnPolicyAlgorithmSelf: iteration = 0 total_timesteps, callback = self._setup_learn( diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 16afe06..4632e48 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -5,7 +5,7 @@ import copy import warnings from abc import ABC, abstractmethod from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -33,6 +33,8 @@ from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor +BaseModelSelf = TypeVar("BaseModelSelf", bound="BaseModel") + class BaseModel(nn.Module): """ @@ -158,7 +160,7 @@ class BaseModel(nn.Module): th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) @classmethod - def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel": + def load(cls: Type[BaseModelSelf], path: str, device: Union[th.device, str] = "auto") -> BaseModelSelf: """ Load model from path. diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index a9244e7..531acd1 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -1,14 +1,15 @@ -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union import torch as th from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise -from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.td3.policies import TD3Policy from stable_baselines3.td3.td3 import TD3 +DDPGSelf = TypeVar("DDPGSelf", bound="DDPG") + class DDPG(TD3): """ @@ -116,7 +117,7 @@ class DDPG(TD3): self._setup_model() def learn( - self, + self: DDPGSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -126,7 +127,7 @@ class DDPG(TD3): tb_log_name: str = "DDPG", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> DDPGSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 80e024b..ea7d9f3 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -14,6 +14,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy +DQNSelf = TypeVar("DQNSelf", bound="DQN") + class DQN(OffPolicyAlgorithm): """ @@ -255,7 +257,7 @@ class DQN(OffPolicyAlgorithm): return action, state def learn( - self, + self: DQNSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -265,7 +267,7 @@ class DQN(OffPolicyAlgorithm): tb_log_name: str = "DQN", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> DQNSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 6bb9c23..d65201b 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, TypeVar, Union import numpy as np import torch as th @@ -11,6 +11,8 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn +PPOSelf = TypeVar("PPOSelf", bound="PPO") + class PPO(OnPolicyAlgorithm): """ @@ -297,7 +299,7 @@ class PPO(OnPolicyAlgorithm): self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self, + self: PPOSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -307,7 +309,7 @@ class PPO(OnPolicyAlgorithm): tb_log_name: str = "PPO", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "PPO": + ) -> PPOSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index de08b75..8505e88 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -13,6 +13,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy +SACSelf = TypeVar("SACSelf", bound="SAC") + class SAC(OffPolicyAlgorithm): """ @@ -289,7 +291,7 @@ class SAC(OffPolicyAlgorithm): self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) def learn( - self, + self: SACSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -299,7 +301,7 @@ class SAC(OffPolicyAlgorithm): tb_log_name: str = "SAC", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> SACSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 51df755..62e33f5 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -13,6 +13,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy +TD3Self = TypeVar("TD3Self", bound="TD3") + class TD3(OffPolicyAlgorithm): """ @@ -205,7 +207,7 @@ class TD3(OffPolicyAlgorithm): self.logger.record("train/critic_loss", np.mean(critic_losses)) def learn( - self, + self: TD3Self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -215,7 +217,7 @@ class TD3(OffPolicyAlgorithm): tb_log_name: str = "TD3", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> TD3Self: return super().learn( total_timesteps=total_timesteps,