mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-01 03:45:11 +00:00
Clean up code + bug fixes
This commit is contained in:
parent
ea20721632
commit
c542009641
15 changed files with 51 additions and 78 deletions
|
|
@ -1,6 +1,3 @@
|
|||
import os
|
||||
|
||||
import gym
|
||||
import pytest
|
||||
|
||||
from torchy_baselines import PPO
|
||||
|
|
@ -15,4 +12,4 @@ from torchy_baselines import PPO
|
|||
[12, dict(pi=[8])],
|
||||
])
|
||||
def test_flexible_mlp(net_arch):
|
||||
model = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
|
||||
_ = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines.common.distributions import DiagGaussianDistribution, TanhBijector, \
|
||||
StateDependentNoiseDistribution
|
||||
from torchy_baselines.common.utils import set_random_seed
|
||||
from torchy_baselines.common.distributions import DiagGaussianDistribution, SquashedDiagGaussianDistribution,\
|
||||
CategoricalDistribution, TanhBijector, StateDependentNoiseDistribution
|
||||
|
||||
|
||||
# TODO: more tests for the other distributions
|
||||
|
|
@ -43,6 +42,7 @@ def test_sde_distribution():
|
|||
|
||||
N_ACTIONS = 1
|
||||
|
||||
|
||||
# TODO: fix for num action > 1
|
||||
# TODO: analytical form for squashed Gaussian?
|
||||
@pytest.mark.parametrize("dist", [
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
|
||||
import torch as th
|
||||
from copy import deepcopy
|
||||
|
||||
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
|
||||
from torchy_baselines.common.identity_env import IdentityEnvBox
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv
|
||||
from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv
|
||||
|
||||
MODEL_LIST = [
|
||||
CEMRL,
|
||||
|
|
@ -81,7 +80,8 @@ def test_save_load(model_class):
|
|||
for optimizer, opt_state in opt_params.items():
|
||||
for param_group_idx, param_group in enumerate(opt_state['param_groups']):
|
||||
for param_key, param_value in param_group.items():
|
||||
if param_key == 'params': # don't know how to handle params correctly, therefore only check if we have the same amount
|
||||
# don't know how to handle params correctly, therefore only check if we have the same amount
|
||||
if param_key == 'params':
|
||||
assert len(param_value) == len(
|
||||
new_opt_params[optimizer]['param_groups'][param_group_idx][param_key])
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,8 @@
|
|||
import pytest
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from torch.distributions import Normal
|
||||
|
||||
from torchy_baselines import A2C, TD3, SAC
|
||||
from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize
|
||||
from torchy_baselines.common.monitor import Monitor
|
||||
|
||||
|
||||
def test_state_dependent_exploration_grad():
|
||||
|
|
@ -35,7 +31,7 @@ def test_state_dependent_exploration_grad():
|
|||
action_dist = Normal(mu, th.sqrt(variance))
|
||||
|
||||
# Sum over the action dimension because we assume they are independent
|
||||
loss = action_dist.log_prob((action).detach()).sum(dim=-1).mean()
|
||||
loss = action_dist.log_prob(action.detach()).sum(dim=-1).mean()
|
||||
loss.backward()
|
||||
|
||||
# From Rueckstiess paper: check that the computed gradient
|
||||
|
|
@ -72,6 +68,6 @@ def test_scheduler():
|
|||
return -2.0 * progress + 1
|
||||
|
||||
model = TD3('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
|
||||
verbose=1, sde_log_std_scheduler=scheduler)
|
||||
verbose=1, sde_log_std_scheduler=scheduler)
|
||||
model.learn(total_timesteps=int(1000), eval_freq=500)
|
||||
assert th.isclose(model.actor.log_std, th.ones_like(model.actor.log_std)).all()
|
||||
|
|
|
|||
|
|
@ -325,9 +325,8 @@ def test_vecenv_wrapper_getattr():
|
|||
assert wrapped.name_test() == CustomWrapperBB
|
||||
|
||||
double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
|
||||
dummy = double_wrapped.var_a # should not raise as it is directly defined here
|
||||
_ = double_wrapped.var_a # should not raise as it is directly defined here
|
||||
with pytest.raises(AttributeError): # should raise due to ambiguity
|
||||
dummy = double_wrapped.var_b
|
||||
_ = double_wrapped.var_b
|
||||
with pytest.raises(AttributeError): # should raise as does not exist
|
||||
dummy = double_wrapped.nonexistent_attribute
|
||||
del dummy # keep linter happy
|
||||
_ = double_wrapped.nonexistent_attribute
|
||||
|
|
|
|||
|
|
@ -225,7 +225,8 @@ class BaseRLModel(object):
|
|||
:param env: (gym.Env) The environment for learning a policy
|
||||
"""
|
||||
if self.check_env(env, self.observation_space, self.action_space) is False:
|
||||
raise ValueError("The given environment is not compatible with model: observation and action spaces do not match")
|
||||
raise ValueError("The given environment is not compatible with model: "
|
||||
"observation and action spaces do not match")
|
||||
# it must be coherent now
|
||||
# if it is not a VecEnv, make it a VecEnv
|
||||
if not isinstance(env, VecEnv):
|
||||
|
|
@ -258,24 +259,6 @@ class BaseRLModel(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
|
||||
adam_epsilon=1e-8, val_interval=None):
|
||||
"""
|
||||
Pretrain a model using behavior cloning:
|
||||
supervised learning given an expert dataset.
|
||||
|
||||
NOTE: only Box and Discrete spaces are supported for now.
|
||||
|
||||
:param dataset: (ExpertDataset) Dataset manager
|
||||
:param n_epochs: (int) Number of iterations on the training set
|
||||
:param learning_rate: (float) Learning rate
|
||||
:param adam_epsilon: (float) the epsilon value for the adam optimizer
|
||||
:param val_interval: (int) Report training and validation losses every n epochs.
|
||||
By default, every 10th of the maximum number of epochs.
|
||||
:return: (BaseRLModel) the pretrained model
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run",
|
||||
eval_env=None, eval_freq=-1, n_eval_episodes=5, reset_num_timesteps=True):
|
||||
|
|
@ -308,11 +291,12 @@ class BaseRLModel(object):
|
|||
"""
|
||||
pass
|
||||
|
||||
def load_parameters(self, load_dict, opt_params=None):
|
||||
def load_parameters(self, load_dict, opt_params):
|
||||
"""
|
||||
Load model parameters from a dictionary
|
||||
load_dict should contain all keys from torch.model.state_dict()
|
||||
If opt_params are given this does also load agent's optimizer-parameters, but can only be handled in child classes.
|
||||
If opt_params are given this does also load agent's optimizer-parameters,
|
||||
but can only be handled in child classes.
|
||||
|
||||
|
||||
:param load_dict: (dict) dict of parameters from model.state_dict()
|
||||
|
|
@ -350,6 +334,7 @@ class BaseRLModel(object):
|
|||
env = data["env"]
|
||||
|
||||
# first create model, but only setup if a env was given
|
||||
# noinspection PyArgumentList
|
||||
model = cls(policy=data["policy_class"], env=env, _init_setup_model=env is not None)
|
||||
|
||||
# load parameters
|
||||
|
|
@ -365,7 +350,8 @@ class BaseRLModel(object):
|
|||
:param load_path: (str) Where to load the model from
|
||||
:param load_data: (bool) Whether we should load and return data
|
||||
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
|
||||
:return: (dict),(dict),(dict) Class parameters, model parameters (state_dict) and dict of optimizer parameters (dict of state_dict)
|
||||
:return: (dict),(dict),(dict) Class parameters, model parameters (state_dict)
|
||||
and dict of optimizer parameters (dict of state_dict)
|
||||
"""
|
||||
# Check if file exists if load_path is a string
|
||||
if isinstance(load_path, str):
|
||||
|
|
@ -403,7 +389,7 @@ class BaseRLModel(object):
|
|||
|
||||
# check for all other .pth files
|
||||
other_files = [file_name for file_name in namelist if
|
||||
os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"]
|
||||
os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"]
|
||||
# if there are any other files which end with .pth and aren't "params.pth"
|
||||
# assume that they each are optimizer parameters
|
||||
if len(other_files) > 0:
|
||||
|
|
@ -521,7 +507,7 @@ class BaseRLModel(object):
|
|||
episode_reward, episode_timesteps = 0.0, 0
|
||||
|
||||
while not done:
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.actor.reset_noise()
|
||||
|
||||
|
|
@ -577,7 +563,8 @@ class BaseRLModel(object):
|
|||
self.rollout_data['actions'].append(scaled_action[0].copy())
|
||||
self.rollout_data['rewards'].append(reward[0].copy())
|
||||
self.rollout_data['dones'].append(np.array(done_bool[0]).copy())
|
||||
self.rollout_data['values'].append(self.vf_net(th.FloatTensor(obs).to(self.device))[0].cpu().detach().numpy())
|
||||
obs_tensor = th.FloatTensor(obs).to(self.device)
|
||||
self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy())
|
||||
|
||||
obs = new_obs
|
||||
# Save the true unnormalized observation
|
||||
|
|
@ -674,9 +661,9 @@ class BaseRLModel(object):
|
|||
with archive.open('params.pth', mode="w") as param_file:
|
||||
th.save(params, param_file)
|
||||
if opt_params is not None:
|
||||
for file_name, dict in opt_params.items():
|
||||
for file_name, dict_ in opt_params.items():
|
||||
with archive.open(file_name + '.pth', mode="w") as opt_param_file:
|
||||
th.save(dict, opt_param_file)
|
||||
th.save(dict_, opt_param_file)
|
||||
|
||||
@staticmethod
|
||||
def excluded_save_params():
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines.common.vec_env import unwrap_vec_normalize
|
||||
|
||||
|
||||
class BaseBuffer(object):
|
||||
"""
|
||||
|
|
@ -79,7 +77,8 @@ class BaseBuffer(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _normalize_obs(self, obs, env=None):
|
||||
@staticmethod
|
||||
def _normalize_obs(obs, env=None):
|
||||
if env is not None:
|
||||
# TODO: get rid of pytorch - numpy conversion
|
||||
return th.FloatTensor(env.normalize_obs(obs.numpy()))
|
||||
|
|
|
|||
|
|
@ -411,7 +411,8 @@ class TanhBijector(object):
|
|||
super(TanhBijector, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, x):
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
return th.tanh(x)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
"""
|
||||
Taken from stable-baselines
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
|
|
@ -185,15 +184,6 @@ class CSVOutputFormat(KVWriter):
|
|||
self.file.close()
|
||||
|
||||
|
||||
def summary_val(key, value):
|
||||
"""
|
||||
:param key: (str)
|
||||
:param value: (float)
|
||||
"""
|
||||
kwargs = {'tag': key, 'simple_value': float(value)}
|
||||
return tf.Summary.Value(**kwargs)
|
||||
|
||||
|
||||
def valid_float_value(value):
|
||||
"""
|
||||
Returns True if the value can be successfully cast into a float
|
||||
|
|
|
|||
|
|
@ -35,4 +35,4 @@ def sync_envs_normalization(env, eval_env):
|
|||
if isinstance(env_tmp, VecNormalize):
|
||||
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
|
||||
env_tmp = env_tmp.venv
|
||||
eval_env_tmp.venv
|
||||
eval_env_tmp = eval_env_tmp.venv
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
def normalize_obs(self, obs):
|
||||
if self.norm_obs:
|
||||
return np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs,
|
||||
self.clip_obs)
|
||||
self.clip_obs)
|
||||
return obs
|
||||
|
||||
def normalize_reward(self, reward):
|
||||
|
|
|
|||
|
|
@ -157,7 +157,8 @@ class PPOPolicy(BasePolicy):
|
|||
return self.action_dist.proba_distribution(mean_actions, deterministic=deterministic)
|
||||
|
||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde, deterministic=deterministic)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde,
|
||||
deterministic=deterministic)
|
||||
|
||||
def actor_forward(self, obs, deterministic=False):
|
||||
latent_pi, _, latent_sde = self._get_latent(obs)
|
||||
|
|
|
|||
|
|
@ -241,6 +241,9 @@ class SACPolicy(BasePolicy):
|
|||
def make_critic(self):
|
||||
return Critic(**self.net_args).to(self.device)
|
||||
|
||||
def forward(self, obs):
|
||||
return self.actor(obs)
|
||||
|
||||
|
||||
MlpPolicy = SACPolicy
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import time
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
|
@ -68,7 +66,8 @@ class SAC(BaseRLModel):
|
|||
_init_setup_model=True):
|
||||
|
||||
super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.target_entropy = target_entropy
|
||||
|
|
@ -131,7 +130,8 @@ class SAC(BaseRLModel):
|
|||
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, use_sde=self.use_sde, device=self.device, **self.policy_kwargs)
|
||||
self.learning_rate, use_sde=self.use_sde,
|
||||
device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
||||
|
|
@ -203,7 +203,6 @@ class SAC(BaseRLModel):
|
|||
ent_coef_loss.backward()
|
||||
self.ent_coef_optimizer.step()
|
||||
|
||||
|
||||
with th.no_grad():
|
||||
# if self.use_sde:
|
||||
# self.actor.reset_noise(batch_size=batch_size)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import time
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
|
@ -62,7 +60,8 @@ class TD3(BaseRLModel):
|
|||
seed=0, device='auto', _init_setup_model=True):
|
||||
|
||||
super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
|
||||
self.buffer_size = buffer_size
|
||||
self.learning_rate = learning_rate
|
||||
|
|
@ -94,7 +93,8 @@ class TD3(BaseRLModel):
|
|||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, use_sde=self.use_sde, device=self.device, **self.policy_kwargs)
|
||||
self.learning_rate, use_sde=self.use_sde,
|
||||
device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
||||
|
|
@ -209,7 +209,8 @@ class TD3(BaseRLModel):
|
|||
# self._update_learning_rate(self.policy.optimizer)
|
||||
|
||||
# Unpack
|
||||
obs, action, advantage, returns = [self.rollout_data[key] for key in ['observations', 'actions', 'advantage', 'returns']]
|
||||
obs, action, advantage, returns = [self.rollout_data[key] for key in
|
||||
['observations', 'actions', 'advantage', 'returns']]
|
||||
|
||||
log_prob, entropy = self.actor.evaluate_actions(obs, action)
|
||||
values = self.vf_net(obs).flatten()
|
||||
|
|
|
|||
Loading…
Reference in a new issue