ORTModule log clean up (#16795)

### ORTModule log clean up

ORTModule log levels: WARNING (the default) is for end users; INFO and
VERBOSE are for internal ORT training developers.

A few issues:
1. ONNX export outputs lots of WARNING messages like "The shape
inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing", which are useless for us and for end users.

![image](https://github.com/microsoft/onnxruntime/assets/10530022/f2409480-32e1-483d-bd18-f14149f0588d)

2. ORT also prints some information like
"CleanUnusedInitializersAndNodeArgs] Removing
initializer" and "ReverseBFSWithStopGradient] Skip building gradient for",
which is also useless for us and for end users most of the time.

![image](https://github.com/microsoft/onnxruntime/assets/10530022/ff3feaf1-3cb2-4392-b087-86b30b72994c)


3. Every rank outputs its own logs, making ORT developers and end users
feel there are too many logs, most of which are not useful until we need to
investigate.

A few improvements for these issues:
1. For ONNX export logs, there are two kinds of logs: a. the export verbose
log; b. other logs printed by the torch C++ backend. So this PR makes the
following change:
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout
and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from
stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other
logs from stdout and stderr (C++ backend)

E.g., for the verbose level, print all logs as usual; for the info level, print
the verbose export log and the filtered logs from the torch C++ backend (removing
messages like "The shape inference of
com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is
missing"). For higher levels, only log the info on rank 0.

2. For ORT gradient graph building and session creation, also suppress the
output and filter the messages when the log level is >= INFO.

3. When the log level is > INFO, only rank 0 emits logs, for a
cleaner user experience.


This is the log for a BLOOM model training run after the change: there are
only a limited number of warnings.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/f270b8d5-2944-49d2-a253-c07057d641a0)
This commit is contained in:
pengwa 2023-07-26 12:42:50 +08:00 committed by GitHub
parent bf006d34a9
commit 39fca225ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 388 additions and 176 deletions

View file

@ -33,7 +33,10 @@ Status ConcatSliceElimination::ApplyImpl(Graph& graph, bool& modified, int graph
modified = true;
}
}
LOGS(logger, INFO) << "Total fused concat node count: " << fused_count;
if (fused_count > 0) {
LOGS(logger, INFO) << "Total fused concat node count: " << fused_count;
}
return Status::OK();
}

View file

@ -252,8 +252,9 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve
modified = true;
}
LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count;
if (shared_count > 0) {
LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count;
}
return Status::OK();
}

View file

@ -50,7 +50,10 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
modified = true;
}
}
LOGS(logger, INFO) << "Total fused reshape node count: " << fused_count;
if (fused_count > 0) {
LOGS(logger, INFO) << "Total fused reshape node count: " << fused_count;
}
return Status::OK();
}

View file

@ -164,8 +164,8 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c
const std::unordered_set<size_t>* edges = GetStopGradientEdges(*n);
for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) {
if (edges != nullptr && edges->count(edge_it->GetDstArgIndex())) {
LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << n->Name();
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << n->Name();
continue;
}
const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()];
@ -173,13 +173,13 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c
if (nullptr != type_proto && type_proto->value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) {
const int32_t type = type_proto->tensor_type().elem_type();
if (GRAD_ALLOWED_TYPES.find(type) == GRAD_ALLOWED_TYPES.end()) {
LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << n->Name() << "because element type is: " << type;
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << n->Name() << "because element type is: " << type;
continue;
}
} else {
LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << n->Name() << "because it is not a Tensor type";
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << n->Name() << "because it is not a Tensor type";
continue;
}
@ -280,7 +280,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
const std::unordered_set<size_t>* edges = GetStopGradientEdges(next_node);
if (edges != nullptr && edges->count(edge_it->GetDstArgIndex())) {
LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << next_node.Name();
continue;
}

View file

