Fix exception cause in base_class.py (#940)

This commit is contained in:
Ram Rachum 2022-06-21 22:58:02 +03:00 committed by GitHub
parent 7ce7b6a8c2
commit d64bcb401a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 20 additions and 17 deletions

View file

@ -32,6 +32,7 @@ Bug Fixes:
- Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled
- Added a check for unbounded actions
- Fixed issues due to newer version of protobuf (tensorboard) and sphinx
- Fix exception causes all over the codebase (@cool-RR)
Deprecations:
^^^^^^^^^^^^^
@ -978,4 +979,4 @@ And all the contributors:
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR

View file

@ -628,11 +628,11 @@ class BaseAlgorithm(ABC):
attr = None
try:
attr = recursive_getattr(self, name)
except Exception:
except Exception as e:
# What errors recursive_getattr could throw? KeyError, but
# possible something else too (e.g. if key is an int?).
# Catch anything for now.
raise ValueError(f"Key {name} is an invalid object name.")
raise ValueError(f"Key {name} is an invalid object name.") from e
if isinstance(attr, th.optim.Optimizer):
# Optimizers do not support "strict" keyword...

View file

@ -380,12 +380,12 @@ class EvalCallback(EventCallback):
if self.model.get_vec_normalize_env() is not None:
try:
sync_envs_normalization(self.training_env, self.eval_env)
except AttributeError:
except AttributeError as e:
raise AssertionError(
"Training and eval env are not wrapped the same way, "
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
"and warning above."
)
) from e
# Reset success rate buffer
self._is_success_buffer = []

View file

@ -147,7 +147,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
try:
_check_obs(obs[key], observation_space.spaces[key], "reset")
except AssertionError as e:
raise AssertionError(f"Error while checking key={key}: " + str(e))
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
else:
_check_obs(obs, observation_space, "reset")
@ -166,7 +166,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
try:
_check_obs(obs[key], observation_space.spaces[key], "step")
except AssertionError as e:
raise AssertionError(f"Error while checking key={key}: " + str(e))
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
else:
_check_obs(obs, observation_space, "step")

View file

@ -105,8 +105,8 @@ class VectorizedActionNoise(ActionNoise):
try:
self.n_envs = int(n_envs)
assert self.n_envs > 0
except (TypeError, AssertionError):
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")
except (TypeError, AssertionError) as e:
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e
self.base_noise = base_noise
self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]

View file

@ -157,8 +157,10 @@ class OffPolicyAlgorithm(BaseAlgorithm):
try:
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
except ValueError:
raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
except ValueError as e:
raise ValueError(
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
) from e
if not isinstance(train_freq[0], int):
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

View file

@ -206,8 +206,8 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb
mode = mode.lower()
try:
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
except KeyError:
raise ValueError("Expected mode to be either 'w' or 'r'.")
except KeyError as e:
raise ValueError("Expected mode to be either 'w' or 'r'.") from e
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
e1 = "writable" if "w" == mode else "readable"
raise ValueError(f"Expected a {e1} file.")
@ -441,7 +441,7 @@ def load_from_zip_file(
# State dicts. Store into params dictionary
# with same name as in .zip file (without .pth)
params[os.path.splitext(file_path)[0]] = th_object
except zipfile.BadZipFile:
except zipfile.BadZipFile as e:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
return data, params, pytorch_variables

View file

@ -28,13 +28,13 @@ def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> in
if current_max_episode_length is None:
raise AttributeError
# if not available check if a valid value was passed as an argument
except AttributeError:
except AttributeError as e:
raise ValueError(
"The max episode length could not be inferred.\n"
"You must specify a `max_episode_steps` when registering the environment,\n"
"use a `gym.wrappers.TimeLimit` wrapper "
"or pass `max_episode_length` to the model constructor"
)
) from e
return current_max_episode_length