mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Fix grad computation for sde test
This commit is contained in:
parent
0885dbe74b
commit
d26fcf4566
1 changed files with 24 additions and 12 deletions
|
|
@ -15,37 +15,49 @@ def test_state_dependent_exploration():
|
|||
"""
|
||||
n_states = 2
|
||||
state_dim = 3
|
||||
# TODO: fix for action_dim > 1
|
||||
action_dim = 1
|
||||
sigma = th.ones(state_dim, 1, requires_grad=True)
|
||||
action_dim = 10
|
||||
sigma_hat = th.ones(state_dim, action_dim, requires_grad=True)
|
||||
# Reduce the number of parameters
|
||||
# sigma_ = th.ones(state_dim, action_dim) * sigma_
|
||||
|
||||
# weights_dist = Normal(th.zeros_like(log_sigma), th.exp(log_sigma))
|
||||
th.manual_seed(2)
|
||||
weights_dist = Normal(th.zeros_like(sigma), sigma)
|
||||
|
||||
weights_dist = Normal(th.zeros_like(sigma_hat), sigma_hat)
|
||||
weights = weights_dist.rsample()
|
||||
|
||||
state = th.rand(n_states, state_dim)
|
||||
mu = th.ones(action_dim)
|
||||
# print(weights.shape, state.shape)
|
||||
noise = th.mm(state, weights)
|
||||
|
||||
variance = th.mm(state ** 2, sigma ** 2)
|
||||
action = mu + noise
|
||||
|
||||
variance = th.mm(state ** 2, sigma_hat ** 2)
|
||||
action_dist = Normal(mu, th.sqrt(variance))
|
||||
|
||||
loss = action_dist.log_prob((mu + noise).detach()).mean()
|
||||
# Sum over the action dimension because we assume they are independent
|
||||
loss = action_dist.log_prob((action).detach()).sum(dim=-1).mean()
|
||||
loss.backward()
|
||||
|
||||
# From Rueckstiess paper
|
||||
grad = th.zeros_like(sigma)
|
||||
# From Rueckstiess paper: check that the computed gradient
|
||||
# correspond to the analytical form
|
||||
grad = th.zeros_like(sigma_hat)
|
||||
for j in range(action_dim):
|
||||
# sigma_hat is the std of the gaussian distribution of the noise matrix weights
|
||||
# sigma_j = sum_j(state_i **2 * sigma_hat_ij ** 2)
|
||||
# sigma_j is the standard deviation of the policy gaussian distribution
|
||||
sigma_j = th.sqrt(variance[:, j])
|
||||
for i in range(state_dim):
|
||||
a = ((noise[:, j] ** 2 - variance[:, j]) / (variance[:, j] ** 2)) * (state[:, i] ** 2 * sigma[i, j])
|
||||
grad[i, j] = a.mean()
|
||||
# Derivative of the log probability of the jth component of the action
|
||||
# w.r.t. the standard deviation sigma_j
|
||||
d_log_policy_j = (noise[:, j] ** 2 - sigma_j ** 2) / sigma_j ** 3
|
||||
# Derivative of sigma_j w.r.t. sigma_hat_ij
|
||||
d_log_sigma_j = (state[:, i] ** 2 * sigma_hat[i, j]) / sigma_j
|
||||
# Chain rule, average over the minibatch
|
||||
grad[i, j] = (d_log_policy_j * d_log_sigma_j).mean()
|
||||
|
||||
# sigma.grad should be equal to grad
|
||||
assert sigma.grad.allclose(grad)
|
||||
assert sigma_hat.grad.allclose(grad)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C])
|
||||
|
|
|
|||
Loading…
Reference in a new issue