@ -33,7 +33,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager
from ._graph_execution_interface import GraphExecutionInterface
from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType
from ._runtime_inspector import RuntimeInspector
from ._utils import check_function_has_param
from ._utils import check_function_has_param, get_rank
from .options import DebugOptions, LogLevel, _RuntimeOptions
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
@ -104,7 +104,6 @@ class GraphExecutionManager(GraphExecutionInterface):
# Input and output infos (including schema) for exported model.
self._input_info: Optional[_InputInfo] = None
self._module_output_schema: Optional[_ModelInputOutputSchemaType] = None
self._warning_log_detected_during_export = False
# Device where the model is placed.
self._device: Optional[torch.device] = _utils.get_device_from_module(module)
@ -253,7 +252,8 @@ class GraphExecutionManager(GraphExecutionInterface):
return session_options, providers, provider_options
@_logger.TrackTime(_logger.TimeTrackerPhase.EXPORT)
@_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT)
@_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False)
def _export_model(self, *inputs, **kwargs) -> bool:
# 1. Set the self._device from the user module
# 2. Verify input schema matches the schema used on the previous model export
@ -304,90 +304,88 @@ class GraphExecutionManager(GraphExecutionInterface):
TODO: How to support dynamic axes? Dimensions are determined by samples
"""
with _logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level) as suppress_output:
# Setup dynamic axes for onnx model
self._input_info = _io.parse_inputs_for_onnx_export(
self._module_parameters, None, input_schema, inputs, kwargs
)
(
output_names,
output_dynamic_axes,
self._module_output_schema,
) = _io.parse_outputs_for_onnx_export_and_extract_schema(
self._original_module, inputs, kwargs, self._logger
)
self._input_info.dynamic_axes.update(output_dynamic_axes)
# FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs
self._flattened_module._input_info = self._input_info
# VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend)
# INFO -> FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend)
# WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend)
# Be noted: rank 0 log only is controlled by logger configured in _logger.py
torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO
self._logger.info("Exporting the PyTorch model to ONNX...")
# Export torch.nn.Module to ONNX
f = io.BytesIO()
# Setup dynamic axes for onnx model
self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs)
(
output_names,
output_dynamic_axes,
self._module_output_schema,
) = _io.parse_outputs_for_onnx_export_and_extract_schema(self._original_module, inputs, kwargs, self._logger)
self._input_info.dynamic_axes.update(output_dynamic_axes)
# Deepcopy inputs, since input values may change after model run.
# NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn).
# Therefore, deepcopy only the data component of the input tensors for export.
sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs)
# NOTE: Flattening the input will change the 'input schema', resulting in a re-export
sample_inputs_as_tuple = tuple(
self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)
)
# Ops behaving differently under train/eval mode need to be exported with the
# correct training flag to reflect the expected behavior.
# For example, the Dropout node in a model is dropped under eval mode.
assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager"
# FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs
self._flattened_module._input_info = self._input_info
try:
with torch.no_grad():
required_export_kwargs = {
"input_names": self._input_info.names,
"output_names": output_names,
"opset_version": self._runtime_options.onnx_opset_version,
"do_constant_folding": False,
"training": self._export_mode,
"dynamic_axes": self._input_info.dynamic_axes,
"verbose": self._debug_options.logging.log_level < LogLevel.WARNING,
"export_params": False,
"keep_initializers_as_inputs": True,
}
# Export torch.nn.Module to ONNX
f = io.BytesIO()
if check_function_has_param(torch.onnx.export, "autograd_inlining"):
# From some PyTorch version, autograd_inlining is a valid argument.
# We allow it to be True if custom autograd function is disabled (where autograd.Function
# anyway is not supported in ONNX until it can be inlined).
required_export_kwargs[
"autograd_inlining"
] = not self._runtime_options.enable_custom_autograd_function
# Deepcopy inputs, since input values may change after model run.
# NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn).
# Therefore, deepcopy only the data component of the input tensors for export.
sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs)
# NOTE: Flattening the input will change the 'input schema', resulting in a re-export
sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device))
# Ops behaving differently under train/eval mode need to be exported with the
# correct training flag to reflect the expected behavior.
# For example, the Dropout node in a model is dropped under eval mode.
assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager"
invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys()
assert (
len(invalid_args) == 0
), f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'."
torch.onnx.export(
self._flattened_module,
sample_inputs_as_tuple,
f,
**required_export_kwargs,
**self._export_extra_kwargs,
)
except Exception as e:
raise wrap_exception( # noqa: B904
ORTModuleONNXModelException,
RuntimeError(
f"There was an error while exporting the PyTorch model to ONNX: "
f"\n\n{_utils.get_exception_as_string(e)}"
),
try:
with torch.no_grad():
required_export_kwargs = {
"input_names": self._input_info.names,
"output_names": output_names,
"opset_version": self._runtime_options.onnx_opset_version,
"do_constant_folding": False,
"training": self._export_mode,
"dynamic_axes": self._input_info.dynamic_axes,
"verbose": torch_exporter_verbose_log,
"export_params": False,
"keep_initializers_as_inputs": True,
}
if check_function_has_param(torch.onnx.export, "autograd_inlining"):
# From some PyTorch version, autograd_inlining is a valid argument.
# We allow it to be True if custom autograd function is disabled (where autograd.Function
# anyway is not supported in ONNX until it can be inlined).
required_export_kwargs[
"autograd_inlining"
] = not self._runtime_options.enable_custom_autograd_function
invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys()
if len(invalid_args) != 0:
error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'."
raise RuntimeError(error_msg)
torch.onnx.export(
self._flattened_module,
sample_inputs_as_tuple,
f,
**required_export_kwargs,
**self._export_extra_kwargs,
)
exported_model = onnx.load_model_from_string(f.getvalue())
exported_model = _post_process_after_export(
exported_model, self._runtime_options.enable_custom_autograd_function
except Exception as e:
raise wrap_exception( # noqa: B904
ORTModuleONNXModelException,
RuntimeError(
f"There was an error while exporting the PyTorch model to ONNX: "
f"\n\n{_utils.get_exception_as_string(e)}"
),
)
exported_model = onnx.load_model_from_string(f.getvalue())
# If anything was captured by suppress_output during export, set the flag to
# raise a single user warning letting users know in the log.
if suppress_output.tell() > 0:
self._warning_log_detected_during_export = True
exported_model = _post_process_after_export(
exported_model, self._runtime_options.enable_custom_autograd_function
)
return exported_model
@ -411,7 +409,8 @@ class GraphExecutionManager(GraphExecutionInterface):
graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer
return graph_transformer_config
@_logger.TrackTime(_logger.TimeTrackerPhase.GRAPH_BUILDER_INIT)
@_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT)
@_logger.SuppressLogs(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT)
def _initialize_graph_builder(self):
"""Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""
@ -483,7 +482,7 @@ class GraphExecutionManager(GraphExecutionInterface):
_utils.reinitialize_graph_execution_manager(self)
@_logger.TrackTime(_logger.TimeTrackerPhase.DETECTION)
@_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION)
def _enable_conditional_optimizations(
self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict
):
@ -544,11 +543,7 @@ class GraphExecutionManager(GraphExecutionInterface):
self._runtime_inspector.enable_memory_inspector(self._original_module)
def _log_feature_stats(self):
rank = 0
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
if rank != 0:
if get_rank() != 0:
return
feature_map: List[Tuple[str, bool, str]] = [
@ -638,11 +633,8 @@ class GraphExecutionManager(GraphExecutionInterface):
switch_str = "ON" if feature_tuple[1] else "OFF"
stat += f"{feature_tuple[0]:<20}:\t{switch_str:<10}:\t{feature_tuple[2]:<80}\n"
stat += f"\n{_logger.LogColor.WARNING}There were one or more warnings or errors raised while exporting the PyTorch model.\n"
stat += f"Please enable INFO level logging with DebugOptions to view all warnings and errors.{_logger.LogColor.ENDC}\n\n"
# Collect ORTModule overheads for different phases.
stat += f"{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n"
stat += f"\n{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n"
stat += f"Versions: ONNX Runtime - {onnxruntime.__version__}, ONNX - {onnx.__version__}\n\n"
stat += f"{_logger.LogColor.HEADER}************************************************************************{_logger.LogColor.ENDC}\n\n"

View file

@ -15,7 +15,11 @@ from .options import DebugOptions
class GraphExecutionManagerFactory:
def __init__(
self, module: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger
self,
module: _FlattenedModule,
debug_options: DebugOptions,
fallback_manager: _FallbackManager,
logger: Logger,
):
self._training_manager = TrainingManager(module, debug_options, fallback_manager, logger)
self._inference_manager = InferenceManager(module, debug_options, fallback_manager, logger)

View file

@ -15,7 +15,7 @@ from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_alg
from ._execution_agent import InferenceAgent
from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
from ._logger import TimeTrackerPhase, TrackTime
from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime
from ._utils import save_tuning_results, set_tuning_results
from .options import DebugOptions, _SkipCheck
@ -111,7 +111,7 @@ class InferenceManager(GraphExecutionManager):
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
or not self._onnx_models.exported_model
):
self.time_tracker.start(TimeTrackerPhase.EndToEnd)
self.time_tracker.start(ORTModuleInitPhase.EndToEnd)
# Exporting module to ONNX for the first time
build_graph = self._export_model(*inputs, **kwargs)
@ -151,7 +151,7 @@ class InferenceManager(GraphExecutionManager):
# Create execution session creates the inference_session
self._create_execution_agent()
self.time_tracker.end(TimeTrackerPhase.EndToEnd)
self.time_tracker.end(ORTModuleInitPhase.EndToEnd)
self._log_feature_stats()
if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
@ -201,7 +201,8 @@ class InferenceManager(GraphExecutionManager):
if self._fallback_manager.is_pending():
return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)
@TrackTime(TimeTrackerPhase.BUILD_GRAPH)
@TrackTime(ORTModuleInitPhase.BUILD_GRAPH)
@SuppressLogs(ORTModuleInitPhase.BUILD_GRAPH)
def _build_graph(self, graph_transformer_config):
"""Build an inference graph using the module_graph_builder"""
@ -214,7 +215,8 @@ class InferenceManager(GraphExecutionManager):
self._export_mode,
)
@TrackTime(TimeTrackerPhase.CREATE_SESSION)
@TrackTime(ORTModuleInitPhase.CREATE_SESSION)
@SuppressLogs(ORTModuleInitPhase.CREATE_SESSION)
def _create_execution_agent(self):
"""Creates an InferenceAgent that can run forward graph on an inference model"""

