mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Configuring ORTModule - Internal Options (#8537)
This commit is contained in:
parent
c6f95841dc
commit
816ad86d14
10 changed files with 445 additions and 3 deletions
|
|
@ -202,6 +202,12 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
file(GLOB onnxruntime_python_ortmodule_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_ortmodule_experimental_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/experimental/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_ortmodule_experimental_json_config_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/experimental/json_config/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/*.py"
|
||||
)
|
||||
|
|
@ -411,6 +417,8 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/amp
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/optim
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/json_config
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/aten_op_executor
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/torch_gpu_allocator
|
||||
|
|
@ -429,6 +437,12 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_experimental_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_experimental_json_config_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/json_config/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_torch_cpp_ext_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/
|
||||
|
|
|
|||
|
|
@ -141,15 +141,19 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self.is_rocm_pytorch = (True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False)
|
||||
|
||||
self._use_external_gpu_allocator = True
|
||||
# assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True
|
||||
self._get_torch_gpu_allocator_function_addresses()
|
||||
|
||||
# WIP feature to enable caching in Gradient accumulation scenario.
|
||||
self._enable_grad_acc_optimization = False
|
||||
|
||||
def _get_torch_gpu_allocator_function_addresses(self):
    """Cache torch's raw GPU alloc/free function addresses on this instance.

    Only runs when the external GPU allocator is enabled and CUDA is available;
    sets ``self._torch_alloc`` and ``self._torch_free`` so ORT can call back into
    torch's caching allocator instead of allocating GPU memory itself.
    NOTE(review): guard condition and assignments reconstructed from the diff view —
    the exact indentation past the assignments is not visible here; confirm against
    the full file.
    """
    if self._use_external_gpu_allocator and torch.cuda.is_available():
        # CPP extension to get torch GPU allocator's alloc and free function addresses
        from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator
        self._torch_alloc = torch_gpu_allocator.gpu_caching_allocator_raw_alloc_address()
        self._torch_free = torch_gpu_allocator.gpu_caching_allocator_raw_delete_address()
|
||||
|
||||
# WIP feature to enable caching in Gradient accumulation scenario.
|
||||
self._enable_grad_acc_optimization = False
|
||||
|
||||
def _validate_module_type(self, module):
|
||||
"""Raises a TypeError if the module is not a torch.nn.Module"""
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# __init__.py
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# __init__.py
|
||||
|
||||
from ._load_config_from_json import load_from_json
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _load_config_from_json.py
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from functools import reduce
|
||||
from ..._graph_execution_manager import _SkipCheck
|
||||
from ...debug_options import DebugOptions, LogLevel, _SaveOnnxOptions
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
def _load_data_from_json(path):
    """Loads data from the json file path provided."""

    # object_hook turns every json object into a SimpleNamespace so the
    # loaded config supports attribute-style access (data.Key.SubKey)
    with open(path) as json_file:
        data = json.load(json_file, object_hook=lambda entry: SimpleNamespace(**entry))

    # a file containing just `null` yields None — treat that as an error
    if data is None:
        raise RuntimeError(f"No data found in provided json file {path}.")

    return data
|
||||
|
||||
def _load_propagate_cast_ops(ortmodule_config_accessor, data):
    """Load PropagateCastOps from json file onto ORTModule"""

    assert hasattr(data, _load_propagate_cast_ops.loading_key)
    log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. Loading attributes from file.")

    def _update_strategy():
        # strategy is given as the enum member's name, e.g. "FLOOD_FILL"
        strategy_name = data.PropagateCastOps.Strategy
        ortmodule_config_accessor._propagate_cast_ops_strategy = \
            C.PropagateCastOpsStrategy.__members__[strategy_name]

    def _update_level():
        ortmodule_config_accessor._propagate_cast_ops_level = data.PropagateCastOps.Level

    def _update_allow():
        ortmodule_config_accessor._propagate_cast_ops_allow = data.PropagateCastOps.Allow

    dispatch = {
        "Strategy": _update_strategy,
        "Level": _update_level,
        "Allow": _update_allow,
    }

    # iterate the keys actually present in the json so an unknown key raises KeyError
    for key in data.PropagateCastOps.__dict__:
        dispatch[key]()
|
||||
|
||||
def _load_use_external_gpu_allocator(ortmodule_config_accessor, data):
    """Load UseExternalGPUAllocator from json file onto ORTModule"""

    assert hasattr(data, _load_use_external_gpu_allocator.loading_key)
    log.info(f"Found keyword {_load_use_external_gpu_allocator.loading_key} in json. Loading attributes from file.")

    flag = data.UseExternalGPUAllocator
    assert isinstance(flag, bool), f"{_load_use_external_gpu_allocator.loading_key} must be a boolean"
    ortmodule_config_accessor._use_external_gpu_allocator = flag
    # refresh the cached torch allocator addresses to reflect the new setting
    ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses()
|
||||
|
||||
def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
    """Load EnableCustomAutogradFunction from json file onto ORTModule"""

    assert hasattr(data, _load_enable_custom_autograd_function.loading_key)
    log.info(f"Found keyword {_load_enable_custom_autograd_function.loading_key} in json. Loading attributes from file.")

    value = data.EnableCustomAutogradFunction
    assert isinstance(value, bool), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean"
    ortmodule_config_accessor._enable_custom_autograd_function = value
|
||||
|
||||
def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data):
    """Load AllowLayerNormModPrecision from json file onto ORTModule"""

    assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key)
    log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. Loading attributes from file.")

    allow = data.AllowLayerNormModPrecision
    assert isinstance(allow, bool), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean"
    ortmodule_config_accessor._allow_layer_norm_mod_precision = allow
|
||||
|
||||
def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
    """Load EnableGradAccOptimization from json file onto ORTModule"""

    assert hasattr(data, _load_enable_grad_acc_optimization.loading_key)
    log.info(f"Found keyword {_load_enable_grad_acc_optimization.loading_key} in json. Loading attributes from file.")

    enabled = data.EnableGradAccOptimization
    assert isinstance(enabled, bool), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean"
    ortmodule_config_accessor._enable_grad_acc_optimization = enabled
|
||||
|
||||
def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
    """Load RunSymbolicShapeInference from json file onto ORTModule"""

    assert hasattr(data, _load_run_symbolic_shape_infer.loading_key)
    log.info(f"Found keyword {_load_run_symbolic_shape_infer.loading_key} in json. Loading attributes from file.")

    run_infer = data.RunSymbolicShapeInference
    assert isinstance(run_infer, bool), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean"
    ortmodule_config_accessor._run_symbolic_shape_infer = run_infer
|
||||
|
||||
def _load_use_static_shape(ortmodule_config_accessor, data):
    """Load UseStaticShape from json file onto ORTModule"""

    assert hasattr(data, _load_use_static_shape.loading_key)
    log.info(f"Found keyword {_load_use_static_shape.loading_key} in json. Loading attributes from file.")

    static_shape = data.UseStaticShape
    assert isinstance(static_shape, bool), f"{_load_use_static_shape.loading_key} must be a boolean"
    ortmodule_config_accessor._use_static_shape = static_shape
|
||||
|
||||
def _load_skip_check(ortmodule_config_accessor, data):
    """Load SkipCheck from json file onto ORTModule"""

    assert hasattr(data, _load_skip_check.loading_key)
    log.info(f"Found keyword {_load_skip_check.loading_key} in json. Loading attributes from file.")

    # OR the named _SkipCheck flags together; an unknown name raises KeyError
    flags = [_SkipCheck[name] for name in data.SkipCheck]
    skip_check = reduce(lambda accumulated, flag: accumulated | flag, flags)
    # only override the accessor's default when at least one real flag was given
    if skip_check.value > 0:
        ortmodule_config_accessor._skip_check = skip_check
|
||||
|
||||
def _load_debug_options(ortmodule_config_accessor, data):
    """Load DebugOptions from json file onto ORTModule"""

    assert hasattr(data, _load_debug_options.loading_key)
    log.info(f"Found keyword {_load_debug_options.loading_key} in json. Loading attributes from file.")

    # defaults mirror DebugOptions' own defaults; overridden per json key below
    settings = {"log_level": LogLevel.WARNING, "save_onnx": False, "onnx_prefix": ''}

    def _update_log_level():
        settings["log_level"] = LogLevel[data.DebugOptions.LogLevel]

    def _update_save_onnx():
        settings["save_onnx"] = data.DebugOptions.SaveONNX

    def _update_onnx_prefix():
        settings["onnx_prefix"] = data.DebugOptions.ONNXPrefix

    def _update_onnx_path():
        # the save path is communicated through an environment variable
        # consumed by _SaveOnnxOptions
        os.environ[_SaveOnnxOptions._path_environment_key] = data.DebugOptions.SaveONNXPath

    dispatch = {
        "LogLevel": _update_log_level,
        "SaveONNX": _update_save_onnx,
        "ONNXPrefix": _update_onnx_prefix,
        "SaveONNXPath": _update_onnx_path,
    }

    # iterate the keys present in the json; an unknown key raises KeyError
    for key in data.DebugOptions.__dict__:
        dispatch[key]()

    ortmodule_config_accessor._debug_options = DebugOptions(**settings)
|
||||
|
||||
def _define_load_function_keys():
    """Define static key variables for each loading function"""

    # each loader carries its json keyword as a function attribute so the
    # dispatch table in load_from_json can be built from the loaders themselves
    for loader, key in (
            (_load_propagate_cast_ops, "PropagateCastOps"),
            (_load_use_external_gpu_allocator, "UseExternalGPUAllocator"),
            (_load_enable_custom_autograd_function, "EnableCustomAutogradFunction"),
            (_load_allow_layer_norm_mod_precision, "AllowLayerNormModPrecision"),
            (_load_enable_grad_acc_optimization, "EnableGradAccOptimization"),
            (_load_run_symbolic_shape_infer, "RunSymbolicShapeInference"),
            (_load_use_static_shape, "UseStaticShape"),
            (_load_skip_check, "SkipCheck"),
            (_load_debug_options, "DebugOptions")):
        loader.loading_key = key
|
||||
|
||||
def load_from_json(ortmodule, path=None):
    """Load config from json file at given path.

    Here is the schema that the json file must adhere to:
    {
        "PropagateCastOps":
        {
            "Strategy": "FLOOD_FILL", # str representing strategy (like "NONE", "FLOOD_FILL"...)
            "Level": 3, # propagate cast ops level as an int
            "Allow": ["ABC", "DEF"] # propagate cast ops allow as list of strs
        },
        "UseExternalGPUAllocator" : false, # bool flag
        "EnableCustomAutogradFunction": true, # bool flag
        "AllowLayerNormModPrecision": true, # bool flag
        "EnableGradAccOptimization": true, # bool flag
        "UseStaticShape": true, # bool flag
        "RunSymbolicShapeInference": false, # bool flag
        "SkipCheck": # list of strs representing `_SkipCheck`s checks to skip which will be aggregated using |
        [
            "SKIP_CHECK_DEVICE",
            "SKIP_CHECK_BUILD_GRADIENT",
            "SKIP_CHECK_EXECUTION_AGENT"
        ],
        "DebugOptions": # debug options for user facing configuration
        {
            "LogLevel": "VERBOSE",
            "SaveONNX": true,
            "ONNXPrefix": "my_model",
            "SaveONNXPath": "/path/to/onnx/directory"
        }
    }

    Args:
        ortmodule (:obj:`ORTModule`): ORTModule instance that needs to be configured
        path (:obj:`str`, optional): Path to json file. Alternatively, users can set the
            environment variable ORTMODULE_JSON_CONFIG_PATH to the json config path. In case
            both path and environment variable are set, the environment variable gets precedence.

    Raises:
        ValueError: If no path is provided and the environment variable is not set.
        KeyError: If the json contains a keyword no loader is registered for.
    """

    global JSON_PATH_ENVIRONMENT_KEY
    JSON_PATH_ENVIRONMENT_KEY = "ORTMODULE_JSON_CONFIG_PATH"
    # the environment variable takes precedence over the path argument
    path = os.getenv(JSON_PATH_ENVIRONMENT_KEY, path)

    # figure out the json path
    if path is None:
        # bug fix: the original f-string referenced the undefined name
        # `json_path_environment_key`, so this path raised NameError instead of ValueError
        raise ValueError("Path to json is not provided. Provide the path through function call "
                         f"or by setting the environment variable {JSON_PATH_ENVIRONMENT_KEY}")

    # load the entire json file
    data = _load_data_from_json(path)

    # define the keys for all loading functions
    _define_load_function_keys()

    # define all load functions to iterate over
    load_functions = {
        _load_propagate_cast_ops.loading_key: _load_propagate_cast_ops,
        _load_use_external_gpu_allocator.loading_key: _load_use_external_gpu_allocator,
        _load_enable_custom_autograd_function.loading_key: _load_enable_custom_autograd_function,
        _load_allow_layer_norm_mod_precision.loading_key: _load_allow_layer_norm_mod_precision,
        _load_enable_grad_acc_optimization.loading_key: _load_enable_grad_acc_optimization,
        _load_run_symbolic_shape_infer.loading_key: _load_run_symbolic_shape_infer,
        _load_use_static_shape.loading_key: _load_use_static_shape,
        _load_skip_check.loading_key: _load_skip_check,
        _load_debug_options.loading_key: _load_debug_options
    }

    for training_mode in [True, False]:
        # update the config for both train and eval modes
        ortmodule_config_accessor = ortmodule._torch_module._execution_manager(training_mode)
        # iterate over the json data instead of checking for keys in json to catch key errors
        for key in data.__dict__:
            load_functions[key](ortmodule_config_accessor, data)
|
||||
|
|
@ -86,6 +86,13 @@ def run_ortmodule_custom_autograd_tests(cwd, log):
|
|||
run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
||||
def run_ortmodule_experimental_json_config_tests(cwd, log):
    """Run the ORTModule experimental json-config pytest file; raise on a non-zero exit."""
    log.debug('Running: ORTModule Experimental Load Config tests')

    command = [
        sys.executable, '-m', 'pytest', '-sv',
        'orttraining_test_ortmodule_experimental_json_config.py',
    ]
    run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
|
@ -110,6 +117,9 @@ def main():
|
|||
# TODO: enable this once the PyTorch used for testing meets the requirements running
|
||||
# auto grad testing.
|
||||
#run_ortmodule_custom_autograd_tests(cwd, log)
|
||||
|
||||
run_ortmodule_experimental_json_config_tests(cwd, log)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
|
||||
import os
|
||||
import torch
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.training.ortmodule.experimental.json_config import load_from_json
|
||||
|
||||
|
||||
class Net(torch.nn.Module):
    """Simple two-layer fully connected network used as the ORTModule test model."""

    def __init__(self, input_size=784, hidden_size=500, num_classes=10):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1):
        # linear -> ReLU -> linear, chained instead of one statement per step
        return self.fc2(self.relu(self.fc1(input1)))
|
||||
|
||||
def test_load_config_from_json_1():
    """Apply config 2 then config 1 and check that config 1's values win."""
    device = 'cuda'
    model = ORTModule(Net().to(device))

    # load json config 2 first, then config 1; the later load should override
    for json_name in ('orttraining_test_ortmodule_experimental_json_config_2.json',
                      'orttraining_test_ortmodule_experimental_json_config_1.json'):
        load_from_json(model, os.path.join(os.getcwd(), json_name))

    for training_mode in [True, False]:
        attributes = model._torch_module._execution_manager(training_mode)

        # propagate cast ops
        assert attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL
        assert attributes._propagate_cast_ops_level == 3
        assert attributes._propagate_cast_ops_allow == ["ABC", "DEF"]

        # boolean flags from config 1
        assert attributes._use_external_gpu_allocator == False
        assert attributes._enable_custom_autograd_function == True
        assert attributes._allow_layer_norm_mod_precision == True
        assert attributes._use_static_shape == True
        assert attributes._run_symbolic_shape_infer == False
        assert attributes._enable_grad_acc_optimization == True

        # skip check: SKIP_CHECK_DEVICE | SKIP_CHECK_BUILD_GRADIENT | SKIP_CHECK_EXECUTION_AGENT
        assert attributes._skip_check.value == 14

        # debug options
        assert attributes._debug_options.save_onnx_models.save == True
        assert attributes._debug_options.save_onnx_models.name_prefix == 'my_model'
        assert attributes._debug_options.logging.log_level.name == "VERBOSE"
|
||||
|
||||
def test_load_config_from_json_2():
    """Apply config 1 then config 2 and check that config 2's values win."""
    device = 'cuda'
    model = ORTModule(Net().to(device))

    # load json config 1 first, then config 2; the later load should override
    for json_name in ('orttraining_test_ortmodule_experimental_json_config_1.json',
                      'orttraining_test_ortmodule_experimental_json_config_2.json'):
        load_from_json(model, os.path.join(os.getcwd(), json_name))

    for training_mode in [True, False]:
        attributes = model._torch_module._execution_manager(training_mode)

        # propagate cast ops
        assert attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.REMOVE_INPUT_OUTPUT_UP_DOWN_CASTS
        assert attributes._propagate_cast_ops_level == 5
        assert attributes._propagate_cast_ops_allow == ["XYZ", "PQR"]

        # boolean flags from config 2
        assert attributes._use_external_gpu_allocator == True
        assert attributes._enable_custom_autograd_function == False
        assert attributes._allow_layer_norm_mod_precision == False
        assert attributes._use_static_shape == False
        assert attributes._run_symbolic_shape_infer == True
        assert attributes._enable_grad_acc_optimization == False

        # skip check: SKIP_CHECK_DEVICE | SKIP_CHECK_EXECUTION_AGENT
        assert attributes._skip_check.value == 10

        # debug options
        assert attributes._debug_options.save_onnx_models.save == True
        assert attributes._debug_options.save_onnx_models.name_prefix == 'my_other_model'
        assert attributes._debug_options.logging.log_level.name == "INFO"
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
{
|
||||
"PropagateCastOps":
|
||||
{
|
||||
"Strategy": "FLOOD_FILL",
|
||||
"Level": 3,
|
||||
"Allow": ["ABC", "DEF"]
|
||||
},
|
||||
"UseExternalGPUAllocator" : false,
|
||||
"EnableCustomAutogradFunction": true,
|
||||
"AllowLayerNormModPrecision": true,
|
||||
"EnableGradAccOptimization": true,
|
||||
"UseStaticShape": true,
|
||||
"RunSymbolicShapeInference": false,
|
||||
"SkipCheck":
|
||||
[
|
||||
"SKIP_CHECK_DEVICE",
|
||||
"SKIP_CHECK_BUILD_GRADIENT",
|
||||
"SKIP_CHECK_EXECUTION_AGENT"
|
||||
],
|
||||
"DebugOptions":
|
||||
{
|
||||
"LogLevel": "VERBOSE",
|
||||
"SaveONNX": true,
|
||||
"ONNXPrefix": "my_model"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
"PropagateCastOps":
|
||||
{
|
||||
"Strategy": "REMOVE_INPUT_OUTPUT_UP_DOWN_CASTS",
|
||||
"Level": 5,
|
||||
"Allow": ["XYZ", "PQR"]
|
||||
},
|
||||
"UseExternalGPUAllocator" : true,
|
||||
"EnableCustomAutogradFunction": false,
|
||||
"AllowLayerNormModPrecision": false,
|
||||
"EnableGradAccOptimization": false,
|
||||
"UseStaticShape": false,
|
||||
"RunSymbolicShapeInference": true,
|
||||
"SkipCheck":
|
||||
[
|
||||
"SKIP_CHECK_DEVICE",
|
||||
"SKIP_CHECK_EXECUTION_AGENT"
|
||||
],
|
||||
"DebugOptions":
|
||||
{
|
||||
"LogLevel": "INFO",
|
||||
"SaveONNX": true,
|
||||
"ONNXPrefix": "my_other_model"
|
||||
}
|
||||
}
|
||||
2
setup.py
2
setup.py
|
|
@ -279,6 +279,8 @@ if enable_training:
|
|||
'onnxruntime.training.amp',
|
||||
'onnxruntime.training.optim',
|
||||
'onnxruntime.training.ortmodule',
|
||||
'onnxruntime.training.ortmodule.experimental',
|
||||
'onnxruntime.training.ortmodule.experimental.json_config',
|
||||
'onnxruntime.training.ortmodule.torch_cpp_extensions',
|
||||
'onnxruntime.training.ortmodule.torch_cpp_extensions.aten_op_executor',
|
||||
'onnxruntime.training.ortmodule.torch_cpp_extensions.torch_gpu_allocator'])
|
||||
|
|
|
|||
Loading…
Reference in a new issue