From 83530560b54e9f41108b175bf9934f4cdff9d754 Mon Sep 17 00:00:00 2001 From: "Steven H. Wang" Date: Tue, 21 Jul 2020 01:12:39 -0700 Subject: [PATCH 1/2] Fix CloudpickleWrapper load (#118) * CloudpickleWrapper: Load using cloudpickle * Update changelog --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/vec_env/base_vec_env.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 722e9e7..9d4a1c2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -30,6 +30,7 @@ Bug Fixes: ^^^^^^^^^^ - Fixed a bug in the ``close()`` method of ``SubprocVecEnv``, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended) - Fix target for updating q values in SAC: the entropy term was not conditioned by terminals states +- Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang) Deprecations: ^^^^^^^^^^^^^ @@ -355,4 +356,4 @@ And all the contributors: @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 -@tirafesi @blurLake @koulakis @joeljosephjin +@tirafesi @blurLake @koulakis @joeljosephjin @shwang diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 6b5a421..6338a30 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,5 +1,4 @@ import inspect -import pickle from abc import ABC, abstractmethod from typing import List, Optional, Sequence, Union @@ -349,7 +348,7 @@ class VecEnvWrapper(VecEnv): return shadowed_wrapper_class -class CloudpickleWrapper(object): +class CloudpickleWrapper: def __init__(self, var): """ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) @@ -362,4 +361,4 @@ class CloudpickleWrapper(object): return cloudpickle.dumps(self.var) def __setstate__(self, obs): - self.var = pickle.loads(obs) + self.var = cloudpickle.loads(obs) From bd2aae0c27d238a6a2c2f3a4c54066dfb30bd49b Mon Sep 17 00:00:00 2001 From: rk37 Date: Sun, 26 Jul 2020 04:35:48 +0800 Subject: [PATCH 2/2] Fix ortho init when `bias=False` with custom policy (#126) * Update policies.py fix AttributeError occurred when use "bias=False" linear layer in custom FeaturesExtractor #124 * Update changelog.rst update the changelog accordingly * Update changelog.rst Co-authored-by: Kong Lingchao Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/policies.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9d4a1c2..be9a7f3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -31,6 +31,7 @@ Bug Fixes: - Fixed a bug in the ``close()`` method of ``SubprocVecEnv``, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended) - Fix target for updating q values in SAC: the entropy term was not conditioned by terminals states - Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang) +- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37) Deprecations: ^^^^^^^^^^^^^ @@ -356,4 +357,4 @@ And all the contributors: @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 -@tirafesi @blurLake @koulakis @joeljosephjin @shwang +@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index acadd88..01b788f 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -188,7 +188,8 @@ class BasePolicy(BaseModel): """ if isinstance(module, (nn.Linear, nn.Conv2d)): nn.init.orthogonal_(module.weight, gain=gain) - module.bias.data.fill_(0.0) + if module.bias is not None: + module.bias.data.fill_(0.0) @abstractmethod def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: