mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Configuring ORTModule - End User Facing Options (#8470)
This commit is contained in:
parent
6f5bf8b8f2
commit
2e28cbaa64
17 changed files with 337 additions and 88 deletions
|
|
@ -64,3 +64,4 @@ torch.cuda.manual_seed = override_torch_cuda_manual_seed
|
|||
|
||||
# ORTModule must be loaded only after all validation passes
|
||||
from .ortmodule import ORTModule
|
||||
from .debug_options import DebugOptions, LogLevel
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from . import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext
|
||||
from .debug_options import DebugOptions, LogLevel
|
||||
from . import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext, _onnx_models
|
||||
from ._custom_autograd_function_exporter import _post_process_after_export
|
||||
from ._graph_execution_interface import GraphExecutionInterface
|
||||
from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION
|
||||
|
|
@ -54,7 +55,7 @@ class _SkipCheck(IntFlag):
|
|||
return _SkipCheck.SKIP_CHECK_DISABLED in self
|
||||
|
||||
class GraphExecutionManager(GraphExecutionInterface):
|
||||
def __init__(self, module):
|
||||
def __init__(self, module, debug_options: DebugOptions):
|
||||
"""Manages building and execution of onnx graphs
|
||||
|
||||
This class is an abstract class and should not directly be instantiated.
|
||||
|
|
@ -69,8 +70,8 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# Original and flattened (tranformed) output module
|
||||
self._flattened_module = module
|
||||
|
||||
# Exported model
|
||||
self._onnx_model = None
|
||||
# onnx models
|
||||
self._onnx_models = _onnx_models.ONNXModels()
|
||||
|
||||
# Model after inference optimization or gradient building.
|
||||
self._optimized_onnx_model = None
|
||||
|
|
@ -87,8 +88,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._skip_check = _SkipCheck.SKIP_CHECK_DISABLED
|
||||
|
||||
# Debug flags
|
||||
self._save_onnx = False
|
||||
self._save_onnx_prefix = ''
|
||||
self._debug_options = debug_options
|
||||
|
||||
# Graph transformer config
|
||||
# Specify cast propagation strategy. Currently three strategies are available, NONE, INSERT-AND-REDUCE and FLOOD-FILL
|
||||
|
|
@ -126,9 +126,6 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._input_info = None
|
||||
self._module_output_schema = None
|
||||
|
||||
# Log level
|
||||
self._loglevel = _logger.LogLevel.WARNING
|
||||
|
||||
# TODO: Single device support for now
|
||||
self._device = _utils.get_device_from_module(module)
|
||||
|
||||
|
|
@ -137,7 +134,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters.
|
||||
for input_parameter in self._module_parameters:
|
||||
if input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
if self._loglevel <= _logger.LogLevel.WARNING:
|
||||
if self._debug_options.logging.log_level <= LogLevel.WARNING:
|
||||
warnings.warn("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!",
|
||||
UserWarning)
|
||||
|
||||
|
|
@ -194,7 +191,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
else:
|
||||
self._graph_builder.build()
|
||||
|
||||
self._optimized_onnx_model = onnx.load_model_from_string(self._graph_builder.get_model())
|
||||
self._onnx_models.optimized_model = onnx.load_model_from_string(self._graph_builder.get_model())
|
||||
self._graph_info = self._graph_builder.get_graph_info()
|
||||
|
||||
def _get_session_config(self):
|
||||
|
|
@ -222,11 +219,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# default to PRIORITY_BASED execution order
|
||||
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
|
||||
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
|
||||
session_options.log_severity_level = int(self._loglevel)
|
||||
|
||||
# enable dumping optimized training graph
|
||||
if self._save_onnx:
|
||||
session_options.optimized_model_filepath = self._save_onnx_prefix + '_training_optimized.onnx'
|
||||
session_options.log_severity_level = int(self._debug_options.logging.log_level)
|
||||
|
||||
return session_options, providers, provider_options
|
||||
|
||||
|
|
@ -243,18 +236,21 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# or the user explicitly changed model parameters after the onnx export.
|
||||
|
||||
schema = _io._extract_schema({'args': copy.copy(inputs), 'kwargs': copy.copy(kwargs)})
|
||||
if self._onnx_model and schema == self._input_info.schema:
|
||||
if self._onnx_models.exported_model and schema == self._input_info.schema:
|
||||
# All required models have already been exported previously
|
||||
return False
|
||||
|
||||
self._set_device_from_module(inputs, kwargs)
|
||||
self._onnx_model = self._get_exported_model(*inputs, **kwargs)
|
||||
_cpp_ext._load_aten_op_executor_cpp_extension_if_needed(self._onnx_model)
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_model, self._save_onnx_prefix + '_torch_exporter.onnx')
|
||||
self._onnx_models.exported_model = self._get_exported_model(*inputs, **kwargs)
|
||||
_cpp_ext._load_aten_op_executor_cpp_extension_if_needed(self._onnx_models.exported_model)
|
||||
if self._debug_options.save_onnx_models.save:
|
||||
self._onnx_models.save_exported_model(self._debug_options.save_onnx_models.path,
|
||||
self._debug_options.save_onnx_models.name_prefix,
|
||||
self._export_mode)
|
||||
|
||||
if self._run_symbolic_shape_infer:
|
||||
self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True)
|
||||
self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes(self._onnx_models.exported_model,
|
||||
auto_merge=True, guess_output_rank=True)
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -292,7 +288,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
try:
|
||||
with torch.set_grad_enabled(self._enable_custom_autograd_function), \
|
||||
_logger.suppress_os_stream_output(log_level=self._loglevel):
|
||||
_logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level):
|
||||
torch.onnx.export(self._flattened_module,
|
||||
sample_inputs_as_tuple,
|
||||
f,
|
||||
|
|
@ -302,7 +298,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
do_constant_folding=False,
|
||||
training=self._export_mode,
|
||||
dynamic_axes=self._input_info.dynamic_axes,
|
||||
verbose=self._loglevel < _logger.LogLevel.WARNING,
|
||||
verbose=self._debug_options.logging.log_level < LogLevel.WARNING,
|
||||
export_params=False,
|
||||
keep_initializers_as_inputs=True)
|
||||
except RuntimeError as e:
|
||||
|
|
@ -337,7 +333,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
# All initializer names along with user inputs are a part of the onnx graph inputs
|
||||
# since the onnx model was exported with the flag keep_initializers_as_inputs=True
|
||||
onnx_initializer_names = {p.name for p in self._onnx_model.graph.input}
|
||||
onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input}
|
||||
|
||||
# TODO: PyTorch exporter bug: changes the initializer order in ONNX model
|
||||
initializer_names = [name for name, _ in self._flattened_module.named_parameters()
|
||||
|
|
@ -353,12 +349,12 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
grad_builder_config.build_gradient_graph = training
|
||||
grad_builder_config.graph_transformer_config = self._get_graph_transformer_config()
|
||||
grad_builder_config.enable_caching = self._enable_grad_acc_optimization
|
||||
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._loglevel)
|
||||
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._debug_options.logging.log_level)
|
||||
self._graph_builder = C.OrtModuleGraphBuilder()
|
||||
|
||||
# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
|
||||
# and are kept as they appear in the exported onnx model.
|
||||
self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config)
|
||||
self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
|
||||
|
||||
# TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train
|
||||
# a set (unordered_set in the backend) that does not require a copy on each reference.
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ from ._inference_manager import InferenceManager
|
|||
|
||||
|
||||
class GraphExecutionManagerFactory(object):
|
||||
def __init__(self, module):
|
||||
self._training_manager = TrainingManager(module)
|
||||
self._inference_manager = InferenceManager(module)
|
||||
def __init__(self, module, debug_options):
|
||||
self._training_manager = TrainingManager(module, debug_options)
|
||||
self._inference_manager = InferenceManager(module, debug_options)
|
||||
|
||||
def __call__(self, is_training):
|
||||
if is_training:
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ class InferenceManager(GraphExecutionManager):
|
|||
InferenceManager is resposible for building and running the forward graph of the inference model
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__(model)
|
||||
def __init__(self, model, debug_options):
|
||||
super().__init__(model, debug_options)
|
||||
self._export_mode = torch.onnx.TrainingMode.EVAL
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -67,10 +67,6 @@ class InferenceManager(GraphExecutionManager):
|
|||
# If model was exported, then initialize the graph builder
|
||||
self._initialize_graph_builder(training=False)
|
||||
|
||||
# Save the onnx model if the model was exported
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_model, self._save_onnx_prefix + '_exported_inference_model.onnx')
|
||||
|
||||
# Build the inference graph
|
||||
if build_graph:
|
||||
self._build_graph()
|
||||
|
|
@ -86,7 +82,7 @@ class InferenceManager(GraphExecutionManager):
|
|||
self._create_execution_agent()
|
||||
|
||||
user_outputs, _ = InferenceManager.execution_session_run_forward(self._execution_agent,
|
||||
self._optimized_onnx_model,
|
||||
self._onnx_models.optimized_model,
|
||||
self._device,
|
||||
*_io._combine_input_buffers_initializers(
|
||||
self._graph_initializers,
|
||||
|
|
@ -104,12 +100,14 @@ class InferenceManager(GraphExecutionManager):
|
|||
"""Build an optimized inference graph using the module_graph_builder"""
|
||||
|
||||
super()._build_graph()
|
||||
if self._save_onnx:
|
||||
onnx.save(self._optimized_onnx_model, self._save_onnx_prefix + '_inference.onnx')
|
||||
if self._debug_options.save_onnx_models.save:
|
||||
self._onnx_models.save_optimized_model(self._debug_options.save_onnx_models.path,
|
||||
self._debug_options.save_onnx_models.name_prefix,
|
||||
self._export_mode)
|
||||
|
||||
def _create_execution_agent(self):
|
||||
"""Creates an InferenceAgent that can run forward graph on an inference model"""
|
||||
|
||||
session_options, providers, provider_options = self._get_session_config()
|
||||
self._execution_agent = InferenceAgent(self._optimized_onnx_model.SerializeToString(),
|
||||
self._execution_agent = InferenceAgent(self._onnx_models.optimized_model.SerializeToString(),
|
||||
session_options, providers, provider_options)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import io
|
|||
import sys
|
||||
import warnings
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
VERBOSE = 0
|
||||
INFO = 1
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# _onnx_models.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
from filelock import SoftFileLock
|
||||
import onnx
|
||||
import os
|
||||
import torch
|
||||
|
||||
def _get_onnx_file_name(name_prefix, name, export_mode):
|
||||
suffix = 'training' if export_mode == torch.onnx.TrainingMode.TRAINING else 'inference'
|
||||
return f"{name_prefix}_{name}_{suffix}.onnx"
|
||||
|
||||
def _save_model(model: onnx.ModelProto, file_path: str):
|
||||
onnx.save(model, file_path)
|
||||
|
||||
@dataclass
|
||||
class ONNXModels:
|
||||
"""Encapsulates all ORTModule onnx models."""
|
||||
|
||||
exported_model: onnx.ModelProto = None
|
||||
optimized_model: onnx.ModelProto = None
|
||||
|
||||
def save_exported_model(self, path, name_prefix, export_mode):
|
||||
# save the ortmodule exported model
|
||||
_save_model(self.exported_model, os.path.join(path,
|
||||
_get_onnx_file_name(name_prefix, 'torch_exported', export_mode)))
|
||||
|
||||
def save_optimized_model(self, path, name_prefix, export_mode):
|
||||
# save the ortmodule optimized model
|
||||
_save_model(self.optimized_model, os.path.join(path,
|
||||
_get_onnx_file_name(name_prefix, 'optimized', export_mode)))
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
# _torch_module.py
|
||||
|
||||
from . import _io
|
||||
from .debug_options import DebugOptions
|
||||
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
|
||||
from ._torch_module_interface import TorchModuleInterface
|
||||
|
||||
|
|
@ -16,7 +17,7 @@ T = TypeVar('T', bound='torch.nn.Module')
|
|||
|
||||
|
||||
class TorchModule(TorchModuleInterface):
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
def __init__(self, module: torch.nn.Module, debug_options: DebugOptions):
|
||||
super(TorchModule, self).__init__(module)
|
||||
self._flattened_module = _io._FlattenedModule(module)
|
||||
|
||||
|
|
@ -36,7 +37,7 @@ class TorchModule(TorchModuleInterface):
|
|||
functools.update_wrapper(
|
||||
self.forward.__func__, self._original_module.forward.__func__)
|
||||
|
||||
self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)
|
||||
self._execution_manager = GraphExecutionManagerFactory(self._flattened_module, debug_options)
|
||||
|
||||
def _apply(self, fn):
|
||||
"""Override original method to delegate execution to the flattened PyTorch user module"""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from ._torch_module import TorchModule
|
|||
|
||||
|
||||
class TorchModuleFactory:
|
||||
def __call__(self, module):
|
||||
def __call__(self, module, debug_options):
|
||||
"""Creates a TorchModule instance based on the input module."""
|
||||
|
||||
return TorchModule(module)
|
||||
return TorchModule(module, debug_options)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from ._execution_agent import TrainingAgent
|
|||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import warnings
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
|
|
@ -22,8 +21,8 @@ class TrainingManager(GraphExecutionManager):
|
|||
TrainingManager is resposible for building and running the forward and backward graph of the training model
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__(model)
|
||||
def __init__(self, model, debug_options):
|
||||
super().__init__(model, debug_options)
|
||||
self._export_mode = torch.onnx.TrainingMode.TRAINING
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -67,7 +66,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
self._initialize_graph_builder(training=True)
|
||||
|
||||
input_info = _io.parse_inputs_for_onnx_export(self._module_parameters,
|
||||
self._onnx_model,
|
||||
self._onnx_models.exported_model,
|
||||
inputs,
|
||||
kwargs)
|
||||
|
||||
|
|
@ -122,7 +121,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
|
||||
|
||||
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent,
|
||||
self._optimized_onnx_model,
|
||||
self._onnx_models.optimized_model,
|
||||
*inputs)
|
||||
|
||||
# Disable materializing grads then None object will not be
|
||||
|
|
@ -233,30 +232,30 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
super()._build_graph()
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._optimized_onnx_model, self._save_onnx_prefix + '_training.onnx')
|
||||
inference_optimized_model = onnx.load_model_from_string(self._graph_builder.get_inference_optimized_model())
|
||||
onnx.save(inference_optimized_model, self._save_onnx_prefix + '_inference_optimized.onnx')
|
||||
if self._debug_options.save_onnx_models.save:
|
||||
self._onnx_models.save_optimized_model(self._debug_options.save_onnx_models.path,
|
||||
self._debug_options.save_onnx_models.name_prefix,
|
||||
self._export_mode)
|
||||
|
||||
def _create_execution_agent(self):
|
||||
"""Creates a TrainingAgent that can run the forward and backward graph on the training model"""
|
||||
|
||||
session_options, providers, provider_options = self._get_session_config()
|
||||
fw_feed_names = [input.name for input in self._optimized_onnx_model.graph.input]
|
||||
fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input]
|
||||
fw_outputs_device_info = [
|
||||
C.OrtDevice(get_ort_device_type(self._device.type),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * len(self._graph_info.user_output_names)
|
||||
|
||||
bw_fetches_names = [output.name for output in self._optimized_onnx_model.graph.output]
|
||||
bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output]
|
||||
bw_outputs_device_info = [
|
||||
C.OrtDevice(get_ort_device_type(self._device.type),
|
||||
C.OrtDevice.default_memory(),
|
||||
_utils.get_device_index(self._device)
|
||||
)] * len(bw_fetches_names)
|
||||
|
||||
self._execution_agent = TrainingAgent(self._optimized_onnx_model.SerializeToString(),
|
||||
self._execution_agent = TrainingAgent(self._onnx_models.optimized_model.SerializeToString(),
|
||||
fw_feed_names,
|
||||
fw_outputs_device_info,
|
||||
bw_fetches_names,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# debug_options.py
|
||||
|
||||
import os
|
||||
|
||||
from ._logger import LogLevel
|
||||
|
||||
class _SaveOnnxOptions:
|
||||
"""Configurable option to save ORTModule intermediate onnx models."""
|
||||
|
||||
# class variable
|
||||
_path_environment_key = 'ORTMODULE_SAVE_ONNX_PATH'
|
||||
|
||||
def __init__(self, save, name_prefix):
|
||||
self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix)
|
||||
|
||||
def _extract_info(self, save, name_prefix):
|
||||
# get the destination path from os env variable
|
||||
destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, os.getcwd())
|
||||
# perform validation only when save is True
|
||||
if save:
|
||||
self._validate(save, name_prefix, destination_path)
|
||||
return save, name_prefix, destination_path
|
||||
|
||||
def _validate(self, save, name_prefix, destination_path):
|
||||
# check if directory is writable
|
||||
if not os.access(destination_path, os.W_OK):
|
||||
raise OSError(f"Directory {destination_path} is not writable. Please set the {_SaveOnnxOptions._path_environment_key} environment variable to a writable path.")
|
||||
|
||||
# check if input prefix is a string
|
||||
if not isinstance(name_prefix, str):
|
||||
raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.")
|
||||
|
||||
# if save_onnx is set, save_onnx_prefix must be a non empty string
|
||||
if not name_prefix:
|
||||
raise ValueError("onnx_prefix must be provided when save_onnx is set.")
|
||||
|
||||
@property
|
||||
def save(self):
|
||||
return self._save
|
||||
|
||||
@property
|
||||
def name_prefix(self):
|
||||
return self._name_prefix
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._path
|
||||
|
||||
|
||||
class _LoggingOptions:
|
||||
"""Configurable option to set the log level in ORTModule."""
|
||||
|
||||
# class variable
|
||||
_log_level_environment_key = 'ORTMODULE_LOG_LEVEL'
|
||||
|
||||
def __init__(self, log_level):
|
||||
self._log_level = self._extract_info(log_level)
|
||||
|
||||
def _extract_info(self, log_level):
|
||||
# get the log_level from os env variable
|
||||
# os env variable log level supercededs the locally provided one
|
||||
self._validate(log_level)
|
||||
log_level = LogLevel[os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)]
|
||||
return log_level
|
||||
|
||||
def _validate(self, log_level):
|
||||
# check if log_level is an instance of LogLevel
|
||||
if not isinstance(log_level, LogLevel):
|
||||
raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.")
|
||||
|
||||
@property
|
||||
def log_level(self):
|
||||
return self._log_level
|
||||
|
||||
class DebugOptions:
|
||||
"""Configurable debugging options for ORTModule.
|
||||
|
||||
DebugOptions provides a way to configure ORTModule debug flags.
|
||||
|
||||
Args:
|
||||
log_level (:obj:`LogLevel`, optional): Configure ORTModule log level. Defaults to LogLevel.WARNING.
|
||||
log_level can also be set by setting the environment variable "ORTMODULE_LOG_LEVEL" to one of
|
||||
"VERBOSE", "INFO", "WARNING", "ERROR", "FATAL". In case both are set, the environment variable
|
||||
takes precedence.
|
||||
save_onnx (:obj:`bool`, optional): Configure ORTModule to save onnx models. Defaults to False.
|
||||
The output directory of the onnx models by default is set to the current working directory.
|
||||
To change the output directory, the environment variable "ORTMODULE_SAVE_ONNX_PATH" can be
|
||||
set to the destination directory path.
|
||||
onnx_prefix (:obj:`str`, optional): Name prefix to the ORTModule ONNX models saved file names.
|
||||
Must be provided if save_onnx is True
|
||||
|
||||
Raises:
|
||||
OSError: If save_onnx is True and output directory is not writable.
|
||||
TypeError: If save_onnx is True and name_prefix is not a valid string. Or if
|
||||
log_level is not an instance of LogLevel.
|
||||
ValueError: If save_onnx is True and name_prefix is an empty string.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix=''):
|
||||
self._save_onnx_models = _SaveOnnxOptions(save_onnx, onnx_prefix)
|
||||
self._logging = _LoggingOptions(log_level)
|
||||
|
||||
@property
|
||||
def save_onnx_models(self):
|
||||
"""Accessor for the save_onnx_models debug flag."""
|
||||
|
||||
return self._save_onnx_models
|
||||
|
||||
@property
|
||||
def logging(self):
|
||||
"""Accessor for the logging debug flag."""
|
||||
|
||||
return self._logging
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
from ._torch_module_factory import TorchModuleFactory
|
||||
from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry
|
||||
from ._custom_gradient_registry import CustomGradientRegistry
|
||||
from .debug_options import DebugOptions
|
||||
|
||||
from onnxruntime.training import register_custom_ops_pytorch_exporter
|
||||
|
||||
|
|
@ -21,10 +22,19 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
ORTModule specializes the user's :class:`torch.nn.Module` model, providing :meth:`~torch.nn.Module.forward`,
|
||||
:meth:`~torch.nn.Module.backward` along with all others :class:`torch.nn.Module`'s APIs.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): User's PyTorch module that ORTModule specializes
|
||||
debug_options (:obj:`DebugOptions`, optional): debugging options for ORTModule.
|
||||
"""
|
||||
|
||||
def __init__(self, module):
|
||||
self._torch_module = TorchModuleFactory()(module)
|
||||
def __init__(self, module, debug_options=None):
|
||||
# Python default arguments are evaluated on function definintion
|
||||
# and not on function invocation. So, if debug_options is not provided,
|
||||
# instantiate it inside the function.
|
||||
if not debug_options:
|
||||
debug_options = DebugOptions()
|
||||
self._torch_module = TorchModuleFactory()(module, debug_options)
|
||||
|
||||
# Create forward dynamically, so each ORTModule instance will have its own copy.
|
||||
# This is needed to be able to copy the forward signatures from the original PyTorch models
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0):
|
|||
|
||||
def is_dynamic_axes(model):
|
||||
# Check inputs
|
||||
for inp in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.input:
|
||||
for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input:
|
||||
shape = inp.type.tensor_type.shape
|
||||
if shape:
|
||||
for dim in shape.dim:
|
||||
|
|
@ -132,7 +132,7 @@ def is_dynamic_axes(model):
|
|||
return False
|
||||
|
||||
# Check outputs
|
||||
for out in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.output:
|
||||
for out in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.output:
|
||||
shape = out.type.tensor_type.shape
|
||||
if shape:
|
||||
for dim in shape.dim:
|
||||
|
|
|
|||
|
|
@ -16,8 +16,9 @@ from collections import OrderedDict
|
|||
from collections import namedtuple
|
||||
from inspect import signature
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTModule, _utils, _io
|
||||
from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel
|
||||
import _test_helpers
|
||||
|
||||
# Import autocasting libs
|
||||
|
|
@ -2714,10 +2715,10 @@ def test_changing_bool_input_re_exports_model(bool_arguments):
|
|||
|
||||
input1 = torch.randn(N, D_in, device=device)
|
||||
ort_model(input1, bool_arguments[0])
|
||||
exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model
|
||||
exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model
|
||||
|
||||
ort_model(input1, bool_arguments[1])
|
||||
exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model
|
||||
exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model
|
||||
|
||||
assert exported_model1 != exported_model2
|
||||
|
||||
|
|
@ -2877,7 +2878,7 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model):
|
|||
_ = ort_model(x)
|
||||
|
||||
input_info = _io.parse_inputs_for_onnx_export(training_manager._module_parameters,
|
||||
training_manager._onnx_model,
|
||||
training_manager._onnx_models.exported_model,
|
||||
x,
|
||||
{})
|
||||
|
||||
|
|
@ -3038,3 +3039,95 @@ def test_ortmodule_nested_list_input():
|
|||
y = copy.deepcopy(x)
|
||||
|
||||
_test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
|
||||
|
||||
@pytest.mark.parametrize("mode", ['training', 'inference'])
|
||||
def test_debug_options_save_onnx_models_os_environment(mode):
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
# Create a temporary directory for the onnx_models
|
||||
with tempfile.TemporaryDirectory() as temporary_dir:
|
||||
os.environ['ORTMODULE_SAVE_ONNX_PATH'] = temporary_dir
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_model'))
|
||||
if mode == 'inference':
|
||||
ort_model.eval()
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
_ = ort_model(x)
|
||||
|
||||
# assert that the onnx models have been saved
|
||||
assert os.path.exists(os.path.join(temporary_dir, f"my_model_torch_exported_{mode}.onnx"))
|
||||
assert os.path.exists(os.path.join(temporary_dir, f"my_model_optimized_{mode}.onnx"))
|
||||
del os.environ['ORTMODULE_SAVE_ONNX_PATH']
|
||||
|
||||
@pytest.mark.parametrize("mode", ['training', 'inference'])
|
||||
def test_debug_options_save_onnx_models_cwd(mode):
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_cwd_model'))
|
||||
if mode == 'inference':
|
||||
ort_model.eval()
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
_ = ort_model(x)
|
||||
|
||||
# assert that the onnx models have been saved
|
||||
assert os.path.exists(os.path.join(os.getcwd(), f"my_cwd_model_torch_exported_{mode}.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), f"my_cwd_model_optimized_{mode}.onnx"))
|
||||
|
||||
os.remove(os.path.join(os.getcwd(), f"my_cwd_model_torch_exported_{mode}.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), f"my_cwd_model_optimized_{mode}.onnx"))
|
||||
|
||||
def test_debug_options_save_onnx_models_validate_fail_on_non_writable_dir():
|
||||
|
||||
os.environ['ORTMODULE_SAVE_ONNX_PATH'] = '/non/existent/directory'
|
||||
with pytest.raises(Exception) as ex_info:
|
||||
_ = DebugOptions(save_onnx=True, onnx_prefix='my_model')
|
||||
assert "Directory /non/existent/directory is not writable." in str(ex_info.value)
|
||||
del os.environ['ORTMODULE_SAVE_ONNX_PATH']
|
||||
|
||||
def test_debug_options_save_onnx_models_validate_fail_on_non_str_prefix():
|
||||
prefix = 23
|
||||
with pytest.raises(Exception) as ex_info:
|
||||
_ = DebugOptions(save_onnx=True, onnx_prefix=prefix)
|
||||
assert f"Expected name prefix of type str, got {type(prefix)}." in str(ex_info.value)
|
||||
|
||||
def test_debug_options_save_onnx_models_validate_fail_on_no_prefix():
|
||||
with pytest.raises(Exception) as ex_info:
|
||||
_ = DebugOptions(save_onnx=True)
|
||||
assert f"onnx_prefix must be provided when save_onnx is set." in str(ex_info.value)
|
||||
|
||||
def test_debug_options_log_level():
|
||||
# NOTE: This test will output verbose logging
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model, DebugOptions(log_level=LogLevel.VERBOSE))
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
_ = ort_model(x)
|
||||
|
||||
# assert that the logging is done in verbose mode
|
||||
assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.VERBOSE
|
||||
|
||||
def test_debug_options_log_level_os_environment():
|
||||
# NOTE: This test will output info logging
|
||||
|
||||
os.environ['ORTMODULE_LOG_LEVEL'] = 'INFO'
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
_ = ort_model(x)
|
||||
|
||||
# assert that the logging is done in info mode
|
||||
assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.INFO
|
||||
del os.environ['ORTMODULE_LOG_LEVEL']
|
||||
|
||||
def test_debug_options_log_level_validation_fails_on_type_mismatch():
|
||||
log_level = 'some_string'
|
||||
with pytest.raises(Exception) as ex_info:
|
||||
_ = DebugOptions(log_level=log_level)
|
||||
assert f"Expected log_level of type LogLevel, got {type(log_level)}." in str(ex_info.value)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import datetime
|
|||
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule import ORTModule, DebugOptions
|
||||
|
||||
def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
||||
# ========================================
|
||||
|
|
@ -376,11 +376,10 @@ def main():
|
|||
)
|
||||
|
||||
if not args.pytorch_only:
|
||||
model = ORTModule(model)
|
||||
# Just for future debugging
|
||||
debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassification')
|
||||
|
||||
# Just for future debugging
|
||||
model._torch_module._execution_manager(model._is_training())._save_onnx = False
|
||||
model._torch_module._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification'
|
||||
model = ORTModule(model, debug_options)
|
||||
|
||||
# Tell pytorch to run this model on the GPU.
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import datetime
|
|||
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule import ORTModule, DebugOptions
|
||||
|
||||
def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, args):
|
||||
# ========================================
|
||||
|
|
@ -377,11 +377,11 @@ def main():
|
|||
)
|
||||
|
||||
if not args.pytorch_only:
|
||||
model = ORTModule(model)
|
||||
# Just for future debugging
|
||||
debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassificationAutoCast')
|
||||
|
||||
model = ORTModule(model, debug_options)
|
||||
|
||||
model._torch_module._execution_manager(is_training=True)._save_onnx = True
|
||||
model._torch_module._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification'
|
||||
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
|
||||
|
||||
# Tell pytorch to run this model on the GPU.
|
||||
|
|
|
|||
|
|
@ -9,14 +9,13 @@ $ deepspeed orttraining_test_ortmodule_deepspeed_zero_stage_1.py \
|
|||
```
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import time
|
||||
from torchvision import datasets, transforms
|
||||
import torch.distributed as dist
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
|
||||
|
||||
import deepspeed
|
||||
|
||||
|
|
@ -190,17 +189,20 @@ def main():
|
|||
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
|
||||
if not args.pytorch_only:
|
||||
print('Training MNIST on ORTModule....')
|
||||
model = ORTModule(model)
|
||||
|
||||
# TODO: change it to False to stop saving ONNX models
|
||||
model._save_onnx = True
|
||||
model._save_onnx_prefix = 'MNIST'
|
||||
|
||||
# Set log level
|
||||
numeric_level = getattr(logging, args.log_level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
log_level_mapping = {"DEBUG": LogLevel.VERBOSE,
|
||||
"INFO": LogLevel.INFO,
|
||||
"WARNING": LogLevel.WARNING,
|
||||
"ERROR": LogLevel.ERROR,
|
||||
"CRITICAL": LogLevel.FATAL}
|
||||
log_level = log_level_mapping.get(args.log_level.upper(), None)
|
||||
if not isinstance(log_level, LogLevel):
|
||||
raise ValueError('Invalid log level: %s' % args.log_level)
|
||||
logging.basicConfig(level=numeric_level)
|
||||
debug_options = DebugOptions(log_level=log_level, save_onnx=False, onnx_prefix='MNIST')
|
||||
|
||||
model = ORTModule(model, debug_options)
|
||||
|
||||
else:
|
||||
print('Training MNIST on vanilla PyTorch....')
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import time
|
|||
from torchvision import datasets, transforms
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule import ORTModule, DebugOptions
|
||||
|
||||
|
||||
class NeuralNet(torch.nn.Module):
|
||||
|
|
@ -167,11 +167,11 @@ def main():
|
|||
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
|
||||
if not args.pytorch_only:
|
||||
print('Training MNIST on ORTModule....')
|
||||
model = ORTModule(model)
|
||||
|
||||
# TODO: change it to False to stop saving ONNX models
|
||||
model._save_onnx = True
|
||||
model._save_onnx_prefix = 'MNIST'
|
||||
# Just for future debugging
|
||||
debug_options = DebugOptions(save_onnx=False, onnx_prefix='MNIST')
|
||||
|
||||
model = ORTModule(model, debug_options)
|
||||
|
||||
# Set log level
|
||||
numeric_level = getattr(logging, args.log_level.upper(), None)
|
||||
|
|
|
|||
Loading…
Reference in a new issue