View file

@ -3,16 +3,21 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import io
import logging
import os
import sys
import tempfile
import textwrap
import time
from contextlib import contextmanager
from enum import IntEnum
from typing import Callable, Dict, List
from functools import partial
from typing import Callable, Dict, List, Optional
from onnxruntime.capi._pybind_state import Severity
from ._utils import get_rank, get_world_size
class LogLevel(IntEnum):
VERBOSE = 0
@ -22,34 +27,6 @@ class LogLevel(IntEnum):
FATAL = 4
@contextmanager
def suppress_os_stream_output(suppress_stdout=True, suppress_stderr=True, log_level=LogLevel.WARNING):
    """Swap ``sys.stdout`` / ``sys.stderr`` for an in-memory buffer while the context is active.

    The swap only happens when ``log_level`` is WARNING or higher. The buffer is yielded so
    the caller can check (e.g. with ``tell()``) whether anything was written to it.

    Args:
        suppress_stdout: Whether ``sys.stdout`` takes part in the redirection.
        suppress_stderr: Whether ``sys.stderr`` takes part in the redirection.
        log_level: The configured log level; redirection is active for WARNING and above.
    """
    # Remember the current stream objects so they can be put back on exit.
    saved_out, saved_err = sys.stdout, sys.stderr
    quiet = log_level >= LogLevel.WARNING
    sink = io.StringIO()

    try:
        if quiet and suppress_stdout:
            sys.stdout = sink
        if quiet and suppress_stderr:
            sys.stderr = sink
        yield sink
    finally:
        # Restoration is keyed on the suppress_* flags only, not on the log level.
        if suppress_stdout:
            sys.stdout = saved_out
        if suppress_stderr:
            sys.stderr = saved_err
