From a31a60d85bb62171d3faa90a89618335ff700c8f Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Thu, 23 May 2024 21:08:12 +0000 Subject: [PATCH] Change run_test.py arg parsing to handle additional args better (#126709) Do not inherit parser from common_utils * I don't think we use any variables in run_test that depend on those, and I think all tests except doctests run in a subprocess so they will parse the args in common_utils and set the variables. I don't think doctests wants any of those variables? Parse known args, add the extra args as extra, pass the extra ones along to the subprocess Removes the first instance of `--` I think I will miss run_test telling me if an arg is valid or not Pull Request resolved: https://github.com/pytorch/pytorch/pull/126709 Approved by: https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/Flamefire --- test/run_test.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index d43a396f144..23160d01281 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -29,7 +29,6 @@ from torch.testing._internal.common_utils import ( IS_CI, IS_MACOS, IS_WINDOWS, - parser as common_parser, retry_shell, set_cwd, shell, @@ -384,7 +383,7 @@ def run_test( ) -> int: env = env or os.environ.copy() maybe_set_hip_visible_devies() - unittest_args = options.additional_unittest_args.copy() + unittest_args = options.additional_args.copy() test_file = test_module.name stepcurrent_key = test_file @@ -1057,7 +1056,6 @@ def parse_args(): description="Run the PyTorch unit test suite", epilog="where TESTS is any of: {}".format(", ".join(TESTS)), formatter_class=argparse.RawTextHelpFormatter, - parents=[common_parser], ) parser.add_argument( "-v", @@ -1206,12 +1204,6 @@ def parse_args(): and "debug" not in BUILD_ENVIRONMENT and "parallelnative" not in BUILD_ENVIRONMENT, ) - parser.add_argument( - "additional_unittest_args", - nargs="*", - help="additional arguments passed through to unittest, e.g., " - "python run_test.py -i sparse -- TestSparse.test_factory_size_check", - ) parser.add_argument( "--shard", nargs=2, @@ -1273,7 +1265,11 @@ def parse_args(): help="Run tests with TorchInductor turned on", ) - return parser.parse_args() + args, extra = parser.parse_known_args() + if "--" in extra: + extra.remove("--") + args.additional_args = extra + return args def exclude_tests( @@ -1626,7 +1622,7 @@ def run_tests( options_clone = copy.deepcopy(options) if can_run_in_pytest(test): options_clone.pytest = True - options_clone.additional_unittest_args.extend(["-m", "serial"]) + options_clone.additional_args.extend(["-m", "serial"]) failure = run_test_module(test, test_directory, options_clone) test_failed = handle_error_messages(failure) if ( @@ -1641,7 +1637,7 @@ def run_tests( options_clone = copy.deepcopy(options) if can_run_in_pytest(test): options_clone.pytest = True - options_clone.additional_unittest_args.extend(["-m", "not serial"]) + options_clone.additional_args.extend(["-m", "not serial"]) pool.apply_async( run_test_module, args=(test, test_directory, options_clone),