Merge branch 'master' into sde

This commit is contained in:
Antonin RAFFIN 2020-07-03 21:51:44 +02:00
commit c1f30acd6f
6 changed files with 396 additions and 96 deletions

View file

@ -17,6 +17,8 @@ New Features:
- Buffer dtype is now set according to action and observation spaces for ``ReplayBuffer``
- Added warning when allocation of a buffer may exceed the available memory of the system
when ``psutil`` is available
- Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped)
- Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped)
Bug Fixes:
^^^^^^^^^^
@ -37,6 +39,7 @@ Documentation:
^^^^^^^^^^^^^^
- Updated notebook links
- Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake)
- Added Unity reacher to the projects page (@koulakis)
@ -337,4 +340,4 @@ And all the contributors:
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake
@tirafesi @blurLake @koulakis

View file

@ -25,3 +25,15 @@ It was the starting point of Stable-Baselines3.
| Author: Antonin Raffin, Freek Stulp
| Github: https://github.com/DLR-RM/stable-baselines3/tree/sde
| Paper: https://arxiv.org/abs/2005.05719
Reacher
-------
A solution to the second project of the Udacity deep reinforcement learning course.
It is an example of:
- wrapping single and multi-agent Unity environments to make them usable in Stable-Baselines3
- creating experimentation scripts which train and run A2C, PPO, TD3 and SAC models (a better choice for this one is https://github.com/DLR-RM/rl-baselines3-zoo)
- generating several pre-trained models which solve the reacher environment
| Author: Marios Koulakis
| Github: https://github.com/koulakis/reacher-deep-reinforcement-learning

View file

@ -2,6 +2,8 @@ import time
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
from abc import ABC, abstractmethod
from collections import deque
import pathlib
import io
import gym
import torch as th
@ -291,7 +293,7 @@ class BaseAlgorithm(ABC):
return self.policy.predict(observation, state, mask, deterministic)
@classmethod
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAlgorithm':
"""
Load the model from a zip-file
@ -475,11 +477,16 @@ class BaseAlgorithm(ABC):
"""
return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]
def save(self, path: str, exclude: Optional[List[str]] = None, include: Optional[List[str]] = None) -> None:
def save(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
exclude: Optional[List[str]] = None,
include: Optional[List[str]] = None,
) -> None:
"""
Save all the attributes of the object and the model parameters in a zip-file.
:param path: path to the file where the rl agent should be saved
:param (Union[str, pathlib.Path, io.BufferedIOBase]): path to the file where the rl agent should be saved
:param exclude: name of parameters that should be excluded in addition to the default one
:param include: name of parameters that might be excluded but should be included anyway
"""

View file

