diff --git a/tests/test_sde.py b/tests/test_sde.py index f397558..955f9ab 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -3,7 +3,7 @@ import pytest import torch as th from torch.distributions import Normal -from torchy_baselines import A2C +from torchy_baselines import A2C, TD3 def test_state_dependent_exploration(): @@ -54,4 +54,12 @@ def test_state_dependent_noise(model_class): model = model_class('MlpPolicy', env, n_steps=200, max_grad_norm=1, use_rms_prop=False, use_sde=True, ent_coef=0.00, verbose=1, create_eval_env=True, learning_rate=3e-4, policy_kwargs=dict(log_std_init=0.0, ortho_init=False, net_arch=[256, dict(pi=[256], vf=[256])]), seed=None) - model.learn(total_timesteps=int(20000), log_interval=5, eval_freq=10000, eval_env=eval_env) + # model.learn(total_timesteps=int(20000), log_interval=5, eval_freq=10000, eval_env=eval_env) + model.learn(total_timesteps=int(1000), log_interval=5, eval_freq=500, eval_env=eval_env) + + +@pytest.mark.parametrize("model_class", [TD3]) +def test_state_dependent_offpolicy_noise(model_class): + model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True, + verbose=1, policy_kwargs=dict(log_std_init=-2)) + model.learn(total_timesteps=int(20000), eval_freq=1000) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index a6b9a41..d96950d 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -327,6 +327,9 @@ class BaseRLModel(object): assert isinstance(env, VecEnv) assert env.num_envs == 1 + if hasattr(self, 'use_sde') and self.use_sde: + self.policy.reset_noise() + while total_steps < n_steps or total_episodes < n_episodes: done = False # Reset environment: not needed for VecEnv @@ -338,6 +341,8 @@ class BaseRLModel(object): if num_timesteps < learning_starts: action = np.array([self.action_space.sample()]) else: + if hasattr(self, 'use_sde'): + deterministic = not self.use_sde action = self.predict(obs, deterministic=deterministic) # Rescale the action from [low, high] to [-1, 1] diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 4cf32f3..adbbda1 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -1,19 +1,46 @@ import torch as th import torch.nn as nn +from torch.distributions import Normal from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork class Actor(BaseNetwork): - def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU): + def __init__(self, obs_dim, action_dim, net_arch, activation_fn=nn.ReLU, + use_sde=False, log_std_init=0.0, clip_noise=0.1): super(Actor, self).__init__() - # TODO: orthogonal initialization? - actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True) - self.actor_net = nn.Sequential(*actor_net) + self.latent_pi, self.log_std = None, None + self.weights_dist, self.exploration_mat = None, None + self.use_sde = use_sde - def forward(self, obs): - return self.actor_net(obs) + if use_sde: + latent_dim = net_arch[-1] + latent_pi = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False) + self.latent_pi = nn.Sequential(*latent_pi) + self.log_std = nn.Parameter(th.ones(latent_dim, action_dim) * log_std_init) + self.actor_net = nn.Sequential(nn.Linear(net_arch[-1], action_dim), nn.Tanh()) + self.clip_noise = clip_noise + self.reset_noise() + else: + actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True) + self.actor_net = nn.Sequential(*actor_net) + + def reset_noise(self): + self.weights_dist = Normal(th.zeros_like(self.log_std), th.exp(self.log_std)) + self.exploration_mat = self.weights_dist.rsample() + + def forward(self, obs, deterministic=True): + if self.use_sde: + latent_pi = self.latent_pi(obs) + if deterministic: + return self.actor_net(latent_pi) + noise = th.mm(latent_pi.detach(), self.exploration_mat) + # noise = th.clamp(noise, -self.clip_noise, self.clip_noise) + # TODO: fix clipping + return th.clamp(self.actor_net(latent_pi) + noise, -1, 1) + else: + return self.actor_net(obs) class Critic(BaseNetwork): @@ -40,7 +67,7 @@ class Critic(BaseNetwork): class TD3Policy(BasePolicy): def __init__(self, observation_space, action_space, learning_rate, net_arch=None, device='cpu', - activation_fn=nn.ReLU): + activation_fn=nn.ReLU, use_sde=False, log_std_init=0.0): super(TD3Policy, self).__init__(observation_space, action_space, device) if net_arch is None: @@ -56,8 +83,14 @@ class TD3Policy(BasePolicy): 'net_arch': self.net_arch, 'activation_fn': self.activation_fn } + self.actor_kwargs = self.net_args.copy() + self.actor_kwargs['use_sde'] = use_sde + self.actor_kwargs['log_std_init'] = log_std_init + self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None + self.use_sde = use_sde + self.log_std_init = log_std_init self._build(learning_rate) def _build(self, learning_rate): @@ -71,14 +104,17 @@ class TD3Policy(BasePolicy): self.critic_target.load_state_dict(self.critic.state_dict()) self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1)) + def reset_noise(self): + return self.actor.reset_noise() + def make_actor(self): - return Actor(**self.net_args).to(self.device) + return Actor(**self.actor_kwargs).to(self.device) def make_critic(self): return Critic(**self.net_args).to(self.device) - def forward(self, obs): - return self.actor(obs) + def forward(self, obs, deterministic=True): + return self.actor(obs, deterministic=deterministic) MlpPolicy = TD3Policy diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 49cf16e..46943a9 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -37,6 +37,8 @@ class TD3(BaseRLModel): :param target_policy_noise: (float) Standard deviation of gaussian noise added to target policy (smoothing noise) :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise. + :param use_sde: (bool) Whether to use State Dependent Exploration (SDE) + instead of action noise exploration (default: False) :param create_eval_env: (bool) Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation @@ -50,6 +52,7 @@ class TD3(BaseRLModel): policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100, train_freq=-1, gradient_steps=-1, n_episodes_rollout=1, tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5, + use_sde=False, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): @@ -70,6 +73,7 @@ class TD3(BaseRLModel): self.policy_delay = policy_delay self.target_noise_clip = target_noise_clip self.target_policy_noise = target_policy_noise + self.use_sde = use_sde if _init_setup_model: self._setup_model() @@ -79,8 +83,8 @@ class TD3(BaseRLModel): obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] self.set_random_seed(self.seed) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - self.policy = self.policy(self.observation_space, self.action_space, - self.learning_rate, device=self.device, **self.policy_kwargs) + self.policy = self.policy(self.observation_space, self.action_space, self.learning_rate, + use_sde=self.use_sde, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) self._create_aliases() @@ -90,12 +94,12 @@ class TD3(BaseRLModel): self.critic = self.policy.critic self.critic_target = self.policy.critic_target - def select_action(self, observation): + def select_action(self, observation, deterministic=True): # Normally not needed observation = np.array(observation) with th.no_grad(): observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device) - return self.actor(observation).cpu().numpy() + return self.actor(observation, deterministic=deterministic).cpu().numpy() def predict(self, observation, state=None, mask=None, deterministic=True): """ @@ -107,7 +111,7 @@ class TD3(BaseRLModel): :param deterministic: (bool) Whether or not to return deterministic actions. :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) """ - return self.unscale_action(self.select_action(observation)) + return self.unscale_action(self.select_action(observation, deterministic=deterministic)) def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0): # Update optimizer learning rate