mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-01 23:30:53 +00:00
Fix policies type annotations (#1735)
This commit is contained in:
parent
a35c08c0d6
commit
d671402c93
9 changed files with 74 additions and 56 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.2.0a10 (WIP)
|
||||
Release 2.2.0a11 (WIP)
|
||||
--------------------------
|
||||
**Support for options at reset, bug fixes and better error messages**
|
||||
|
||||
|
|
@ -62,6 +62,7 @@ Others:
|
|||
- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/distributions.py`` type hints
|
||||
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
|
||||
- Fixed ``stable_baselines3/common/policies.py`` type hints
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ exclude = [
|
|||
"stable_baselines3/common/vec_env/patch_gym.py",
|
||||
"stable_baselines3/common/off_policy_algorithm.py",
|
||||
"stable_baselines3/common/distributions.py",
|
||||
"stable_baselines3/common/policies.py",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
|
@ -45,8 +46,7 @@ ignore_missing_imports = true
|
|||
follow_imports = "silent"
|
||||
show_error_codes = true
|
||||
exclude = """(?x)(
|
||||
stable_baselines3/common/policies.py$
|
||||
| stable_baselines3/common/vec_env/__init__.py$
|
||||
stable_baselines3/common/vec_env/__init__.py$
|
||||
| stable_baselines3/common/vec_env/vec_normalize.py$
|
||||
| tests/test_logger.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
NatureCNN,
|
||||
create_mlp,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
|
||||
|
||||
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
|
||||
|
|
@ -119,7 +119,7 @@ class BaseModel(nn.Module):
|
|||
"""Helper method to create a features extractor."""
|
||||
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
|
||||
def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
|
||||
def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
||||
|
|
@ -219,6 +219,9 @@ class BaseModel(nn.Module):
|
|||
"""
|
||||
vectorized_env = False
|
||||
if isinstance(observation, dict):
|
||||
assert isinstance(
|
||||
self.observation_space, spaces.Dict
|
||||
), f"The observation provided is a dict but the obs space is {self.observation_space}"
|
||||
for key, obs in observation.items():
|
||||
obs_space = self.observation_space.spaces[key]
|
||||
vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
|
||||
|
|
@ -228,7 +231,7 @@ class BaseModel(nn.Module):
|
|||
)
|
||||
return vectorized_env
|
||||
|
||||
def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
|
||||
def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
|
||||
"""
|
||||
Convert an input observation to a PyTorch tensor that can be fed to a model.
|
||||
Includes sugar-coating to handle different observations (e.g. normalizing images).
|
||||
|
|
@ -239,6 +242,9 @@ class BaseModel(nn.Module):
|
|||
"""
|
||||
vectorized_env = False
|
||||
if isinstance(observation, dict):
|
||||
assert isinstance(
|
||||
self.observation_space, spaces.Dict
|
||||
), f"The observation provided is a dict but the obs space is {self.observation_space}"
|
||||
# need to copy the dict as the dict in VecFrameStack will become a torch tensor
|
||||
observation = copy.deepcopy(observation)
|
||||
for key, obs in observation.items():
|
||||
|
|
@ -249,7 +255,7 @@ class BaseModel(nn.Module):
|
|||
obs_ = np.array(obs)
|
||||
vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
|
||||
# Add batch dimension if needed
|
||||
observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))
|
||||
observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) # type: ignore[misc]
|
||||
|
||||
elif is_image_space(self.observation_space):
|
||||
# Handle the different cases for images
|
||||
|
|
@ -263,10 +269,10 @@ class BaseModel(nn.Module):
|
|||
# Dict obs need to be handled separately
|
||||
vectorized_env = is_vectorized_observation(observation, self.observation_space)
|
||||
# Add batch dimension if needed
|
||||
observation = observation.reshape((-1, *self.observation_space.shape))
|
||||
observation = observation.reshape((-1, *self.observation_space.shape)) # type: ignore[misc]
|
||||
|
||||
observation = obs_as_tensor(observation, self.device)
|
||||
return observation, vectorized_env
|
||||
obs_tensor = obs_as_tensor(observation, self.device)
|
||||
return obs_tensor, vectorized_env
|
||||
|
||||
|
||||
class BasePolicy(BaseModel, ABC):
|
||||
|
|
@ -308,7 +314,7 @@ class BasePolicy(BaseModel, ABC):
|
|||
module.bias.data.fill_(0.0)
|
||||
|
||||
@abstractmethod
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
Get the action according to the policy for a given observation.
|
||||
|
||||
|
|
@ -354,27 +360,28 @@ class BasePolicy(BaseModel, ABC):
|
|||
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
|
||||
)
|
||||
|
||||
observation, vectorized_env = self.obs_to_tensor(observation)
|
||||
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
|
||||
|
||||
with th.no_grad():
|
||||
actions = self._predict(observation, deterministic=deterministic)
|
||||
actions = self._predict(obs_tensor, deterministic=deterministic)
|
||||
# Convert to numpy, and reshape to the original action shape
|
||||
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))
|
||||
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc]
|
||||
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
if self.squash_output:
|
||||
# Rescale to proper domain when using squashing
|
||||
actions = self.unscale_action(actions)
|
||||
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
|
||||
else:
|
||||
# Actions could be on arbitrary scale, so clip the actions to avoid
|
||||
# out of bound error (e.g. if sampling from a Gaussian distribution)
|
||||
actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
|
||||
|
||||
# Remove batch dimension if needed
|
||||
if not vectorized_env:
|
||||
assert isinstance(actions, np.ndarray)
|
||||
actions = actions.squeeze(axis=0)
|
||||
|
||||
return actions, state
|
||||
return actions, state # type: ignore[return-value]
|
||||
|
||||
def scale_action(self, action: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
|
|
@ -384,6 +391,9 @@ class BasePolicy(BaseModel, ABC):
|
|||
:param action: Action to scale
|
||||
:return: Scaled action
|
||||
"""
|
||||
assert isinstance(
|
||||
self.action_space, spaces.Box
|
||||
), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}"
|
||||
low, high = self.action_space.low, self.action_space.high
|
||||
return 2.0 * ((action - low) / (high - low)) - 1.0
|
||||
|
||||
|
|
@ -394,6 +404,9 @@ class BasePolicy(BaseModel, ABC):
|
|||
|
||||
:param scaled_action: Action to un-scale
|
||||
"""
|
||||
assert isinstance(
|
||||
self.action_space, spaces.Box
|
||||
), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
|
||||
low, high = self.action_space.low, self.action_space.high
|
||||
return low + (0.5 * (scaled_action + 1.0) * (high - low))
|
||||
|
||||
|
|
@ -522,7 +535,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
data = super()._get_constructor_parameters()
|
||||
|
||||
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
|
||||
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) # type: ignore[arg-type, return-value]
|
||||
|
||||
data.update(
|
||||
dict(
|
||||
|
|
@ -616,7 +629,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
module.apply(partial(self.init_weights, gain=gain))
|
||||
|
||||
# Setup optimizer with initial learning rate
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg]
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
"""
|
||||
|
|
@ -639,11 +652,11 @@ class ActorCriticPolicy(BasePolicy):
|
|||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
actions = distribution.get_actions(deterministic=deterministic)
|
||||
log_prob = distribution.log_prob(actions)
|
||||
actions = actions.reshape((-1, *self.action_space.shape))
|
||||
actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc]
|
||||
return actions, values, log_prob
|
||||
|
||||
def extract_features(
|
||||
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||
def extract_features( # type: ignore[override]
|
||||
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
|
@ -691,7 +704,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
else:
|
||||
raise ValueError("Invalid action distribution")
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
Get the action according to the policy for a given observation.
|
||||
|
||||
|
|
@ -701,7 +714,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
"""
|
||||
return self.get_distribution(observation).get_actions(deterministic=deterministic)
|
||||
|
||||
def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
|
||||
def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
|
||||
"""
|
||||
Evaluate actions according to the current policy,
|
||||
given the observations.
|
||||
|
|
@ -725,7 +738,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
entropy = distribution.entropy()
|
||||
return values, log_prob, entropy
|
||||
|
||||
def get_distribution(self, obs: th.Tensor) -> Distribution:
|
||||
def get_distribution(self, obs: PyTorchObs) -> Distribution:
|
||||
"""
|
||||
Get the current policy distribution given the observations.
|
||||
|
||||
|
|
@ -736,7 +749,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
latent_pi = self.mlp_extractor.forward_actor(features)
|
||||
return self._get_action_dist_from_latent(latent_pi)
|
||||
|
||||
def predict_values(self, obs: th.Tensor) -> th.Tensor:
|
||||
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
|
||||
"""
|
||||
Get the estimated values according to the current policy given the observations.
|
||||
|
||||
|
|
@ -921,6 +934,8 @@ class ContinuousCritic(BaseModel):
|
|||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
features_extractor: BaseFeaturesExtractor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
|
|
@ -944,10 +959,10 @@ class ContinuousCritic(BaseModel):
|
|||
|
||||
self.share_features_extractor = share_features_extractor
|
||||
self.n_critics = n_critics
|
||||
self.q_networks = []
|
||||
self.q_networks: List[nn.Module] = []
|
||||
for idx in range(n_critics):
|
||||
q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
|
||||
q_net = nn.Sequential(*q_net)
|
||||
q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
|
||||
q_net = nn.Sequential(*q_net_list)
|
||||
self.add_module(f"qf{idx}", q_net)
|
||||
self.q_networks.append(q_net)
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->
|
|||
|
||||
|
||||
def preprocess_obs(
|
||||
obs: th.Tensor,
|
||||
obs: Union[th.Tensor, Dict[str, th.Tensor]],
|
||||
observation_space: spaces.Space,
|
||||
normalize_images: bool = True,
|
||||
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
|
||||
|
|
@ -105,6 +105,16 @@ def preprocess_obs(
|
|||
(True by default)
|
||||
:return:
|
||||
"""
|
||||
if isinstance(observation_space, spaces.Dict):
|
||||
# Do not modify by reference the original observation
|
||||
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
|
||||
preprocessed_obs = {}
|
||||
for key, _obs in obs.items():
|
||||
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
|
||||
return preprocessed_obs # type: ignore[return-value]
|
||||
|
||||
assert isinstance(obs, th.Tensor), f"Expecting a torch Tensor, but got {type(obs)}"
|
||||
|
||||
if isinstance(observation_space, spaces.Box):
|
||||
if normalize_images and is_image_space(observation_space):
|
||||
return obs.float() / 255.0
|
||||
|
|
@ -126,15 +136,6 @@ def preprocess_obs(
|
|||
|
||||
elif isinstance(observation_space, spaces.MultiBinary):
|
||||
return obs.float()
|
||||
|
||||
elif isinstance(observation_space, spaces.Dict):
|
||||
# Do not modify by reference the original observation
|
||||
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
|
||||
preprocessed_obs = {}
|
||||
for key, _obs in obs.items():
|
||||
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
|
||||
return preprocessed_obs
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
|
|||
TensorDict = Dict[str, th.Tensor]
|
||||
OptimizerStateDict = Dict[str, Any]
|
||||
MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"]
|
||||
PyTorchObs = Union[th.Tensor, TensorDict]
|
||||
|
||||
# A schedule takes the remaining progress as input
|
||||
# and outputs a scalar (e.g. learning rate, clip range, ...)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
NatureCNN,
|
||||
create_mlp,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
|
||||
|
||||
class QNetwork(BasePolicy):
|
||||
|
|
@ -56,7 +56,7 @@ class QNetwork(BasePolicy):
|
|||
q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
|
||||
self.q_net = nn.Sequential(*q_net)
|
||||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs) -> th.Tensor:
|
||||
"""
|
||||
Predict the q-values.
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ class QNetwork(BasePolicy):
|
|||
"""
|
||||
return self.q_net(self.extract_features(obs, self.features_extractor))
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||
q_values = self(observation)
|
||||
# Greedy action
|
||||
action = q_values.argmax(dim=1).reshape(-1)
|
||||
|
|
@ -177,10 +177,10 @@ class DQNPolicy(BasePolicy):
|
|||
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
|
||||
return QNetwork(**net_args).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
|
||||
return self.q_net._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _get_constructor_parameters(self) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
create_mlp,
|
||||
get_actor_critic_arch,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
LOG_STD_MAX = 2
|
||||
|
|
@ -144,7 +144,7 @@ class Actor(BasePolicy):
|
|||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||
|
||||
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
"""
|
||||
Get the parameters for the action distribution.
|
||||
|
||||
|
|
@ -164,17 +164,17 @@ class Actor(BasePolicy):
|
|||
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
||||
return mean_actions, log_std, {}
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||
# Note: the action is squashed
|
||||
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
|
||||
|
||||
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
|
||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||
# return action and associated log prob
|
||||
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
return self(observation, deterministic)
|
||||
|
||||
|
||||
|
|
@ -346,10 +346,10 @@ class SACPolicy(BasePolicy):
|
|||
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||
return ContinuousCritic(**critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
return self.actor(observation, deterministic)
|
||||
|
||||
def set_training_mode(self, mode: bool) -> None:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from stable_baselines3.common.torch_layers import (
|
|||
create_mlp,
|
||||
get_actor_critic_arch,
|
||||
)
|
||||
from stable_baselines3.common.type_aliases import Schedule
|
||||
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
||||
|
||||
|
||||
class Actor(BasePolicy):
|
||||
|
|
@ -77,7 +77,7 @@ class Actor(BasePolicy):
|
|||
features = self.extract_features(obs, self.features_extractor)
|
||||
return self.mu(features)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
# Note: the deterministic parameter is ignored in the case of TD3.
|
||||
# Predictions are always deterministic.
|
||||
return self(observation)
|
||||
|
|
@ -233,10 +233,10 @@ class TD3Policy(BasePolicy):
|
|||
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||
return ContinuousCritic(**critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def forward(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
return self._predict(observation, deterministic=deterministic)
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
|
||||
# Note: the deterministic parameter is ignored in the case of TD3.
|
||||
# Predictions are always deterministic.
|
||||
return self.actor(observation)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a10
|
||||
2.2.0a11
|
||||
|
|
|
|||
Loading…
Reference in a new issue