mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-17 01:45:03 +00:00
Merge branch 'master' into sde
This commit is contained in:
commit
c1f30acd6f
6 changed files with 396 additions and 96 deletions
|
|
@ -17,6 +17,8 @@ New Features:
|
|||
- Buffer dtype is now set according to action and observation spaces for ``ReplayBuffer``
|
||||
- Added warning when allocation of a buffer may exceed the available memory of the system
|
||||
when ``psutil`` is available
|
||||
- Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped)
|
||||
- Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -37,6 +39,7 @@ Documentation:
|
|||
^^^^^^^^^^^^^^
|
||||
- Updated notebook links
|
||||
- Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake)
|
||||
- Added Unity reacher to the projects page (@koulakis)
|
||||
|
||||
|
||||
|
||||
|
|
@ -337,4 +340,4 @@ And all the contributors:
|
|||
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
|
||||
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
|
||||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
@tirafesi @blurLake
|
||||
@tirafesi @blurLake @koulakis
|
||||
|
|
|
|||
|
|
@ -25,3 +25,15 @@ It was the starting point of Stable-Baselines3.
|
|||
| Author: Antonin Raffin, Freek Stulp
|
||||
| Github: https://github.com/DLR-RM/stable-baselines3/tree/sde
|
||||
| Paper: https://arxiv.org/abs/2005.05719
|
||||
|
||||
Reacher
|
||||
-------
|
||||
A solution to the second project of the Udacity deep reinforcement learning course.
|
||||
It is an example of:
|
||||
|
||||
- wrapping single and multi-agent Unity environments to make them usable in Stable-Baselines3
|
||||
- creating experimentation scripts which train and run A2C, PPO, TD3 and SAC models (a better choice for this one is https://github.com/DLR-RM/rl-baselines3-zoo)
|
||||
- generating several pre-trained models which solve the reacher environment
|
||||
|
||||
| Author: Marios Koulakis
|
||||
| Github: https://github.com/koulakis/reacher-deep-reinforcement-learning
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import time
|
|||
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
import pathlib
|
||||
import io
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -291,7 +293,7 @@ class BaseAlgorithm(ABC):
|
|||
return self.policy.predict(observation, state, mask, deterministic)
|
||||
|
||||
@classmethod
|
||||
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
|
||||
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAlgorithm':
|
||||
"""
|
||||
Load the model from a zip-file
|
||||
|
||||
|
|
@ -475,11 +477,16 @@ class BaseAlgorithm(ABC):
|
|||
"""
|
||||
return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]
|
||||
|
||||
def save(self, path: str, exclude: Optional[List[str]] = None, include: Optional[List[str]] = None) -> None:
|
||||
def save(
|
||||
self,
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
exclude: Optional[List[str]] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save all the attributes of the object and the model parameters in a zip-file.
|
||||
|
||||
:param path: path to the file where the rl agent should be saved
|
||||
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file where the rl agent should be saved
|
||||
:param exclude: name of parameters that should be excluded in addition to the default one
|
||||
:param include: name of parameters that might be excluded but should be included anyway
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import time
|
||||
import pickle
|
||||
import warnings
|
||||
import pathlib
|
||||
from typing import Union, Type, Optional, Dict, Any, Callable, List, Tuple
|
||||
import io
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -16,6 +17,7 @@ from stable_baselines3.common.type_aliases import GymEnv, RolloutReturn, MaybeCa
|
|||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.save_util import save_to_pkl, load_from_pkl
|
||||
|
||||
|
||||
class OffPolicyAlgorithm(BaseAlgorithm):
|
||||
|
|
@ -126,7 +128,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
# For gSDE only
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def _setup_model(self):
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space,
|
||||
|
|
@ -136,24 +138,23 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
self.lr_schedule, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def save_replay_buffer(self, path: str):
|
||||
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
|
||||
"""
|
||||
Save the replay buffer as a pickle file.
|
||||
|
||||
:param path: (str) Path to the file where the replay buffer should be saved
|
||||
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the file where the replay buffer should be saved.
|
||||
If path is a str or pathlib.Path, the path is automatically created if necessary.
|
||||
"""
|
||||
assert self.replay_buffer is not None, "The replay buffer is not defined"
|
||||
with open(path, 'wb') as file_handler:
|
||||
pickle.dump(self.replay_buffer, file_handler)
|
||||
save_to_pkl(path, self.replay_buffer, self.verbose)
|
||||
|
||||
def load_replay_buffer(self, path: str):
|
||||
def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
|
||||
"""
|
||||
Load a replay buffer from a pickle file.
|
||||
|
||||
:param path: (str) Path to the pickled replay buffer.
|
||||
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer.
|
||||
"""
|
||||
with open(path, 'rb') as file_handler:
|
||||
self.replay_buffer = pickle.load(file_handler)
|
||||
self.replay_buffer = load_from_pkl(path, self.verbose)
|
||||
assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
|
||||
|
||||
def _setup_learn(self,
|
||||
|
|
|
|||
|
|
@ -2,14 +2,16 @@
|
|||
Save util taken from stable_baselines
|
||||
used to serialize data (class parameters) of model classes
|
||||
"""
|
||||
import os
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
import functools
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
from typing import Dict, Any, Tuple, Optional, Union
|
||||
import warnings
|
||||
import zipfile
|
||||
import pathlib
|
||||
import pickle
|
||||
|
||||
import torch as th
|
||||
import cloudpickle
|
||||
|
|
@ -30,10 +32,11 @@ def recursive_getattr(obj: Any, attr: str, *args) -> Any:
|
|||
:param attr: (str) Attribute to retrieve
|
||||
:return: (Any) The attribute
|
||||
"""
|
||||
|
||||
def _getattr(obj: Any, attr: str) -> Any:
|
||||
return getattr(obj, attr, *args)
|
||||
|
||||
return functools.reduce(_getattr, [obj] + attr.split('.'))
|
||||
return functools.reduce(_getattr, [obj] + attr.split("."))
|
||||
|
||||
|
||||
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
|
||||
|
|
@ -48,7 +51,7 @@ def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
|
|||
:param attr: (str) Attribute to set
|
||||
:param val: (Any) New value of the attribute
|
||||
"""
|
||||
pre, _, post = attr.rpartition('.')
|
||||
pre, _, post = attr.rpartition(".")
|
||||
return setattr(recursive_getattr(obj, pre) if pre else obj, post, val)
|
||||
|
||||
|
||||
|
|
@ -92,16 +95,14 @@ def data_to_json(data: Dict[str, Any]) -> str:
|
|||
# Also store type of the class for consumption
|
||||
# from other languages/humans, so we have an
|
||||
# idea what was being stored.
|
||||
base64_encoded = base64.b64encode(
|
||||
cloudpickle.dumps(data_item)
|
||||
).decode()
|
||||
base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode()
|
||||
|
||||
# Use ":" to make sure we do
|
||||
# not override these keys
|
||||
# when we include variables of the object later
|
||||
cloudpickle_serialization = {
|
||||
":type:": str(type(data_item)),
|
||||
":serialized:": base64_encoded
|
||||
":serialized:": base64_encoded,
|
||||
}
|
||||
|
||||
# Add first-level JSON-serializable items of the
|
||||
|
|
@ -112,7 +113,9 @@ def data_to_json(data: Dict[str, Any]) -> str:
|
|||
if hasattr(data_item, "__dict__") or isinstance(data_item, dict):
|
||||
# Take elements from __dict__ for custom classes
|
||||
item_generator = (
|
||||
data_item.items if isinstance(data_item, dict) else data_item.__dict__.items
|
||||
data_item.items
|
||||
if isinstance(data_item, dict)
|
||||
else data_item.__dict__.items
|
||||
)
|
||||
for variable_name, variable_item in item_generator():
|
||||
# Check if serializable. If not, just include the
|
||||
|
|
@ -127,8 +130,9 @@ def data_to_json(data: Dict[str, Any]) -> str:
|
|||
return json_string
|
||||
|
||||
|
||||
def json_to_data(json_string: str,
|
||||
custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
def json_to_data(
|
||||
json_string: str, custom_objects: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Turn JSON serialization of class-parameters back into dictionary.
|
||||
|
||||
|
|
@ -164,9 +168,11 @@ def json_to_data(json_string: str,
|
|||
base64_object = base64.b64decode(serialization.encode())
|
||||
deserialized_object = cloudpickle.loads(base64_object)
|
||||
except RuntimeError:
|
||||
warnings.warn(f"Could not deserialize object {data_key}. " +
|
||||
"Consider using `custom_objects` argument to replace " +
|
||||
"this object.")
|
||||
warnings.warn(
|
||||
f"Could not deserialize object {data_key}. "
|
||||
+ "Consider using `custom_objects` argument to replace "
|
||||
+ "this object."
|
||||
)
|
||||
return_data[data_key] = deserialized_object
|
||||
else:
|
||||
# Read as it is
|
||||
|
|
@ -174,71 +180,211 @@ def json_to_data(json_string: str,
|
|||
return return_data
|
||||
|
||||
|
||||
def save_to_zip_file(save_path: str, data: Dict[str, Any] = None,
|
||||
params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None:
|
||||
@functools.singledispatch
|
||||
def open_path(
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose=0, suffix=None
|
||||
):
|
||||
"""
|
||||
Opens a path for reading or writing with a preferred suffix and raises debug information.
|
||||
If the provided path is a derivative of io.BufferedIOBase it ensures that the file
|
||||
matches the provided mode, i.e. if the mode is read ("r", "read") it checks that the file is readable.
|
||||
If the mode is write ("w", "write") it checks that the file is writable.
|
||||
|
||||
If the provided path is a string or a pathlib.Path, the behavior depends on the mode. If the mode is "read"
|
||||
and the path does not exist, it attempts to open "{path}.{suffix}" instead if a suffix is provided.
|
||||
If the mode is "write" and the path does not exist, it creates all the parent folders. If the path
|
||||
points to a folder, it changes the path to "{path}_2". If the file already exists and verbose == 2,
|
||||
it emits a warning.
|
||||
|
||||
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open.
|
||||
If path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
|
||||
parent folders actually exist. If path is an io.BufferedIOBase, it already exists and is used as is.
|
||||
:param mode: (str) how to open the file. "w"|"write" for writing, "r"|"read" for reading.
|
||||
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
|
||||
:param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix.
|
||||
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
|
||||
is not None, we attempt to open the path with the suffix.
|
||||
"""
|
||||
if not isinstance(path, io.BufferedIOBase):
|
||||
raise TypeError("Path parameter has invalid type.", io.BufferedIOBase)
|
||||
if path.closed:
|
||||
raise ValueError("File stream is closed.")
|
||||
mode = mode.lower()
|
||||
try:
|
||||
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
|
||||
except KeyError:
|
||||
raise ValueError("Expected mode to be either 'w' or 'r'.")
|
||||
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
|
||||
e1 = "writable" if "w" == mode else "readable"
|
||||
raise ValueError(f"Expected a {e1} file.")
|
||||
return path
|
||||
|
||||
|
||||
@open_path.register(str)
|
||||
def open_path_str(
|
||||
path: str, mode: str, verbose=0, suffix=None
|
||||
) -> io.BufferedIOBase:
|
||||
"""
|
||||
Open a path given by a string. If writing to the path, the function ensures
|
||||
that the path exists.
|
||||
|
||||
:param path: (str) the path to open. If mode is "w" then it ensures that the path exists
|
||||
by creating the necessary folders and renaming path if it points to a folder.
|
||||
:param mode: (str) how to open the file. "w" for writing, "r" for reading.
|
||||
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
|
||||
:param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix.
|
||||
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
|
||||
is not None, we attempt to open the path with the suffix.
|
||||
"""
|
||||
return open_path(pathlib.Path(path), mode, verbose, suffix)
|
||||
|
||||
|
||||
@open_path.register(pathlib.Path)
|
||||
def open_path_pathlib(
|
||||
path: pathlib.Path, mode: str, verbose=0, suffix=None
|
||||
) -> io.BufferedIOBase:
|
||||
"""
|
||||
Open a path given by a pathlib.Path. If writing to the path, the function ensures
|
||||
that the path exists.
|
||||
|
||||
:param path: (pathlib.Path) the path to check. If mode is "w" then it
|
||||
ensures that the path exists by creating the necessary folders and
|
||||
renaming path if it points to a folder.
|
||||
:param mode: (str) how to open the file. "w" for writing, "r" for reading.
|
||||
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
|
||||
:param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix.
|
||||
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
|
||||
is not None, we attempt to open the path with the suffix.
|
||||
"""
|
||||
if mode not in ("w", "r"):
|
||||
raise ValueError("Expected mode to be either 'w' or 'r'.")
|
||||
|
||||
if mode == "r":
|
||||
try:
|
||||
path = path.open("rb")
|
||||
except FileNotFoundError as error:
|
||||
if suffix is not None and suffix != "":
|
||||
newpath = pathlib.Path(f"{path}.{suffix}")
|
||||
if verbose == 2:
|
||||
warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
|
||||
path, suffix = newpath, None
|
||||
else:
|
||||
raise error
|
||||
else:
|
||||
try:
|
||||
if path.suffix == "" and suffix is not None and suffix != "":
|
||||
path = pathlib.Path(f"{path}.{suffix}")
|
||||
if path.exists() and path.is_file() and verbose == 2:
|
||||
warnings.warn(f"Path '{path}' exists, will overwrite it.")
|
||||
path = path.open("wb")
|
||||
except IsADirectoryError:
|
||||
warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
|
||||
path = pathlib.Path(f"{path}_2")
|
||||
except FileNotFoundError: # Occurs when the parent folder doesn't exist
|
||||
warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# if opening was successful uses the identity function
|
||||
# if opening failed with IsADirectory|FileNotFound, calls open_path_pathlib
|
||||
# with corrections
|
||||
# if reading failed with FileNotFoundError, calls open_path_pathlib with suffix
|
||||
|
||||
return open_path(path, mode, verbose, suffix)
|
||||
|
||||
|
||||
def save_to_zip_file(
|
||||
save_path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
data: Dict[str, Any] = None,
|
||||
params: Dict[str, Any] = None,
|
||||
tensors: Dict[str, Any] = None,
|
||||
verbose=0,
|
||||
) -> None:
|
||||
"""
|
||||
Save a model to a zip archive.
|
||||
|
||||
:param save_path: Where to store the model.
|
||||
:param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model.
|
||||
If save_path is a str or pathlib.Path, ensures that the path actually exists.
|
||||
:param data: Class parameters being stored.
|
||||
:param params: Model parameters being stored expected to contain an entry for every
|
||||
state_dict with its name and the state_dict.
|
||||
:param tensors: Extra tensor variables expected to contain name and value of tensors
|
||||
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information
|
||||
"""
|
||||
|
||||
save_path = open_path(save_path, "w", verbose=0, suffix="zip")
|
||||
# data/params can be None, so do not
|
||||
# try to serialize them blindly
|
||||
if data is not None:
|
||||
serialized_data = data_to_json(data)
|
||||
|
||||
# Check postfix if save_path is a string
|
||||
if isinstance(save_path, str):
|
||||
_, ext = os.path.splitext(save_path)
|
||||
if ext == "":
|
||||
save_path += ".zip"
|
||||
|
||||
# Create a zip-archive and write our objects
|
||||
# there. This works when save_path is either
|
||||
# str or a file-like
|
||||
with zipfile.ZipFile(save_path, "w") as archive:
|
||||
# Create a zip-archive and write our objects there.
|
||||
with zipfile.ZipFile(save_path, mode="w") as archive:
|
||||
# Do not try to save "None" elements
|
||||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if tensors is not None:
|
||||
with archive.open('tensors.pth', mode="w") as tensors_file:
|
||||
with archive.open("tensors.pth", mode="w") as tensors_file:
|
||||
th.save(tensors, tensors_file)
|
||||
if params is not None:
|
||||
for file_name, dict_ in params.items():
|
||||
with archive.open(file_name + '.pth', mode="w") as param_file:
|
||||
with archive.open(file_name + ".pth", mode="w") as param_file:
|
||||
th.save(dict_, param_file)
|
||||
|
||||
|
||||
def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]],
|
||||
Optional[TensorDict],
|
||||
Optional[TensorDict]]):
|
||||
def save_to_pkl(
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0
|
||||
) -> None:
|
||||
"""
|
||||
Save an object to path creating the necessary folders along the way.
|
||||
If the path exists and is a directory, it will raise a warning and rename the path.
|
||||
If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'.
|
||||
|
||||
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open.
|
||||
If path is a str or pathlib.Path, single dispatch ensures that the
|
||||
path actually exists. If path is an io.BufferedIOBase, it already exists and is used as is.
|
||||
:param obj: The object to save.
|
||||
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
|
||||
"""
|
||||
with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler:
|
||||
pickle.dump(obj, file_handler)
|
||||
|
||||
|
||||
def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0) -> Any:
|
||||
"""
|
||||
Load an object from the path. If a suffix is provided in the path, it will use that suffix.
|
||||
If the path does not exist, it will attempt to load using the .pkl suffix.
|
||||
|
||||
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open.
|
||||
If path is a str or pathlib.Path that does not exist, the path with the
|
||||
".pkl" suffix is tried instead. If path is an io.BufferedIOBase, it is read from directly.
|
||||
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
|
||||
"""
|
||||
with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler:
|
||||
return pickle.load(file_handler)
|
||||
|
||||
|
||||
def load_from_zip_file(
|
||||
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
load_data: bool = True,
|
||||
verbose=0,
|
||||
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
|
||||
"""
|
||||
Load model data from a .zip archive
|
||||
|
||||
:param load_path: Where to load the model from
|
||||
:param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from
|
||||
:param load_data: Whether we should load and return data
|
||||
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
|
||||
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
|
||||
and dict of extra tensors
|
||||
"""
|
||||
# Check if file exists if load_path is a string
|
||||
if isinstance(load_path, str):
|
||||
if not os.path.exists(load_path):
|
||||
if os.path.exists(load_path + ".zip"):
|
||||
load_path += ".zip"
|
||||
else:
|
||||
raise ValueError(f"Error: the file {load_path} could not be found")
|
||||
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
|
||||
|
||||
# set device to cpu if cuda is not available
|
||||
device = get_device()
|
||||
|
||||
# Open the zip archive and load data
|
||||
try:
|
||||
with zipfile.ZipFile(load_path, "r") as archive:
|
||||
with zipfile.ZipFile(load_path) as archive:
|
||||
namelist = archive.namelist()
|
||||
# If data or parameters is not in the
|
||||
# zip archive, assume they were stored
|
||||
|
|
@ -254,7 +400,7 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona
|
|||
|
||||
if "tensors.pth" in namelist and load_data:
|
||||
# Load extra tensors
|
||||
with archive.open('tensors.pth', mode="r") as tensor_file:
|
||||
with archive.open("tensors.pth", mode="r") as tensor_file:
|
||||
# File has to be seekable, but tensor_file is not, so load in BytesIO first
|
||||
# fixed in python >= 3.7
|
||||
file_content = io.BytesIO()
|
||||
|
|
@ -265,8 +411,12 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona
|
|||
tensors = th.load(file_content, map_location=device)
|
||||
|
||||
# check for all other .pth files
|
||||
other_files = [file_name for file_name in namelist if
|
||||
os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"]
|
||||
other_files = [
|
||||
file_name
|
||||
for file_name in namelist
|
||||
if os.path.splitext(file_name)[1] == ".pth"
|
||||
and file_name != "tensors.pth"
|
||||
]
|
||||
# if there are any other files which end with .pth and aren't "tensors.pth"
|
||||
# assume that they each are optimizer parameters
|
||||
if len(other_files) > 0:
|
||||
|
|
@ -279,7 +429,9 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona
|
|||
# go to start of file
|
||||
file_content.seek(0)
|
||||
# load the parameters with the right ``map_location``
|
||||
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
|
||||
params[os.path.splitext(file_path)[0]] = th.load(
|
||||
file_content, map_location=device
|
||||
)
|
||||
except zipfile.BadZipFile:
|
||||
# load_path wasn't a zip file
|
||||
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import io
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
import gym
|
||||
|
|
@ -12,6 +14,11 @@ from stable_baselines3.common.base_class import BaseAlgorithm
|
|||
from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||
from stable_baselines3.common.save_util import (
|
||||
open_path,
|
||||
save_to_pkl,
|
||||
load_from_pkl,
|
||||
)
|
||||
|
||||
MODEL_LIST = [
|
||||
PPO,
|
||||
|
|
@ -46,17 +53,21 @@ def test_save_load(tmp_path, model_class):
|
|||
env = DummyVecEnv([lambda: select_env(model_class)])
|
||||
|
||||
# create model
|
||||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
env.reset()
|
||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||
observations = np.concatenate(
|
||||
[env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0
|
||||
)
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = deepcopy(model.policy.state_dict())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
random_params = dict(
|
||||
(param_name, th.rand_like(param)) for param_name, param in params.items()
|
||||
)
|
||||
|
||||
# Update model parameters with the new random values
|
||||
model.policy.load_state_dict(random_params)
|
||||
|
|
@ -64,7 +75,9 @@ def test_save_load(tmp_path, model_class):
|
|||
new_params = model.policy.state_dict()
|
||||
# Check that all params are different now
|
||||
for k in params:
|
||||
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
|
||||
assert not th.allclose(
|
||||
params[k], new_params[k]
|
||||
), "Parameters did not change as expected."
|
||||
|
||||
params = new_params
|
||||
|
||||
|
|
@ -74,14 +87,16 @@ def test_save_load(tmp_path, model_class):
|
|||
# Check
|
||||
model.save(tmp_path / "test_save.zip")
|
||||
del model
|
||||
model = model_class.load(str(tmp_path / "test_save"), env=env)
|
||||
model = model_class.load(str(tmp_path / "test_save.zip"), env=env)
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = model.policy.state_dict()
|
||||
|
||||
# Check that all params are the same as before save load procedure now
|
||||
for key in params:
|
||||
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
|
||||
assert th.allclose(
|
||||
params[key], new_params[key]
|
||||
), "Model parameters not the same after save and load."
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions, _ = model.predict(observations, deterministic=True)
|
||||
|
|
@ -107,7 +122,7 @@ def test_set_env(model_class):
|
|||
env3 = select_env(model_class)
|
||||
|
||||
# create model
|
||||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]))
|
||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
|
||||
# learn
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
|
@ -132,21 +147,21 @@ def test_exclude_include_saved_params(tmp_path, model_class):
|
|||
env = DummyVecEnv([lambda: select_env(model_class)])
|
||||
|
||||
# create model, set verbose as 2, which is not standard
|
||||
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2)
|
||||
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=2)
|
||||
|
||||
# Check if exclude works
|
||||
model.save(tmp_path / "test_save.zip", exclude=["verbose"])
|
||||
model.save(tmp_path / "test_save", exclude=["verbose"])
|
||||
del model
|
||||
model = model_class.load(str(tmp_path / "test_save"))
|
||||
model = model_class.load(str(tmp_path / "test_save.zip"))
|
||||
# check if verbose was not saved
|
||||
assert model.verbose != 2
|
||||
|
||||
# set verbose as something different than standard settings
|
||||
model.verbose = 2
|
||||
# Check if include works
|
||||
model.save(tmp_path / "test_save.zip", exclude=["verbose"], include=["verbose"])
|
||||
model.save(tmp_path / "test_save", exclude=["verbose"], include=["verbose"])
|
||||
del model
|
||||
model = model_class.load(str(tmp_path / "test_save"))
|
||||
model = model_class.load(str(tmp_path / "test_save.zip"))
|
||||
assert model.verbose == 2
|
||||
|
||||
# clear file from os
|
||||
|
|
@ -155,13 +170,14 @@ def test_exclude_include_saved_params(tmp_path, model_class):
|
|||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
|
||||
def test_save_load_replay_buffer(tmp_path, model_class):
|
||||
replay_path = tmp_path / 'replay_buffer.pkl'
|
||||
model = model_class('MlpPolicy', select_env(model_class), buffer_size=1000)
|
||||
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
|
||||
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
|
||||
model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000)
|
||||
model.learn(500)
|
||||
old_replay_buffer = deepcopy(model.replay_buffer)
|
||||
model.save_replay_buffer(replay_path)
|
||||
model.save_replay_buffer(path)
|
||||
model.replay_buffer = None
|
||||
model.load_replay_buffer(replay_path)
|
||||
model.load_replay_buffer(path)
|
||||
|
||||
assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations)
|
||||
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
|
||||
|
|
@ -169,11 +185,13 @@ def test_save_load_replay_buffer(tmp_path, model_class):
|
|||
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
|
||||
|
||||
# test extending replay buffer
|
||||
model.replay_buffer.extend(old_replay_buffer.observations, old_replay_buffer.observations,
|
||||
old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones)
|
||||
|
||||
# clear file from os
|
||||
os.remove(replay_path)
|
||||
model.replay_buffer.extend(
|
||||
old_replay_buffer.observations,
|
||||
old_replay_buffer.observations,
|
||||
old_replay_buffer.actions,
|
||||
old_replay_buffer.rewards,
|
||||
old_replay_buffer.dones,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [DQN, SAC, TD3])
|
||||
|
|
@ -186,12 +204,17 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
|
|||
See https://github.com/DLR-RM/stable-baselines3/issues/46
|
||||
"""
|
||||
# remove gym warnings
|
||||
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
|
||||
warnings.filterwarnings(action='ignore', category=UserWarning, module='gym')
|
||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")
|
||||
|
||||
model = model_class('MlpPolicy', select_env(model_class), buffer_size=100,
|
||||
optimize_memory_usage=optimize_memory_usage, policy_kwargs=dict(net_arch=[64]),
|
||||
learning_starts=10)
|
||||
model = model_class(
|
||||
"MlpPolicy",
|
||||
select_env(model_class),
|
||||
buffer_size=100,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
policy_kwargs=dict(net_arch=[64]),
|
||||
learning_starts=10,
|
||||
)
|
||||
|
||||
model.learn(150)
|
||||
|
||||
|
|
@ -205,13 +228,15 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
|
|||
if optimize_memory_usage:
|
||||
assert len(recwarn) == 1
|
||||
warning = recwarn.pop(UserWarning)
|
||||
assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
|
||||
assert "The last trajectory in the replay buffer will be truncated" in str(
|
||||
warning.message
|
||||
)
|
||||
else:
|
||||
assert len(recwarn) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
@pytest.mark.parametrize("policy_str", ['MlpPolicy', 'CnnPolicy'])
|
||||
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
|
||||
def test_save_load_policy(tmp_path, model_class, policy_str):
|
||||
"""
|
||||
Test saving and loading policy only.
|
||||
|
|
@ -220,25 +245,29 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
:param policy_str: (str) Name of the policy.
|
||||
"""
|
||||
kwargs = {}
|
||||
if policy_str == 'MlpPolicy':
|
||||
if policy_str == "MlpPolicy":
|
||||
env = select_env(model_class)
|
||||
else:
|
||||
if model_class in [SAC, TD3, DQN]:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(buffer_size=250)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2,
|
||||
discrete=model_class == DQN)
|
||||
env = FakeImageEnv(
|
||||
screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN
|
||||
)
|
||||
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
# create model
|
||||
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]),
|
||||
verbose=1, **kwargs)
|
||||
model = model_class(
|
||||
policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs
|
||||
)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
env.reset()
|
||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||
observations = np.concatenate(
|
||||
[env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0
|
||||
)
|
||||
|
||||
policy = model.policy
|
||||
policy_class = policy.__class__
|
||||
|
|
@ -251,7 +280,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
params = deepcopy(policy.state_dict())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
random_params = dict(
|
||||
(param_name, th.rand_like(param)) for param_name, param in params.items()
|
||||
)
|
||||
|
||||
# Update model parameters with the new random values
|
||||
policy.load_state_dict(random_params)
|
||||
|
|
@ -259,7 +290,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
new_params = policy.state_dict()
|
||||
# Check that all params are different now
|
||||
for k in params:
|
||||
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
|
||||
assert not th.allclose(
|
||||
params[k], new_params[k]
|
||||
), "Parameters did not change as expected."
|
||||
|
||||
params = new_params
|
||||
|
||||
|
|
@ -286,7 +319,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
|
||||
# Check that all params are the same as before save load procedure now
|
||||
for key in params:
|
||||
assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
|
||||
assert th.allclose(
|
||||
params[key], new_params[key]
|
||||
), "Policy parameters not the same after save and load."
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions, _ = policy.predict(observations, deterministic=True)
|
||||
|
|
@ -301,3 +336,93 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
os.remove(tmp_path / "policy.pkl")
|
||||
if actor_class is not None:
|
||||
os.remove(tmp_path / "actor.pkl")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
    """
    Check ``open_path``, ``save_to_pkl`` and ``load_from_pkl`` when the path
    is given either as a string or as a ``pathlib.Path``.

    :param tmp_path: (pathlib.Path) Temporary directory provided by pytest.
    :param pathtype: (type) Callable used to build the path argument
        (``str`` or ``pathlib.Path``).
    """
    # NOTE: ``pytest.warns(None)`` is deprecated (pytest 7) and rejected by
    # pytest 8, so warnings are recorded with the standard library instead.
    # ``simplefilter("always")`` makes sure no warning is swallowed by a
    # filter or by the "already emitted once" registry.

    # check that suffix isn't added because we used open_path first
    with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
    assert not record

    # test custom suffix
    with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
    assert not record

    # test without suffix: the file is expected to be found at "t1.pkl"
    with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
    assert not record

    # test that a warning is raised when the path doesn't exist:
    # only "t2.pkl" is written, then "t2" is opened for reading with suffix="pkl"
    with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed

    # default verbosity: the data is still loaded and nothing is reported
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert (
            load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl"))
            == "foo"
        )
    assert len(record) == 0

    # verbose=2: same result, but exactly one warning is expected
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert (
            load_from_pkl(
                open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)
            )
            == "foo"
        )
    assert len(record) == 1

    # create a file at the suffix-less path "t2"
    pathlib.Path(f"{tmp_path}/t2").write_text("rubbish")

    # test that exactly one warning is raised over the three verbosity levels
    # when opening for writing.
    # NOTE(review): this comment used to say "only raised when verbose = 0";
    # the read-mode checks above only warn for verbose=2, so presumably the
    # same level is responsible here -- confirm against ``open_path``.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        for verbose in (0, 1, 2):
            open_path(
                pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=verbose
            ).close()
    assert len(record) == 1
