mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-01 23:30:53 +00:00
Fix gSDE loading issue in test mode (#45)
* Fix gSDE loading issue in test mode * Forward `reset_noise` method * Re-add `make_actor` * Reformat
This commit is contained in:
parent
353ea81080
commit
11d33eb4ae
4 changed files with 19 additions and 6 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Pre-Release 0.7.0a0 (WIP)
|
||||
Pre-Release 0.7.0a1 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -18,6 +18,7 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fixed ``render()`` method for ``VecEnvs``
|
||||
- Fixed ``seed()``` method for ``SubprocVecEnv``
|
||||
- Fixed loading on GPU for testing when using gSDE and ``deterministic=False``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -375,6 +375,10 @@ class BaseRLModel(ABC):
|
|||
for name in tensors:
|
||||
recursive_setattr(model, name, tensors[name])
|
||||
|
||||
# Sample gSDE exploration matrix, so it uses the right device
|
||||
# see issue #44
|
||||
if model.use_sde:
|
||||
model.policy.reset_noise()
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -128,8 +128,8 @@ class Actor(BasePolicy):
|
|||
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
|
||||
'get_std() is only available when using gSDE'
|
||||
msg = 'get_std() is only available when using gSDE'
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
||||
return self.action_dist.get_std(self.log_std)
|
||||
|
||||
def reset_noise(self, batch_size: int = 1) -> None:
|
||||
|
|
@ -138,8 +138,8 @@ class Actor(BasePolicy):
|
|||
|
||||
:param batch_size: (int)
|
||||
"""
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
|
||||
'reset_noise() is only available when using gSDE'
|
||||
msg = 'reset_noise() is only available when using gSDE'
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||
|
||||
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
|
|
@ -354,6 +354,14 @@ class SACPolicy(BasePolicy):
|
|||
))
|
||||
return data
|
||||
|
||||
def reset_noise(self, batch_size: int = 1) -> None:
|
||||
"""
|
||||
Sample new weights for the exploration matrix, when using gSDE.
|
||||
|
||||
:param batch_size: (int)
|
||||
"""
|
||||
self.actor.reset_noise(batch_size=batch_size)
|
||||
|
||||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.7.0a0
|
||||
0.7.0a1
|
||||
|
|
|
|||
Loading…
Reference in a new issue