mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Fix typing, key error
This commit is contained in:
parent
e9d8e05cc8
commit
e61d34a6f0
3 changed files with 7 additions and 4 deletions
|
|
@ -135,7 +135,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
self.action_space, self.device,
|
||||
optimize_memory_usage=self.optimize_memory_usage)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, **self.policy_kwargs)
|
||||
self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
n_envs=self.n_envs)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, use_sde=self.use_sde, device=self.device,
|
||||
**self.policy_kwargs)
|
||||
**self.policy_kwargs) # pytype:disable=not-instantiable
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def collect_rollouts(self,
|
||||
|
|
|
|||
|
|
@ -111,15 +111,18 @@ class BasePolicy(nn.Module, ABC):
|
|||
def forward(self, *args, **kwargs):
|
||||
del args, kwargs
|
||||
|
||||
@abstractmethod
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
Get the action according to the policy for a given observation.
|
||||
|
||||
By default provides a dummy implementation -- not all BasePolicy classes
|
||||
implement this, e.g. if they are a Critic in an Actor-Critic method.
|
||||
|
||||
:param observation: (th.Tensor)
|
||||
:param deterministic: (bool) Whether to use stochastic or deterministic actions
|
||||
:return: (th.Tensor) Taken action according to the policy
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def predict(self,
|
||||
observation: np.ndarray,
|
||||
|
|
@ -370,7 +373,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
def _get_data(self) -> Dict[str, Any]:
|
||||
data = super()._get_data()
|
||||
|
||||
default_none_kwargs = self.dist_kwargs or collections.defaultdict()
|
||||
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
|
||||
|
||||
data.update(dict(
|
||||
net_arch=self.net_arch,
|
||||
|
|
|
|||
Loading…
Reference in a new issue