Fix train_freq at load time (#332)

* Fix train_freq loading

* Update docker

* Add sanity checks + tests for train freq
This commit is contained in:
Antonin RAFFIN 2021-02-27 19:53:13 +01:00 committed by GitHub
parent 0c50d75ecb
commit b2c94a677d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 77 additions and 14 deletions

View file

@ -1,4 +1,4 @@
image: stablebaselines/stable-baselines3-cpu:0.11.0a4
image: stablebaselines/stable-baselines3-cpu:0.11.1
type-check:
script:

View file

@ -3,6 +3,15 @@
Changelog
==========
Pre-Release 0.11.1 (2021-02-27)
-------------------------------
Bug Fixes:
^^^^^^^^^^
- Fixed a bug where ``train_freq`` was not properly converted when loading a saved model
Pre-Release 0.11.0 (2021-02-27)
-------------------------------

View file

@ -131,15 +131,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
# see https://github.com/hill-a/stable-baselines/issues/863
self.remove_time_limit_termination = remove_time_limit_termination
if isinstance(train_freq, int):
train_freq = (train_freq, "step")
try:
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
except ValueError:
raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
self.train_freq = TrainFreq(*train_freq)
# Save train freq parameter, will be converted later to TrainFreq object
self.train_freq = train_freq
self.actor = None # type: Optional[th.nn.Module]
self.replay_buffer = None # type: Optional[ReplayBuffer]
@ -149,6 +142,28 @@ class OffPolicyAlgorithm(BaseAlgorithm):
# For gSDE only
self.use_sde_at_warmup = use_sde_at_warmup
def _convert_train_freq(self) -> None:
    """
    Convert the `train_freq` parameter (int or tuple) to a TrainFreq object.

    A value that is not a tuple is interpreted as a frequency expressed
    in steps, i.e. it is first turned into ``(train_freq, "step")``.
    Nothing happens when ``self.train_freq`` already is a ``TrainFreq``,
    so the conversion can safely be called more than once.

    :raises ValueError: if the unit is rejected by ``TrainFrequencyUnit``
        or if the frequency is not an integer.
    """
    # Already converted, nothing to do
    if isinstance(self.train_freq, TrainFreq):
        return

    train_freq = self.train_freq

    # The value of the train frequency will be checked later
    if not isinstance(train_freq, tuple):
        train_freq = (train_freq, "step")

    try:
        train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
    except ValueError as e:
        # Chain the original error explicitly so the traceback shows it
        # as the direct cause of this more descriptive message
        raise ValueError(
            f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
        ) from e

    if not isinstance(train_freq[0], int):
        raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

    self.train_freq = TrainFreq(*train_freq)
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
@ -167,6 +182,9 @@ class OffPolicyAlgorithm(BaseAlgorithm):
)
self.policy = self.policy.to(self.device)
# Convert train freq parameter to TrainFreq object
self._convert_train_freq()
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
"""
Save the replay buffer as a pickle file.

View file

@ -1 +1 @@
0.11.0
0.11.1

View file

@ -103,3 +103,39 @@ def test_dqn():
create_eval_env=True,
)
model.learn(total_timesteps=500, eval_freq=250)
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
def test_train_freq(tmp_path, train_freq):
    """
    Check that every accepted form of ``train_freq`` survives a save/load cycle.

    The model is trained, saved, then reloaded twice (once relying on the
    saved value, once passing ``train_freq`` again at load time) and
    training is resumed after each reload.
    """
    archive_path = tmp_path / "test_save.zip"

    model = SAC(
        "MlpPolicy",
        "Pendulum-v0",
        policy_kwargs={"net_arch": [64, 64], "n_critics": 1},
        learning_starts=100,
        buffer_size=10000,
        verbose=1,
        train_freq=train_freq,
    )
    model.learn(total_timesteps=150)
    model.save(archive_path)
    env = model.get_env()

    # Reload using the train frequency stored in the archive
    model = SAC.load(archive_path, env=env)
    model.learn(total_timesteps=150)

    # Reload while overriding the train frequency explicitly
    model = SAC.load(archive_path, train_freq=train_freq, env=env)
    model.learn(total_timesteps=150)
@pytest.mark.parametrize("train_freq", ["4", ("1", "episode"), "non_sense", (1, "close")])
def test_train_freq_fail(train_freq):
    """
    Check that invalid ``train_freq`` values are rejected with a ValueError
    while the model is being created or trained.
    """
    sac_kwargs = {
        "policy_kwargs": {"net_arch": [64, 64], "n_critics": 1},
        "learning_starts": 100,
        "buffer_size": 10000,
        "verbose": 1,
        "train_freq": train_freq,
    }
    with pytest.raises(ValueError):
        model = SAC("MlpPolicy", "Pendulum-v0", **sac_kwargs)
        model.learn(total_timesteps=250)

View file

@ -176,7 +176,7 @@ def test_set_env(model_class):
kwargs = {}
if model_class in {DQN, DDPG, SAC, TD3}:
kwargs = dict(learning_starts=100)
kwargs = dict(learning_starts=100, train_freq=4)
elif model_class in {A2C, PPO}:
kwargs = dict(n_steps=64)
@ -238,12 +238,12 @@ def test_save_load_env_cnn(tmp_path, model_class):
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False)
kwargs = dict(policy_kwargs=dict(net_arch=[32]))
if model_class == TD3:
kwargs.update(dict(buffer_size=100, learning_starts=50))
kwargs.update(dict(buffer_size=100, learning_starts=50, train_freq=4))
model = model_class("CnnPolicy", env, **kwargs).learn(100)
model.save(tmp_path / "test_save")
# Test loading with env and continuing training
model = model_class.load(str(tmp_path / "test_save.zip"), env=env).learn(100)
model = model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
# clear file from os
os.remove(tmp_path / "test_save.zip")