mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Update doc: SB3-Contrib (#267)
* Fix big when saving/loading q-net alone * Rename variables to match SB3-contrib * Update docker image * Set min version for tensorboard * Add SB3-Contrib to doc * Update DQN * Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me> * Update wording Co-authored-by: Adam Gleave <adam@gleave.me>
This commit is contained in:
parent
b8c72a5348
commit
944dfdafe4
15 changed files with 234 additions and 41 deletions
|
|
@ -1,4 +1,4 @@
|
|||
image: stablebaselines/stable-baselines3-cpu:0.9.0a2
|
||||
image: stablebaselines/stable-baselines3-cpu:0.11.0a4
|
||||
|
||||
type-check:
|
||||
script:
|
||||
|
|
|
|||
12
README.md
12
README.md
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
# Stable Baselines3
|
||||
|
||||
Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
|
||||
Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
|
||||
|
||||
You can read a detailed presentation of Stable Baselines in the [Medium article](https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-reinforcement-learning-made-easy-df87c4b2fc82).
|
||||
|
||||
|
|
@ -50,7 +50,6 @@ A migration guide from SB2 to SB3 can be found in the [documentation](https://st
|
|||
|
||||
Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/)
|
||||
|
||||
|
||||
## RL Baselines3 Zoo: A Collection of Trained RL Agents
|
||||
|
||||
[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3.
|
||||
|
|
@ -68,6 +67,15 @@ Github repo: https://github.com/DLR-RM/rl-baselines3-zoo
|
|||
|
||||
Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html
|
||||
|
||||
## SB3-Contrib: Experimental RL Features
|
||||
|
||||
We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
|
||||
|
||||
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC) or Quantile Regression DQN (QR-DQN).
|
||||
|
||||
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
**Note:** Stable-Baselines3 supports PyTorch 1.4+.
|
||||
|
|
|
|||
|
|
@ -31,6 +31,10 @@ Actions ``gym.spaces``:
|
|||
- ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`.
|
||||
|
||||
.. note::
|
||||
|
||||
Some logging values (like ``ep_rew_mean``, ``ep_len_mean``) are only available when using a ``Monitor`` wrapper
|
||||
|
|
|
|||
|
|
@ -87,8 +87,6 @@ Looking at the training curve (episode reward function of the timesteps) is a go
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
We suggest you reading `Deep Reinforcement Learning that Matters <https://arxiv.org/abs/1709.06560>`_ for a good discussion about RL evaluation.
|
||||
|
||||
You can also take a look at this `blog post <https://openlab-flowers.inria.fr/t/how-many-random-seeds-should-i-use-statistical-power-analysis-in-deep-reinforcement-learning-experiments/457>`_
|
||||
|
|
@ -122,6 +120,7 @@ Discrete Actions - Single Process
|
|||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
DQN with extensions (double DQN, prioritized replay, ...) are the recommended algorithms.
|
||||
We notably provide QR-DQN in our :ref:`contrib repo <sb3_contrib>`.
|
||||
DQN is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).
|
||||
|
||||
Discrete Actions - Multiprocessed
|
||||
|
|
@ -136,7 +135,7 @@ Continuous Actions
|
|||
Continuous Actions - Single Process
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Current State Of The Art (SOTA) algorithms are ``SAC`` and ``TD3``.
|
||||
Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo <sb3_contrib>`).
|
||||
Please use the hyperparameters in the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ for best results.
|
||||
|
||||
|
||||
|
|
@ -156,7 +155,7 @@ Goal Environment
|
|||
-----------------
|
||||
|
||||
If your environment follows the ``GoalEnv`` interface (cf :ref:`HER <her>`), then you should use
|
||||
HER + (SAC/TD3/DDPG/DQN) depending on the action space.
|
||||
HER + (SAC/TD3/DDPG/DQN/TQC) depending on the action space.
|
||||
|
||||
|
||||
.. note::
|
||||
|
|
|
|||
97
docs/guide/sb3_contrib.rst
Normal file
97
docs/guide/sb3_contrib.rst
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
.. _sb3_contrib:
|
||||
|
||||
==================
|
||||
SB3 Contrib
|
||||
==================
|
||||
|
||||
We implement experimental features in a separate contrib repository:
|
||||
`SB3-Contrib`_
|
||||
|
||||
This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still
|
||||
providing the latest features, like Truncated Quantile Critics (TQC) or
|
||||
Quantile Regression DQN (QR-DQN).
|
||||
|
||||
Why create this repository?
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Over the span of stable-baselines and stable-baselines3, the community
|
||||
has been eager to contribute in form of better logging utilities,
|
||||
environment wrappers, extended support (e.g. different action spaces)
|
||||
and learning algorithms.
|
||||
|
||||
However sometimes these utilities were too niche to be considered for
|
||||
stable-baselines or proved to be too difficult to integrate well into
|
||||
the existing code without creating a mess. sb3-contrib aims to fix this by not
|
||||
requiring the neatest code integration with existing code and not
|
||||
setting limits on what is too niche: almost everything remotely useful
|
||||
goes!
|
||||
We hope this allows us to provide reliable implementations
|
||||
following stable-baselines usual standards (consistent style, documentation, etc)
|
||||
beyond the relatively small scope of utilities in the main repository.
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
See documentation for the full list of included features.
|
||||
|
||||
**RL Algorithms**:
|
||||
|
||||
- `Truncated Quantile Critics (TQC)`_
|
||||
- `Quantile Regression DQN (QR-DQN)`_
|
||||
|
||||
**Gym Wrappers**:
|
||||
|
||||
- `Time Feature Wrapper`_
|
||||
|
||||
Documentation
|
||||
-------------
|
||||
|
||||
Documentation is available online: https://sb3-contrib.readthedocs.io/
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
To install Stable-Baselines3 contrib with pip, execute:
|
||||
|
||||
::
|
||||
|
||||
pip install sb3-contrib
|
||||
|
||||
We recommend to use the ``master`` version of Stable Baselines3 and SB3-Contrib.
|
||||
|
||||
To install Stable Baselines3 ``master`` version:
|
||||
|
||||
::
|
||||
|
||||
pip install git+https://github.com/DLR-RM/stable-baselines3
|
||||
|
||||
To install Stable Baselines3 contrib ``master`` version:
|
||||
|
||||
::
|
||||
|
||||
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
SB3-Contrib follows the SB3 API and folder structure. So, if you are familiar with SB3,
|
||||
using SB3-Contrib should be easy too.
|
||||
|
||||
Here is an example of training a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from sb3_contrib import QRDQN
|
||||
|
||||
policy_kwargs = dict(n_quantiles=50)
|
||||
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=4)
|
||||
model.save("qrdqn_cartpole")
|
||||
|
||||
|
||||
|
||||
.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||
.. _Truncated Quantile Critics (TQC): https://arxiv.org/abs/2005.04269
|
||||
.. _Quantile Regression DQN (QR-DQN): https://arxiv.org/abs/1710.10044
|
||||
.. _Time Feature Wrapper: https://arxiv.org/abs/1712.00378
|
||||
|
|
@ -3,10 +3,10 @@
|
|||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
Welcome to Stable Baselines3 docs! - RL Baselines Made Easy
|
||||
===========================================================
|
||||
Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations
|
||||
========================================================================
|
||||
|
||||
`Stable Baselines3 <https://github.com/DLR-RM/stable-baselines3>`_ is a set of improved implementations of reinforcement learning algorithms in PyTorch.
|
||||
`Stable Baselines3 (SB3) <https://github.com/DLR-RM/stable-baselines3>`_ is a set of reliable implementations of reinforcement learning algorithms in PyTorch.
|
||||
It is the next major version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_.
|
||||
|
||||
|
||||
|
|
@ -16,6 +16,8 @@ RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/
|
|||
|
||||
RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning.
|
||||
|
||||
SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||
|
||||
|
||||
Main Features
|
||||
--------------
|
||||
|
|
@ -45,6 +47,7 @@ Main Features
|
|||
guide/callbacks
|
||||
guide/tensorboard
|
||||
guide/rl_zoo
|
||||
guide/sb3_contrib
|
||||
guide/imitation
|
||||
guide/migration
|
||||
guide/checking_nan
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.11.0a2 (WIP)
|
||||
Pre-Release 0.11.0a4 (WIP)
|
||||
-------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -44,6 +44,9 @@ Others:
|
|||
- Add signatures to callable type annotations (@ernestum)
|
||||
- Improve error message in ``NatureCNN``
|
||||
- Added checks for supported action spaces to improve clarity of error messages for the user
|
||||
- Renamed variables in the ``train()`` method of ``SAC``, ``TD3`` and ``DQN`` to match SB3-Contrib.
|
||||
- Updated docker base image to Ubuntu 18.04
|
||||
- Set tensorboard min version to 2.2.0 (earlier version are apparently not working with PyTorch)
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
@ -55,6 +58,7 @@ Documentation:
|
|||
- Added example of learning rate schedule
|
||||
- Added SUMO-RL as example project (@LucasAlegre)
|
||||
- Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre)
|
||||
- Added SB3-Contrib page
|
||||
|
||||
Pre-Release 0.10.0 (2020-10-28)
|
||||
-------------------------------
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
CPU_PARENT=ubuntu:16.04
|
||||
GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu16.04
|
||||
CPU_PARENT=ubuntu:18.04
|
||||
GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04
|
||||
|
||||
TAG=stablebaselines/stable-baselines3
|
||||
VERSION=$(cat ./stable_baselines3/version.txt)
|
||||
|
|
|
|||
7
setup.py
7
setup.py
|
|
@ -10,7 +10,7 @@ long_description = """
|
|||
|
||||
# Stable Baselines3
|
||||
|
||||
Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
|
||||
Stable Baselines3 is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
|
||||
|
||||
These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.
|
||||
|
||||
|
|
@ -29,6 +29,9 @@ https://stable-baselines3.readthedocs.io/en/master/
|
|||
RL Baselines3 Zoo:
|
||||
https://github.com/DLR-RM/rl-baselines3-zoo
|
||||
|
||||
SB3 Contrib:
|
||||
https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||
|
||||
## Quick example
|
||||
|
||||
Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym.
|
||||
|
|
@ -112,7 +115,7 @@ setup(
|
|||
"atari_py~=0.2.0",
|
||||
"pillow",
|
||||
# Tensorboard support
|
||||
"tensorboard",
|
||||
"tensorboard>=2.2.0",
|
||||
# Checking memory taken by replay buffer
|
||||
"psutil",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -155,23 +155,23 @@ class DQN(OffPolicyAlgorithm):
|
|||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
||||
|
||||
with th.no_grad():
|
||||
# Compute the target Q values
|
||||
target_q = self.q_net_target(replay_data.next_observations)
|
||||
# Compute the next Q-values using the target network
|
||||
next_q_values = self.q_net_target(replay_data.next_observations)
|
||||
# Follow greedy policy: use the one with the highest value
|
||||
target_q, _ = target_q.max(dim=1)
|
||||
next_q_values, _ = next_q_values.max(dim=1)
|
||||
# Avoid potential broadcast issue
|
||||
target_q = target_q.reshape(-1, 1)
|
||||
next_q_values = next_q_values.reshape(-1, 1)
|
||||
# 1-step TD target
|
||||
target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
|
||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
||||
|
||||
# Get current Q estimates
|
||||
current_q = self.q_net(replay_data.observations)
|
||||
# Get current Q-values estimates
|
||||
current_q_values = self.q_net(replay_data.observations)
|
||||
|
||||
# Retrieve the q-values for the actions from the replay buffer
|
||||
current_q = th.gather(current_q, dim=1, index=replay_data.actions.long())
|
||||
current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())
|
||||
|
||||
# Compute Huber loss (less sensitive to outliers)
|
||||
loss = F.smooth_l1_loss(current_q, target_q)
|
||||
loss = F.smooth_l1_loss(current_q_values, target_q_values)
|
||||
losses.append(loss.item())
|
||||
|
||||
# Optimize the policy
|
||||
|
|
|
|||
|
|
@ -74,7 +74,6 @@ class QNetwork(BasePolicy):
|
|||
features_dim=self.features_dim,
|
||||
activation_fn=self.activation_fn,
|
||||
features_extractor=self.features_extractor,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -223,20 +223,20 @@ class SAC(OffPolicyAlgorithm):
|
|||
with th.no_grad():
|
||||
# Select action according to policy
|
||||
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
|
||||
# Compute the target Q value: min over all critics targets
|
||||
targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
|
||||
target_q, _ = th.min(targets, dim=1, keepdim=True)
|
||||
# Compute the next Q values: min over all critics targets
|
||||
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
|
||||
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
|
||||
# add entropy term
|
||||
target_q = target_q - ent_coef * next_log_prob.reshape(-1, 1)
|
||||
next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
|
||||
# td error + entropy term
|
||||
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
|
||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
||||
|
||||
# Get current Q estimates for each critic network
|
||||
# Get current Q-values estimates for each critic network
|
||||
# using action from the replay buffer
|
||||
current_q_estimates = self.critic(replay_data.observations, replay_data.actions)
|
||||
current_q_values = self.critic(replay_data.observations, replay_data.actions)
|
||||
|
||||
# Compute critic loss
|
||||
critic_loss = 0.5 * sum([F.mse_loss(current_q, q_backup) for current_q in current_q_estimates])
|
||||
critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
# Optimize the critic
|
||||
|
|
|
|||
|
|
@ -146,16 +146,16 @@ class TD3(OffPolicyAlgorithm):
|
|||
noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
|
||||
next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
|
||||
|
||||
# Compute the target Q value: min over all critics targets
|
||||
targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
|
||||
target_q, _ = th.min(targets, dim=1, keepdim=True)
|
||||
target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
|
||||
# Compute the next Q-values: min over all critics targets
|
||||
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
|
||||
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
|
||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
||||
|
||||
# Get current Q estimates for each critic network
|
||||
current_q_estimates = self.critic(replay_data.observations, replay_data.actions)
|
||||
# Get current Q-values estimates for each critic network
|
||||
current_q_values = self.critic(replay_data.observations, replay_data.actions)
|
||||
|
||||
# Compute critic loss
|
||||
critic_loss = sum([F.mse_loss(current_q, target_q) for current_q in current_q_estimates])
|
||||
critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
# Optimize the critics
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.11.0a2
|
||||
0.11.0a4
|
||||
|
|
|
|||
|
|
@ -408,6 +408,82 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
os.remove(tmp_path / "actor.pkl")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [DQN])
|
||||
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
|
||||
def test_save_load_q_net(tmp_path, model_class, policy_str):
|
||||
"""
|
||||
Test saving and loading q-network/quantile net only.
|
||||
|
||||
:param model_class: (BaseAlgorithm) A RL model
|
||||
:param policy_str: (str) Name of the policy.
|
||||
"""
|
||||
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
|
||||
if policy_str == "MlpPolicy":
|
||||
env = select_env(model_class)
|
||||
else:
|
||||
if model_class in [DQN]:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(
|
||||
buffer_size=250,
|
||||
learning_starts=100,
|
||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||
)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
|
||||
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
# create model
|
||||
model = model_class(policy_str, env, verbose=1, **kwargs)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
env.reset()
|
||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||
|
||||
q_net = model.q_net
|
||||
q_net_class = q_net.__class__
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = deepcopy(q_net.state_dict())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
|
||||
# Update model parameters with the new random values
|
||||
q_net.load_state_dict(random_params)
|
||||
|
||||
new_params = q_net.state_dict()
|
||||
# Check that all params are different now
|
||||
for k in params:
|
||||
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
|
||||
|
||||
params = new_params
|
||||
|
||||
# get selected actions
|
||||
selected_actions, _ = q_net.predict(observations, deterministic=True)
|
||||
|
||||
# Save and load q_net
|
||||
q_net.save(tmp_path / "q_net.pkl")
|
||||
|
||||
del q_net
|
||||
|
||||
q_net = q_net_class.load(tmp_path / "q_net.pkl")
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = q_net.state_dict()
|
||||
|
||||
# Check that all params are the same as before save load procedure now
|
||||
for key in params:
|
||||
assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions, _ = q_net.predict(observations, deterministic=True)
|
||||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||
|
||||
# clear file from os
|
||||
os.remove(tmp_path / "q_net.pkl")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
|
||||
def test_open_file_str_pathlib(tmp_path, pathtype):
|
||||
# check that suffix isn't added because we used open_path first
|
||||
|
|
|
|||
Loading…
Reference in a new issue