From 6e75b6808b440996ab4988b259c27f13aff6945d Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 29 Jan 2025 18:30:48 +0100 Subject: [PATCH] hack --- tests/models/timm_backbone/test_modeling_timm_backbone.py | 7 +++++++ tests/test_modeling_common.py | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py index 296a38c17..8d7c486ac 100644 --- a/tests/models/timm_backbone/test_modeling_timm_backbone.py +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -119,6 +119,13 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste description="`TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`." ) def test_batching_equivalence(self): + + import os + test_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0] + if not hasattr(self, "_data"): + self._data = {} + if test_name not in self._data: + self._data[test_name] = {"atol": 1e-4, "rtol": 1e-4} super().test_batching_equivalence() def test_timm_transformer_backbone_equivalence(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1d52e8281..dfc5b9d22 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -812,7 +812,13 @@ class ModelTesterMixin: torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}" ) try: - torch.testing.assert_close(batched_row, single_row_object, atol=1e-5, rtol=1e-5) + import os + test_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0] + if test_name in getattr(self, "_data", {}): + test_data = self._data[test_name] + atol = test_data.get("atol", 1e-5) + rtol = test_data.get("rtol", 1e-5) + torch.testing.assert_close(batched_row, single_row_object, atol=atol, rtol=rtol) except AssertionError as e: msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n" msg += str(e)