Configuring ORTModule - End User Facing Options (#8470)

This commit is contained in:
baijumeswani 2021-07-28 10:51:43 -07:00 committed by GitHub
parent 6f5bf8b8f2
commit 2e28cbaa64
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 337 additions and 88 deletions

View file

@ -64,3 +64,4 @@ torch.cuda.manual_seed = override_torch_cuda_manual_seed
# ORTModule must be loaded only after all validation passes
from .ortmodule import ORTModule
from .debug_options import DebugOptions, LogLevel

View file

@ -3,7 +3,8 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from . import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext
from .debug_options import DebugOptions, LogLevel
from . import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext, _onnx_models
from ._custom_autograd_function_exporter import _post_process_after_export
from ._graph_execution_interface import GraphExecutionInterface
from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION
@ -54,7 +55,7 @@ class _SkipCheck(IntFlag):
return _SkipCheck.SKIP_CHECK_DISABLED in self
class GraphExecutionManager(GraphExecutionInterface):
def __init__(self, module):
def __init__(self, module, debug_options: DebugOptions):
"""Manages building and execution of onnx graphs
This class is an abstract class and should not directly be instantiated.
@ -69,8 +70,8 @@ class GraphExecutionManager(GraphExecutionInterface):
# Original and flattened (transformed) output module
self._flattened_module = module
# Exported model
self._onnx_model = None
# onnx models
self._onnx_models = _onnx_models.ONNXModels()
# Model after inference optimization or gradient building.
self._optimized_onnx_model = None
@ -87,8 +88,7 @@ class GraphExecutionManager(GraphExecutionInterface):
self._skip_check = _SkipCheck.SKIP_CHECK_DISABLED
# Debug flags
self._save_onnx = False
self._save_onnx_prefix = ''
self._debug_options = debug_options
# Graph transformer config
# Specify cast propagation strategy. Currently three strategies are available, NONE, INSERT-AND-REDUCE and FLOOD-FILL
@ -126,9 +126,6 @@ class GraphExecutionManager(GraphExecutionInterface):
self._input_info = None
self._module_output_schema = None
# Log level
self._loglevel = _logger.LogLevel.WARNING
# TODO: Single device support for now
self._device = _utils.get_device_from_module(module)
@ -137,7 +134,7 @@ class GraphExecutionManager(GraphExecutionInterface):
# TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters.
for input_parameter in self._module_parameters:
if input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
if self._loglevel <= _logger.LogLevel.WARNING:
if self._debug_options.logging.log_level <= LogLevel.WARNING:
warnings.warn("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!",
UserWarning)
@ -194,7 +191,7 @@ class GraphExecutionManager(GraphExecutionInterface):
else:
self._graph_builder.build()
self._optimized_onnx_model = onnx.load_model_from_string(self._graph_builder.get_model())
self._onnx_models.optimized_model = onnx.load_model_from_string(self._graph_builder.get_model())
self._graph_info = self._graph_builder.get_graph_info()
def _get_session_config(self):
@ -222,11 +219,7 @@ class GraphExecutionManager(GraphExecutionInterface):
# default to PRIORITY_BASED execution order
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
session_options.log_severity_level = int(self._loglevel)
# enable dumping optimized training graph
if self._save_onnx:
session_options.optimized_model_filepath = self._save_onnx_prefix + '_training_optimized.onnx'
session_options.log_severity_level = int(self._debug_options.logging.log_level)
return session_options, providers, provider_options
@ -243,18 +236,21 @@ class GraphExecutionManager(GraphExecutionInterface):
# or the user explicitly changed model parameters after the onnx export.
schema = _io._extract_schema({'args': copy.copy(inputs), 'kwargs': copy.copy(kwargs)})
if self._onnx_model and schema == self._input_info.schema:
if self._onnx_models.exported_model and schema == self._input_info.schema:
# All required models have already been exported previously
return False
self._set_device_from_module(inputs, kwargs)
self._onnx_model = self._get_exported_model(*inputs, **kwargs)
_cpp_ext._load_aten_op_executor_cpp_extension_if_needed(self._onnx_model)
if self._save_onnx:
onnx.save(self._onnx_model, self._save_onnx_prefix + '_torch_exporter.onnx')
self._onnx_models.exported_model = self._get_exported_model(*inputs, **kwargs)
_cpp_ext._load_aten_op_executor_cpp_extension_if_needed(self._onnx_models.exported_model)
if self._debug_options.save_onnx_models.save:
self._onnx_models.save_exported_model(self._debug_options.save_onnx_models.path,
self._debug_options.save_onnx_models.name_prefix,
self._export_mode)
if self._run_symbolic_shape_infer:
self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True)
self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes(self._onnx_models.exported_model,
auto_merge=True, guess_output_rank=True)
return True
@ -292,7 +288,7 @@ class GraphExecutionManager(GraphExecutionInterface):
try:
with torch.set_grad_enabled(self._enable_custom_autograd_function), \
_logger.suppress_os_stream_output(log_level=self._loglevel):
_logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level):
torch.onnx.export(self._flattened_module,
sample_inputs_as_tuple,
f,
@ -302,7 +298,7 @@ class GraphExecutionManager(GraphExecutionInterface):
do_constant_folding=False,
training=self._export_mode,
dynamic_axes=self._input_info.dynamic_axes,
verbose=self._loglevel < _logger.LogLevel.WARNING,
verbose=self._debug_options.logging.log_level < LogLevel.WARNING,
export_params=False,
keep_initializers_as_inputs=True)
except RuntimeError as e:
@ -337,7 +333,7 @@ class GraphExecutionManager(GraphExecutionInterface):
# All initializer names along with user inputs are a part of the onnx graph inputs
# since the onnx model was exported with the flag keep_initializers_as_inputs=True
onnx_initializer_names = {p.name for p in self._onnx_model.graph.input}
onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input}
# TODO: PyTorch exporter bug: changes the initializer order in ONNX model
initializer_names = [name for name, _ in self._flattened_module.named_parameters()
@ -353,12 +349,12 @@ class GraphExecutionManager(GraphExecutionInterface):
grad_builder_config.build_gradient_graph = training
grad_builder_config.graph_transformer_config = self._get_graph_transformer_config()
grad_builder_config.enable_caching = self._enable_grad_acc_optimization
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._loglevel)
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._debug_options.logging.log_level)
self._graph_builder = C.OrtModuleGraphBuilder()
# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
# and are kept as they appear in the exported onnx model.
self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config)
self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
# TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train
# a set (unordered_set in the backend) that does not require a copy on each reference.

View file

@ -8,9 +8,9 @@ from ._inference_manager import InferenceManager
class GraphExecutionManagerFactory(object):
def __init__(self, module):
self._training_manager = TrainingManager(module)
self._inference_manager = InferenceManager(module)
def __init__(self, module, debug_options):
self._training_manager = TrainingManager(module, debug_options)
self._inference_manager = InferenceManager(module, debug_options)
def __call__(self, is_training):
if is_training:

View file

@ -18,8 +18,8 @@ class InferenceManager(GraphExecutionManager):
InferenceManager is resposible for building and running the forward graph of the inference model
"""
def __init__(self, model):
super().__init__(model)
def __init__(self, model, debug_options):
super().__init__(model, debug_options)
self._export_mode = torch.onnx.TrainingMode.EVAL
@staticmethod
@ -67,10 +67,6 @@ class InferenceManager(GraphExecutionManager):
# If model was exported, then initialize the graph builder
self._initialize_graph_builder(training=False)
# Save the onnx model if the model was exported
if self._save_onnx:
onnx.save(self._onnx_model, self._save_onnx_prefix + '_exported_inference_model.onnx')
# Build the inference graph
if build_graph:
self._build_graph()
@ -86,7 +82,7 @@ class InferenceManager(GraphExecutionManager):
self._create_execution_agent()
user_outputs, _ = InferenceManager.execution_session_run_forward(self._execution_agent,
self._optimized_onnx_model,
self._onnx_models.optimized_model,
self._device,
*_io._combine_input_buffers_initializers(
self._graph_initializers,
@ -104,12 +100,14 @@ class InferenceManager(GraphExecutionManager):
"""Build an optimized inference graph using the module_graph_builder"""
super()._build_graph()
if self._save_onnx:
onnx.save(self._optimized_onnx_model, self._save_onnx_prefix + '_inference.onnx')
if self._debug_options.save_onnx_models.save:
self._onnx_models.save_optimized_model(self._debug_options.save_onnx_models.path,
self._debug_options.save_onnx_models.name_prefix,
self._export_mode)
def _create_execution_agent(self):
"""Creates an InferenceAgent that can run forward graph on an inference model"""
session_options, providers, provider_options = self._get_session_config()
self._execution_agent = InferenceAgent(self._optimized_onnx_model.SerializeToString(),
self._execution_agent = InferenceAgent(self._onnx_models.optimized_model.SerializeToString(),
session_options, providers, provider_options)

View file

@ -10,6 +10,7 @@ import io
import sys
import warnings
class LogLevel(IntEnum):
VERBOSE = 0
INFO = 1

View file

@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# _onnx_models.py
from dataclasses import dataclass
from filelock import SoftFileLock
import onnx
import os
import torch
def _get_onnx_file_name(name_prefix, name, export_mode):
suffix = 'training' if export_mode == torch.onnx.TrainingMode.TRAINING else 'inference'
return f"{name_prefix}_{name}_{suffix}.onnx"
def _save_model(model: onnx.ModelProto, file_path: str) -> None:
    """Write ``model`` to ``file_path`` by delegating to ``onnx.save``."""
    onnx.save(model, file_path)
@dataclass
class ONNXModels:
    """Holder for the onnx models that ORTModule keeps around.

    Both fields default to None and are expected to be assigned by the owner
    of this object before the corresponding ``save_*`` method is called.
    """

    # Model stored under the name "exported"; saved as "<prefix>_torch_exported_<mode>.onnx".
    exported_model: onnx.ModelProto = None
    # Model stored under the name "optimized"; saved as "<prefix>_optimized_<mode>.onnx".
    optimized_model: onnx.ModelProto = None

    def save_exported_model(self, path, name_prefix, export_mode):
        """Save ``exported_model`` into the directory ``path``.

        The file name is derived from ``name_prefix`` and ``export_mode`` via
        ``_get_onnx_file_name`` using the fixed middle part ``torch_exported``.
        """
        file_name = _get_onnx_file_name(name_prefix, 'torch_exported', export_mode)
        destination = os.path.join(path, file_name)
        _save_model(self.exported_model, destination)

    def save_optimized_model(self, path, name_prefix, export_mode):
        """Save ``optimized_model`` into the directory ``path``.

        The file name is derived from ``name_prefix`` and ``export_mode`` via
        ``_get_onnx_file_name`` using the fixed middle part ``optimized``.
        """
        file_name = _get_onnx_file_name(name_prefix, 'optimized', export_mode)
        destination = os.path.join(path, file_name)
        _save_model(self.optimized_model, destination)

View file

@ -3,6 +3,7 @@
# _torch_module.py
from . import _io
from .debug_options import DebugOptions
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
from ._torch_module_interface import TorchModuleInterface
@ -16,7 +17,7 @@ T = TypeVar('T', bound='torch.nn.Module')
class TorchModule(TorchModuleInterface):
def __init__(self, module: torch.nn.Module):
def __init__(self, module: torch.nn.Module, debug_options: DebugOptions):
super(TorchModule, self).__init__(module)
self._flattened_module = _io._FlattenedModule(module)
@ -36,7 +37,7 @@ class TorchModule(TorchModuleInterface):
functools.update_wrapper(
self.forward.__func__, self._original_module.forward.__func__)
self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)
self._execution_manager = GraphExecutionManagerFactory(self._flattened_module, debug_options)
def _apply(self, fn):
"""Override original method to delegate execution to the flattened PyTorch user module"""

View file

@ -6,7 +6,7 @@ from ._torch_module import TorchModule
class TorchModuleFactory:
def __call__(self, module):
def __call__(self, module, debug_options):
"""Creates a TorchModule instance based on the input module."""
return TorchModule(module)
return TorchModule(module, debug_options)

View file

@ -10,7 +10,6 @@ from ._execution_agent import TrainingAgent
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type
import onnx
import torch
import warnings
from torch.utils.dlpack import from_dlpack, to_dlpack
@ -22,8 +21,8 @@ class TrainingManager(GraphExecutionManager):
TrainingManager is resposible for building and running the forward and backward graph of the training model
"""
def __init__(self, model):
super().__init__(model)
def __init__(self, model, debug_options):
super().__init__(model, debug_options)
self._export_mode = torch.onnx.TrainingMode.TRAINING
@staticmethod
@ -67,7 +66,7 @@ class TrainingManager(GraphExecutionManager):
self._initialize_graph_builder(training=True)
input_info = _io.parse_inputs_for_onnx_export(self._module_parameters,
self._onnx_model,
self._onnx_models.exported_model,
inputs,
kwargs)
@ -122,7 +121,7 @@ class TrainingManager(GraphExecutionManager):
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent,
self._optimized_onnx_model,
self._onnx_models.optimized_model,
*inputs)
# Disable materializing grads then None object will not be
@ -233,30 +232,30 @@ class TrainingManager(GraphExecutionManager):
super()._build_graph()
if self._save_onnx:
onnx.save(self._optimized_onnx_model, self._save_onnx_prefix + '_training.onnx')
inference_optimized_model = onnx.load_model_from_string(self._graph_builder.get_inference_optimized_model())
onnx.save(inference_optimized_model, self._save_onnx_prefix + '_inference_optimized.onnx')
if self._debug_options.save_onnx_models.save:
self._onnx_models.save_optimized_model(self._debug_options.save_onnx_models.path,
self._debug_options.save_onnx_models.name_prefix,
self._export_mode)
def _create_execution_agent(self):
"""Creates a TrainingAgent that can run the forward and backward graph on the training model"""
session_options, providers, provider_options = self._get_session_config()
fw_feed_names = [input.name for input in self._optimized_onnx_model.graph.input]
fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input]
fw_outputs_device_info = [
C.OrtDevice(get_ort_device_type(self._device.type),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * len(self._graph_info.user_output_names)
bw_fetches_names = [output.name for output in self._optimized_onnx_model.graph.output]
bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output]
bw_outputs_device_info = [
C.OrtDevice(get_ort_device_type(self._device.type),
C.OrtDevice.default_memory(),
_utils.get_device_index(self._device)
)] * len(bw_fetches_names)
self._execution_agent = TrainingAgent(self._optimized_onnx_model.SerializeToString(),
self._execution_agent = TrainingAgent(self._onnx_models.optimized_model.SerializeToString(),
fw_feed_names,
fw_outputs_device_info,
bw_fetches_names,

View file

@ -0,0 +1,116 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# debug_options.py
import os
from ._logger import LogLevel
class _SaveOnnxOptions:
"""Configurable option to save ORTModule intermediate onnx models."""
# class variable
_path_environment_key = 'ORTMODULE_SAVE_ONNX_PATH'
def __init__(self, save, name_prefix):
self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix)
def _extract_info(self, save, name_prefix):
# get the destination path from os env variable
destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, os.getcwd())
# perform validation only when save is True
if save:
self._validate(save, name_prefix, destination_path)
return save, name_prefix, destination_path
def _validate(self, save, name_prefix, destination_path):
# check if directory is writable
if not os.access(destination_path, os.W_OK):
raise OSError(f"Directory {destination_path} is not writable. Please set the {_SaveOnnxOptions._path_environment_key} environment variable to a writable path.")
# check if input prefix is a string
if not isinstance(name_prefix, str):
raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.")
# if save_onnx is set, save_onnx_prefix must be a non empty string
if not name_prefix:
raise ValueError("onnx_prefix must be provided when save_onnx is set.")
@property
def save(self):
return self._save
@property
def name_prefix(self):
return self._name_prefix
@property
def path(self):
return self._path
class _LoggingOptions:
    """Configurable option to set the log level in ORTModule.

    Args:
        log_level: The ``LogLevel`` to use when the ``ORTMODULE_LOG_LEVEL``
            environment variable is not set.

    Raises:
        TypeError: If ``log_level`` is not an instance of ``LogLevel``.
        KeyError: If ``ORTMODULE_LOG_LEVEL`` is set to something that is not
            the name of a ``LogLevel`` member.
    """

    # Name of the environment variable that overrides the log level.
    _log_level_environment_key = 'ORTMODULE_LOG_LEVEL'

    def __init__(self, log_level):
        self._log_level = self._extract_info(log_level)

    def _extract_info(self, log_level):
        """Return the effective log level, preferring the environment variable."""
        # The provided value is validated even when the environment overrides it.
        self._validate(log_level)

        # os env variable log level supersedes the locally provided one
        env_value = os.getenv(_LoggingOptions._log_level_environment_key)
        if env_value is None:
            return log_level

        # Member names are upper case; accept any casing and stray whitespace.
        level_name = env_value.strip().upper()
        try:
            return LogLevel[level_name]
        except KeyError:
            valid_names = ', '.join(level.name for level in LogLevel)
            # KeyError is kept as the exception type so existing handlers still match.
            raise KeyError(f"Invalid value {env_value!r} for environment variable "
                           f"{_LoggingOptions._log_level_environment_key}. "
                           f"Expected one of: {valid_names}.") from None

    def _validate(self, log_level):
        """Raise TypeError unless ``log_level`` is a ``LogLevel`` instance."""
        # check if log_level is an instance of LogLevel
        if not isinstance(log_level, LogLevel):
            raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.")

    @property
    def log_level(self):
        """The effective log level resolved at construction time."""
        return self._log_level
class DebugOptions:
    """User-facing bundle of ORTModule debugging settings.

    Args:
        log_level (:obj:`LogLevel`, optional): ORTModule log level. Defaults to LogLevel.WARNING.
            The environment variable "ORTMODULE_LOG_LEVEL" may also be used; when both are
            given, the environment variable wins.
        save_onnx (:obj:`bool`, optional): Whether ORTModule saves its onnx models. Defaults to False.
            Models go to the current working directory unless the environment variable
            "ORTMODULE_SAVE_ONNX_PATH" names another destination directory.
        onnx_prefix (:obj:`str`, optional): Prefix for the saved onnx file names.
            Required (non-empty) when save_onnx is True.

    Raises:
        OSError: If save_onnx is True and the output directory is not writable.
        TypeError: If save_onnx is True and onnx_prefix is not a string, or if
            log_level is not an instance of LogLevel.
        ValueError: If save_onnx is True and onnx_prefix is an empty string.
    """

    def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix=''):
        # The save options are built (and validated) before the logging options.
        self._save_onnx_models = _SaveOnnxOptions(save=save_onnx, name_prefix=onnx_prefix)
        self._logging = _LoggingOptions(log_level=log_level)

    @property
    def save_onnx_models(self):
        """The save-onnx-models settings object created in the constructor."""
        return self._save_onnx_models

    @property
    def logging(self):
        """The logging settings object created in the constructor."""
        return self._logging

View file

@ -6,6 +6,7 @@
from ._torch_module_factory import TorchModuleFactory
from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry
from ._custom_gradient_registry import CustomGradientRegistry
from .debug_options import DebugOptions
from onnxruntime.training import register_custom_ops_pytorch_exporter
@ -21,10 +22,19 @@ class ORTModule(torch.nn.Module):
ORTModule specializes the user's :class:`torch.nn.Module` model, providing :meth:`~torch.nn.Module.forward`,
:meth:`~torch.nn.Module.backward` along with all others :class:`torch.nn.Module`'s APIs.
Args:
module (torch.nn.Module): User's PyTorch module that ORTModule specializes
debug_options (:obj:`DebugOptions`, optional): debugging options for ORTModule.
"""
def __init__(self, module):
self._torch_module = TorchModuleFactory()(module)
def __init__(self, module, debug_options=None):
# Python default arguments are evaluated on function definition
# and not on function invocation. So, if debug_options is not provided,
# instantiate it inside the function.
if not debug_options:
debug_options = DebugOptions()
self._torch_module = TorchModuleFactory()(module, debug_options)
# Create forward dynamically, so each ORTModule instance will have its own copy.
# This is needed to be able to copy the forward signatures from the original PyTorch models

View file

@ -124,7 +124,7 @@ def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0):
def is_dynamic_axes(model):
# Check inputs
for inp in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.input:
for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input:
shape = inp.type.tensor_type.shape
if shape:
for dim in shape.dim:
@ -132,7 +132,7 @@ def is_dynamic_axes(model):
return False
# Check outputs
for out in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.output:
for out in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.output:
shape = out.type.tensor_type.shape
if shape:
for dim in shape.dim:

View file

@ -16,8 +16,9 @@ from collections import OrderedDict
from collections import namedtuple
from inspect import signature
import tempfile
import os
from onnxruntime.training.ortmodule import ORTModule, _utils, _io
from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel
import _test_helpers
# Import autocasting libs
@ -2714,10 +2715,10 @@ def test_changing_bool_input_re_exports_model(bool_arguments):
input1 = torch.randn(N, D_in, device=device)
ort_model(input1, bool_arguments[0])
exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model
exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model
ort_model(input1, bool_arguments[1])
exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model
exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model
assert exported_model1 != exported_model2
@ -2877,7 +2878,7 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model):
_ = ort_model(x)
input_info = _io.parse_inputs_for_onnx_export(training_manager._module_parameters,
training_manager._onnx_model,
training_manager._onnx_models.exported_model,
x,
{})
@ -3038,3 +3039,95 @@ def test_ortmodule_nested_list_input():
y = copy.deepcopy(x)
_test_helpers.assert_values_are_close(pt_model(x), ort_model(y))
@pytest.mark.parametrize("mode", ['training', 'inference'])
def test_debug_options_save_onnx_models_os_environment(mode):
    """Models are saved into the directory named by ORTMODULE_SAVE_ONNX_PATH."""
    device = 'cuda'
    N, D_in, H, D_out = 64, 784, 500, 10

    # Create a temporary directory for the onnx_models
    with tempfile.TemporaryDirectory() as temporary_dir:
        os.environ['ORTMODULE_SAVE_ONNX_PATH'] = temporary_dir
        try:
            model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
            ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_model'))
            if mode == 'inference':
                ort_model.eval()
            x = torch.randn(N, D_in, device=device)
            _ = ort_model(x)

            # assert that the onnx models have been saved
            assert os.path.exists(os.path.join(temporary_dir, f"my_model_torch_exported_{mode}.onnx"))
            assert os.path.exists(os.path.join(temporary_dir, f"my_model_optimized_{mode}.onnx"))
        finally:
            # Remove the override even on failure so it cannot leak into later tests.
            os.environ.pop('ORTMODULE_SAVE_ONNX_PATH', None)
@pytest.mark.parametrize("mode", ['training', 'inference'])
def test_debug_options_save_onnx_models_cwd(mode):
    """Without a path override, models are saved into the current working directory."""
    device = 'cuda'
    N, D_in, H, D_out = 64, 784, 500, 10

    exported_path = os.path.join(os.getcwd(), f"my_cwd_model_torch_exported_{mode}.onnx")
    optimized_path = os.path.join(os.getcwd(), f"my_cwd_model_optimized_{mode}.onnx")
    try:
        model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
        ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_cwd_model'))
        if mode == 'inference':
            ort_model.eval()
        x = torch.randn(N, D_in, device=device)
        _ = ort_model(x)

        # assert that the onnx models have been saved
        assert os.path.exists(exported_path)
        assert os.path.exists(optimized_path)
    finally:
        # Clean up whatever was written, even when an assertion above failed.
        for saved_path in (exported_path, optimized_path):
            if os.path.exists(saved_path):
                os.remove(saved_path)
def test_debug_options_save_onnx_models_validate_fail_on_non_writable_dir():
    """DebugOptions rejects a destination directory that is not writable."""
    os.environ['ORTMODULE_SAVE_ONNX_PATH'] = '/non/existent/directory'
    try:
        with pytest.raises(OSError) as ex_info:
            _ = DebugOptions(save_onnx=True, onnx_prefix='my_model')
        assert "Directory /non/existent/directory is not writable." in str(ex_info.value)
    finally:
        # Remove the override even on failure so it cannot leak into later tests.
        os.environ.pop('ORTMODULE_SAVE_ONNX_PATH', None)
def test_debug_options_save_onnx_models_validate_fail_on_non_str_prefix():
    """DebugOptions rejects a name prefix that is not a string."""
    prefix = 23
    # The validator raises TypeError for this case; do not accept unrelated failures.
    with pytest.raises(TypeError) as ex_info:
        _ = DebugOptions(save_onnx=True, onnx_prefix=prefix)
    assert f"Expected name prefix of type str, got {type(prefix)}." in str(ex_info.value)
def test_debug_options_save_onnx_models_validate_fail_on_no_prefix():
    """DebugOptions rejects save_onnx=True when no name prefix is given."""
    # The validator raises ValueError for this case; do not accept unrelated failures.
    with pytest.raises(ValueError) as ex_info:
        _ = DebugOptions(save_onnx=True)
    assert "onnx_prefix must be provided when save_onnx is set." in str(ex_info.value)
def test_debug_options_log_level():
    """A log level passed through DebugOptions reaches the execution manager."""
    # NOTE: This test will output verbose logging

    device = 'cuda'
    batch_size, input_dim, hidden_dim, output_dim = 64, 784, 500, 10

    pt_model = NeuralNetSinglePositionalArgument(input_dim, hidden_dim, output_dim).to(device)
    verbose_options = DebugOptions(log_level=LogLevel.VERBOSE)
    ort_model = ORTModule(pt_model, verbose_options)
    sample_input = torch.randn(batch_size, input_dim, device=device)
    ort_model(sample_input)

    # assert that the logging is done in verbose mode
    training_manager = ort_model._torch_module._execution_manager(True)
    assert training_manager._debug_options.logging.log_level == LogLevel.VERBOSE
def test_debug_options_log_level_os_environment():
    """ORTMODULE_LOG_LEVEL sets the log level when no DebugOptions are passed."""
    # NOTE: This test will output info logging

    os.environ['ORTMODULE_LOG_LEVEL'] = 'INFO'
    try:
        device = 'cuda'
        N, D_in, H, D_out = 64, 784, 500, 10
        model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
        ort_model = ORTModule(model)
        x = torch.randn(N, D_in, device=device)
        _ = ort_model(x)

        # assert that the logging is done in info mode
        assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.INFO
    finally:
        # Remove the override even on failure so it cannot leak into later tests.
        os.environ.pop('ORTMODULE_LOG_LEVEL', None)
def test_debug_options_log_level_validation_fails_on_type_mismatch():
    """DebugOptions rejects a log_level that is not a LogLevel instance."""
    log_level = 'some_string'
    # The validator raises TypeError for this case; do not accept unrelated failures.
    with pytest.raises(TypeError) as ex_info:
        _ = DebugOptions(log_level=log_level)
    assert f"Expected log_level of type LogLevel, got {type(log_level)}." in str(ex_info.value)

View file

@ -17,7 +17,7 @@ import datetime
import onnxruntime
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule import ORTModule, DebugOptions
def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
# ========================================
@ -376,11 +376,10 @@ def main():
)
if not args.pytorch_only:
model = ORTModule(model)
# Just for future debugging
debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassification')
# Just for future debugging
model._torch_module._execution_manager(model._is_training())._save_onnx = False
model._torch_module._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification'
model = ORTModule(model, debug_options)
# Tell pytorch to run this model on the GPU.
if torch.cuda.is_available() and not args.no_cuda:

