diff --git a/test/run_test.py b/test/run_test.py index d43a396f144..23160d01281 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -29,7 +29,6 @@ from torch.testing._internal.common_utils import ( IS_CI, IS_MACOS, IS_WINDOWS, - parser as common_parser, retry_shell, set_cwd, shell, @@ -384,7 +383,7 @@ def run_test( ) -> int: env = env or os.environ.copy() maybe_set_hip_visible_devies() - unittest_args = options.additional_unittest_args.copy() + unittest_args = options.additional_args.copy() test_file = test_module.name stepcurrent_key = test_file @@ -1057,7 +1056,6 @@ def parse_args(): description="Run the PyTorch unit test suite", epilog="where TESTS is any of: {}".format(", ".join(TESTS)), formatter_class=argparse.RawTextHelpFormatter, - parents=[common_parser], ) parser.add_argument( "-v", @@ -1206,12 +1204,6 @@ def parse_args(): and "debug" not in BUILD_ENVIRONMENT and "parallelnative" not in BUILD_ENVIRONMENT, ) - parser.add_argument( - "additional_unittest_args", - nargs="*", - help="additional arguments passed through to unittest, e.g., " - "python run_test.py -i sparse -- TestSparse.test_factory_size_check", - ) parser.add_argument( "--shard", nargs=2, @@ -1273,7 +1265,11 @@ def parse_args(): help="Run tests with TorchInductor turned on", ) - return parser.parse_args() + args, extra = parser.parse_known_args() + if "--" in extra: + extra.remove("--") + args.additional_args = extra + return args def exclude_tests( @@ -1626,7 +1622,7 @@ def run_tests( options_clone = copy.deepcopy(options) if can_run_in_pytest(test): options_clone.pytest = True - options_clone.additional_unittest_args.extend(["-m", "serial"]) + options_clone.additional_args.extend(["-m", "serial"]) failure = run_test_module(test, test_directory, options_clone) test_failed = handle_error_messages(failure) if ( @@ -1641,7 +1637,7 @@ def run_tests( options_clone = copy.deepcopy(options) if can_run_in_pytest(test): options_clone.pytest = True - options_clone.additional_unittest_args.extend(["-m", "not serial"]) + options_clone.additional_args.extend(["-m", "not serial"]) pool.apply_async( run_test_module, args=(test, test_directory, options_clone),