diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 0ca6fa9..a65ec1d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -21,6 +21,7 @@ New Features: when ``psutil`` is available - Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped) - Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped) +- Introduced ``BaseModel`` abstract parent for ``BasePolicy``, which critics inherit from. Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index f66220e..4e51367 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -320,8 +320,8 @@ class BaseAlgorithm(ABC): f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}") # check if observation space and action space are part of the saved parameters - if ("observation_space" not in data or "action_space" not in data) and "env" not in data: - raise ValueError("The observation_space and action_space was not given, can't verify new environments") + if "observation_space" not in data or "action_space" not in data: + raise KeyError("The observation_space and action_space were not given, can't verify new environments") # check if given env is valid if env is not None: check_for_correct_spaces(env, data["observation_space"], data["action_space"]) @@ -425,8 +425,10 @@ class BaseAlgorithm(ABC): :return: (Tuple[int, BaseCallback]) """ self.start_time = time.time() - self.ep_info_buffer = deque(maxlen=100) - self.ep_success_buffer = deque(maxlen=100) + if self.ep_info_buffer is None or reset_num_timesteps: + # Initialize buffers if they don't exist, or reinitialize if resetting counters + self.ep_info_buffer = deque(maxlen=100) + self.ep_success_buffer = deque(maxlen=100) if self.action_noise is not None: self.action_noise.reset() diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 6fa70f8..1a20928 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -17,6 +17,20 @@ class Distribution(ABC): def __init__(self): super(Distribution, self).__init__() + @abstractmethod + def proba_distribution_net(self, *args, **kwargs): + """Create the layers and parameters that represent the distribution. + + Subclasses must define this, but the arguments and return type vary between + concrete classes.""" + + @abstractmethod + def proba_distribution(self, *args, **kwargs) -> 'Distribution': + """Set parameters of the distribution. + + :return: (Distribution) self + """ + @abstractmethod def log_prob(self, x: th.Tensor) -> th.Tensor: """ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 08b0ab2..bc387c2 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -21,9 +21,12 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis StateDependentNoiseDistribution) -class BasePolicy(nn.Module, ABC): +class BaseModel(nn.Module, ABC): """ - The base policy object + The base model object: makes predictions in response to observations. + + In the case of policies, the prediction is an action. In the case of critics, it is the + estimated value of the observation. :param observation_space: (gym.spaces.Space) The observation space of the environment :param action_space: (gym.spaces.Space) The action space of the environment @@ -39,8 +42,6 @@ class BasePolicy(nn.Module, ABC): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer - :param squash_output: (bool) For continuous actions, whether the output is squashed - or not using a ``tanh()`` function. """ def __init__(self, @@ -52,9 +53,8 @@ class BasePolicy(nn.Module, ABC): features_extractor: Optional[nn.Module] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - squash_output: bool = False): - super(BasePolicy, self).__init__() + optimizer_kwargs: Optional[Dict[str, Any]] = None): + super(BaseModel, self).__init__() if optimizer_kwargs is None: optimizer_kwargs = {} @@ -67,7 +67,6 @@ class BasePolicy(nn.Module, ABC): self.device = get_device(device) self.features_extractor = features_extractor self.normalize_images = normalize_images - self._squash_output = squash_output self.optimizer_class = optimizer_class self.optimizer_kwargs = optimizer_kwargs @@ -76,6 +75,10 @@ class BasePolicy(nn.Module, ABC): self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs + @abstractmethod + def forward(self, *args, **kwargs): + del args, kwargs + def extract_features(self, obs: th.Tensor) -> th.Tensor: """ Preprocess the observation if needed and extract features. @@ -87,9 +90,88 @@ class BasePolicy(nn.Module, ABC): preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) return self.features_extractor(preprocessed_obs) + def _get_data(self) -> Dict[str, Any]: + """ + Get data that need to be saved in order to re-create the model. + This corresponds to the arguments of the constructor. + + :return: (Dict[str, Any]) + """ + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + # Passed to the constructor by child class + # squash_output=self.squash_output, + # features_extractor=self.features_extractor + normalize_images=self.normalize_images, + ) + + def save(self, path: str) -> None: + """ + Save model to a given location. + + :param path: (str) + """ + th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) + + @classmethod + def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel': + """ + Load model from path. + + :param path: (str) + :param device: (Union[th.device, str]) Device on which the policy should be loaded. + :return: (BasePolicy) + """ + device = get_device(device) + saved_variables = th.load(path, map_location=device) + # Create policy object + model = cls(**saved_variables['data']) + # Load weights + model.load_state_dict(saved_variables['state_dict']) + model.to(device) + return model + + def load_from_vector(self, vector: np.ndarray): + """ + Load parameters from a 1D vector. + + :param vector: (np.ndarray) + """ + th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) + + def parameters_to_vector(self) -> np.ndarray: + """ + Convert the parameters to a 1D vector. + + :return: (np.ndarray) + """ + return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() + + +class BasePolicy(BaseModel): + """The base policy object. + + Parameters are mostly the same as `BaseModel`; additions are documented below. + + :param args: positional arguments passed through to `BaseModel`. + :param kwargs: keyword arguments passed through to `BaseModel`. + :param squash_output: (bool) For continuous actions, whether the output is squashed + or not using a ``tanh()`` function. + """ + def __init__(self, *args, squash_output: bool = False, **kwargs): + super(BasePolicy, self).__init__(*args, **kwargs) + self._squash_output = squash_output + + @staticmethod + def _dummy_schedule(progress_remaining: float) -> float: + """ (float) Useful for pickling policy.""" + del progress_remaining + return 0.0 + @property def squash_output(self) -> bool: - """ (bool) Getter for squash_output.""" + """(bool) Getter for squash_output.""" return self._squash_output @staticmethod @@ -101,16 +183,7 @@ class BasePolicy(nn.Module, ABC): nn.init.orthogonal_(module.weight, gain=gain) module.bias.data.fill_(0.0) - @staticmethod - def _dummy_schedule(progress_remaining: float) -> float: - """ (float) Useful for pickling policy.""" - del progress_remaining - return 0.0 - @abstractmethod - def forward(self, *args, **kwargs): - del args, kwargs - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. @@ -122,7 +195,6 @@ class BasePolicy(nn.Module, ABC): :param deterministic: (bool) Whether to use stochastic or deterministic actions :return: (th.Tensor) Taken action according to the policy """ - raise NotImplementedError() def predict(self, observation: np.ndarray, @@ -140,6 +212,7 @@ class BasePolicy(nn.Module, ABC): :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state (used in recurrent policies) """ + # TODO (GH/1): add support for RNN policies # if state is None: # state = self.initial_state # if mask is None: @@ -204,64 +277,6 @@ class BasePolicy(nn.Module, ABC): low, high = self.action_space.low, self.action_space.high return low + (0.5 * (scaled_action + 1.0) * (high - low)) - def _get_data(self) -> Dict[str, Any]: - """ - Get data that need to be saved in order to re-create the policy. - This corresponds to the arguments of the constructor. - - :return: (Dict[str, Any]) - """ - return dict( - observation_space=self.observation_space, - action_space=self.action_space, - # Passed to the constructor by child class - # squash_output=self.squash_output, - # features_extractor=self.features_extractor - normalize_images=self.normalize_images, - ) - - def save(self, path: str) -> None: - """ - Save policy to a given location. - - :param path: (str) - """ - th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) - - @classmethod - def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy': - """ - Load policy from path. - - :param path: (str) - :param device: (Union[th.device, str]) Device on which the policy should be loaded. - :return: (BasePolicy) - """ - device = get_device(device) - saved_variables = th.load(path, map_location=device) - # Create policy object - model = cls(**saved_variables['data']) - # Load weights - model.load_state_dict(saved_variables['state_dict']) - model.to(device) - return model - - def load_from_vector(self, vector: np.ndarray): - """ - Load parameters from a 1D vector. - - :param vector: (np.ndarray) - """ - th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) - - def parameters_to_vector(self) -> np.ndarray: - """ - Convert the parameters to a 1D vector. - - :return: (np.ndarray) - """ - return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() - class ActorCriticPolicy(BasePolicy): """ @@ -438,6 +453,8 @@ class ActorCriticPolicy(BasePolicy): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) elif isinstance(self.action_dist, BernoulliDistribution): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + else: + raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) # Init weights: use orthogonal initialization @@ -626,7 +643,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): optimizer_kwargs) -class ContinuousCritic(BasePolicy): +class ContinuousCritic(BaseModel): """ Critic network(s) for DDPG/SAC/TD3. It represents the action-state value function (Q-value function).