diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 535b407a22..edd4a4f456 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -202,6 +202,12 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/*.py" ) + file(GLOB onnxruntime_python_ortmodule_experimental_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/experimental/*.py" + ) + file(GLOB onnxruntime_python_ortmodule_experimental_json_config_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/experimental/json_config/*.py" + ) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/*.py" ) @@ -411,6 +417,8 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/amp COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/optim COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental/json_config COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator @@ -429,6 +437,12 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_srcs} $/onnxruntime/training/ortmodule/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_ortmodule_experimental_srcs} + $/onnxruntime/training/ortmodule/experimental/ + COMMAND ${CMAKE_COMMAND} -E copy + 
${onnxruntime_python_ortmodule_experimental_json_config_srcs} + $/onnxruntime/training/ortmodule/experimental/json_config/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/ diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index c58ec49458..79dc9b057a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -141,15 +141,19 @@ class GraphExecutionManager(GraphExecutionInterface): self.is_rocm_pytorch = (True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False) self._use_external_gpu_allocator = True + # assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True + self._get_torch_gpu_allocator_function_addresses() + + # WIP feature to enable caching in Gradient accumulation scenario. + self._enable_grad_acc_optimization = False + + def _get_torch_gpu_allocator_function_addresses(self): if self._use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator self._torch_alloc = torch_gpu_allocator.gpu_caching_allocator_raw_alloc_address() self._torch_free = torch_gpu_allocator.gpu_caching_allocator_raw_delete_address() - # WIP feature to enable caching in Gradient accumulation scenario. 
- self._enable_grad_acc_optimization = False - def _validate_module_type(self, module): """Raises a TypeError if the module is not a torch.nn.Module""" diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/__init__.py b/orttraining/orttraining/python/training/ortmodule/experimental/__init__.py new file mode 100644 index 0000000000..043afd846f --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/experimental/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# __init__.py diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/__init__.py b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/__init__.py new file mode 100644 index 0000000000..0d62a217c7 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# __init__.py + +from ._load_config_from_json import load_from_json diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py new file mode 100644 index 0000000000..41ecae4ad2 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py @@ -0,0 +1,240 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# _load_config_from_json.py + +import json +import os +import logging +from types import SimpleNamespace + +from onnxruntime.capi import _pybind_state as C +from functools import reduce +from ..._graph_execution_manager import _SkipCheck +from ...debug_options import DebugOptions, LogLevel, _SaveOnnxOptions + +log = logging.getLogger(__name__) + +def _load_data_from_json(path): + """Loads data from the json file path provided.""" + + data = None + with open(path) as f: + data = json.load(f, object_hook=lambda d: SimpleNamespace(**d)) + + if data is None: + raise RuntimeError(f"No data found in provided json file {path}.") + + return data + +def _load_propagate_cast_ops(ortmodule_config_accessor, data): + """Load PropagateCastOps from json file onto ORTModule""" + + assert hasattr(data, _load_propagate_cast_ops.loading_key) + log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. Loading attributes from file.") + + def _update_strategy(): + ortmodule_config_accessor._propagate_cast_ops_strategy = \ + C.PropagateCastOpsStrategy.__members__[data.PropagateCastOps.Strategy] + + def _update_level(): + ortmodule_config_accessor._propagate_cast_ops_level = data.PropagateCastOps.Level + + def _update_allow(): + ortmodule_config_accessor._propagate_cast_ops_allow = data.PropagateCastOps.Allow + + key_to_function_mapping = { + "Strategy": _update_strategy, + "Level": _update_level, + "Allow": _update_allow + } + + for key, _ in data.PropagateCastOps.__dict__.items(): + key_to_function_mapping[key]() + +def _load_use_external_gpu_allocator(ortmodule_config_accessor, data): + """Load UseExternalGPUAllocator from json file onto ORTModule""" + + assert hasattr(data, _load_use_external_gpu_allocator.loading_key) + log.info(f"Found keyword {_load_use_external_gpu_allocator.loading_key} in json. 
Loading attributes from file.") + + assert isinstance(data.UseExternalGPUAllocator, bool), f"{_load_use_external_gpu_allocator.loading_key} must be a boolean" + ortmodule_config_accessor._use_external_gpu_allocator = data.UseExternalGPUAllocator + ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses() + +def _load_enable_custom_autograd_function(ortmodule_config_accessor, data): + """Load EnableCustomAutogradFunction from json file onto ORTModule""" + + assert hasattr(data, _load_enable_custom_autograd_function.loading_key) + log.info(f"Found keyword {_load_enable_custom_autograd_function.loading_key} in json. Loading attributes from file.") + + assert isinstance(data.EnableCustomAutogradFunction, bool), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean" + ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction + +def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data): + """Load AllowLayerNormModPrecision from json file onto ORTModule""" + + assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key) + log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. Loading attributes from file.") + + assert isinstance(data.AllowLayerNormModPrecision, bool), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean" + ortmodule_config_accessor._allow_layer_norm_mod_precision = data.AllowLayerNormModPrecision + +def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data): + """Load EnableGradAccOptimization from json file onto ORTModule""" + + assert hasattr(data, _load_enable_grad_acc_optimization.loading_key) + log.info(f"Found keyword {_load_enable_grad_acc_optimization.loading_key} in json. 
Loading attributes from file.") + + assert isinstance(data.EnableGradAccOptimization, bool), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean" + ortmodule_config_accessor._enable_grad_acc_optimization = data.EnableGradAccOptimization + +def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data): + """Load RunSymbolicShapeInference from json file onto ORTModule""" + + assert hasattr(data, _load_run_symbolic_shape_infer.loading_key) + log.info(f"Found keyword {_load_run_symbolic_shape_infer.loading_key} in json. Loading attributes from file.") + + assert isinstance(data.RunSymbolicShapeInference, bool), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean" + ortmodule_config_accessor._run_symbolic_shape_infer = data.RunSymbolicShapeInference + +def _load_use_static_shape(ortmodule_config_accessor, data): + """Load UseStaticShape from json file onto ORTModule""" + + assert hasattr(data, _load_use_static_shape.loading_key) + log.info(f"Found keyword {_load_use_static_shape.loading_key} in json. Loading attributes from file.") + + assert isinstance(data.UseStaticShape, bool), f"{_load_use_static_shape.loading_key} must be a boolean" + ortmodule_config_accessor._use_static_shape = data.UseStaticShape + +def _load_skip_check(ortmodule_config_accessor, data): + """Load SkipCheck from json file onto ORTModule""" + + assert hasattr(data, _load_skip_check.loading_key) + log.info(f"Found keyword {_load_skip_check.loading_key} in json. Loading attributes from file.") + + skip_check = reduce(lambda x, y: x|y, [_SkipCheck[name] for name in data.SkipCheck]) + if skip_check.value > 0: + ortmodule_config_accessor._skip_check = skip_check + +def _load_debug_options(ortmodule_config_accessor, data): + """Load DebugOptions from json file onto ORTModule""" + + assert hasattr(data, _load_debug_options.loading_key) + log.info(f"Found keyword {_load_debug_options.loading_key} in json. 
Loading attributes from file.") + + log_level = LogLevel.WARNING + def _update_log_level(): + nonlocal log_level + log_level = LogLevel[data.DebugOptions.LogLevel] + + save_onnx = False + def _update_save_onnx(): + nonlocal save_onnx + save_onnx = data.DebugOptions.SaveONNX + + onnx_prefix = '' + def _update_onnx_prefix(): + nonlocal onnx_prefix + onnx_prefix = data.DebugOptions.ONNXPrefix + + def _update_onnx_path(): + os.environ[_SaveOnnxOptions._path_environment_key] = data.DebugOptions.SaveONNXPath + + key_to_function_mapping = { + "LogLevel": _update_log_level, + "SaveONNX": _update_save_onnx, + "ONNXPrefix": _update_onnx_prefix, + "SaveONNXPath": _update_onnx_path + } + + for key, _ in data.DebugOptions.__dict__.items(): + key_to_function_mapping[key]() + + debug_options = DebugOptions(log_level=log_level, save_onnx=save_onnx, onnx_prefix=onnx_prefix) + ortmodule_config_accessor._debug_options = debug_options + +def _define_load_function_keys(): + """Define static key variables for each loading function""" + + _load_propagate_cast_ops.loading_key = "PropagateCastOps" + _load_use_external_gpu_allocator.loading_key = "UseExternalGPUAllocator" + _load_enable_custom_autograd_function.loading_key = "EnableCustomAutogradFunction" + _load_allow_layer_norm_mod_precision.loading_key = "AllowLayerNormModPrecision" + _load_enable_grad_acc_optimization.loading_key = "EnableGradAccOptimization" + _load_run_symbolic_shape_infer.loading_key = "RunSymbolicShapeInference" + _load_use_static_shape.loading_key = "UseStaticShape" + _load_skip_check.loading_key = "SkipCheck" + _load_debug_options.loading_key = "DebugOptions" + +def load_from_json(ortmodule, path=None): + """Load config from json file at given path. + + Here is the schema that the json file must adhere to: + { + "PropagateCastOps": + { + "Strategy": "FLOOD_FILL", # str representing strategy (like "NONE", "FLOOD_FILL"...) 
+ "Level": 3, # propagate cast ops level as an int + "Allow": ["ABC", "DEF"] # propagate cast ops allow as list of strs + }, + "UseExternalGPUAllocator" : false, # bool flag + "EnableCustomAutogradFunction": true, # bool flag + "AllowLayerNormModPrecision": true, # bool flag + "EnableGradAccOptimization": true, # bool flag + "UseStaticShape": true, # bool flag + "RunSymbolicShapeInference": false, # bool flag + "SkipCheck": # list of strs representing `_SkipCheck`s checks to skip which will be aggregated using | + [ + "SKIP_CHECK_DEVICE", + "SKIP_CHECK_BUILD_GRADIENT", + "SKIP_CHECK_EXECUTION_AGENT" + ], + "DebugOptions": # debug options for user facing configuration + { + "LogLevel": "VERBOSE", + "SaveONNX": true, + "ONNXPrefix": "my_model", + "SaveONNXPath": "/path/to/onnx/directory" + } + } + + Args: + ortmodule (:obj:`ORTModule`): ORTModule instance that needs to be configured + path (:obj:`str`, optional): Path to json file. Alternatively, users can set the + environement variable ORTMODULE_JSON_CONFIG_PATH to the json config path. In case + both path and environment variable are set, the environment variable gets precedence. + """ + + global JSON_PATH_ENVIRONMENT_KEY + JSON_PATH_ENVIRONMENT_KEY = "ORTMODULE_JSON_CONFIG_PATH" + path = os.getenv(JSON_PATH_ENVIRONMENT_KEY, path) + + # figure out the json path + if path is None: + raise ValueError(f"Path to json is not provided. 
Provide the path through function call or by setting the environment variable {JSON_PATH_ENVIRONMENT_KEY}") + + # load the entire json file + data = _load_data_from_json(path) + + # define the keys for all loading functions + _define_load_function_keys() + + # define all load functions to iterate over + load_functions = { + _load_propagate_cast_ops.loading_key: _load_propagate_cast_ops, + _load_use_external_gpu_allocator.loading_key: _load_use_external_gpu_allocator, + _load_enable_custom_autograd_function.loading_key: _load_enable_custom_autograd_function, + _load_allow_layer_norm_mod_precision.loading_key: _load_allow_layer_norm_mod_precision, + _load_enable_grad_acc_optimization.loading_key: _load_enable_grad_acc_optimization, + _load_run_symbolic_shape_infer.loading_key: _load_run_symbolic_shape_infer, + _load_use_static_shape.loading_key: _load_use_static_shape, + _load_skip_check.loading_key: _load_skip_check, + _load_debug_options.loading_key: _load_debug_options + } + + for training_mode in [True, False]: + # update the debug config for both train and eval modes + ortmodule_config_accessor = ortmodule._torch_module._execution_manager(training_mode) + # iterate over the json data instead of checking for keys in json to catch key errors + for key, _ in data.__dict__.items(): + load_functions[key](ortmodule_config_accessor, data) diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 5cd9ab203a..533e0469ce 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -86,6 +86,13 @@ def run_ortmodule_custom_autograd_tests(cwd, log): run_subprocess(command, cwd=cwd, log=log).check_returncode() +def run_ortmodule_experimental_json_config_tests(cwd, log): + log.debug('Running: ORTModule Experimental Load Config tests') + + command = [sys.executable, '-m', 'pytest', '-sv', 
'orttraining_test_ortmodule_experimental_json_config.py'] + + run_subprocess(command, cwd=cwd, log=log).check_returncode() + def main(): args = parse_arguments() @@ -110,6 +117,9 @@ def main(): # TODO: enable this once the PyTorch used for testing meets the requirements running # auto grad testing. #run_ortmodule_custom_autograd_tests(cwd, log) + + run_ortmodule_experimental_json_config_tests(cwd, log) + return 0 diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py new file mode 100644 index 0000000000..ead7cbf426 --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py @@ -0,0 +1,113 @@ + +import os +import torch +from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.capi import _pybind_state as C +from onnxruntime.training.ortmodule.experimental.json_config import load_from_json + + +class Net(torch.nn.Module): + def __init__(self, input_size=784, hidden_size=500, num_classes=10): + super(Net, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input1): + out = self.fc1(input1) + out = self.relu(out) + out = self.fc2(out) + return out + +def test_load_config_from_json_1(): + device = 'cuda' + model = ORTModule(Net().to(device)) + + # load from json once. 
+ path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_2.json') + load_from_json(model, path_to_json) + + # load from json another time + path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_1.json') + load_from_json(model, path_to_json) + + for training_mode in [True, False]: + ort_model_attributes = model._torch_module._execution_manager(training_mode) + + # test propagate cast ops + assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL + assert ort_model_attributes._propagate_cast_ops_level == 3 + assert ort_model_attributes._propagate_cast_ops_allow == ["ABC", "DEF"] + + # test use external gpu allocator + assert ort_model_attributes._use_external_gpu_allocator == False + + # test enable custom autograd function + assert ort_model_attributes._enable_custom_autograd_function == True + + # test allow layer norm mod precision + assert ort_model_attributes._allow_layer_norm_mod_precision == True + + # test use static shape + assert ort_model_attributes._use_static_shape == True + + # test run symbolic shape inference + assert ort_model_attributes._run_symbolic_shape_infer == False + + # test enable grad acc optimization + assert ort_model_attributes._enable_grad_acc_optimization == True + + # test skip check + assert ort_model_attributes._skip_check.value == 14 + + # test debug options + assert ort_model_attributes._debug_options.save_onnx_models.save == True + assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == 'my_model' + assert ort_model_attributes._debug_options.logging.log_level.name == "VERBOSE" + +def test_load_config_from_json_2(): + device = 'cuda' + model = ORTModule(Net().to(device)) + + # load from json once. 
+ path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_1.json') + load_from_json(model, path_to_json) + + # load from json another time + path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_2.json') + load_from_json(model, path_to_json) + + for training_mode in [True, False]: + ort_model_attributes = model._torch_module._execution_manager(training_mode) + + # test propagate cast ops + assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.REMOVE_INPUT_OUTPUT_UP_DOWN_CASTS + assert ort_model_attributes._propagate_cast_ops_level == 5 + assert ort_model_attributes._propagate_cast_ops_allow == ["XYZ", "PQR"] + + # test use external gpu allocator + assert ort_model_attributes._use_external_gpu_allocator == True + + # test enable custom autograd function + assert ort_model_attributes._enable_custom_autograd_function == False + + # test allow layer norm mod precision + assert ort_model_attributes._allow_layer_norm_mod_precision == False + + # test use static shape + assert ort_model_attributes._use_static_shape == False + + # test run symbolic shape inference + assert ort_model_attributes._run_symbolic_shape_infer == True + + # test enable grad acc optimization + assert ort_model_attributes._enable_grad_acc_optimization == False + + # test skip check + assert ort_model_attributes._skip_check.value == 10 + + # test debug options + assert ort_model_attributes._debug_options.save_onnx_models.save == True + assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == 'my_other_model' + assert ort_model_attributes._debug_options.logging.log_level.name == "INFO" diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json new file mode 100644 index 0000000000..dd3d5453e9 --- /dev/null +++ 
b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json @@ -0,0 +1,26 @@ +{ + "PropagateCastOps": + { + "Strategy": "FLOOD_FILL", + "Level": 3, + "Allow": ["ABC", "DEF"] + }, + "UseExternalGPUAllocator" : false, + "EnableCustomAutogradFunction": true, + "AllowLayerNormModPrecision": true, + "EnableGradAccOptimization": true, + "UseStaticShape": true, + "RunSymbolicShapeInference": false, + "SkipCheck": + [ + "SKIP_CHECK_DEVICE", + "SKIP_CHECK_BUILD_GRADIENT", + "SKIP_CHECK_EXECUTION_AGENT" + ], + "DebugOptions": + { + "LogLevel": "VERBOSE", + "SaveONNX": true, + "ONNXPrefix": "my_model" + } +} diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json new file mode 100644 index 0000000000..fe1cecc03a --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json @@ -0,0 +1,25 @@ +{ + "PropagateCastOps": + { + "Strategy": "REMOVE_INPUT_OUTPUT_UP_DOWN_CASTS", + "Level": 5, + "Allow": ["XYZ", "PQR"] + }, + "UseExternalGPUAllocator" : true, + "EnableCustomAutogradFunction": false, + "AllowLayerNormModPrecision": false, + "EnableGradAccOptimization": false, + "UseStaticShape": false, + "RunSymbolicShapeInference": true, + "SkipCheck": + [ + "SKIP_CHECK_DEVICE", + "SKIP_CHECK_EXECUTION_AGENT" + ], + "DebugOptions": + { + "LogLevel": "INFO", + "SaveONNX": true, + "ONNXPrefix": "my_other_model" + } +} diff --git a/setup.py b/setup.py index 4e32f487ee..49bfe44ea6 100644 --- a/setup.py +++ b/setup.py @@ -279,6 +279,8 @@ if enable_training: 'onnxruntime.training.amp', 'onnxruntime.training.optim', 'onnxruntime.training.ortmodule', + 'onnxruntime.training.ortmodule.experimental', + 'onnxruntime.training.ortmodule.experimental.json_config', 'onnxruntime.training.ortmodule.torch_cpp_extensions', 
'onnxruntime.training.ortmodule.torch_cpp_extensions.aten_op_executor', 'onnxruntime.training.ortmodule.torch_cpp_extensions.torch_gpu_allocator'])