View file

@ -17,7 +17,7 @@ import datetime
import onnxruntime
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule import ORTModule, DebugOptions
def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, args):
# ========================================
@ -377,11 +377,11 @@ def main():
)
if not args.pytorch_only:
model = ORTModule(model)
# Just for future debugging
debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassificationAutoCast')
model = ORTModule(model, debug_options)
model._torch_module._execution_manager(is_training=True)._save_onnx = True
model._torch_module._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification'
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
# Tell pytorch to run this model on the GPU.

View file

@ -9,14 +9,13 @@ $ deepspeed orttraining_test_ortmodule_deepspeed_zero_stage_1.py \
```
"""
import argparse
import logging
import torch
import time
from torchvision import datasets, transforms
import torch.distributed as dist
import onnxruntime
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
import deepspeed
@ -190,17 +189,20 @@ def main():
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
if not args.pytorch_only:
print('Training MNIST on ORTModule....')
model = ORTModule(model)
# TODO: change it to False to stop saving ONNX models
model._save_onnx = True
model._save_onnx_prefix = 'MNIST'
# Set log level
numeric_level = getattr(logging, args.log_level.upper(), None)
if not isinstance(numeric_level, int):
log_level_mapping = {"DEBUG": LogLevel.VERBOSE,
"INFO": LogLevel.INFO,
"WARNING": LogLevel.WARNING,
"ERROR": LogLevel.ERROR,
"CRITICAL": LogLevel.FATAL}
log_level = log_level_mapping.get(args.log_level.upper(), None)
if not isinstance(log_level, LogLevel):
raise ValueError('Invalid log level: %s' % args.log_level)
logging.basicConfig(level=numeric_level)
debug_options = DebugOptions(log_level=log_level, save_onnx=False, onnx_prefix='MNIST')
model = ORTModule(model, debug_options)
else:
print('Training MNIST on vanilla PyTorch....')

View file

@ -5,7 +5,7 @@ import time
from torchvision import datasets, transforms
import onnxruntime
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule import ORTModule, DebugOptions
class NeuralNet(torch.nn.Module):
@ -167,11 +167,11 @@ def main():
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
if not args.pytorch_only:
print('Training MNIST on ORTModule....')
model = ORTModule(model)
# TODO: change it to False to stop saving ONNX models
model._save_onnx = True
model._save_onnx_prefix = 'MNIST'
# Just for future debugging
debug_options = DebugOptions(save_onnx=False, onnx_prefix='MNIST')
model = ORTModule(model, debug_options)
# Set log level
numeric_level = getattr(logging, args.log_level.upper(), None)