Fix return type for load, learn in BaseAlgorithm (#1043)

* Fix return type for load, learn in BaseAlgorithm

* Update changelog

* Add typing extensions to dependencies

* Import directly from typing for python >3.11

* Reorder changelog to reflect merge order

* Roll back to typevar solution

* Updated changelog

* Remove typing extensions requirement

* Update base_class.py

* Remove final point in changelog

* Additional type fixes across project

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Juan Rocamonde 2022-09-26 12:13:56 +02:00 committed by GitHub
parent 899eee6bd4
commit 432b3f876d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 58 additions and 33 deletions

View file

@ -31,6 +31,8 @@ Bug Fixes:
- Fixed incorrect type annotation of the replay_buffer_class argument in ``common.OffPolicyAlgorithm`` initializer, where an instance instead of a class was required (@Rocamonde)
- Fixed loading a saved model with a different number of environments
- Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde)
- Fixed the return type of the ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use a ``TypeVar`` bound to the class and therefore return the type of the calling subclass (@Rocamonde)
Deprecations:
^^^^^^^^^^^^^

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Dict, Optional, Type, TypeVar, Union
import torch as th
from gym import spaces
@ -9,6 +9,8 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
A2CSelf = TypeVar("A2CSelf", bound="A2C")
class A2C(OnPolicyAlgorithm):
"""
@ -183,7 +185,7 @@ class A2C(OnPolicyAlgorithm):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
def learn(
self,
self: A2CSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
@ -193,7 +195,7 @@ class A2C(OnPolicyAlgorithm):
tb_log_name: str = "A2C",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "A2C":
) -> A2CSelf:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -5,7 +5,7 @@ import pathlib
import time
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@ -53,6 +53,9 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE
return env
BaseAlgorithmSelf = TypeVar("BaseAlgorithmSelf", bound="BaseAlgorithm")
class BaseAlgorithm(ABC):
"""
The base of RL algorithms
@ -537,7 +540,7 @@ class BaseAlgorithm(ABC):
@abstractmethod
def learn(
self,
self: BaseAlgorithmSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
@ -547,7 +550,7 @@ class BaseAlgorithm(ABC):
n_eval_episodes: int = 5,
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "BaseAlgorithm":
) -> BaseAlgorithmSelf:
"""
Return a trained model.
@ -671,7 +674,7 @@ class BaseAlgorithm(ABC):
@classmethod
def load(
cls,
cls: Type[BaseAlgorithmSelf],
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
@ -679,7 +682,7 @@ class BaseAlgorithm(ABC):
print_system_info: bool = False,
force_reset: bool = True,
**kwargs,
) -> "BaseAlgorithm":
) -> BaseAlgorithmSelf:
"""
Load the model from a zip-file.
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
@ -709,7 +712,10 @@ class BaseAlgorithm(ABC):
get_system_info()
data, params, pytorch_variables = load_from_zip_file(
path, device=device, custom_objects=custom_objects, print_system_info=print_system_info
path,
device=device,
custom_objects=custom_objects,
print_system_info=print_system_info,
)
# Remove stored device information and replace with ours

View file

@ -4,7 +4,7 @@ import sys
import time
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@ -21,6 +21,8 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
OffPolicyAlgorithmSelf = TypeVar("OffPolicyAlgorithmSelf", bound="OffPolicyAlgorithm")
class OffPolicyAlgorithm(BaseAlgorithm):
"""
@ -319,7 +321,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
)
def learn(
self,
self: OffPolicyAlgorithmSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
@ -329,7 +331,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
tb_log_name: str = "run",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "OffPolicyAlgorithm":
) -> OffPolicyAlgorithmSelf:
total_timesteps, callback = self._setup_learn(
total_timesteps,

View file

@ -1,6 +1,6 @@
import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@ -14,6 +14,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithm")
class OnPolicyAlgorithm(BaseAlgorithm):
"""
@ -225,7 +227,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
raise NotImplementedError
def learn(
self,
self: OnPolicyAlgorithmSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
@ -235,7 +237,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
tb_log_name: str = "OnPolicyAlgorithm",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "OnPolicyAlgorithm":
) -> OnPolicyAlgorithmSelf:
iteration = 0
total_timesteps, callback = self._setup_learn(

View file

@ -5,7 +5,7 @@ import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@ -33,6 +33,8 @@ from stable_baselines3.common.torch_layers import (
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
BaseModelSelf = TypeVar("BaseModelSelf", bound="BaseModel")
class BaseModel(nn.Module):
"""
@ -158,7 +160,7 @@ class BaseModel(nn.Module):
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
@classmethod
def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel":
def load(cls: Type[BaseModelSelf], path: str, device: Union[th.device, str] = "auto") -> BaseModelSelf:
"""
Load model from path.

View file

@ -1,14 +1,15 @@
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3
DDPGSelf = TypeVar("DDPGSelf", bound="DDPG")
class DDPG(TD3):
"""
@ -116,7 +117,7 @@ class DDPG(TD3):
self._setup_model()
def learn(
self,
self: DDPGSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
@ -126,7 +127,7 @@ class DDPG(TD3):
tb_log_name: str = "DDPG",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
) -> DDPGSelf:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -1,5 +1,5 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@ -14,6 +14,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
DQNSelf = TypeVar("DQNSelf", bound="DQN")
class DQN(OffPolicyAlgorithm):
"""
@ -255,7 +257,7 @@ class DQN(OffPolicyAlgorithm):
return action, state
def learn(
self,
self: DQNSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
@ -265,7 +267,7 @@ class DQN(OffPolicyAlgorithm):
tb_log_name: str = "DQN",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
) -> DQNSelf:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -1,5 +1,5 @@
import warnings
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Dict, Optional, Type, TypeVar, Union
import numpy as np
import torch as th
@ -11,6 +11,8 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
PPOSelf = TypeVar("PPOSelf", bound="PPO")
class PPO(OnPolicyAlgorithm):
"""
@ -297,7 +299,7 @@ class PPO(OnPolicyAlgorithm):
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
self,
self: PPOSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
@ -307,7 +309,7 @@ class PPO(OnPolicyAlgorithm):
tb_log_name: str = "PPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "PPO":
) -> PPOSelf:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@ -13,6 +13,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
SACSelf = TypeVar("SACSelf", bound="SAC")
class SAC(OffPolicyAlgorithm):
"""
@ -289,7 +291,7 @@ class SAC(OffPolicyAlgorithm):
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn(
self,
self: SACSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
@ -299,7 +301,7 @@ class SAC(OffPolicyAlgorithm):
tb_log_name: str = "SAC",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
) -> SACSelf:
return super().learn(
total_timesteps=total_timesteps,

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
@ -13,6 +13,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
TD3Self = TypeVar("TD3Self", bound="TD3")
class TD3(OffPolicyAlgorithm):
"""
@ -205,7 +207,7 @@ class TD3(OffPolicyAlgorithm):
self.logger.record("train/critic_loss", np.mean(critic_losses))
def learn(
self,
self: TD3Self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
@ -215,7 +217,7 @@ class TD3(OffPolicyAlgorithm):
tb_log_name: str = "TD3",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
) -> TD3Self:
return super().learn(
total_timesteps=total_timesteps,