mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-29 23:07:07 +00:00
Bug fix for off-policy normalization
Now working properly
This commit is contained in:
parent
f719544386
commit
cdb62a93fe
3 changed files with 18 additions and 8 deletions
|
|
@ -24,7 +24,6 @@ TODO:
|
|||
- SDE: reduce the number of parameters (only n_features instead of n_features x n_actions) for A2C
|
||||
(done for TD3)
|
||||
- SDE: learn the feature extractor?
|
||||
- DEBUG normalization with replay buffer (apparently a problem with observation normalization)
|
||||
|
||||
Later:
|
||||
- get_parameters / set_parameters
|
||||
|
|
|
|||
|
|
@ -46,10 +46,10 @@ def test_vec_env():
|
|||
@pytest.mark.parametrize("model_class", [SAC, TD3, CEMRL])
|
||||
def test_offpolicy_normalization(model_class):
|
||||
env = DummyVecEnv([lambda: gym.make(ENV_ID)])
|
||||
env = VecNormalize(env, norm_obs=False, norm_reward=False, clip_obs=10., clip_reward=10.)
|
||||
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
|
||||
|
||||
eval_env = DummyVecEnv([lambda: gym.make(ENV_ID)])
|
||||
eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False, clip_obs=10., clip_reward=10.)
|
||||
|
||||
model = model_class('MlpPolicy', env, verbose=1)
|
||||
model.learn(total_timesteps=10000, eval_env=eval_env, eval_freq=1000)
|
||||
model.learn(total_timesteps=1000, eval_env=eval_env, eval_freq=500)
|
||||
|
|
|
|||
|
|
@ -333,6 +333,11 @@ class BaseRLModel(object):
|
|||
assert isinstance(env, VecEnv)
|
||||
assert env.num_envs == 1
|
||||
|
||||
# Retrieve unnormalized observation for saving into the buffer
|
||||
if self._vec_normalize_env is not None:
|
||||
obs_ = self._vec_normalize_env.get_original_obs()
|
||||
|
||||
|
||||
self.rollout_data = None
|
||||
if hasattr(self, 'use_sde') and self.use_sde:
|
||||
self.actor.reset_noise()
|
||||
|
|
@ -377,12 +382,13 @@ class BaseRLModel(object):
|
|||
if replay_buffer is not None:
|
||||
# Store only the unnormalized version
|
||||
if self._vec_normalize_env is not None:
|
||||
# TODO: save it instead of unnormalizing
|
||||
obs = self._vec_normalize_env.unnormalize_obs(obs)
|
||||
new_obs = self._vec_normalize_env.get_original_obs()
|
||||
reward = self._vec_normalize_env.get_original_reward()
|
||||
new_obs_ = self._vec_normalize_env.get_original_obs()
|
||||
reward_ = self._vec_normalize_env.get_original_reward()
|
||||
else:
|
||||
# Avoid changing the original ones
|
||||
obs_, new_obs_, reward_ = obs, new_obs, reward
|
||||
|
||||
replay_buffer.add(obs, new_obs, action, reward, done_bool)
|
||||
replay_buffer.add(obs_, new_obs_, action, reward_, done_bool)
|
||||
|
||||
if self.rollout_data is not None:
|
||||
# Assume only one env
|
||||
|
|
@ -392,6 +398,11 @@ class BaseRLModel(object):
|
|||
self.rollout_data['dones'].append(np.array(done_bool[0]).copy())
|
||||
|
||||
obs = new_obs
|
||||
# Keep the true unnormalized observation for the next step;
|
||||
# alternatively, obs_ = self._vec_normalize_env.unnormalize_obs(obs)
|
||||
# would be a good approximation
|
||||
if self._vec_normalize_env is not None:
|
||||
obs_ = new_obs_
|
||||
|
||||
num_timesteps += 1
|
||||
episode_timesteps += 1
|
||||
|
|
|
|||
Loading…
Reference in a new issue