|
||||
|
||||
|
||||
def test_open_file(tmp_path):
    """
    Check how ``open_path`` handles invalid arguments and already opened
    file-like objects.

    :param tmp_path: (pathlib.Path) Temporary directory provided by pytest.
    """
    # path must match the type
    with pytest.raises(TypeError):
        open_path(123, None, None, None)

    p1 = tmp_path / "test1"
    fp = p1.open("wb")
    try:
        # provided path must match the mode: ``fp`` was opened for writing
        with pytest.raises(ValueError):
            open_path(fp, "r")
        with pytest.raises(ValueError):
            open_path(fp, "randomstuff")

        # test identity: an open file requested with its own mode is returned as is
        same_fp = open_path(fp, "w")
        assert same_fp is fp
    finally:
        # Close outside of ``pytest.raises`` so that an error raised by
        # ``close()`` itself cannot be mistaken for the expected one,
        # and so that the handle is released even if an assertion fails.
        fp.close()

    # Can't use a closed path
    with pytest.raises(ValueError):
        open_path(fp, "w")

    # Same checks with an in-memory buffer
    buff = io.BytesIO()
    assert buff.writable()
    assert buff.readable()
    same_buff = open_path(buff, "w")
    assert same_buff is buff

    buff.close()
    with pytest.raises(ValueError):
        open_path(buff, "w")
|
||||
|
|
|
|||
Loading…
Reference in a new issue