Fix typing, key error

This commit is contained in:
Adam Gleave 2020-07-02 21:35:06 -07:00
parent e9d8e05cc8
commit e61d34a6f0
3 changed files with 7 additions and 4 deletions

View file

@@ -135,7 +135,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
self.action_space, self.device,
optimize_memory_usage=self.optimize_memory_usage)
self.policy = self.policy_class(self.observation_space, self.action_space,
self.lr_schedule, **self.policy_kwargs)
self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable
self.policy = self.policy.to(self.device)
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:

View file

@@ -96,7 +96,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
n_envs=self.n_envs)
self.policy = self.policy_class(self.observation_space, self.action_space,
self.lr_schedule, use_sde=self.use_sde, device=self.device,
**self.policy_kwargs)
**self.policy_kwargs) # pytype:disable=not-instantiable
self.policy = self.policy.to(self.device)
def collect_rollouts(self,

View file

@@ -111,15 +111,18 @@ class BasePolicy(nn.Module, ABC):
def forward(self, *args, **kwargs):
    """
    Placeholder forward pass.

    Accepts arbitrary positional and keyword arguments, discards them and
    implicitly returns ``None``.
    """
    # Drop both argument collections explicitly so they are marked as
    # intentionally unused.
    del kwargs
    del args
@abstractmethod
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
By default provides a dummy implementation -- not all BasePolicy classes
implement this, e.g. if they are a Critic in an Actor-Critic method.
:param observation: (th.Tensor)
:param deterministic: (bool) Whether to use stochastic or deterministic actions
:return: (th.Tensor) Taken action according to the policy
"""
raise NotImplementedError()
def predict(self,
observation: np.ndarray,
@@ -370,7 +373,7 @@ class ActorCriticPolicy(BasePolicy):
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
default_none_kwargs = self.dist_kwargs or collections.defaultdict()
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
data.update(dict(
net_arch=self.net_arch,