diff --git a/test/run_test.py b/test/run_test.py index 340e1b17023..53d7245d720 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -673,6 +673,7 @@ def run_doctests(test_module, test_directory, options): import pathlib pkgpath = pathlib.Path(torch.__file__).parent + exclude_module_list = [] enabled = { # TODO: expose these options to the user # For now disable all feature-conditional tests @@ -687,6 +688,7 @@ def run_doctests(test_module, test_directory, options): 'autograd_profiler': 0, 'cpp_ext': 0, 'monitor': 0, + "onnx": "auto", } # Resolve "auto" based on a test to determine if the feature is available. @@ -710,6 +712,17 @@ def run_doctests(test_module, test_directory, options): else: enabled['qengine'] = True + if enabled["onnx"] == "auto": + try: + import onnx # NOQA + import onnxscript # NOQA + import onnxruntime # NOQA + except ImportError: + exclude_module_list.append("torch.onnx._internal.fx.*") + enabled["onnx"] = False + else: + enabled["onnx"] = True + # Set doctest environment variables if enabled['cuda']: os.environ['TORCH_DOCTEST_CUDA'] = '1' @@ -732,6 +745,9 @@ def run_doctests(test_module, test_directory, options): if enabled['monitor']: os.environ['TORCH_DOCTEST_MONITOR'] = '1' + if enabled["onnx"]: + os.environ['TORCH_DOCTEST_ONNX'] = '1' + if 0: # TODO: could try to enable some of these os.environ['TORCH_DOCTEST_QUANTIZED_DYNAMIC'] = '1' @@ -739,7 +755,6 @@ def run_doctests(test_module, test_directory, options): os.environ['TORCH_DOCTEST_AUTOGRAD'] = '1' os.environ['TORCH_DOCTEST_HUB'] = '1' os.environ['TORCH_DOCTEST_DATALOADER'] = '1' - os.environ['TORCH_DOCTEST_ONNX'] = '1' os.environ['TORCH_DOCTEST_FUTURES'] = '1' pkgpath = os.path.dirname(torch.__file__) @@ -757,7 +772,8 @@ def run_doctests(test_module, test_directory, options): xdoctest_verbose = max(1, options.verbose) run_summary = xdoctest.runner.doctest_module( os.fspath(pkgpath), config=xdoctest_config, verbose=xdoctest_verbose, - command=options.xdoctest_command, argv=[]) + command=options.xdoctest_command, argv=[], + exclude=exclude_module_list) result = 1 if run_summary.get('n_failed', 0) else 0 return result