From 03a84f97eaea47804569fcb2f20280fea6241ab0 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 1 Dec 2019 16:46:39 +0100 Subject: [PATCH] Add monte-carlo test for SDE distribution --- tests/test_distributions.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 47651e4..130a1bc 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,8 +1,9 @@ import numpy as np import torch as th +from torchy_baselines.common.utils import set_random_seed from torchy_baselines.common.distributions import DiagGaussianDistribution, SquashedDiagGaussianDistribution,\ - CategoricalDistribution, TanhBijector + CategoricalDistribution, TanhBijector, StateDependentNoiseDistribution # TODO: more tests for the other distributions def test_bijector(): @@ -18,3 +19,21 @@ def test_bijector(): assert th.max(th.abs(squashed_actions)) <= 1.0 # Check the inverse method assert th.isclose(TanhBijector.inverse(squashed_actions), actions).all() + + +def test_sde_distribution(): + n_samples = int(5e6) + n_features = 2 + n_actions = 1 + deterministic_actions = th.ones(n_samples, n_actions) * 0.1 + state = th.ones(n_samples, n_features) * 0.3 + dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False) + + set_random_seed(1) + _, log_std = dist.proba_distribution_net(n_features) + dist.sample_weights(log_std, batch_size=n_samples) + + actions, _ = dist.proba_distribution(deterministic_actions, log_std, state) + + assert th.allclose(actions.mean(), dist.distribution.mean.mean(), rtol=1e-3) + assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=1e-3)