diff --git a/test/test_comparison_utils.py b/test/test_comparison_utils.py index fccc217bb7b..172e2c40929 100644 --- a/test/test_comparison_utils.py +++ b/test/test_comparison_utils.py @@ -2,7 +2,7 @@ # Owner(s): ["module: internals"] import torch -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, run_tests class TestComparisonUtils(TestCase): def test_all_equal_no_assert(self): @@ -30,3 +30,7 @@ class TestComparisonUtils(TestCase): with self.assertRaises(RuntimeError): torch._assert_tensor_metadata(t, [3], [1], torch.float) + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_pruning_op.py b/test/test_pruning_op.py index 88e5a4e57be..ef28381c190 100644 --- a/test/test_pruning_op.py +++ b/test/test_pruning_op.py @@ -4,7 +4,7 @@ import hypothesis.strategies as st from hypothesis import given import numpy as np import torch -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, run_tests import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() @@ -76,3 +76,7 @@ class PruningOpTest(TestCase): ) def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype): self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype) + + +if __name__ == '__main__': + run_tests()