mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-29 03:31:08 +00:00
Fix train_freq at load time (#332)
* Fix train_freq loading * Update docker * Add sanity checks + tests for train freq
This commit is contained in:
parent
0c50d75ecb
commit
b2c94a677d
6 changed files with 77 additions and 14 deletions
|
|
@ -1,4 +1,4 @@
|
|||
image: stablebaselines/stable-baselines3-cpu:0.11.0a4
|
||||
image: stablebaselines/stable-baselines3-cpu:0.11.1
|
||||
|
||||
type-check:
|
||||
script:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,15 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.11.1 (2021-02-27)
|
||||
-------------------------------
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug where ``train_freq`` was not properly converted when loading a saved model
|
||||
|
||||
|
||||
|
||||
Pre-Release 0.11.0 (2021-02-27)
|
||||
-------------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -131,15 +131,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
# see https://github.com/hill-a/stable-baselines/issues/863
|
||||
self.remove_time_limit_termination = remove_time_limit_termination
|
||||
|
||||
if isinstance(train_freq, int):
|
||||
train_freq = (train_freq, "step")
|
||||
|
||||
try:
|
||||
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
|
||||
except ValueError:
|
||||
raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
|
||||
|
||||
self.train_freq = TrainFreq(*train_freq)
|
||||
# Save train freq parameter, will be converted later to TrainFreq object
|
||||
self.train_freq = train_freq
|
||||
|
||||
self.actor = None # type: Optional[th.nn.Module]
|
||||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
|
|
@ -149,6 +142,28 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
# For gSDE only
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def _convert_train_freq(self) -> None:
|
||||
"""
|
||||
Convert `train_freq` parameter (int or tuple)
|
||||
to a TrainFreq object.
|
||||
"""
|
||||
if not isinstance(self.train_freq, TrainFreq):
|
||||
train_freq = self.train_freq
|
||||
|
||||
# The value of the train frequency will be checked later
|
||||
if not isinstance(train_freq, tuple):
|
||||
train_freq = (train_freq, "step")
|
||||
|
||||
try:
|
||||
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
|
||||
except ValueError:
|
||||
raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
|
||||
|
||||
if not isinstance(train_freq[0], int):
|
||||
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")
|
||||
|
||||
self.train_freq = TrainFreq(*train_freq)
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
|
@ -167,6 +182,9 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
# Convert train freq parameter to TrainFreq object
|
||||
self._convert_train_freq()
|
||||
|
||||
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
|
||||
"""
|
||||
Save the replay buffer as a pickle file.
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.11.0
|
||||
0.11.1
|
||||
|
|
|
|||
|
|
@ -103,3 +103,39 @@ def test_dqn():
|
|||
create_eval_env=True,
|
||||
)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
|
||||
def test_train_freq(tmp_path, train_freq):
|
||||
|
||||
model = SAC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
policy_kwargs=dict(net_arch=[64, 64], n_critics=1),
|
||||
learning_starts=100,
|
||||
buffer_size=10000,
|
||||
verbose=1,
|
||||
train_freq=train_freq,
|
||||
)
|
||||
model.learn(total_timesteps=150)
|
||||
model.save(tmp_path / "test_save.zip")
|
||||
env = model.get_env()
|
||||
model = SAC.load(tmp_path / "test_save.zip", env=env)
|
||||
model.learn(total_timesteps=150)
|
||||
model = SAC.load(tmp_path / "test_save.zip", train_freq=train_freq, env=env)
|
||||
model.learn(total_timesteps=150)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("train_freq", ["4", ("1", "episode"), "non_sense", (1, "close")])
|
||||
def test_train_freq_fail(train_freq):
|
||||
with pytest.raises(ValueError):
|
||||
model = SAC(
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
policy_kwargs=dict(net_arch=[64, 64], n_critics=1),
|
||||
learning_starts=100,
|
||||
buffer_size=10000,
|
||||
verbose=1,
|
||||
train_freq=train_freq,
|
||||
)
|
||||
model.learn(total_timesteps=250)
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ def test_set_env(model_class):
|
|||
|
||||
kwargs = {}
|
||||
if model_class in {DQN, DDPG, SAC, TD3}:
|
||||
kwargs = dict(learning_starts=100)
|
||||
kwargs = dict(learning_starts=100, train_freq=4)
|
||||
elif model_class in {A2C, PPO}:
|
||||
kwargs = dict(n_steps=64)
|
||||
|
||||
|
|
@ -238,12 +238,12 @@ def test_save_load_env_cnn(tmp_path, model_class):
|
|||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False)
|
||||
kwargs = dict(policy_kwargs=dict(net_arch=[32]))
|
||||
if model_class == TD3:
|
||||
kwargs.update(dict(buffer_size=100, learning_starts=50))
|
||||
kwargs.update(dict(buffer_size=100, learning_starts=50, train_freq=4))
|
||||
|
||||
model = model_class("CnnPolicy", env, **kwargs).learn(100)
|
||||
model.save(tmp_path / "test_save")
|
||||
# Test loading with env and continuing training
|
||||
model = model_class.load(str(tmp_path / "test_save.zip"), env=env).learn(100)
|
||||
model = model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
|
||||
# clear file from os
|
||||
os.remove(tmp_path / "test_save.zip")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue