Fix gSDE loading issue in test mode (#45)

* Fix gSDE loading issue in test mode

* Forward `reset_noise` method

* Re-add `make_actor`

* Reformat
This commit is contained in:
Antonin RAFFIN 2020-06-08 11:15:10 +02:00 committed by GitHub
parent 353ea81080
commit 11d33eb4ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 6 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Pre-Release 0.7.0a0 (WIP)
Pre-Release 0.7.0a1 (WIP)
------------------------------
Breaking Changes:
@ -18,6 +18,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed ``render()`` method for ``VecEnvs``
- Fixed ``seed()`` method for ``SubprocVecEnv``
- Fixed loading on GPU for testing when using gSDE and ``deterministic=False``
Deprecations:
^^^^^^^^^^^^^

View file

@ -375,6 +375,10 @@ class BaseRLModel(ABC):
for name in tensors:
recursive_setattr(model, name, tensors[name])
# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if model.use_sde:
model.policy.reset_noise()
return model
@staticmethod

View file

@ -128,8 +128,8 @@ class Actor(BasePolicy):
:return: (th.Tensor)
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
'get_std() is only available when using gSDE'
msg = 'get_std() is only available when using gSDE'
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
return self.action_dist.get_std(self.log_std)
def reset_noise(self, batch_size: int = 1) -> None:
@ -138,8 +138,8 @@ class Actor(BasePolicy):
:param batch_size: (int)
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
'reset_noise() is only available when using gSDE'
msg = 'reset_noise() is only available when using gSDE'
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
@ -354,6 +354,14 @@ class SACPolicy(BasePolicy):
))
return data
def reset_noise(self, batch_size: int = 1) -> None:
    """
    Sample new weights for the exploration matrix, when using gSDE.

    This only forwards the call to ``self.actor.reset_noise`` so the noise
    can be resampled through the policy object (e.g. ``model.policy.reset_noise()``
    after loading a model).
    NOTE(review): the actor's ``reset_noise`` asserts that gSDE is in use,
    so this presumably must not be called otherwise -- confirm against callers.

    :param batch_size: (int) passed through unchanged to the actor
    """
    self.actor.reset_noise(batch_size=batch_size)
def make_actor(self) -> Actor:
    """
    Create an ``Actor`` from ``self.actor_kwargs`` and move it to ``self.device``.

    :return: (Actor) the result of calling ``.to(self.device)`` on the new actor
    """
    return Actor(**self.actor_kwargs).to(self.device)

View file

@ -1 +1 @@
0.7.0a0
0.7.0a1