diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index f459de9..b7d92ad 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -135,7 +135,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): self.action_space, self.device, optimize_memory_usage=self.optimize_memory_usage) self.policy = self.policy_class(self.observation_space, self.action_space, - self.lr_schedule, **self.policy_kwargs) + self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable self.policy = self.policy.to(self.device) def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 9090d78..2937b77 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -96,7 +96,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): n_envs=self.n_envs) self.policy = self.policy_class(self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, device=self.device, - **self.policy_kwargs) + **self.policy_kwargs) # pytype:disable=not-instantiable self.policy = self.policy.to(self.device) def collect_rollouts(self, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 48e066c..2aeffd9 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -111,15 +111,18 @@ class BasePolicy(nn.Module, ABC): def forward(self, *args, **kwargs): del args, kwargs - @abstractmethod def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. + By default provides a dummy implementation -- not all BasePolicy classes + implement this, e.g. if they are a Critic in an Actor-Critic method. + :param observation: (th.Tensor) :param deterministic: (bool) Whether to use stochastic or deterministic actions :return: (th.Tensor) Taken action according to the policy """ + raise NotImplementedError() def predict(self, observation: np.ndarray, @@ -370,7 +373,7 @@ class ActorCriticPolicy(BasePolicy): def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - default_none_kwargs = self.dist_kwargs or collections.defaultdict() + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) data.update(dict( net_arch=self.net_arch,