@ -1,7 +1,8 @@
import time
import pickle
import warnings
import pathlib
from typing import Union, Type, Optional, Dict, Any, Callable, List, Tuple
import io
import gym
import torch as th
@ -16,6 +17,7 @@ from stable_baselines3.common.type_aliases import GymEnv, RolloutReturn, MaybeCa
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.save_util import save_to_pkl, load_from_pkl
class OffPolicyAlgorithm(BaseAlgorithm):
@ -126,7 +128,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
# For gSDE only
self.use_sde_at_warmup = use_sde_at_warmup
def _setup_model(self):
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space,
@ -136,24 +138,23 @@ class OffPolicyAlgorithm(BaseAlgorithm):
self.lr_schedule, **self.policy_kwargs)
self.policy = self.policy.to(self.device)
def save_replay_buffer(self, path: str):
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
"""
Save the replay buffer as a pickle file.
:param path: (str) Path to the file where the replay buffer should be saved
:param path: (Union[str,pathlib.Path, io.BufferedIOBase]) Path to the file where the replay buffer should be saved.
if path is a str or pathlib.Path, the path is automatically created if necessary.
"""
assert self.replay_buffer is not None, "The replay buffer is not defined"
with open(path, 'wb') as file_handler:
pickle.dump(self.replay_buffer, file_handler)
save_to_pkl(path, self.replay_buffer, self.verbose)
def load_replay_buffer(self, path: str):
def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
"""
Load a replay buffer from a pickle file.
:param path: (str) Path to the pickled replay buffer.
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer.
"""
with open(path, 'rb') as file_handler:
self.replay_buffer = pickle.load(file_handler)
self.replay_buffer = load_from_pkl(path, self.verbose)
assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
def _setup_learn(self,

View file

@ -2,14 +2,16 @@
Save util taken from stable_baselines
used to serialize data (class parameters) of model classes
"""
import os
import io
import os
import json
import base64
import functools
from typing import Dict, Any, Tuple, Optional
from typing import Dict, Any, Tuple, Optional, Union
import warnings
import zipfile
import pathlib
import pickle
import torch as th
import cloudpickle
@ -30,10 +32,11 @@ def recursive_getattr(obj: Any, attr: str, *args) -> Any:
:param attr: (str) Attribute to retrieve
:return: (Any) The attribute
"""
def _getattr(obj: Any, attr: str) -> Any:
return getattr(obj, attr, *args)
return functools.reduce(_getattr, [obj] + attr.split('.'))
return functools.reduce(_getattr, [obj] + attr.split("."))
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
@ -48,7 +51,7 @@ def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
:param attr: (str) Attribute to set
:param val: (Any) New value of the attribute
"""
pre, _, post = attr.rpartition('.')
pre, _, post = attr.rpartition(".")
return setattr(recursive_getattr(obj, pre) if pre else obj, post, val)
@ -92,16 +95,14 @@ def data_to_json(data: Dict[str, Any]) -> str:
# Also store type of the class for consumption
# from other languages/humans, so we have an
# idea what was being stored.
base64_encoded = base64.b64encode(
cloudpickle.dumps(data_item)
).decode()
base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode()
# Use ":" to make sure we do
# not override these keys
# when we include variables of the object later
cloudpickle_serialization = {
":type:": str(type(data_item)),
":serialized:": base64_encoded
":serialized:": base64_encoded,
}
# Add first-level JSON-serializable items of the
@ -112,7 +113,9 @@ def data_to_json(data: Dict[str, Any]) -> str:
if hasattr(data_item, "__dict__") or isinstance(data_item, dict):
# Take elements from __dict__ for custom classes
item_generator = (
data_item.items if isinstance(data_item, dict) else data_item.__dict__.items
data_item.items
if isinstance(data_item, dict)
else data_item.__dict__.items
)
for variable_name, variable_item in item_generator():
# Check if serializable. If not, just include the
@ -127,8 +130,9 @@ def data_to_json(data: Dict[str, Any]) -> str:
return json_string
def json_to_data(json_string: str,
custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
def json_to_data(
json_string: str, custom_objects: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Turn JSON serialization of class-parameters back into dictionary.
@ -164,9 +168,11 @@ def json_to_data(json_string: str,
base64_object = base64.b64decode(serialization.encode())
deserialized_object = cloudpickle.loads(base64_object)
except RuntimeError:
warnings.warn(f"Could not deserialize object {data_key}. " +
"Consider using `custom_objects` argument to replace " +
"this object.")
warnings.warn(
f"Could not deserialize object {data_key}. "
+ "Consider using `custom_objects` argument to replace "
+ "this object."
)
return_data[data_key] = deserialized_object
else:
# Read as it is
@ -174,71 +180,211 @@ def json_to_data(json_string: str,
return return_data
def save_to_zip_file(save_path: str, data: Dict[str, Any] = None,
params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None:
@functools.singledispatch
def open_path(
path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose=0, suffix=None
):
"""
Opens a path for reading or writing with a preferred suffix and raises debug information.
If the provided path is a derivative of io.BufferedIOBase it ensures that the file
matches the provided mode, i.e. If the mode is read ("r", "read") it checks that the path is readable.
If the mode is write ("w", "write") it checks that the file is writable.
If the provided path is a string or a pathlib.Path, it ensures that it exists. If the mode is "read"
it checks that it exists, if it doesn't exist it attempts to read path.suffix if a suffix is provided.
If the mode is "write" and the path does not exist, it creates all the parent folders. If the path
points to a folder, it changes the path to path_2. If the path already exists and verbose == 2,
it raises a warning.
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param mode: (str) how to open the file. "w"|"write" for writing, "r"|"read" for reading.
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
:param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
"""
if not isinstance(path, io.BufferedIOBase):
raise TypeError("Path parameter has invalid type.", io.BufferedIOBase)
if path.closed:
raise ValueError("File stream is closed.")
mode = mode.lower()
try:
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
except KeyError:
raise ValueError("Expected mode to be either 'w' or 'r'.")
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
e1 = "writable" if "w" == mode else "readable"
raise ValueError(f"Expected a {e1} file.")
return path
@open_path.register(str)
def open_path_str(
path: str, mode: str, verbose=0, suffix=None
) -> io.BufferedIOBase:
"""
Open a path given by a string. If writing to the path, the function ensures
that the path exists.
:param path: (str) the path to open. If mode is "w" then it ensures that the path exists
by creating the necessary folders and renaming path if it points to a folder.
:param mode: (str) how to open the file. "w" for writing, "r" for reading.
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
:param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
"""
return open_path(pathlib.Path(path), mode, verbose, suffix)
@open_path.register(pathlib.Path)
def open_path_pathlib(
path: pathlib.Path, mode: str, verbose=0, suffix=None
) -> io.BufferedIOBase:
"""
Open a path given by a string. If writing to the path, the function ensures
that the path exists.
:param path: (pathlib.Path) the path to check. If mode is "w" then it
ensures that the path exists by creating the necessary folders and
renaming path if it points to a folder.
:param mode: (str) how to open the file. "w" for writing, "r" for reading.
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
:param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
"""
if mode not in ("w", "r"):
raise ValueError("Expected mode to be either 'w' or 'r'.")
if mode == "r":
try:
path = path.open("rb")
except FileNotFoundError as error:
if suffix is not None and suffix != "":
newpath = pathlib.Path(f"{path}.{suffix}")
if verbose == 2:
warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
path, suffix = newpath, None
else:
raise error
else:
try:
if path.suffix == "" and suffix is not None and suffix != "":
path = pathlib.Path(f"{path}.{suffix}")
if path.exists() and path.is_file() and verbose == 2:
warnings.warn(f"Path '{path}' exists, will overwrite it.")
path = path.open("wb")
except IsADirectoryError:
warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
path = pathlib.Path(f"{path}_2")
except FileNotFoundError: # Occurs when the parent folder doesn't exist
warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
path.parent.mkdir(exist_ok=True, parents=True)
# if opening was successful uses the identity function
# if opening failed with IsADirectory|FileNotFound, calls open_path_pathlib
# with corrections
# if reading failed with FileNotFoundError, calls open_path_pathlib with suffix
return open_path(path, mode, verbose, suffix)
def save_to_zip_file(
save_path: Union[str, pathlib.Path, io.BufferedIOBase],
data: Dict[str, Any] = None,
params: Dict[str, Any] = None,
tensors: Dict[str, Any] = None,
verbose=0,
) -> None:
"""
Save a model to a zip archive.
:param save_path: Where to store the model.
:param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model.
if save_path is a str or pathlib.Path ensures that the path actually exists.
:param data: Class parameters being stored.
:param params: Model parameters being stored expected to contain an entry for every
state_dict with its name and the state_dict.
:param tensors: Extra tensor variables expected to contain name and value of tensors
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information
"""
save_path = open_path(save_path, "w", verbose=0, suffix="zip")
# data/params can be None, so do not
# try to serialize them blindly
if data is not None:
serialized_data = data_to_json(data)
# Check postfix if save_path is a string
if isinstance(save_path, str):
_, ext = os.path.splitext(save_path)
if ext == "":
save_path += ".zip"
# Create a zip-archive and write our objects
# there. This works when save_path is either
# str or a file-like
with zipfile.ZipFile(save_path, "w") as archive:
# Create a zip-archive and write our objects there.
with zipfile.ZipFile(save_path, mode="w") as archive:
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
if tensors is not None:
with archive.open('tensors.pth', mode="w") as tensors_file:
with archive.open("tensors.pth", mode="w") as tensors_file:
th.save(tensors, tensors_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + '.pth', mode="w") as param_file:
with archive.open(file_name + ".pth", mode="w") as param_file:
th.save(dict_, param_file)
def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]],
Optional[TensorDict],
Optional[TensorDict]]):
def save_to_pkl(
path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0
) -> None:
"""
Save an object to path creating the necessary folders along the way.
If the path exists and is a directory, it will raise a warning and rename the path.
If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'.
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param obj: The object to save.
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
"""
with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler:
pickle.dump(obj, file_handler)
def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0) -> Any:
"""
Load an object from the path. If a suffix is provided in the path, it will use that suffix.
If the path does not exist, it will attempt to load using the .pkl suffix.
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information.
"""
with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler:
return pickle.load(file_handler)
def load_from_zip_file(
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
load_data: bool = True,
verbose=0,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
"""
Load model data from a .zip archive
:param load_path: Where to load the model from
:param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
and dict of extra tensors
"""
# Check if file exists if load_path is a string
if isinstance(load_path, str):
if not os.path.exists(load_path):
if os.path.exists(load_path + ".zip"):
load_path += ".zip"
else:
raise ValueError(f"Error: the file {load_path} could not be found")
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
# set device to cpu if cuda is not available
device = get_device()
# Open the zip archive and load data
try:
with zipfile.ZipFile(load_path, "r") as archive:
with zipfile.ZipFile(load_path) as archive:
namelist = archive.namelist()
# If data or parameters is not in the
# zip archive, assume they were stored
@ -254,7 +400,7 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona
if "tensors.pth" in namelist and load_data:
# Load extra tensors
with archive.open('tensors.pth', mode="r") as tensor_file:
with archive.open("tensors.pth", mode="r") as tensor_file:
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
# fixed in python >= 3.7
file_content = io.BytesIO()
@ -265,8 +411,12 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona
tensors = th.load(file_content, map_location=device)
# check for all other .pth files
other_files = [file_name for file_name in namelist if
os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"]
other_files = [
file_name
for file_name in namelist
if os.path.splitext(file_name)[1] == ".pth"
and file_name != "tensors.pth"
]
# if there are any other files which end with .pth and aren't "params.pth"
# assume that they each are optimizer parameters
if len(other_files) > 0:
@ -279,7 +429,9 @@ def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optiona
# go to start of file
file_content.seek(0)
# load the parameters with the right ``map_location``
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
params[os.path.splitext(file_path)[0]] = th.load(
file_content, map_location=device
)
except zipfile.BadZipFile:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")

View file

@ -1,6 +1,8 @@
import os
import io
import warnings
from copy import deepcopy
import pathlib
import pytest
import gym
@ -12,6 +14,11 @@ from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.save_util import (
open_path,
save_to_pkl,
load_from_pkl,
)
MODEL_LIST = [
PPO,
@ -46,17 +53,21 @@ def test_save_load(tmp_path, model_class):
env = DummyVecEnv([lambda: select_env(model_class)])
# create model
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
observations = np.concatenate(
[env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0
)
# Get dictionary of current parameters
params = deepcopy(model.policy.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
random_params = dict(
(param_name, th.rand_like(param)) for param_name, param in params.items()
)
# Update model parameters with the new random values
model.policy.load_state_dict(random_params)
@ -64,7 +75,9 @@ def test_save_load(tmp_path, model_class):
new_params = model.policy.state_dict()
# Check that all params are different now
for k in params:
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
assert not th.allclose(
params[k], new_params[k]
), "Parameters did not change as expected."
params = new_params
@ -74,14 +87,16 @@ def test_save_load(tmp_path, model_class):
# Check
model.save(tmp_path / "test_save.zip")
del model
model = model_class.load(str(tmp_path / "test_save"), env=env)
model = model_class.load(str(tmp_path / "test_save.zip"), env=env)
# check if params are still the same after load
new_params = model.policy.state_dict()
# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
assert th.allclose(
params[key], new_params[key]
), "Model parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
@ -107,7 +122,7 @@ def test_set_env(model_class):
env3 = select_env(model_class)
# create model
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]))
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]))
# learn
model.learn(total_timesteps=1000, eval_freq=500)
@ -132,21 +147,21 @@ def test_exclude_include_saved_params(tmp_path, model_class):
env = DummyVecEnv([lambda: select_env(model_class)])
# create model, set verbose as 2, which is not standard
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2)
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=2)
# Check if exclude works
model.save(tmp_path / "test_save.zip", exclude=["verbose"])
model.save(tmp_path / "test_save", exclude=["verbose"])
del model
model = model_class.load(str(tmp_path / "test_save"))
model = model_class.load(str(tmp_path / "test_save.zip"))
# check if verbose was not saved
assert model.verbose != 2
# set verbose as something different then standard settings
model.verbose = 2
# Check if include works
model.save(tmp_path / "test_save.zip", exclude=["verbose"], include=["verbose"])
model.save(tmp_path / "test_save", exclude=["verbose"], include=["verbose"])
del model
model = model_class.load(str(tmp_path / "test_save"))
model = model_class.load(str(tmp_path / "test_save.zip"))
assert model.verbose == 2
# clear file from os
@ -155,13 +170,14 @@ def test_exclude_include_saved_params(tmp_path, model_class):
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
def test_save_load_replay_buffer(tmp_path, model_class):
replay_path = tmp_path / 'replay_buffer.pkl'
model = model_class('MlpPolicy', select_env(model_class), buffer_size=1000)
path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl")
path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning
model = model_class("MlpPolicy", select_env(model_class), buffer_size=1000)
model.learn(500)
old_replay_buffer = deepcopy(model.replay_buffer)
model.save_replay_buffer(replay_path)
model.save_replay_buffer(path)
model.replay_buffer = None
model.load_replay_buffer(replay_path)
model.load_replay_buffer(path)
assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations)
assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
@ -169,11 +185,13 @@ def test_save_load_replay_buffer(tmp_path, model_class):
assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones)
# test extending replay buffer
model.replay_buffer.extend(old_replay_buffer.observations, old_replay_buffer.observations,
old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones)
# clear file from os
os.remove(replay_path)
model.replay_buffer.extend(
old_replay_buffer.observations,
old_replay_buffer.observations,
old_replay_buffer.actions,
old_replay_buffer.rewards,
old_replay_buffer.dones,
)
@pytest.mark.parametrize("model_class", [DQN, SAC, TD3])
@ -186,12 +204,17 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
See https://github.com/DLR-RM/stable-baselines3/issues/46
"""
# remove gym warnings
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
warnings.filterwarnings(action='ignore', category=UserWarning, module='gym')
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=UserWarning, module="gym")
model = model_class('MlpPolicy', select_env(model_class), buffer_size=100,
optimize_memory_usage=optimize_memory_usage, policy_kwargs=dict(net_arch=[64]),
learning_starts=10)
model = model_class(
"MlpPolicy",
select_env(model_class),
buffer_size=100,
optimize_memory_usage=optimize_memory_usage,
policy_kwargs=dict(net_arch=[64]),
learning_starts=10,
)
model.learn(150)
@ -205,13 +228,15 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
if optimize_memory_usage:
assert len(recwarn) == 1
warning = recwarn.pop(UserWarning)
assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
assert "The last trajectory in the replay buffer will be truncated" in str(
warning.message
)
else:
assert len(recwarn) == 0
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ['MlpPolicy', 'CnnPolicy'])
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_policy(tmp_path, model_class, policy_str):
"""
Test saving and loading policy only.
@ -220,25 +245,29 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
:param policy_str: (str) Name of the policy.
"""
kwargs = {}
if policy_str == 'MlpPolicy':
if policy_str == "MlpPolicy":
env = select_env(model_class)
else:
if model_class in [SAC, TD3, DQN]:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2,
discrete=model_class == DQN)
env = FakeImageEnv(
screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN
)
env = DummyVecEnv([lambda: env])
# create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]),
verbose=1, **kwargs)
model = model_class(
policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs
)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
observations = np.concatenate(
[env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0
)
policy = model.policy
policy_class = policy.__class__
@ -251,7 +280,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
params = deepcopy(policy.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
random_params = dict(
(param_name, th.rand_like(param)) for param_name, param in params.items()
)
# Update model parameters with the new random values
policy.load_state_dict(random_params)
@ -259,7 +290,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
new_params = policy.state_dict()
# Check that all params are different now
for k in params:
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
assert not th.allclose(
params[k], new_params[k]
), "Parameters did not change as expected."
params = new_params
@ -286,7 +319,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
assert th.allclose(
params[key], new_params[key]
), "Policy parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions, _ = policy.predict(observations, deterministic=True)
@ -301,3 +336,93 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
os.remove(tmp_path / "policy.pkl")
if actor_class is not None:
os.remove(tmp_path / "actor.pkl")
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
# check that suffix isn't added because we used open_path first
with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
assert not record
# test custom suffix
with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
assert not record
# test without suffix
with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
assert not record
# test that a warning is raised when the path doesn't exist
with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
assert (
load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl"))
== "foo"
)
assert len(record) == 0
with pytest.warns(None) as record:
assert (
load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2))
== "foo"
)
assert len(record) == 1
fp = pathlib.Path(f"{tmp_path}/t2").open("w")
fp.write("rubbish")
fp.close()
# test that a warning is only raised when verbose = 0
with pytest.warns(None) as record:
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
assert len(record) == 1
def test_open_file(tmp_path):
# path must much the type
with pytest.raises(TypeError):
open_path(123, None, None, None)
p1 = tmp_path / "test1"
fp = p1.open("wb")
# provided path must match the mode
with pytest.raises(ValueError):
open_path(fp, "r")
with pytest.raises(ValueError):
open_path(fp, "randomstuff")
# test identity
_ = open_path(fp, "w")
assert _ is not None
assert fp is _
# Can't use a closed path
with pytest.raises(ValueError):
fp.close()
open_path(fp, "w")
buff = io.BytesIO()
assert buff.writable()
assert buff.readable() is ("w" == "w")
_ = open_path(buff, "w")
assert _ is buff
with pytest.raises(ValueError):
buff.close()
open_path(buff, "w")