From 97db5b42237ece09401cc23096256da5ceeca14b Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 5 Jul 2022 14:53:43 +0200 Subject: [PATCH] Update expected values in DecisionTransformerModelIntegrationTest (#18016) Co-authored-by: ydshieh --- .../test_modeling_decision_transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/decision_transformer/test_modeling_decision_transformer.py b/tests/models/decision_transformer/test_modeling_decision_transformer.py index 9124c64fa..3ac89cf9b 100644 --- a/tests/models/decision_transformer/test_modeling_decision_transformer.py +++ b/tests/models/decision_transformer/test_modeling_decision_transformer.py @@ -206,7 +206,9 @@ class DecisionTransformerModelIntegrationTest(unittest.TestCase): torch.manual_seed(0) state = torch.randn(1, 1, config.state_dim).to(device=torch_device, dtype=torch.float32) # env.reset() - expected_outputs = torch.tensor([[0.2384, -0.2955, 0.8741], [0.6765, -0.0793, -0.1298]], device=torch_device) + expected_outputs = torch.tensor( + [[0.242793, -0.28693074, 0.8742613], [0.67815274, -0.08101085, -0.12952147]], device=torch_device + ) returns_to_go = torch.tensor(TARGET_RETURN, device=torch_device, dtype=torch.float32).reshape(1, 1, 1) states = state