Raise error for abstract methods

This commit is contained in:
Antonin Raffin 2020-01-20 12:57:40 +01:00
parent e5c6601726
commit 0bed698ec5
2 changed files with 9 additions and 9 deletions

View file

@ -276,7 +276,7 @@ class BaseRLModel(object):
:param n_eval_episodes: (int) Number of episode to evaluate the agent
:return: (BaseRLModel) the trained model
"""
pass
raise NotImplementedError()
@abstractmethod
def predict(self, observation, state=None, mask=None, deterministic=False):
@ -289,7 +289,7 @@ class BaseRLModel(object):
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
"""
pass
raise NotImplementedError()
def load_parameters(self, load_dict, opt_params):
"""

View file

@ -58,7 +58,7 @@ class VecEnv(object):
:return: ([int] or [float]) observation
"""
pass
raise NotImplementedError()
@abstractmethod
def step_async(self, actions):
@ -70,7 +70,7 @@ class VecEnv(object):
You should not call this if a step_async run is
already pending.
"""
pass
raise NotImplementedError()
@abstractmethod
def step_wait(self):
@ -79,14 +79,14 @@ class VecEnv(object):
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
"""
pass
raise NotImplementedError()
@abstractmethod
def close(self):
"""
Clean up the environment's resources.
"""
pass
raise NotImplementedError()
@abstractmethod
def get_attr(self, attr_name, indices=None):
@ -97,7 +97,7 @@ class VecEnv(object):
:param indices: (list,int) Indices of envs to get attribute from
:return: (list) List of values of 'attr_name' in all environments
"""
pass
raise NotImplementedError()
@abstractmethod
def set_attr(self, attr_name, value, indices=None):
@ -109,7 +109,7 @@ class VecEnv(object):
:param indices: (list,int) Indices of envs to assign value
:return: (NoneType)
"""
pass
raise NotImplementedError()
@abstractmethod
def env_method(self, method_name, *method_args, **method_kwargs):
@ -122,7 +122,7 @@ class VecEnv(object):
:param method_kwargs: (dict) Any keyword arguments to provide in the call
:return: (list) List of items returned by the environment's method call
"""
pass
raise NotImplementedError()
def step(self, actions):
"""