From 3c028f3d5ccbd493eae70efca281ef0dd2ffb999 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 22 Dec 2022 17:28:18 +0100 Subject: [PATCH] Fix `load_from_tensor` (#1231) --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/policies.py | 2 +- stable_baselines3/version.txt | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b6efffc..bab86a4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.7.0a8 (WIP) +Release 1.7.0a9 (WIP) -------------------------- Breaking Changes: @@ -39,6 +39,7 @@ Bug Fixes: - Fixed ``Self`` return type using ``TypeVar`` - Fixed the env checker, the key was not passed when checking images from Dict observation space - Fixed ``normalize_images`` which was not passed to parent class in some cases +- Fixed ``load_from_vector`` that was broken with newer PyTorch version when passing PyTorch tensor Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 679ad92..7331815 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -199,7 +199,7 @@ class BaseModel(nn.Module): :param vector: """ - th.nn.utils.vector_to_parameters(th.FloatTensor(vector, device=self.device), self.parameters()) + th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters()) def parameters_to_vector(self) -> np.ndarray: """ diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index bee32dc..82d01f9 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a8 +1.7.0a9