ORTMODULE_LOG_LEVEL_MAP: Dict[LogLevel, List[int]] = {
LogLevel.VERBOSE: [Severity.VERBOSE, logging.DEBUG],
LogLevel.INFO: [Severity.INFO, logging.INFO],
@ -67,6 +44,21 @@ def ortmodule_loglevel_to_python_loglevel(loglevel: LogLevel) -> int:
return ORTMODULE_LOG_LEVEL_MAP.get(loglevel, [Severity.WARNING, logging.WARNING])[1]
def configure_ortmodule_logger(log_level: LogLevel) -> logging.Logger:
    """Create and configure the ORTModule logger for the current process.

    Rules applied:
        1. When more than one process is in use, the logger name carries a
           ``.rank-<rank>`` suffix.
        2. When the log level is above INFO, the logger is disabled on every
           non-zero rank.

    Args:
        log_level: The ORTModule log level to apply.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    if get_world_size() > 1:
        logger_name = f"orttraining.rank-{get_rank()}"
    else:
        logger_name = "orttraining"
    ortmodule_logger = logging.getLogger(logger_name)

    # Above INFO, only rank 0 is allowed to log.
    silence_this_rank = log_level > LogLevel.INFO and get_rank() != 0
    ortmodule_logger.disabled = silence_this_rank

    python_level = ortmodule_loglevel_to_python_loglevel(log_level)
    ortmodule_logger.setLevel(python_level)
    return ortmodule_logger
class LogColor:
HEADER = "\033[95m"
BLUE = "\033[94m"
@ -79,26 +71,26 @@ class LogColor:
UNDERLINE = "\033[4m"
class TimeTrackerPhase(IntEnum):
EndToEnd = 0 # The total overhead of ORT first-time initialization
EXPORT = 1 # The latency of preparing and exporting the model to ONNX
GRAPH_BUILDER_INIT = 2 # The latency of initializing the graph builder
DETECTION = 3 # The latency of runtime detection
BUILD_GRAPH = 4 # The latency of optimizing forward graph (and building the gradient graph for training).
CREATE_SESSION = 5 # The latency of creating the session
class ORTModuleInitPhase(IntEnum):
EndToEnd = 0 # The end to end of ORT first-time initialization
EXPORT = 1 # The phase of preparing and exporting the model to ONNX
GRAPH_BUILDER_INIT = 2 # The phase of initializing the graph builder
DETECTION = 3 # The phase of runtime detection
BUILD_GRAPH = 4 # The phase of optimizing forward graph (and building the gradient graph for training).
CREATE_SESSION = 5 # The phase of creating the session
def to_string(self) -> str:
if self == TimeTrackerPhase.EndToEnd:
if self == ORTModuleInitPhase.EndToEnd:
return "end to end"
if self == TimeTrackerPhase.EXPORT:
if self == ORTModuleInitPhase.EXPORT:
return "export"
elif self == TimeTrackerPhase.GRAPH_BUILDER_INIT:
elif self == ORTModuleInitPhase.GRAPH_BUILDER_INIT:
return "graph builder init"
elif self == TimeTrackerPhase.DETECTION:
elif self == ORTModuleInitPhase.DETECTION:
return "runtime detection"
elif self == TimeTrackerPhase.BUILD_GRAPH:
elif self == ORTModuleInitPhase.BUILD_GRAPH:
return "graph building"
elif self == TimeTrackerPhase.CREATE_SESSION:
elif self == ORTModuleInitPhase.CREATE_SESSION:
return "session creation"
else:
return "invalid"
@ -112,24 +104,24 @@ class TimeTracker:
def __init__(
self,
):
self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(TimeTrackerPhase)
self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(TimeTrackerPhase)
self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase)
self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase)
def start(self, phase: TimeTrackerPhase):
def start(self, phase: ORTModuleInitPhase):
self.starts_[phase] = time.time()
def end(self, phase: TimeTrackerPhase):
def end(self, phase: ORTModuleInitPhase):
self.ends_[phase] = time.time()
def _get_duration(self, phase: TimeTrackerPhase):
def _get_duration(self, phase: ORTModuleInitPhase):
if self.ends_[phase] == TimeTracker.NOT_RECORD or self.starts_[phase] == TimeTracker.NOT_RECORD:
return TimeTracker.NOT_RECORD
return self.ends_[phase] - self.starts_[phase]
def to_string(self, log_details=False) -> str:
end_to_end_str = self._get_duration(TimeTrackerPhase.EndToEnd)
end_to_end_str = self._get_duration(ORTModuleInitPhase.EndToEnd)
end_to_end_str = f"{end_to_end_str:.2f}" if end_to_end_str != TimeTracker.NOT_RECORD else "N/A"
export_str = self._get_duration(TimeTrackerPhase.EXPORT)
export_str = self._get_duration(ORTModuleInitPhase.EXPORT)
export_str = f"{export_str:.2f}" if export_str != TimeTracker.NOT_RECORD else "N/A"
overhead_title_str = (
f"Total ORT initialization overhead is {end_to_end_str}s where export takes {export_str}s.\n"
@ -139,9 +131,9 @@ class TimeTracker:
return overhead_title_str
duration_summaries = []
for phase in TimeTrackerPhase:
for phase in ORTModuleInitPhase:
_get_duration = self._get_duration(phase)
if phase in [TimeTrackerPhase.EndToEnd, TimeTrackerPhase.EXPORT]:
if phase in [ORTModuleInitPhase.EndToEnd, ORTModuleInitPhase.EXPORT]:
continue
val = (
@ -155,7 +147,7 @@ class TimeTracker:
class TrackTime:
"""A function decorator to track time spent in different phases of ORT backend first-time initialization."""
def __init__(self, phase: TimeTrackerPhase):
def __init__(self, phase: ORTModuleInitPhase):
self.phase = phase
def __call__(self, func: Callable):
@ -168,3 +160,108 @@ class TrackTime:
return result
return wrapper
@contextmanager
def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None):
"""Suppress output from being printed to stdout and stderr.
If on_exit is not None, it will be called when the context manager exits.
"""
if enable:
# stdout and stderr is written to a tempfile instead
with tempfile.TemporaryFile() as fp:
try:
# Store original stdout and stderr file no.
old_stdout = os.dup(sys.stdout.fileno())
old_stderr = os.dup(sys.stderr.fileno())
# Redirect stdout and stderr (printed from Python or C++) to the file.
os.dup2(fp.fileno(), sys.stdout.fileno())
os.dup2(fp.fileno(), sys.stderr.fileno())
yield
finally:
sys.stdout.flush()
sys.stderr.flush()
# Restore stdout and stderr.
os.dup2(old_stdout, sys.stdout.fileno())
os.dup2(old_stderr, sys.stderr.fileno())
if on_exit:
on_exit(fp)
else:
yield
def _log_with_filter(logger: logging.Logger, record_filters: Optional[List[str]], name: Optional[str], fo):
"""Log the content by filtering with list of string patterns.
Args:
logger: The logger to log the content.
record_filters: The list of string patterns to filter the content.
If record_filters is None, the full content will be logged.
name: The name of log filter.
fo: The file object to read the content.
"""
if fo.tell() > 0:
if logger.disabled:
return
fo.seek(0)
suppress_output_messages = fo.readlines()
if record_filters:
filtered_messages = []
filtered_lines = 0
for suppressed_message in suppress_output_messages:
msg = suppressed_message.decode("utf-8")
found = False
for warning in record_filters:
if warning in msg:
found = True
filtered_lines += 1
break
if not found:
filtered_messages.extend(textwrap.wrap(msg, 180))
if filtered_messages:
filtered_messages.insert(0, f"[{name}] Filtered logs ({filtered_lines} records suppressed):")
logger.warning("\n ".join(filtered_messages))
else:
out_messages = []
for suppressed_message in suppress_output_messages:
out_messages.extend(textwrap.wrap(suppressed_message.decode("utf-8"), 180))
if out_messages:
out_messages.insert(0, f"[{name}] Full logs:")
logger.warning("\n ".join(out_messages))
class SuppressLogs:
    """Decorator that captures stdout/stderr output of one ORT backend initialization phase.

    The decorated callable must be a method whose first argument exposes ``_logger`` and
    ``_debug_options`` attributes. Captured output is handed to ``_log_with_filter`` when
    the decorated call finishes.
    """

    def __init__(self, phase: ORTModuleInitPhase, is_ort_filter=True):
        # phase: the initialization phase, used as the label of the emitted log record.
        # is_ort_filter: True selects the ONNX Runtime filter list from the debug options,
        # False selects the torch exporter filter list.
        self.phase = phase
        self.is_ort_filter = is_ort_filter

    def __call__(self, func: Callable):
        def wrapper(graph_execution_manager, *args, **kwargs):
            if not hasattr(graph_execution_manager, "_logger"):
                raise RuntimeError("The class of the function to be tracked must have a '_logger' attribute.")

            if not hasattr(graph_execution_manager, "_debug_options"):
                raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.")

            debug_options = graph_execution_manager._debug_options
            # Capture is active for INFO and every higher level.
            capture_enabled = debug_options.log_level >= LogLevel.INFO

            if self.is_ort_filter:
                record_filters = debug_options.onnxruntime_log_filter
            else:
                record_filters = debug_options.torch_exporter_filter

            report_captured_logs = partial(
                _log_with_filter,
                graph_execution_manager._logger,
                record_filters,
                self.phase.to_string(),
            )

            with _suppress_os_stream_output(enable=capture_enabled, on_exit=report_captured_logs):
                outcome = func(graph_execution_manager, *args, **kwargs)
            return outcome

        return wrapper

View file

@ -114,7 +114,7 @@ class InputDensityObserver:
"""
def __init__(self, logger: Logger, log_steps=1):
self._logger: Logger = logger
self._logger = logger
self._embedding_graph_input_to_padding_idx_map = {}
self._loss_label_graph_input_to_ignore_idx_map = {}
self._stats = []

View file

@ -19,7 +19,11 @@ T = TypeVar("T", bound="torch.nn.Module")
class TorchModuleORT(TorchModuleInterface):
def __init__(
self, module: torch.nn.Module, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger
self,
module: torch.nn.Module,
debug_options: DebugOptions,
fallback_manager: _FallbackManager,
logger: Logger,
):
super().__init__(module)
self._flattened_module = _io._FlattenedModule(module)

View file

@ -18,7 +18,7 @@ from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPo
from ._gradient_accumulation_manager import GradientAccumulationManager
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
from ._io import _FlattenedModule, _InputInfo
from ._logger import TimeTrackerPhase, TrackTime
from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime
from ._runtime_inspector import Phase
from ._utils import save_tuning_results, set_tuning_results
from .graph_transformer_registry import GraphTransformerRegistry
@ -32,7 +32,11 @@ class TrainingManager(GraphExecutionManager):
"""
def __init__(
self, model: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger
self,
model: _FlattenedModule,
debug_options: DebugOptions,
fallback_manager: _FallbackManager,
logger: Logger,
):
super().__init__(model, debug_options, fallback_manager, logger)
@ -246,7 +250,7 @@ class TrainingManager(GraphExecutionManager):
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
or not self._onnx_models.exported_model
):
self.time_tracker.start(TimeTrackerPhase.EndToEnd)
self.time_tracker.start(ORTModuleInitPhase.EndToEnd)
build_gradient_graph = self._export_model(*inputs, **kwargs)
@ -302,7 +306,7 @@ class TrainingManager(GraphExecutionManager):
self._runtime_options.enable_grad_acc_optimization, self._flattened_module, self._graph_info
)
self.time_tracker.end(TimeTrackerPhase.EndToEnd)
self.time_tracker.end(ORTModuleInitPhase.EndToEnd)
self._log_feature_stats()
self._gradient_accumulation_manager.maybe_update_cache_before_run()
@ -349,7 +353,8 @@ class TrainingManager(GraphExecutionManager):
if self._fallback_manager.is_pending():
return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)
@TrackTime(TimeTrackerPhase.BUILD_GRAPH)
@TrackTime(ORTModuleInitPhase.BUILD_GRAPH)
@SuppressLogs(ORTModuleInitPhase.BUILD_GRAPH)
def _build_graph(self, graph_transformer_config):
"""Build an optimized gradient graph using the module_graph_builder"""
@ -394,7 +399,8 @@ class TrainingManager(GraphExecutionManager):
else:
self._gradient_map.append(-1)
@TrackTime(TimeTrackerPhase.CREATE_SESSION)
@TrackTime(ORTModuleInitPhase.CREATE_SESSION)
@SuppressLogs(ORTModuleInitPhase.CREATE_SESSION)
def _create_execution_agent(self):
"""Creates a TrainingAgent that can run the forward and backward graph on the training model"""
@ -427,6 +433,7 @@ class TrainingManager(GraphExecutionManager):
] * len(bw_fetches_names)
local_device_rank = self._device.index if device_type == "ort" else _utils.get_device_index(self._device)
self._execution_agent = TrainingAgent(
self._onnx_models.optimized_model.SerializeToString(),
fw_feed_names,

View file

@ -444,3 +444,19 @@ def set_tuning_results(session, is_training, tuning_results_path):
if os.path.isfile(tuning_result_file):
with open(tuning_result_file, encoding="utf-8") as f:
session.set_tuning_results(json.load(f))
def get_rank() -> int:
    """Returns the rank of the current process.

    Returns 0 if the torch.distributed package is not available in this PyTorch build
    or if distributed training is not initialized.
    """
    # is_available() must be checked first: per the PyTorch documentation, a build without
    # distributed support does not expose the other torch.distributed APIs.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
def get_world_size() -> int:
    """Returns the world size of the current process.

    Returns 1 when the torch.distributed package is unavailable in this PyTorch build or when
    distributed training has not been initialized.
    """
    # Per the PyTorch documentation, torch.distributed exposes no API other than is_available()
    # when the build lacks distributed support, so availability must be checked before
    # is_initialized() is touched (otherwise the attribute lookup itself fails).
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1

View file

@ -7,12 +7,14 @@ from enum import IntFlag
from functools import reduce
from logging import Logger
from packaging import version
from onnxruntime.capi import _pybind_state as C
from onnxruntime.training import ortmodule
from ._fallback import _FallbackPolicy
from ._logger import LogLevel
from ._utils import parse_os_env_skip_check_flags
from ._utils import get_runtime_pytorch_version, parse_os_env_skip_check_flags
class _SaveOnnxOptions:
@ -131,6 +133,39 @@ class DebugOptions:
return self._logging
@property
def torch_exporter_filter(self):
    """Substrings of torch exporter messages configured for filtering.

    Returns:
        A list of message substrings when the log level is INFO or higher and the runtime
        PyTorch version is older than 2.0; otherwise None.
    """
    # The PyTorch version is only queried once the log level check has passed.
    if not (self.log_level >= LogLevel.INFO):
        return None
    if not (get_runtime_pytorch_version() < version.parse("2.0")):
        return None

    # Each entry is preceded by example messages it is meant to match.
    known_exporter_noise = [
        # WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
        # WARNING: The shape inference of com.microsoft::PythonOp type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
        # WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
        # WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
        "type is missing, so it may result in wrong shape inference",
        # Warning: Checker does not support models with experimental ops: ATen
        "Checker does not support models with experimental ops:",
        "Dropout is a training op and should not be exported in inference mode.",
        # Warning: Shape inference does not support models with experimental operators: ATen
        "Shape inference does not support models with experimental operators:",
        # Warning: Unsupported operator Trilu. No schema registered for this operator.
        # Warning: Unsupported operator ATen. No schema registered for this operator.
        # Warning: Unsupported operator SoftmaxCrossEntropyLossInternal. No schema registered for this operator.
        "No schema registered for this operator.",
    ]
    return known_exporter_noise
@property
def onnxruntime_log_filter(self):
    """Substrings of onnxruntime log messages configured for filtering.

    Returns:
        A list of message substrings when the log level is INFO or higher; otherwise None.
    """
    if not (self.log_level >= LogLevel.INFO):
        return None

    known_ort_noise = [
        "CleanUnusedInitializersAndNodeArgs] Removing initializer",
        "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED",
    ]
    return known_ort_noise
class _SkipCheck(IntFlag):
"""Enumeration to specify which checks should be skipped, allowing faster execution"""

View file

@ -3,8 +3,8 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# isort: skip_file
# Import ordering is important in this module to aviod circular dependencies
import logging
# Import ordering is important in this module to avoid circular dependencies
from ._torch_module_factory import TorchModuleFactory
from ._torch_module_ort import TorchModuleORT
from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry
@ -12,7 +12,7 @@ from ._custom_gradient_registry import CustomGradientRegistry
from . import _utils
from .options import DebugOptions
from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException
from ._logger import ortmodule_loglevel_to_python_loglevel
from ._logger import configure_ortmodule_logger
from onnxruntime.training import ortmodule
from onnxruntime.tools import pytorch_export_contrib_ops
@ -53,8 +53,7 @@ class ORTModule(torch.nn.Module):
if not debug_options:
debug_options = DebugOptions()
self._logger = logging.getLogger(__name__)
self._logger.setLevel(ortmodule_loglevel_to_python_loglevel(debug_options.logging.log_level))
self._logger = configure_ortmodule_logger(debug_options.logging.log_level)
# Fallback settings
self._fallback_manager = _FallbackManager(

View file

@ -6061,3 +6061,52 @@ def test_e2e_padding_elimination():
assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]
@pytest.mark.skipif(
    Version(torch.__version__) >= Version("1.13.0"),
    reason="PyTorch since 1.13 don't output expected warning messages any more",
)
@pytest.mark.parametrize("log_level", [LogLevel.VERBOSE, LogLevel.INFO, LogLevel.WARNING])
def test_ortmodule_log_level_control(log_level, caplog):
    """Check that the exporter's missing-shape-inference warning is captured only at VERBOSE level."""

    class NeuralNetCrossEntropyLoss(torch.nn.Module):
        """Embedding lookup followed by a cross-entropy loss."""

        def __init__(self, num_embeddings, embedding_dim):
            super().__init__()
            self.embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=1)

        def forward(self, input, positions):
            embedded = self.embedding(input)
            output = torch.transpose(embedded, 0, 1)
            # The ignored class index is taken from the size of the transposed output's second dim.
            criterion = torch.nn.CrossEntropyLoss(ignore_index=output.size(1))
            return criterion(output, positions)

    device = "cuda"
    num_embeddings, embedding_dim = 32, 128
    pt_model = NeuralNetCrossEntropyLoss(num_embeddings, embedding_dim).to(device)
    ort_model = ORTModule(pt_model, DebugOptions(log_level=log_level))

    fp16_enabled = True

    def run_step(model, input, positions):
        """Run one forward/backward pass under autocast and return the loss."""
        with amp.autocast(fp16_enabled):
            loss = model(input, positions)
        loss.backward()
        return loss

    seq_len = random.randint(16, 32)
    input = torch.randint(high=num_embeddings, size=(seq_len,), dtype=torch.int64, device=device)
    positions = torch.randint(high=seq_len, size=(embedding_dim,), dtype=torch.int64, device=device)
    run_step(ort_model, input, positions)

    expected_fragment = "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing"
    found_missing_inference_log = False
    for captured in caplog.records:
        text = captured.getMessage()
        print(text)
        found_missing_inference_log = expected_fragment in text
        if found_missing_inference_log:
            # Stop scanning (and printing) at the first matching record.
            break

    # Only VERBOSE is expected to let the exporter warning through.
    should_be_logged = log_level == LogLevel.VERBOSE
    assert found_missing_inference_log == should_be_logged

View file

@ -564,7 +564,7 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback, caplog):
if i == 0:
# For the first time, run ORTModule, feature map is logged as warning
# And the fallback warning is logged.
assert len(caplog.records) == 2
assert len(caplog.records) >= 2
else:
# For the other time, only the fallback warning is logged.
assert len(caplog.records) == 1
@ -576,8 +576,8 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback, caplog):
if i == 0:
# For the first time, run ORTModule, feature map is logged as warning
# And the fallback warning is logged.
assert len(caplog.records) == 2
assert "Fallback to PyTorch due to exception" in caplog.records[1].message
assert len(caplog.records) >= 2
assert "Fallback to PyTorch due to exception" in caplog.records[-1].message
caplog.clear()
else:
# For the other time, no fallback warning will be logged because