mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Fix return type for load, learn in BaseAlgorithm (#1043)
* Fix return type for load, learn in BaseAlgorithm * Update changelog * Add typing extensions to dependencies * Import directly from typing for python >3.11 * Reorder changelog to reflect merge order * Roll back to typevar solution * Updated changelog * Remove typing extensions requirement * Update base_class.py * Remove final point in changelog * Additional type fixes across project Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
899eee6bd4
commit
432b3f876d
11 changed files with 58 additions and 33 deletions
|
|
@ -31,6 +31,8 @@ Bug Fixes:
|
|||
- Fixed incorrect type annotation of the replay_buffer_class argument in ``common.OffPolicyAlgorithm`` initializer, where an instance instead of a class was required (@Rocamonde)
|
||||
- Fixed loading saved model with different number of envrionments
|
||||
- Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde)
|
||||
- Fixed the return type of ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use ``TypeVar`` (@Rocamonde)
|
||||
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, Optional, Type, Union
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
|
@ -9,6 +9,8 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
|
|||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance
|
||||
|
||||
A2CSelf = TypeVar("A2CSelf", bound="A2C")
|
||||
|
||||
|
||||
class A2C(OnPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -183,7 +185,7 @@ class A2C(OnPolicyAlgorithm):
|
|||
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: A2CSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 100,
|
||||
|
|
@ -193,7 +195,7 @@ class A2C(OnPolicyAlgorithm):
|
|||
tb_log_name: str = "A2C",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> "A2C":
|
||||
) -> A2CSelf:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pathlib
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -53,6 +53,9 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE
|
|||
return env
|
||||
|
||||
|
||||
BaseAlgorithmSelf = TypeVar("BaseAlgorithmSelf", bound="BaseAlgorithm")
|
||||
|
||||
|
||||
class BaseAlgorithm(ABC):
|
||||
"""
|
||||
The base of RL algorithms
|
||||
|
|
@ -537,7 +540,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
self: BaseAlgorithmSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 100,
|
||||
|
|
@ -547,7 +550,7 @@ class BaseAlgorithm(ABC):
|
|||
n_eval_episodes: int = 5,
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> "BaseAlgorithm":
|
||||
) -> BaseAlgorithmSelf:
|
||||
"""
|
||||
Return a trained model.
|
||||
|
||||
|
|
@ -671,7 +674,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
cls: Type[BaseAlgorithmSelf],
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
env: Optional[GymEnv] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
|
|
@ -679,7 +682,7 @@ class BaseAlgorithm(ABC):
|
|||
print_system_info: bool = False,
|
||||
force_reset: bool = True,
|
||||
**kwargs,
|
||||
) -> "BaseAlgorithm":
|
||||
) -> BaseAlgorithmSelf:
|
||||
"""
|
||||
Load the model from a zip-file.
|
||||
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
|
||||
|
|
@ -709,7 +712,10 @@ class BaseAlgorithm(ABC):
|
|||
get_system_info()
|
||||
|
||||
data, params, pytorch_variables = load_from_zip_file(
|
||||
path, device=device, custom_objects=custom_objects, print_system_info=print_system_info
|
||||
path,
|
||||
device=device,
|
||||
custom_objects=custom_objects,
|
||||
print_system_info=print_system_info,
|
||||
)
|
||||
|
||||
# Remove stored device information and replace with ours
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sys
|
|||
import time
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -21,6 +21,8 @@ from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
|
|||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
||||
|
||||
OffPolicyAlgorithmSelf = TypeVar("OffPolicyAlgorithmSelf", bound="OffPolicyAlgorithm")
|
||||
|
||||
|
||||
class OffPolicyAlgorithm(BaseAlgorithm):
|
||||
"""
|
||||
|
|
@ -319,7 +321,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
)
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: OffPolicyAlgorithmSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
|
|
@ -329,7 +331,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
tb_log_name: str = "run",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> "OffPolicyAlgorithm":
|
||||
) -> OffPolicyAlgorithmSelf:
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
total_timesteps,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -14,6 +14,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
||||
OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithm")
|
||||
|
||||
|
||||
class OnPolicyAlgorithm(BaseAlgorithm):
|
||||
"""
|
||||
|
|
@ -225,7 +227,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
raise NotImplementedError
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: OnPolicyAlgorithmSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
|
|
@ -235,7 +237,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
tb_log_name: str = "OnPolicyAlgorithm",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> "OnPolicyAlgorithm":
|
||||
) -> OnPolicyAlgorithmSelf:
|
||||
iteration = 0
|
||||
|
||||
total_timesteps, callback = self._setup_learn(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import copy
|
|||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -33,6 +33,8 @@ from stable_baselines3.common.torch_layers import (
|
|||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
|
||||
|
||||
BaseModelSelf = TypeVar("BaseModelSelf", bound="BaseModel")
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
"""
|
||||
|
|
@ -158,7 +160,7 @@ class BaseModel(nn.Module):
|
|||
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel":
|
||||
def load(cls: Type[BaseModelSelf], path: str, device: Union[th.device, str] = "auto") -> BaseModelSelf:
|
||||
"""
|
||||
Load model from path.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import torch as th
|
||||
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.td3.policies import TD3Policy
|
||||
from stable_baselines3.td3.td3 import TD3
|
||||
|
||||
DDPGSelf = TypeVar("DDPGSelf", bound="DDPG")
|
||||
|
||||
|
||||
class DDPG(TD3):
|
||||
"""
|
||||
|
|
@ -116,7 +117,7 @@ class DDPG(TD3):
|
|||
self._setup_model()
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: DDPGSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
|
|
@ -126,7 +127,7 @@ class DDPG(TD3):
|
|||
tb_log_name: str = "DDPG",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
) -> DDPGSelf:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -14,6 +14,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
|
||||
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
DQNSelf = TypeVar("DQNSelf", bound="DQN")
|
||||
|
||||
|
||||
class DQN(OffPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -255,7 +257,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
return action, state
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: DQNSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
|
|
@ -265,7 +267,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
tb_log_name: str = "DQN",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
) -> DQNSelf:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
|
@ -11,6 +11,8 @@ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticP
|
|||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
||||
|
||||
PPOSelf = TypeVar("PPOSelf", bound="PPO")
|
||||
|
||||
|
||||
class PPO(OnPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -297,7 +299,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: PPOSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 1,
|
||||
|
|
@ -307,7 +309,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
tb_log_name: str = "PPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> "PPO":
|
||||
) -> PPOSelf:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -13,6 +13,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
|
||||
|
||||
SACSelf = TypeVar("SACSelf", bound="SAC")
|
||||
|
||||
|
||||
class SAC(OffPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -289,7 +291,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: SACSelf,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
|
|
@ -299,7 +301,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
tb_log_name: str = "SAC",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
) -> SACSelf:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -13,6 +13,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul
|
|||
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
|
||||
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
|
||||
|
||||
TD3Self = TypeVar("TD3Self", bound="TD3")
|
||||
|
||||
|
||||
class TD3(OffPolicyAlgorithm):
|
||||
"""
|
||||
|
|
@ -205,7 +207,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
self.logger.record("train/critic_loss", np.mean(critic_losses))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
self: TD3Self,
|
||||
total_timesteps: int,
|
||||
callback: MaybeCallback = None,
|
||||
log_interval: int = 4,
|
||||
|
|
@ -215,7 +217,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
tb_log_name: str = "TD3",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
) -> TD3Self:
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
|
|
|||
Loading…
Reference in a new issue