From 7d8ebb9e989aa9798191cbb21e5c44fc1b61ed35 Mon Sep 17 00:00:00 2001 From: Marios Koulakis Date: Tue, 30 Jun 2020 15:03:02 +0200 Subject: [PATCH 1/2] Udacity Reacher Project with Unity (#79) * Add the reacher project to the sample projects * Update the change log * Remove github incompatible link notation * Update changelog.rst Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- docs/misc/projects.rst | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 6725594..14bce0a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -37,6 +37,7 @@ Documentation: ^^^^^^^^^^^^^^ - Updated notebook links - Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake) +- Added Unity reacher to the projects page (@koulakis) @@ -339,4 +340,4 @@ And all the contributors: @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 -@tirafesi @blurLake +@tirafesi @blurLake @koulakis diff --git a/docs/misc/projects.rst b/docs/misc/projects.rst index 75a8093..9a3b2f1 100644 --- a/docs/misc/projects.rst +++ b/docs/misc/projects.rst @@ -25,3 +25,15 @@ It was the starting point of Stable-Baselines3. | Author: Antonin Raffin, Freek Stulp | Github: https://github.com/DLR-RM/stable-baselines3/tree/sde | Paper: https://arxiv.org/abs/2005.05719 + +Reacher +------- +A solution to the second project of the Udacity deep reinforcement learning course. +It is an example of: + +- wrapping single and multi-agent Unity environments to make them usable in Stable-Baselines3 +- creating experimentation scripts which train and run A2C, PPO, TD3 and SAC models (a better choice for this one is https://github.com/DLR-RM/rl-baselines3-zoo) +- generating several pre-trained models which solve the reacher environment + +| Author: Marios Koulakis +| Github: https://github.com/koulakis/reacher-deep-reinforcement-learning From 4aa66ed34a631a7ab891781df07b978ba3547cab Mon Sep 17 00:00:00 2001 From: Stelios Tymvios <52372765+PartiallyTyped@users.noreply.github.com> Date: Fri, 3 Jul 2020 01:14:21 +0300 Subject: [PATCH 2/2] Automatically create paths for saved objects (#80) * automatically create paths for saved objects * Minor Corrections, more tests * linting * typing * Correct mode checking * corrected tests to reflect new verbose functionality --- docs/misc/changelog.rst | 2 + stable_baselines3/common/base_class.py | 13 +- .../common/off_policy_algorithm.py | 21 +- stable_baselines3/common/save_util.py | 242 ++++++++++++++---- tests/test_save_load.py | 199 +++++++++++--- 5 files changed, 382 insertions(+), 95 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 14bce0a..45d78d3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,8 @@ New Features: - Buffer dtype is now set according to action and observation spaces for ``ReplayBuffer`` - Added warning when allocation of a buffer may exceed the available memory of the system when ``psutil`` is available +- Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped) +- Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped) Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 9b52ec4..4453231 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -2,6 +2,8 @@ import time from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable from abc import ABC, abstractmethod from collections import deque +import pathlib +import io import gym import torch as th @@ -291,7 +293,7 @@ class BaseAlgorithm(ABC): return self.policy.predict(observation, state, mask, deterministic) @classmethod - def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs): + def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAlgorithm': """ Load the model from a zip-file @@ -475,11 +477,16 @@ class BaseAlgorithm(ABC): """ return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"] - def save(self, path: str, exclude: Optional[List[str]] = None, include: Optional[List[str]] = None) -> None: + def save( + self, + path: Union[str, pathlib.Path, io.BufferedIOBase], + exclude: Optional[List[str]] = None, + include: Optional[List[str]] = None, + ) -> None: """ Save all the attributes of the object and the model parameters in a zip-file. - :param path: path to the file where the rl agent should be saved + :param (Union[str, pathlib.Path, io.BufferedIOBase]): path to the file where the rl agent should be saved :param exclude: name of parameters that should be excluded in addition to the default one :param include: name of parameters that might be excluded but should be included anyway """ diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 5b75e44..f459de9 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -1,7 +1,8 @@ import time -import pickle import warnings +import pathlib from typing import Union, Type, Optional, Dict, Any, Callable, List, Tuple +import io import gym import torch as th @@ -16,6 +17,7 @@ from stable_baselines3.common.type_aliases import GymEnv, RolloutReturn, MaybeCa from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.save_util import save_to_pkl, load_from_pkl class OffPolicyAlgorithm(BaseAlgorithm): @@ -126,7 +128,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): # For gSDE only self.use_sde_at_warmup = use_sde_at_warmup - def _setup_model(self): + def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space, @@ -136,24 +138,23 @@ class OffPolicyAlgorithm(BaseAlgorithm): self.lr_schedule, **self.policy_kwargs) self.policy = self.policy.to(self.device) - def save_replay_buffer(self, path: str): + def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: """ Save the replay buffer as a pickle file. - :param path: (str) Path to the file where the replay buffer should be saved + :param path: (Union[str,pathlib.Path, io.BufferedIOBase]) Path to the file where the replay buffer should be saved. + if path is a str or pathlib.Path, the path is automatically created if necessary. """ assert self.replay_buffer is not None, "The replay buffer is not defined" - with open(path, 'wb') as file_handler: - pickle.dump(self.replay_buffer, file_handler) + save_to_pkl(path, self.replay_buffer, self.verbose) - def load_replay_buffer(self, path: str): + def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: """ Load a replay buffer from a pickle file. - :param path: (str) Path to the pickled replay buffer. + :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer. """ - with open(path, 'rb') as file_handler: - self.replay_buffer = pickle.load(file_handler) + self.replay_buffer = load_from_pkl(path, self.verbose) assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class' def _setup_learn(self, diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index fe5f890..0d2fe98 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -2,14 +2,16 @@ Save util taken from stable_baselines used to serialize data (class parameters) of model classes """ -import os import io +import os import json import base64 import functools -from typing import Dict, Any, Tuple, Optional +from typing import Dict, Any, Tuple, Optional, Union import warnings import zipfile +import pathlib +import pickle import torch as th import cloudpickle @@ -30,10 +32,11 @@ def recursive_getattr(obj: Any, attr: str, *args) -> Any: :param attr: (str) Attribute to retrieve :return: (Any) The attribute """ + def _getattr(obj: Any, attr: str) -> Any: return getattr(obj, attr, *args) - return functools.reduce(_getattr, [obj] + attr.split('.')) + return functools.reduce(_getattr, [obj] + attr.split(".")) def recursive_setattr(obj: Any, attr: str, val: Any) -> None: @@ -48,7 +51,7 @@ def recursive_setattr(obj: Any, attr: str, val: Any) -> None: :param attr: (str) Attribute to set :param val: (Any) New value of the attribute """ - pre, _, post = attr.rpartition('.') + pre, _, post = attr.rpartition(".") return setattr(recursive_getattr(obj, pre) if pre else obj, post, val) @@ -92,16 +95,14 @@ def data_to_json(data: Dict[str, Any]) -> str: # Also store type of the class for consumption # from other languages/humans, so we have an # idea what was being stored. - base64_encoded = base64.b64encode( - cloudpickle.dumps(data_item) - ).decode() + base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode() # Use ":" to make sure we do # not override these keys # when we include variables of the object later cloudpickle_serialization = { ":type:": str(type(data_item)), - ":serialized:": base64_encoded + ":serialized:": base64_encoded, } # Add first-level JSON-serializable items of the @@ -112,7 +113,9 @@ def data_to_json(data: Dict[str, Any]) -> str: if hasattr(data_item, "__dict__") or isinstance(data_item, dict): # Take elements from __dict__ for custom classes item_generator = ( - data_item.items if isinstance(data_item, dict) else data_item.__dict__.items + data_item.items + if isinstance(data_item, dict) + else data_item.__dict__.items ) for variable_name, variable_item in item_generator(): # Check if serializable. If not, just include the @@ -127,8 +130,9 @@ def data_to_json(data: Dict[str, Any]) -> str: return json_string -def json_to_data(json_string: str, - custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: +def json_to_data( + json_string: str, custom_objects: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: """ Turn JSON serialization of class-parameters back into dictionary. @@ -164,9 +168,11 @@ def json_to_data(json_string: str, base64_object = base64.b64decode(serialization.encode()) deserialized_object = cloudpickle.loads(base64_object) except RuntimeError: - warnings.warn(f"Could not deserialize object {data_key}. " + - "Consider using `custom_objects` argument to replace " + - "this object.") + warnings.warn( + f"Could not deserialize object {data_key}. " + + "Consider using `custom_objects` argument to replace " + + "this object." + ) return_data[data_key] = deserialized_object else: # Read as it is @@ -174,71 +180,211 @@ def json_to_data(json_string: str, return return_data -def save_to_zip_file(save_path: str, data: Dict[str, Any] = None, - params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None: +@functools.singledispatch +def open_path( + path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose=0, suffix=None +): + """ + Opens a path for reading or writing with a preferred suffix and raises debug information. + If the provided path is a derivative of io.BufferedIOBase it ensures that the file + matches the provided mode, i.e. If the mode is read ("r", "read") it checks that the path is readable. + If the mode is write ("w", "write") it checks that the file is writable. + + If the provided path is a string or a pathlib.Path, it ensures that it exists. If the mode is "read" + it checks that it exists, if it doesn't exist it attempts to read path.suffix if a suffix is provided. + If the mode is "write" and the path does not exist, it creates all the parent folders. If the path + points to a folder, it changes the path to path_2. If the path already exists and verbose == 2, + it raises a warning. + + :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open. + if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the + path actually exists. If path is a io.BufferedIOBase the path exists. + :param mode: (str) how to open the file. "w"|"write" for writing, "r"|"read" for reading. + :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix. + If mode is "r" then we attempt to open the path. If an error is raised and the suffix + is not None, we attempt to open the path with the suffix. + """ + if not isinstance(path, io.BufferedIOBase): + raise TypeError("Path parameter has invalid type.", io.BufferedIOBase) + if path.closed: + raise ValueError("File stream is closed.") + mode = mode.lower() + try: + mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode] + except KeyError: + raise ValueError("Expected mode to be either 'w' or 'r'.") + if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable(): + e1 = "writable" if "w" == mode else "readable" + raise ValueError(f"Expected a {e1} file.") + return path + + +@open_path.register(str) +def open_path_str( + path: str, mode: str, verbose=0, suffix=None +) -> io.BufferedIOBase: + """ + Open a path given by a string. If writing to the path, the function ensures + that the path exists. + + :param path: (str) the path to open. If mode is "w" then it ensures that the path exists + by creating the necessary folders and renaming path if it points to a folder. + :param mode: (str) how to open the file. "w" for writing, "r" for reading. + :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix. + If mode is "r" then we attempt to open the path. If an error is raised and the suffix + is not None, we attempt to open the path with the suffix. + """ + return open_path(pathlib.Path(path), mode, verbose, suffix) + + +@open_path.register(pathlib.Path) +def open_path_pathlib( + path: pathlib.Path, mode: str, verbose=0, suffix=None +) -> io.BufferedIOBase: + """ + Open a path given by a string. If writing to the path, the function ensures + that the path exists. + + :param path: (pathlib.Path) the path to check. If mode is "w" then it + ensures that the path exists by creating the necessary folders and + renaming path if it points to a folder. + :param mode: (str) how to open the file. "w" for writing, "r" for reading. + :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix. + If mode is "r" then we attempt to open the path. If an error is raised and the suffix + is not None, we attempt to open the path with the suffix. + """ + if mode not in ("w", "r"): + raise ValueError("Expected mode to be either 'w' or 'r'.") + + if mode == "r": + try: + path = path.open("rb") + except FileNotFoundError as error: + if suffix is not None and suffix != "": + newpath = pathlib.Path(f"{path}.{suffix}") + if verbose == 2: + warnings.warn(f"Path '{path}' not found. Attempting {newpath}.") + path, suffix = newpath, None + else: + raise error + else: + try: + if path.suffix == "" and suffix is not None and suffix != "": + path = pathlib.Path(f"{path}.{suffix}") + if path.exists() and path.is_file() and verbose == 2: + warnings.warn(f"Path '{path}' exists, will overwrite it.") + path = path.open("wb") + except IsADirectoryError: + warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2") + path = pathlib.Path(f"{path}_2") + except FileNotFoundError: # Occurs when the parent folder doesn't exist + warnings.warn(f"Path '{path.parent}' does not exist. Will create it.") + path.parent.mkdir(exist_ok=True, parents=True) + + # if opening was successful uses the identity function + # if opening failed with IsADirectory|FileNotFound, calls open_path_pathlib + # with corrections + # if reading failed with FileNotFoundError, calls open_path_pathlib with suffix + + return open_path(path, mode, verbose, suffix) + + +def save_to_zip_file( + save_path: Union[str, pathlib.Path, io.BufferedIOBase], + data: Dict[str, Any] = None, + params: Dict[str, Any] = None, + tensors: Dict[str, Any] = None, + verbose=0, +) -> None: """ Save a model to a zip archive. - :param save_path: Where to store the model. + :param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model. + if save_path is a str or pathlib.Path ensures that the path actually exists. :param data: Class parameters being stored. :param params: Model parameters being stored expected to contain an entry for every state_dict with its name and the state_dict. :param tensors: Extra tensor variables expected to contain name and value of tensors + :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information """ + save_path = open_path(save_path, "w", verbose=0, suffix="zip") # data/params can be None, so do not # try to serialize them blindly if data is not None: serialized_data = data_to_json(data) - # Check postfix if save_path is a string - if isinstance(save_path, str): - _, ext = os.path.splitext(save_path) - if ext == "": - save_path += ".zip" - - # Create a zip-archive and write our objects - # there. This works when save_path is either - # str or a file-like - with zipfile.ZipFile(save_path, "w") as archive: + # Create a zip-archive and write our objects there. + with zipfile.ZipFile(save_path, mode="w") as archive: # Do not try to save "None" elements if data is not None: archive.writestr("data", serialized_data) if tensors is not None: - with archive.open('tensors.pth', mode="w") as tensors_file: + with archive.open("tensors.pth", mode="w") as tensors_file: th.save(tensors, tensors_file) if params is not None: for file_name, dict_ in params.items(): - with archive.open(file_name + '.pth', mode="w") as param_file: + with archive.open(file_name + ".pth", mode="w") as param_file: th.save(dict_, param_file) -def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]], - Optional[TensorDict], - Optional[TensorDict]]): +def save_to_pkl( + path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0 +) -> None: + """ + Save an object to path creating the necessary folders along the way. + If the path exists and is a directory, it will raise a warning and rename the path. + If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'. + + :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open. + if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the + path actually exists. If path is a io.BufferedIOBase the path exists. + :param obj: The object to save. + :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. + """ + with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler: + pickle.dump(obj, file_handler) + + +def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0) -> Any: + """ + Load an object from the path. If a suffix is provided in the path, it will use that suffix. + If the path does not exist, it will attempt to load using the .pkl suffix. + + :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open. + if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the + path actually exists. If path is a io.BufferedIOBase the path exists. + :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. + """ + with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler: + return pickle.load(file_handler) + + +def load_from_zip_file( + load_path: Union[str, pathlib.Path, io.BufferedIOBase], + load_data: bool = True, + verbose=0, +) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]): """ Load model data from a .zip archive - :param load_path: Where to load the model from + :param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from :param load_data: Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict) and dict of extra tensors """ - # Check if file exists if load_path is a string - if isinstance(load_path, str): - if not os.path.exists(load_path): - if os.path.exists(load_path + ".zip"): - load_path += ".zip" - else: - raise ValueError(f"Error: the file {load_path} could not be found") + load_path = open_path(load_path, "r", verbose=verbose, suffix="zip") # set device to cpu if cuda is not available device = get_device() # Open the zip archive and load data try: - with zipfile.ZipFile(load_path, "r") as archive: + with zipfile.ZipFile(load_path) as archive: namelist = archive.namelist() # If data or parameters is not in the # zip archive, assume they were stored @@ -254,7 +400,7 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona if "tensors.pth" in namelist and load_data: # Load extra tensors - with archive.open('tensors.pth', mode="r") as tensor_file: + with archive.open("tensors.pth", mode="r") as tensor_file: # File has to be seekable, but opt_param_file is not, so load in BytesIO first # fixed in python >= 3.7 file_content = io.BytesIO() @@ -265,8 +411,12 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona tensors = th.load(file_content, map_location=device) # check for all other .pth files - other_files = [file_name for file_name in namelist if - os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"] + other_files = [ + file_name + for file_name in namelist + if os.path.splitext(file_name)[1] == ".pth" + and file_name != "tensors.pth" + ] # if there are any other files which end with .pth and aren't "params.pth" # assume that they each are optimizer parameters if len(other_files) > 0: @@ -279,7 +429,9 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona # go to start of file file_content.seek(0) # load the parameters with the right ``map_location`` - params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device) + params[os.path.splitext(file_path)[0]] = th.load( + file_content, map_location=device + ) except zipfile.BadZipFile: # load_path wasn't a zip file raise ValueError(f"Error: the file {load_path} wasn't a zip-file") diff --git a/tests/test_save_load.py b/tests/test_save_load.py index d107115..95f8c93 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,6 +1,8 @@ import os +import io import warnings from copy import deepcopy +import pathlib import pytest import gym @@ -12,6 +14,11 @@ from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3.common.identity_env import FakeImageEnv +from stable_baselines3.common.save_util import ( + open_path, + save_to_pkl, + load_from_pkl, +) MODEL_LIST = [ PPO, @@ -46,17 +53,21 @@ def test_save_load(tmp_path, model_class): env = DummyVecEnv([lambda: select_env(model_class)]) # create model - model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1) + model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1) model.learn(total_timesteps=500, eval_freq=250) env.reset() - observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) + observations = np.concatenate( + [env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0 + ) # Get dictionary of current parameters params = deepcopy(model.policy.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = dict( + (param_name, th.rand_like(param)) for param_name, param in params.items() + ) # Update model parameters with the new random values model.policy.load_state_dict(random_params) @@ -64,7 +75,9 @@ def test_save_load(tmp_path, model_class): new_params = model.policy.state_dict() # Check that all params are different now for k in params: - assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." + assert not th.allclose( + params[k], new_params[k] + ), "Parameters did not change as expected." params = new_params @@ -74,14 +87,16 @@ def test_save_load(tmp_path, model_class): # Check model.save(tmp_path / "test_save.zip") del model - model = model_class.load(str(tmp_path / "test_save"), env=env) + model = model_class.load(str(tmp_path / "test_save.zip"), env=env) # check if params are still the same after load new_params = model.policy.state_dict() # Check that all params are the same as before save load procedure now for key in params: - assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load." + assert th.allclose( + params[key], new_params[key] + ), "Model parameters not the same after save and load." # check if model still selects the same actions new_selected_actions, _ = model.predict(observations, deterministic=True) @@ -107,7 +122,7 @@ def test_set_env(model_class): env3 = select_env(model_class) # create model - model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16])) + model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16])) # learn model.learn(total_timesteps=1000, eval_freq=500) @@ -132,21 +147,21 @@ def test_exclude_include_saved_params(tmp_path, model_class): env = DummyVecEnv([lambda: select_env(model_class)]) # create model, set verbose as 2, which is not standard - model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2) + model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=2) # Check if exclude works - model.save(tmp_path / "test_save.zip", exclude=["verbose"]) + model.save(tmp_path / "test_save", exclude=["verbose"]) del model - model = model_class.load(str(tmp_path / "test_save")) + model = model_class.load(str(tmp_path / "test_save.zip")) # check if verbose was not saved assert model.verbose != 2 # set verbose as something different then standard settings model.verbose = 2 # Check if include works - model.save(tmp_path / "test_save.zip", exclude=["verbose"], include=["verbose"]) + model.save(tmp_path / "test_save", exclude=["verbose"], include=["verbose"]) del model - model = model_class.load(str(tmp_path / "test_save")) + model = model_class.load(str(tmp_path / "test_save.zip")) assert model.verbose == 2 # clear file from os @@ -155,13 +170,14 @@ def test_exclude_include_saved_params(tmp_path, model_class): @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) def test_save_load_replay_buffer(tmp_path, model_class): - replay_path = tmp_path / 'replay_buffer.pkl' - model = model_class('MlpPolicy', select_env(model_class), buffer_size=1000) + path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl") + path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning + model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000) model.learn(500) old_replay_buffer = deepcopy(model.replay_buffer) - model.save_replay_buffer(replay_path) + model.save_replay_buffer(path) model.replay_buffer = None - model.load_replay_buffer(replay_path) + model.load_replay_buffer(path) assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations) assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions) @@ -169,11 +185,13 @@ def test_save_load_replay_buffer(tmp_path, model_class): assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) # test extending replay buffer - model.replay_buffer.extend(old_replay_buffer.observations, old_replay_buffer.observations, - old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones) - - # clear file from os - os.remove(replay_path) + model.replay_buffer.extend( + old_replay_buffer.observations, + old_replay_buffer.observations, + old_replay_buffer.actions, + old_replay_buffer.rewards, + old_replay_buffer.dones, + ) @pytest.mark.parametrize("model_class", [DQN, SAC, TD3]) @@ -186,12 +204,17 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage): See https://github.com/DLR-RM/stable-baselines3/issues/46 """ # remove gym warnings - warnings.filterwarnings(action='ignore', category=DeprecationWarning) - warnings.filterwarnings(action='ignore', category=UserWarning, module='gym') + warnings.filterwarnings(action="ignore", category=DeprecationWarning) + warnings.filterwarnings(action="ignore", category=UserWarning, module="gym") - model = model_class('MlpPolicy', select_env(model_class), buffer_size=100, - optimize_memory_usage=optimize_memory_usage, policy_kwargs=dict(net_arch=[64]), - learning_starts=10) + model = model_class( + "MlpPolicy", + select_env(model_class), + buffer_size=100, + optimize_memory_usage=optimize_memory_usage, + policy_kwargs=dict(net_arch=[64]), + learning_starts=10, + ) model.learn(150) @@ -205,13 +228,15 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage): if optimize_memory_usage: assert len(recwarn) == 1 warning = recwarn.pop(UserWarning) - assert "The last trajectory in the replay buffer will be truncated" in str(warning.message) + assert "The last trajectory in the replay buffer will be truncated" in str( + warning.message + ) else: assert len(recwarn) == 0 @pytest.mark.parametrize("model_class", MODEL_LIST) -@pytest.mark.parametrize("policy_str", ['MlpPolicy', 'CnnPolicy']) +@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) def test_save_load_policy(tmp_path, model_class, policy_str): """ Test saving and loading policy only. @@ -220,25 +245,29 @@ def test_save_load_policy(tmp_path, model_class, policy_str): :param policy_str: (str) Name of the policy. """ kwargs = {} - if policy_str == 'MlpPolicy': + if policy_str == "MlpPolicy": env = select_env(model_class) else: if model_class in [SAC, TD3, DQN]: # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict(buffer_size=250) - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, - discrete=model_class == DQN) + env = FakeImageEnv( + screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN + ) env = DummyVecEnv([lambda: env]) # create model - model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), - verbose=1, **kwargs) + model = model_class( + policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs + ) model.learn(total_timesteps=500, eval_freq=250) env.reset() - observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) + observations = np.concatenate( + [env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0 + ) policy = model.policy policy_class = policy.__class__ @@ -251,7 +280,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str): params = deepcopy(policy.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = dict( + (param_name, th.rand_like(param)) for param_name, param in params.items() + ) # Update model parameters with the new random values policy.load_state_dict(random_params) @@ -259,7 +290,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str): new_params = policy.state_dict() # Check that all params are different now for k in params: - assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." + assert not th.allclose( + params[k], new_params[k] + ), "Parameters did not change as expected." params = new_params @@ -286,7 +319,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str): # Check that all params are the same as before save load procedure now for key in params: - assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load." + assert th.allclose( + params[key], new_params[key] + ), "Policy parameters not the same after save and load." # check if model still selects the same actions new_selected_actions, _ = policy.predict(observations, deterministic=True) @@ -301,3 +336,93 @@ def test_save_load_policy(tmp_path, model_class, policy_str): os.remove(tmp_path / "policy.pkl") if actor_class is not None: os.remove(tmp_path / "actor.pkl") + + +@pytest.mark.parametrize("pathtype", [str, pathlib.Path]) +def test_open_file_str_pathlib(tmp_path, pathtype): + # check that suffix isn't added because we used open_path first + with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1: + save_to_pkl(fp1, "foo") + assert fp1.closed + with pytest.warns(None) as record: + assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo" + assert not record + + # test custom suffix + with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1: + save_to_pkl(fp1, "foo") + assert fp1.closed + with pytest.warns(None) as record: + assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo" + assert not record + + # test without suffix + with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1: + save_to_pkl(fp1, "foo") + assert fp1.closed + with pytest.warns(None) as record: + assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo" + assert not record + + # test that a warning is raised when the path doesn't exist + with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1: + save_to_pkl(fp1, "foo") + assert fp1.closed + with pytest.warns(None) as record: + assert ( + load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) + == "foo" + ) + assert len(record) == 0 + + with pytest.warns(None) as record: + assert ( + load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) + == "foo" + ) + assert len(record) == 1 + + fp = pathlib.Path(f"{tmp_path}/t2").open("w") + fp.write("rubbish") + fp.close() + # test that a warning is only raised when verbose = 0 + with pytest.warns(None) as record: + open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close() + open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close() + open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close() + assert len(record) == 1 + + +def test_open_file(tmp_path): + + # path must much the type + with pytest.raises(TypeError): + open_path(123, None, None, None) + + p1 = tmp_path / "test1" + fp = p1.open("wb") + + # provided path must match the mode + with pytest.raises(ValueError): + open_path(fp, "r") + with pytest.raises(ValueError): + open_path(fp, "randomstuff") + + # test identity + _ = open_path(fp, "w") + assert _ is not None + assert fp is _ + + # Can't use a closed path + with pytest.raises(ValueError): + fp.close() + open_path(fp, "w") + + buff = io.BytesIO() + assert buff.writable() + assert buff.readable() is ("w" == "w") + _ = open_path(buff, "w") + assert _ is buff + with pytest.raises(ValueError): + buff.close() + open_path(buff, "w")