mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
ORTModule log clean up (#16795)
### ORTModule log clean up ORTModule log level - WARNING(Default) is for end users; INFO and VERBOSE is for internal ORT training developers. Few issues: 1. ONNX export will output lots of WARNING error message like "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is missing", which is useless for us or end users.  3. ORT also print some information like ""CleanUnusedInitializersAndNodeArgs] Removing initializer","ReverseBFSWithStopGradient] Skip building gradient for", which is also useless for us or end users most of the time.  5. Different ranks output logs and making ORT developers or end users feels there are too many logs but usually not useful until we need investigate. Few improvements for the issues: 1. For ONNX export logs, there are two kinds of logs: a. export verbose log; b. other logs printed by torch C++ backend. So this PR make following change: # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) # INFO -> FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) e.g. for verbose level, print all logs as usually; for info level, print verbose export log, and filtered logs from torch C++ backend (removing messages like this "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is missing") . For higher level, only log the info on rank 0. 2. For ORT gradient graph build and session creation, also suppress the message and filtered out the message when log level >=INFO. 3. log level > INFO, then only logs on rank 0 is logged, to have a cleaner user experience This is the log for a BLOOM model training after the change: there are limited of warnings. 
This commit is contained in:
parent
bf006d34a9
commit
39fca225ea
16 changed files with 388 additions and 176 deletions
|
|
@ -33,7 +33,10 @@ Status ConcatSliceElimination::ApplyImpl(Graph& graph, bool& modified, int graph
|
|||
modified = true;
|
||||
}
|
||||
}
|
||||
LOGS(logger, INFO) << "Total fused concat node count: " << fused_count;
|
||||
|
||||
if (fused_count > 0) {
|
||||
LOGS(logger, INFO) << "Total fused concat node count: " << fused_count;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -252,8 +252,9 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve
|
|||
|
||||
modified = true;
|
||||
}
|
||||
|
||||
LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count;
|
||||
if (shared_count > 0) {
|
||||
LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,10 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
|
|||
modified = true;
|
||||
}
|
||||
}
|
||||
LOGS(logger, INFO) << "Total fused reshape node count: " << fused_count;
|
||||
|
||||
if (fused_count > 0) {
|
||||
LOGS(logger, INFO) << "Total fused reshape node count: " << fused_count;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -164,8 +164,8 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c
|
|||
const std::unordered_set<size_t>* edges = GetStopGradientEdges(*n);
|
||||
for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) {
|
||||
if (edges != nullptr && edges->count(edge_it->GetDstArgIndex())) {
|
||||
LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << n->Name();
|
||||
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << n->Name();
|
||||
continue;
|
||||
}
|
||||
const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()];
|
||||
|
|
@ -173,13 +173,13 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c
|
|||
if (nullptr != type_proto && type_proto->value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) {
|
||||
const int32_t type = type_proto->tensor_type().elem_type();
|
||||
if (GRAD_ALLOWED_TYPES.find(type) == GRAD_ALLOWED_TYPES.end()) {
|
||||
LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << n->Name() << "because element type is: " << type;
|
||||
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << n->Name() << "because element type is: " << type;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << n->Name() << "because it is not a Tensor type";
|
||||
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << n->Name() << "because it is not a Tensor type";
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -280,7 +280,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
|
|||
|
||||
const std::unordered_set<size_t>* edges = GetStopGradientEdges(next_node);
|
||||
if (edges != nullptr && edges->count(edge_it->GetDstArgIndex())) {
|
||||
LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << next_node.Name();
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager
|
|||
from ._graph_execution_interface import GraphExecutionInterface
|
||||
from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType
|
||||
from ._runtime_inspector import RuntimeInspector
|
||||
from ._utils import check_function_has_param
|
||||
from ._utils import check_function_has_param, get_rank
|
||||
from .options import DebugOptions, LogLevel, _RuntimeOptions
|
||||
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
|
||||
|
||||
|
|
@ -104,7 +104,6 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# Input and output infos (including schema) for exported model.
|
||||
self._input_info: Optional[_InputInfo] = None
|
||||
self._module_output_schema: Optional[_ModelInputOutputSchemaType] = None
|
||||
self._warning_log_detected_during_export = False
|
||||
|
||||
# Device where the model is placed.
|
||||
self._device: Optional[torch.device] = _utils.get_device_from_module(module)
|
||||
|
|
@ -253,7 +252,8 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
return session_options, providers, provider_options
|
||||
|
||||
@_logger.TrackTime(_logger.TimeTrackerPhase.EXPORT)
|
||||
@_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT)
|
||||
@_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False)
|
||||
def _export_model(self, *inputs, **kwargs) -> bool:
|
||||
# 1. Set the self._device from the user module
|
||||
# 2. Verify input schema matches the schema used on the previous model export
|
||||
|
|
@ -304,90 +304,88 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
TODO: How to support dynamic axes? Dimensions are determined by samples
|
||||
"""
|
||||
with _logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level) as suppress_output:
|
||||
# Setup dynamic axes for onnx model
|
||||
self._input_info = _io.parse_inputs_for_onnx_export(
|
||||
self._module_parameters, None, input_schema, inputs, kwargs
|
||||
)
|
||||
(
|
||||
output_names,
|
||||
output_dynamic_axes,
|
||||
self._module_output_schema,
|
||||
) = _io.parse_outputs_for_onnx_export_and_extract_schema(
|
||||
self._original_module, inputs, kwargs, self._logger
|
||||
)
|
||||
self._input_info.dynamic_axes.update(output_dynamic_axes)
|
||||
|
||||
# FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs
|
||||
self._flattened_module._input_info = self._input_info
|
||||
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend)
|
||||
# INFO -> FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend)
|
||||
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend)
|
||||
# Be noted: rank 0 log only is controlled by logger configured in _logger.py
|
||||
torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO
|
||||
self._logger.info("Exporting the PyTorch model to ONNX...")
|
||||
|
||||
# Export torch.nn.Module to ONNX
|
||||
f = io.BytesIO()
|
||||
# Setup dynamic axes for onnx model
|
||||
self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs)
|
||||
(
|
||||
output_names,
|
||||
output_dynamic_axes,
|
||||
self._module_output_schema,
|
||||
) = _io.parse_outputs_for_onnx_export_and_extract_schema(self._original_module, inputs, kwargs, self._logger)
|
||||
self._input_info.dynamic_axes.update(output_dynamic_axes)
|
||||
|
||||
# Deepcopy inputs, since input values may change after model run.
|
||||
# NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn).
|
||||
# Therefore, deepcopy only the data component of the input tensors for export.
|
||||
sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs)
|
||||
# NOTE: Flattening the input will change the 'input schema', resulting in a re-export
|
||||
sample_inputs_as_tuple = tuple(
|
||||
self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)
|
||||
)
|
||||
# Ops behaving differently under train/eval mode need to be exported with the
|
||||
# correct training flag to reflect the expected behavior.
|
||||
# For example, the Dropout node in a model is dropped under eval mode.
|
||||
assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager"
|
||||
# FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs
|
||||
self._flattened_module._input_info = self._input_info
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
required_export_kwargs = {
|
||||
"input_names": self._input_info.names,
|
||||
"output_names": output_names,
|
||||
"opset_version": self._runtime_options.onnx_opset_version,
|
||||
"do_constant_folding": False,
|
||||
"training": self._export_mode,
|
||||
"dynamic_axes": self._input_info.dynamic_axes,
|
||||
"verbose": self._debug_options.logging.log_level < LogLevel.WARNING,
|
||||
"export_params": False,
|
||||
"keep_initializers_as_inputs": True,
|
||||
}
|
||||
# Export torch.nn.Module to ONNX
|
||||
f = io.BytesIO()
|
||||
|
||||
if check_function_has_param(torch.onnx.export, "autograd_inlining"):
|
||||
# From some PyTorch version, autograd_inlining is a valid argument.
|
||||
# We allow it to be True if custom autograd function is disabled (where autograd.Function
|
||||
# anyway is not supported in ONNX until it can be inlined).
|
||||
required_export_kwargs[
|
||||
"autograd_inlining"
|
||||
] = not self._runtime_options.enable_custom_autograd_function
|
||||
# Deepcopy inputs, since input values may change after model run.
|
||||
# NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn).
|
||||
# Therefore, deepcopy only the data component of the input tensors for export.
|
||||
sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs)
|
||||
# NOTE: Flattening the input will change the 'input schema', resulting in a re-export
|
||||
sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device))
|
||||
# Ops behaving differently under train/eval mode need to be exported with the
|
||||
# correct training flag to reflect the expected behavior.
|
||||
# For example, the Dropout node in a model is dropped under eval mode.
|
||||
assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager"
|
||||
|
||||
invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys()
|
||||
assert (
|
||||
len(invalid_args) == 0
|
||||
), f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'."
|
||||
torch.onnx.export(
|
||||
self._flattened_module,
|
||||
sample_inputs_as_tuple,
|
||||
f,
|
||||
**required_export_kwargs,
|
||||
**self._export_extra_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise wrap_exception( # noqa: B904
|
||||
ORTModuleONNXModelException,
|
||||
RuntimeError(
|
||||
f"There was an error while exporting the PyTorch model to ONNX: "
|
||||
f"\n\n{_utils.get_exception_as_string(e)}"
|
||||
),
|
||||
try:
|
||||
with torch.no_grad():
|
||||
required_export_kwargs = {
|
||||
"input_names": self._input_info.names,
|
||||
"output_names": output_names,
|
||||
"opset_version": self._runtime_options.onnx_opset_version,
|
||||
"do_constant_folding": False,
|
||||
"training": self._export_mode,
|
||||
"dynamic_axes": self._input_info.dynamic_axes,
|
||||
"verbose": torch_exporter_verbose_log,
|
||||
"export_params": False,
|
||||
"keep_initializers_as_inputs": True,
|
||||
}
|
||||
|
||||
if check_function_has_param(torch.onnx.export, "autograd_inlining"):
|
||||
# From some PyTorch version, autograd_inlining is a valid argument.
|
||||
# We allow it to be True if custom autograd function is disabled (where autograd.Function
|
||||
# anyway is not supported in ONNX until it can be inlined).
|
||||
required_export_kwargs[
|
||||
"autograd_inlining"
|
||||
] = not self._runtime_options.enable_custom_autograd_function
|
||||
|
||||
invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys()
|
||||
|
||||
if len(invalid_args) != 0:
|
||||
error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'."
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
torch.onnx.export(
|
||||
self._flattened_module,
|
||||
sample_inputs_as_tuple,
|
||||
f,
|
||||
**required_export_kwargs,
|
||||
**self._export_extra_kwargs,
|
||||
)
|
||||
exported_model = onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
exported_model = _post_process_after_export(
|
||||
exported_model, self._runtime_options.enable_custom_autograd_function
|
||||
except Exception as e:
|
||||
raise wrap_exception( # noqa: B904
|
||||
ORTModuleONNXModelException,
|
||||
RuntimeError(
|
||||
f"There was an error while exporting the PyTorch model to ONNX: "
|
||||
f"\n\n{_utils.get_exception_as_string(e)}"
|
||||
),
|
||||
)
|
||||
exported_model = onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
# If anything was captured by suppress_output during export, set the flag to
|
||||
# raise a single user warning letting users know in the log.
|
||||
if suppress_output.tell() > 0:
|
||||
self._warning_log_detected_during_export = True
|
||||
exported_model = _post_process_after_export(
|
||||
exported_model, self._runtime_options.enable_custom_autograd_function
|
||||
)
|
||||
|
||||
return exported_model
|
||||
|
||||
|
|
@ -411,7 +409,8 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer
|
||||
return graph_transformer_config
|
||||
|
||||
@_logger.TrackTime(_logger.TimeTrackerPhase.GRAPH_BUILDER_INIT)
|
||||
@_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT)
|
||||
@_logger.SuppressLogs(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT)
|
||||
def _initialize_graph_builder(self):
|
||||
"""Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""
|
||||
|
||||
|
|
@ -483,7 +482,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
_utils.reinitialize_graph_execution_manager(self)
|
||||
|
||||
@_logger.TrackTime(_logger.TimeTrackerPhase.DETECTION)
|
||||
@_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION)
|
||||
def _enable_conditional_optimizations(
|
||||
self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict
|
||||
):
|
||||
|
|
@ -544,11 +543,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._runtime_inspector.enable_memory_inspector(self._original_module)
|
||||
|
||||
def _log_feature_stats(self):
|
||||
rank = 0
|
||||
if torch.distributed.is_initialized():
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
if rank != 0:
|
||||
if get_rank() != 0:
|
||||
return
|
||||
|
||||
feature_map: List[Tuple[str, bool, str]] = [
|
||||
|
|
@ -638,11 +633,8 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
switch_str = "ON" if feature_tuple[1] else "OFF"
|
||||
stat += f"{feature_tuple[0]:<20}:\t{switch_str:<10}:\t{feature_tuple[2]:<80}\n"
|
||||
|
||||
stat += f"\n{_logger.LogColor.WARNING}There were one or more warnings or errors raised while exporting the PyTorch model.\n"
|
||||
stat += f"Please enable INFO level logging with DebugOptions to view all warnings and errors.{_logger.LogColor.ENDC}\n\n"
|
||||
|
||||
# Collect ORTModule overheads for different phases.
|
||||
stat += f"{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n"
|
||||
stat += f"\n{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n"
|
||||
|
||||
stat += f"Versions: ONNX Runtime - {onnxruntime.__version__}, ONNX - {onnx.__version__}\n\n"
|
||||
stat += f"{_logger.LogColor.HEADER}************************************************************************{_logger.LogColor.ENDC}\n\n"
|
||||
|
|
|
|||
|
|
@ -15,7 +15,11 @@ from .options import DebugOptions
|
|||
|
||||
class GraphExecutionManagerFactory:
|
||||
def __init__(
|
||||
self, module: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger
|
||||
self,
|
||||
module: _FlattenedModule,
|
||||
debug_options: DebugOptions,
|
||||
fallback_manager: _FallbackManager,
|
||||
logger: Logger,
|
||||
):
|
||||
self._training_manager = TrainingManager(module, debug_options, fallback_manager, logger)
|
||||
self._inference_manager = InferenceManager(module, debug_options, fallback_manager, logger)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_alg
|
|||
from ._execution_agent import InferenceAgent
|
||||
from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy
|
||||
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
|
||||
from ._logger import TimeTrackerPhase, TrackTime
|
||||
from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime
|
||||
from ._utils import save_tuning_results, set_tuning_results
|
||||
from .options import DebugOptions, _SkipCheck
|
||||
|
||||
|
|
@ -111,7 +111,7 @@ class InferenceManager(GraphExecutionManager):
|
|||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
|
||||
or not self._onnx_models.exported_model
|
||||
):
|
||||
self.time_tracker.start(TimeTrackerPhase.EndToEnd)
|
||||
self.time_tracker.start(ORTModuleInitPhase.EndToEnd)
|
||||
|
||||
# Exporting module to ONNX for the first time
|
||||
build_graph = self._export_model(*inputs, **kwargs)
|
||||
|
|
@ -151,7 +151,7 @@ class InferenceManager(GraphExecutionManager):
|
|||
# Create execution session creates the inference_session
|
||||
self._create_execution_agent()
|
||||
|
||||
self.time_tracker.end(TimeTrackerPhase.EndToEnd)
|
||||
self.time_tracker.end(ORTModuleInitPhase.EndToEnd)
|
||||
self._log_feature_stats()
|
||||
|
||||
if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
|
|
@ -201,7 +201,8 @@ class InferenceManager(GraphExecutionManager):
|
|||
if self._fallback_manager.is_pending():
|
||||
return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)
|
||||
|
||||
@TrackTime(TimeTrackerPhase.BUILD_GRAPH)
|
||||
@TrackTime(ORTModuleInitPhase.BUILD_GRAPH)
|
||||
@SuppressLogs(ORTModuleInitPhase.BUILD_GRAPH)
|
||||
def _build_graph(self, graph_transformer_config):
|
||||
"""Build an inference graph using the module_graph_builder"""
|
||||
|
||||
|
|
@ -214,7 +215,8 @@ class InferenceManager(GraphExecutionManager):
|
|||
self._export_mode,
|
||||
)
|
||||
|
||||
@TrackTime(TimeTrackerPhase.CREATE_SESSION)
|
||||
@TrackTime(ORTModuleInitPhase.CREATE_SESSION)
|
||||
@SuppressLogs(ORTModuleInitPhase.CREATE_SESSION)
|
||||
def _create_execution_agent(self):
|
||||
"""Creates an InferenceAgent that can run forward graph on an inference model"""
|
||||
|
||||
|
|
|
|||
|
|
@ -3,16 +3,21 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from enum import IntEnum
|
||||
from typing import Callable, Dict, List
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from onnxruntime.capi._pybind_state import Severity
|
||||
|
||||
from ._utils import get_rank, get_world_size
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
VERBOSE = 0
|
||||
|
|
@ -22,34 +27,6 @@ class LogLevel(IntEnum):
|
|||
FATAL = 4
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_os_stream_output(suppress_stdout=True, suppress_stderr=True, log_level=LogLevel.WARNING):
|
||||
"""Suppress output from being printed to stdout and stderr if log_level is WARNING or higher.
|
||||
|
||||
If there is any output detected, a single warning is issued in the context
|
||||
"""
|
||||
|
||||
# stdout and stderr is written to a tempfile instead
|
||||
stdout = sys.stdout
|
||||
stderr = sys.stderr
|
||||
|
||||
suppress_logs = log_level >= LogLevel.WARNING
|
||||
|
||||
fo = io.StringIO()
|
||||
|
||||
try:
|
||||
if suppress_stdout and suppress_logs:
|
||||
sys.stdout = fo
|
||||
if suppress_stderr and suppress_logs:
|
||||
sys.stderr = fo
|
||||
yield fo
|
||||
finally:
|
||||
if suppress_stdout:
|
||||
sys.stdout = stdout
|
||||
if suppress_stderr:
|
||||
sys.stderr = stderr
|
||||
|
||||
|
||||
ORTMODULE_LOG_LEVEL_MAP: Dict[LogLevel, List[int]] = {
|
||||
LogLevel.VERBOSE: [Severity.VERBOSE, logging.DEBUG],
|
||||
LogLevel.INFO: [Severity.INFO, logging.INFO],
|
||||
|
|
@ -67,6 +44,21 @@ def ortmodule_loglevel_to_python_loglevel(loglevel: LogLevel) -> int:
|
|||
return ORTMODULE_LOG_LEVEL_MAP.get(loglevel, [Severity.WARNING, logging.WARNING])[1]
|
||||
|
||||
|
||||
def configure_ortmodule_logger(log_level: LogLevel) -> logging.Logger:
|
||||
"""Configure the logger for ortmodule according to following rules.
|
||||
1. If multiple processes are used, the rank will be appended
|
||||
to the logger name.
|
||||
2. If the log level is greater than info, the logger will be
|
||||
disabled for non-zero ranks.
|
||||
"""
|
||||
rank_info = f".rank-{get_rank()}" if get_world_size() > 1 else ""
|
||||
logger = logging.getLogger(f"orttraining{rank_info}")
|
||||
# Disable the logger for non-zero ranks when level > info
|
||||
logger.disabled = log_level > LogLevel.INFO and get_rank() != 0
|
||||
logger.setLevel(ortmodule_loglevel_to_python_loglevel(log_level))
|
||||
return logger
|
||||
|
||||
|
||||
class LogColor:
|
||||
HEADER = "\033[95m"
|
||||
BLUE = "\033[94m"
|
||||
|
|
@ -79,26 +71,26 @@ class LogColor:
|
|||
UNDERLINE = "\033[4m"
|
||||
|
||||
|
||||
class TimeTrackerPhase(IntEnum):
|
||||
EndToEnd = 0 # The total overhead of ORT first-time initialization
|
||||
EXPORT = 1 # The latency of preparing and exporting the model to ONNX
|
||||
GRAPH_BUILDER_INIT = 2 # The latency of initializing the graph builder
|
||||
DETECTION = 3 # The latency of runtime detection
|
||||
BUILD_GRAPH = 4 # The latency of optimizing forward graph (and building the gradient graph for training).
|
||||
CREATE_SESSION = 5 # The latency of creating the session
|
||||
class ORTModuleInitPhase(IntEnum):
|
||||
EndToEnd = 0 # The end to end of ORT first-time initialization
|
||||
EXPORT = 1 # The phase of preparing and exporting the model to ONNX
|
||||
GRAPH_BUILDER_INIT = 2 # The phase of initializing the graph builder
|
||||
DETECTION = 3 # The phase of runtime detection
|
||||
BUILD_GRAPH = 4 # The phase of optimizing forward graph (and building the gradient graph for training).
|
||||
CREATE_SESSION = 5 # The phase of creating the session
|
||||
|
||||
def to_string(self) -> str:
|
||||
if self == TimeTrackerPhase.EndToEnd:
|
||||
if self == ORTModuleInitPhase.EndToEnd:
|
||||
return "end to end"
|
||||
if self == TimeTrackerPhase.EXPORT:
|
||||
if self == ORTModuleInitPhase.EXPORT:
|
||||
return "export"
|
||||
elif self == TimeTrackerPhase.GRAPH_BUILDER_INIT:
|
||||
elif self == ORTModuleInitPhase.GRAPH_BUILDER_INIT:
|
||||
return "graph builder init"
|
||||
elif self == TimeTrackerPhase.DETECTION:
|
||||
elif self == ORTModuleInitPhase.DETECTION:
|
||||
return "runtime detection"
|
||||
elif self == TimeTrackerPhase.BUILD_GRAPH:
|
||||
elif self == ORTModuleInitPhase.BUILD_GRAPH:
|
||||
return "graph building"
|
||||
elif self == TimeTrackerPhase.CREATE_SESSION:
|
||||
elif self == ORTModuleInitPhase.CREATE_SESSION:
|
||||
return "session creation"
|
||||
else:
|
||||
return "invalid"
|
||||
|
|
@ -112,24 +104,24 @@ class TimeTracker:
|
|||
def __init__(
|
||||
self,
|
||||
):
|
||||
self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(TimeTrackerPhase)
|
||||
self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(TimeTrackerPhase)
|
||||
self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase)
|
||||
self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase)
|
||||
|
||||
def start(self, phase: TimeTrackerPhase):
|
||||
def start(self, phase: ORTModuleInitPhase):
|
||||
self.starts_[phase] = time.time()
|
||||
|
||||
def end(self, phase: TimeTrackerPhase):
|
||||
def end(self, phase: ORTModuleInitPhase):
|
||||
self.ends_[phase] = time.time()
|
||||
|
||||
def _get_duration(self, phase: TimeTrackerPhase):
|
||||
def _get_duration(self, phase: ORTModuleInitPhase):
|
||||
if self.ends_[phase] == TimeTracker.NOT_RECORD or self.starts_[phase] == TimeTracker.NOT_RECORD:
|
||||
return TimeTracker.NOT_RECORD
|
||||
return self.ends_[phase] - self.starts_[phase]
|
||||
|
||||
def to_string(self, log_details=False) -> str:
|
||||
end_to_end_str = self._get_duration(TimeTrackerPhase.EndToEnd)
|
||||
end_to_end_str = self._get_duration(ORTModuleInitPhase.EndToEnd)
|
||||
end_to_end_str = f"{end_to_end_str:.2f}" if end_to_end_str != TimeTracker.NOT_RECORD else "N/A"
|
||||
export_str = self._get_duration(TimeTrackerPhase.EXPORT)
|
||||
export_str = self._get_duration(ORTModuleInitPhase.EXPORT)
|
||||
export_str = f"{export_str:.2f}" if export_str != TimeTracker.NOT_RECORD else "N/A"
|
||||
overhead_title_str = (
|
||||
f"Total ORT initialization overhead is {end_to_end_str}s where export takes {export_str}s.\n"
|
||||
|
|
@ -139,9 +131,9 @@ class TimeTracker:
|
|||
return overhead_title_str
|
||||
|
||||
duration_summaries = []
|
||||
for phase in TimeTrackerPhase:
|
||||
for phase in ORTModuleInitPhase:
|
||||
_get_duration = self._get_duration(phase)
|
||||
if phase in [TimeTrackerPhase.EndToEnd, TimeTrackerPhase.EXPORT]:
|
||||
if phase in [ORTModuleInitPhase.EndToEnd, ORTModuleInitPhase.EXPORT]:
|
||||
continue
|
||||
|
||||
val = (
|
||||
|
|
@ -155,7 +147,7 @@ class TimeTracker:
|
|||
class TrackTime:
|
||||
"""A function decorator to track time spent in different phases of ORT backend first-time initialization."""
|
||||
|
||||
def __init__(self, phase: TimeTrackerPhase):
|
||||
def __init__(self, phase: ORTModuleInitPhase):
|
||||
self.phase = phase
|
||||
|
||||
def __call__(self, func: Callable):
|
||||
|
|
@ -168,3 +160,108 @@ class TrackTime:
|
|||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None):
|
||||
"""Suppress output from being printed to stdout and stderr.
|
||||
|
||||
If on_exit is not None, it will be called when the context manager exits.
|
||||
"""
|
||||
if enable:
|
||||
# stdout and stderr is written to a tempfile instead
|
||||
with tempfile.TemporaryFile() as fp:
|
||||
try:
|
||||
# Store original stdout and stderr file no.
|
||||
old_stdout = os.dup(sys.stdout.fileno())
|
||||
old_stderr = os.dup(sys.stderr.fileno())
|
||||
|
||||
# Redirect stdout and stderr (printed from Python or C++) to the file.
|
||||
os.dup2(fp.fileno(), sys.stdout.fileno())
|
||||
os.dup2(fp.fileno(), sys.stderr.fileno())
|
||||
yield
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
||||
# Restore stdout and stderr.
|
||||
os.dup2(old_stdout, sys.stdout.fileno())
|
||||
os.dup2(old_stderr, sys.stderr.fileno())
|
||||
|
||||
if on_exit:
|
||||
on_exit(fp)
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
def _log_with_filter(logger: logging.Logger, record_filters: Optional[List[str]], name: Optional[str], fo):
|
||||
"""Log the content by filtering with list of string patterns.
|
||||
Args:
|
||||
logger: The logger to log the content.
|
||||
record_filters: The list of string patterns to filter the content.
|
||||
If record_filters is None, the full content will be logged.
|
||||
name: The name of log filter.
|
||||
fo: The file object to read the content.
|
||||
"""
|
||||
if fo.tell() > 0:
|
||||
if logger.disabled:
|
||||
return
|
||||
|
||||
fo.seek(0)
|
||||
suppress_output_messages = fo.readlines()
|
||||
if record_filters:
|
||||
filtered_messages = []
|
||||
filtered_lines = 0
|
||||
for suppressed_message in suppress_output_messages:
|
||||
msg = suppressed_message.decode("utf-8")
|
||||
found = False
|
||||
for warning in record_filters:
|
||||
if warning in msg:
|
||||
found = True
|
||||
filtered_lines += 1
|
||||
break
|
||||
|
||||
if not found:
|
||||
filtered_messages.extend(textwrap.wrap(msg, 180))
|
||||
if filtered_messages:
|
||||
filtered_messages.insert(0, f"[{name}] Filtered logs ({filtered_lines} records suppressed):")
|
||||
logger.warning("\n ".join(filtered_messages))
|
||||
else:
|
||||
out_messages = []
|
||||
for suppressed_message in suppress_output_messages:
|
||||
out_messages.extend(textwrap.wrap(suppressed_message.decode("utf-8"), 180))
|
||||
if out_messages:
|
||||
out_messages.insert(0, f"[{name}] Full logs:")
|
||||
logger.warning("\n ".join(out_messages))
|
||||
|
||||
|
||||
class SuppressLogs:
|
||||
"""A function decorator to suppress in different phases of ORT backend first-time initialization."""
|
||||
|
||||
def __init__(self, phase: ORTModuleInitPhase, is_ort_filter=True):
|
||||
self.phase = phase
|
||||
self.is_ort_filter = is_ort_filter
|
||||
|
||||
def __call__(self, func: Callable):
|
||||
def wrapper(graph_execution_manager, *args, **kwargs):
|
||||
if not hasattr(graph_execution_manager, "_logger"):
|
||||
raise RuntimeError("The class of the function to be tracked must have a '_logger' attribute.")
|
||||
|
||||
if not hasattr(graph_execution_manager, "_debug_options"):
|
||||
raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.")
|
||||
|
||||
with _suppress_os_stream_output(
|
||||
enable=graph_execution_manager._debug_options.log_level >= LogLevel.INFO,
|
||||
on_exit=partial(
|
||||
_log_with_filter,
|
||||
graph_execution_manager._logger,
|
||||
graph_execution_manager._debug_options.onnxruntime_log_filter
|
||||
if self.is_ort_filter
|
||||
else graph_execution_manager._debug_options.torch_exporter_filter,
|
||||
self.phase.to_string(),
|
||||
),
|
||||
):
|
||||
result = func(graph_execution_manager, *args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ class InputDensityObserver:
|
|||
"""
|
||||
|
||||
def __init__(self, logger: Logger, log_steps=1):
|
||||
self._logger: Logger = logger
|
||||
self._logger = logger
|
||||
self._embedding_graph_input_to_padding_idx_map = {}
|
||||
self._loss_label_graph_input_to_ignore_idx_map = {}
|
||||
self._stats = []
|
||||
|
|
|
|||
|
|
@ -19,7 +19,11 @@ T = TypeVar("T", bound="torch.nn.Module")
|
|||
|
||||
class TorchModuleORT(TorchModuleInterface):
|
||||
def __init__(
|
||||
self, module: torch.nn.Module, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
debug_options: DebugOptions,
|
||||
fallback_manager: _FallbackManager,
|
||||
logger: Logger,
|
||||
):
|
||||
super().__init__(module)
|
||||
self._flattened_module = _io._FlattenedModule(module)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPo
|
|||
from ._gradient_accumulation_manager import GradientAccumulationManager
|
||||
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
|
||||
from ._io import _FlattenedModule, _InputInfo
|
||||
from ._logger import TimeTrackerPhase, TrackTime
|
||||
from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime
|
||||
from ._runtime_inspector import Phase
|
||||
from ._utils import save_tuning_results, set_tuning_results
|
||||
from .graph_transformer_registry import GraphTransformerRegistry
|
||||
|
|
@ -32,7 +32,11 @@ class TrainingManager(GraphExecutionManager):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, model: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger
|
||||
self,
|
||||
model: _FlattenedModule,
|
||||
debug_options: DebugOptions,
|
||||
fallback_manager: _FallbackManager,
|
||||
logger: Logger,
|
||||
):
|
||||
super().__init__(model, debug_options, fallback_manager, logger)
|
||||
|
||||
|
|
@ -246,7 +250,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
|
||||
or not self._onnx_models.exported_model
|
||||
):
|
||||
self.time_tracker.start(TimeTrackerPhase.EndToEnd)
|
||||
self.time_tracker.start(ORTModuleInitPhase.EndToEnd)
|
||||
|
||||
build_gradient_graph = self._export_model(*inputs, **kwargs)
|
||||
|
||||
|
|
@ -302,7 +306,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
self._runtime_options.enable_grad_acc_optimization, self._flattened_module, self._graph_info
|
||||
)
|
||||
|
||||
self.time_tracker.end(TimeTrackerPhase.EndToEnd)
|
||||
self.time_tracker.end(ORTModuleInitPhase.EndToEnd)
|
||||
self._log_feature_stats()
|
||||
|
||||
self._gradient_accumulation_manager.maybe_update_cache_before_run()
|
||||
|
|
@ -349,7 +353,8 @@ class TrainingManager(GraphExecutionManager):
|
|||
if self._fallback_manager.is_pending():
|
||||
return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)
|
||||
|
||||
@TrackTime(TimeTrackerPhase.BUILD_GRAPH)
|
||||
@TrackTime(ORTModuleInitPhase.BUILD_GRAPH)
|
||||
@SuppressLogs(ORTModuleInitPhase.BUILD_GRAPH)
|
||||
def _build_graph(self, graph_transformer_config):
|
||||
"""Build an optimized gradient graph using the module_graph_builder"""
|
||||
|
||||
|
|
@ -394,7 +399,8 @@ class TrainingManager(GraphExecutionManager):
|
|||
else:
|
||||
self._gradient_map.append(-1)
|
||||
|
||||
@TrackTime(TimeTrackerPhase.CREATE_SESSION)
|
||||
@TrackTime(ORTModuleInitPhase.CREATE_SESSION)
|
||||
@SuppressLogs(ORTModuleInitPhase.CREATE_SESSION)
|
||||
def _create_execution_agent(self):
|
||||
"""Creates a TrainingAgent that can run the forward and backward graph on the training model"""
|
||||
|
||||
|
|
@ -427,6 +433,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
] * len(bw_fetches_names)
|
||||
|
||||
local_device_rank = self._device.index if device_type == "ort" else _utils.get_device_index(self._device)
|
||||
|
||||
self._execution_agent = TrainingAgent(
|
||||
self._onnx_models.optimized_model.SerializeToString(),
|
||||
fw_feed_names,
|
||||
|
|
|
|||
|
|
@ -444,3 +444,19 @@ def set_tuning_results(session, is_training, tuning_results_path):
|
|||
if os.path.isfile(tuning_result_file):
|
||||
with open(tuning_result_file, encoding="utf-8") as f:
|
||||
session.set_tuning_results(json.load(f))
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
"""Returns the rank of the current process. If distributed training is not initialized, returns 0."""
|
||||
if torch.distributed.is_initialized():
|
||||
return torch.distributed.get_rank()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Returns the world size of the current process. If distributed training is not initialized, returns 1."""
|
||||
if torch.distributed.is_initialized():
|
||||
return torch.distributed.get_world_size()
|
||||
|
||||
return 1
|
||||
|
|
|
|||
|
|
@ -7,12 +7,14 @@ from enum import IntFlag
|
|||
from functools import reduce
|
||||
from logging import Logger
|
||||
|
||||
from packaging import version
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from ._fallback import _FallbackPolicy
|
||||
from ._logger import LogLevel
|
||||
from ._utils import parse_os_env_skip_check_flags
|
||||
from ._utils import get_runtime_pytorch_version, parse_os_env_skip_check_flags
|
||||
|
||||
|
||||
class _SaveOnnxOptions:
|
||||
|
|
@ -131,6 +133,39 @@ class DebugOptions:
|
|||
|
||||
return self._logging
|
||||
|
||||
@property
|
||||
def torch_exporter_filter(self):
|
||||
"""Accessor for the filter export logs configuration."""
|
||||
if self.log_level >= LogLevel.INFO and get_runtime_pytorch_version() < version.parse("2.0"):
|
||||
return [
|
||||
# WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
|
||||
# WARNING: The shape inference of com.microsoft::PythonOp type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
|
||||
# WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
|
||||
# WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
|
||||
"type is missing, so it may result in wrong shape inference",
|
||||
# Warning: Checker does not support models with experimental ops: ATen
|
||||
"Checker does not support models with experimental ops:",
|
||||
"Dropout is a training op and should not be exported in inference mode.",
|
||||
# Warning: Shape inference does not support models with experimental operators: ATen
|
||||
"Shape inference does not support models with experimental operators:",
|
||||
# Warning: Unsupported operator Trilu. No schema registered for this operator.
|
||||
# Warning: Unsupported operator ATen. No schema registered for this operator.
|
||||
# Warning: Unsupported operator SoftmaxCrossEntropyLossInternal. No schema registered for this operator.
|
||||
"No schema registered for this operator.",
|
||||
]
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def onnxruntime_log_filter(self):
|
||||
"""Accessor for the filter onnxruntime logs configuration."""
|
||||
if self.log_level >= LogLevel.INFO:
|
||||
return [
|
||||
"CleanUnusedInitializersAndNodeArgs] Removing initializer",
|
||||
"Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED",
|
||||
]
|
||||
return None
|
||||
|
||||
|
||||
class _SkipCheck(IntFlag):
|
||||
"""Enumeration to specify which checks should be skipped, allowing faster execution"""
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
# isort: skip_file
|
||||
# Import ordering is important in this module to aviod circular dependencies
|
||||
import logging
|
||||
# Import ordering is important in this module to avoid circular dependencies
|
||||
|
||||
from ._torch_module_factory import TorchModuleFactory
|
||||
from ._torch_module_ort import TorchModuleORT
|
||||
from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry
|
||||
|
|
@ -12,7 +12,7 @@ from ._custom_gradient_registry import CustomGradientRegistry
|
|||
from . import _utils
|
||||
from .options import DebugOptions
|
||||
from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException
|
||||
from ._logger import ortmodule_loglevel_to_python_loglevel
|
||||
from ._logger import configure_ortmodule_logger
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from onnxruntime.tools import pytorch_export_contrib_ops
|
||||
|
|
@ -53,8 +53,7 @@ class ORTModule(torch.nn.Module):
|
|||
if not debug_options:
|
||||
debug_options = DebugOptions()
|
||||
|
||||
self._logger = logging.getLogger(__name__)
|
||||
self._logger.setLevel(ortmodule_loglevel_to_python_loglevel(debug_options.logging.log_level))
|
||||
self._logger = configure_ortmodule_logger(debug_options.logging.log_level)
|
||||
|
||||
# Fallback settings
|
||||
self._fallback_manager = _FallbackManager(
|
||||
|
|
|
|||
|
|
@ -6061,3 +6061,52 @@ def test_e2e_padding_elimination():
|
|||
assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
|
||||
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
|
||||
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
Version(torch.__version__) >= Version("1.13.0"),
|
||||
reason="PyTorch since 1.13 don't output expected warning messages any more",
|
||||
)
|
||||
@pytest.mark.parametrize("log_level", [LogLevel.VERBOSE, LogLevel.INFO, LogLevel.WARNING])
|
||||
def test_ortmodule_log_level_control(log_level, caplog):
|
||||
class NeuralNetCrossEntropyLoss(torch.nn.Module):
|
||||
def __init__(self, num_embeddings, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=1)
|
||||
|
||||
def forward(self, input, positions):
|
||||
output = torch.transpose(self.embedding(input), 0, 1)
|
||||
ignored_index = output.size(1)
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
|
||||
return loss_fct(output, positions)
|
||||
|
||||
device = "cuda"
|
||||
num_embeddings, embedding_dim = 32, 128
|
||||
pt_model = NeuralNetCrossEntropyLoss(num_embeddings, embedding_dim).to(device)
|
||||
|
||||
ort_model = ORTModule(pt_model, DebugOptions(log_level=log_level))
|
||||
use_fp16 = True
|
||||
|
||||
def run_step(model, input, positions):
|
||||
with amp.autocast(use_fp16):
|
||||
loss = model(input, positions)
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
N = random.randint(16, 32) # noqa: N806
|
||||
input = torch.randint(high=num_embeddings, size=(N,), dtype=torch.int64, device=device)
|
||||
positions = torch.randint(high=N, size=(embedding_dim,), dtype=torch.int64, device=device)
|
||||
_ = run_step(ort_model, input, positions)
|
||||
|
||||
found_missing_inference_log = False
|
||||
for record in caplog.records:
|
||||
msg = record.getMessage()
|
||||
print(msg)
|
||||
if "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing" in msg:
|
||||
found_missing_inference_log = True
|
||||
break
|
||||
|
||||
if log_level == LogLevel.VERBOSE:
|
||||
assert found_missing_inference_log
|
||||
else:
|
||||
assert not found_missing_inference_log
|
||||
|
|
|
|||
|
|
@ -564,7 +564,7 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback, caplog):
|
|||
if i == 0:
|
||||
# For the first time, run ORTModule, feature map is logged as warning
|
||||
# And the fallback warning is logged.
|
||||
assert len(caplog.records) == 2
|
||||
assert len(caplog.records) >= 2
|
||||
else:
|
||||
# For the other time, only the fallback warning is logged.
|
||||
assert len(caplog.records) == 1
|
||||
|
|
@ -576,8 +576,8 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback, caplog):
|
|||
if i == 0:
|
||||
# For the first time, run ORTModule, feature map is logged as warning
|
||||
# And the fallback warning is logged.
|
||||
assert len(caplog.records) == 2
|
||||
assert "Fallback to PyTorch due to exception" in caplog.records[1].message
|
||||
assert len(caplog.records) >= 2
|
||||
assert "Fallback to PyTorch due to exception" in caplog.records[-1].message
|
||||
caplog.clear()
|
||||
else:
|
||||
# For the other time, no fallback warning will be logged because
|
||||
|
|
|
|||
Loading…
Reference in a new issue