From d64bcb401ad7d45799af1feee5c1058943be23f0 Mon Sep 17 00:00:00 2001 From: Ram Rachum Date: Tue, 21 Jun 2022 22:58:02 +0300 Subject: [PATCH] Fix exception cause in base_class.py (#940) --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/base_class.py | 4 ++-- stable_baselines3/common/callbacks.py | 4 ++-- stable_baselines3/common/env_checker.py | 4 ++-- stable_baselines3/common/noise.py | 4 ++-- stable_baselines3/common/off_policy_algorithm.py | 6 ++++-- stable_baselines3/common/save_util.py | 8 ++++---- stable_baselines3/her/her_replay_buffer.py | 4 ++-- 8 files changed, 20 insertions(+), 17 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 480b08e..52bf3e4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -32,6 +32,7 @@ Bug Fixes: - Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled - Added a check for unbounded actions - Fixed issues due to newer version of protobuf (tensorboard) and sphinx +- Fix exception causes all over the codebase (@cool-RR) Deprecations: ^^^^^^^^^^^^^ @@ -978,4 +979,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG +@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e16814c..36a73b8 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -628,11 +628,11 @@ class BaseAlgorithm(ABC): attr = None try: attr = recursive_getattr(self, name) - except Exception: + except Exception as e: # What errors recursive_getattr could throw? KeyError, but # possible something else too (e.g. if key is an int?). # Catch anything for now. - raise ValueError(f"Key {name} is an invalid object name.") + raise ValueError(f"Key {name} is an invalid object name.") from e if isinstance(attr, th.optim.Optimizer): # Optimizers do not support "strict" keyword... diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index c5f297c..e9f46fe 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -380,12 +380,12 @@ class EvalCallback(EventCallback): if self.model.get_vec_normalize_env() is not None: try: sync_envs_normalization(self.training_env, self.eval_env) - except AttributeError: + except AttributeError as e: raise AssertionError( "Training and eval env are not wrapped the same way, " "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback " "and warning above." - ) + ) from e # Reset success rate buffer self._is_success_buffer = [] diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index ed07e7e..3b2c502 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -147,7 +147,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action try: _check_obs(obs[key], observation_space.spaces[key], "reset") except AssertionError as e: - raise AssertionError(f"Error while checking key={key}: " + str(e)) + raise AssertionError(f"Error while checking key={key}: " + str(e)) from e else: _check_obs(obs, observation_space, "reset") @@ -166,7 +166,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action try: _check_obs(obs[key], observation_space.spaces[key], "step") except AssertionError as e: - raise AssertionError(f"Error while checking key={key}: " + str(e)) + raise AssertionError(f"Error while checking key={key}: " + str(e)) from e else: _check_obs(obs, observation_space, "step") diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index 119ed36..baa72e9 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -105,8 +105,8 @@ class VectorizedActionNoise(ActionNoise): try: self.n_envs = int(n_envs) assert self.n_envs > 0 - except (TypeError, AssertionError): - raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") + except (TypeError, AssertionError) as e: + raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e self.base_noise = base_noise self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)] diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index ca57166..99a02ff 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -157,8 +157,10 @@ class OffPolicyAlgorithm(BaseAlgorithm): try: train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) - except ValueError: - raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!") + except ValueError as e: + raise ValueError( + f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!" + ) from e if not isinstance(train_freq[0], int): raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}") diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index e0b104f..1569001 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -206,8 +206,8 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb mode = mode.lower() try: mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode] - except KeyError: - raise ValueError("Expected mode to be either 'w' or 'r'.") + except KeyError as e: + raise ValueError("Expected mode to be either 'w' or 'r'.") from e if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable(): e1 = "writable" if "w" == mode else "readable" raise ValueError(f"Expected a {e1} file.") @@ -441,7 +441,7 @@ def load_from_zip_file( # State dicts. Store into params dictionary # with same name as in .zip file (without .pth) params[os.path.splitext(file_path)[0]] = th_object - except zipfile.BadZipFile: + except zipfile.BadZipFile as e: # load_path wasn't a zip file - raise ValueError(f"Error: the file {load_path} wasn't a zip-file") + raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e return data, params, pytorch_variables diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index c461d19..3c19aac 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -28,13 +28,13 @@ def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> in if current_max_episode_length is None: raise AttributeError # if not available check if a valid value was passed as an argument - except AttributeError: + except AttributeError as e: raise ValueError( "The max episode length could not be inferred.\n" "You must specify a `max_episode_steps` when registering the environment,\n" "use a `gym.wrappers.TimeLimit` wrapper " "or pass `max_episode_length` to the model constructor" - ) + ) from e return current_max_episode_length