mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
Fix load_from_tensor (#1231)
This commit is contained in:
parent
5549b34231
commit
3c028f3d5c
3 changed files with 4 additions and 3 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.7.0a8 (WIP)
|
||||
Release 1.7.0a9 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -39,6 +39,7 @@ Bug Fixes:
|
|||
- Fixed ``Self`` return type using ``TypeVar``
|
||||
- Fixed the env checker, the key was not passed when checking images from Dict observation space
|
||||
- Fixed ``normalize_images`` which was not passed to parent class in some cases
|
||||
- Fixed ``load_from_vector`` that was broken with newer PyTorch version when passing PyTorch tensor
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ class BaseModel(nn.Module):
|
|||
|
||||
:param vector:
|
||||
"""
|
||||
th.nn.utils.vector_to_parameters(th.FloatTensor(vector, device=self.device), self.parameters())
|
||||
th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters())
|
||||
|
||||
def parameters_to_vector(self) -> np.ndarray:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a8
|
||||
1.7.0a9
|
||||
|
|
|
|||
Loading…
Reference in a new issue