From 39fca225ea90f306955d4c35b3dc46170a468d43 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 26 Jul 2023 12:42:50 +0800 Subject: [PATCH] ORTModule log clean up (#16795) ### ORTModule log clean up ORTModule log level - WARNING(Default) is for end users; INFO and VERBOSE is for internal ORT training developers. Few issues: 1. ONNX export will output lots of WARNING error message like "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is missing", which is useless for us or end users. ![image](https://github.com/microsoft/onnxruntime/assets/10530022/f2409480-32e1-483d-bd18-f14149f0588d) 3. ORT also print some information like ""CleanUnusedInitializersAndNodeArgs] Removing initializer","ReverseBFSWithStopGradient] Skip building gradient for", which is also useless for us or end users most of the time. ![image](https://github.com/microsoft/onnxruntime/assets/10530022/ff3feaf1-3cb2-4392-b087-86b30b72994c) 5. Different ranks output logs and making ORT developers or end users feels there are too many logs but usually not useful until we need investigate. Few improvements for the issues: 1. For ONNX export logs, there are two kinds of logs: a. export verbose log; b. other logs printed by torch C++ backend. So this PR make following change: # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) # INFO -> FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) e.g. for verbose level, print all logs as usually; for info level, print verbose export log, and filtered logs from torch C++ backend (removing messages like this "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal/ATen/PythonOp type is missing") . For higher level, only log the info on rank 0. 2. For ORT gradient graph build and session creation, also suppress the message and filtered out the message when log level >=INFO. 3. log level > INFO, then only logs on rank 0 is logged, to have a cleaner user experience This is the log for a BLOOM model training after the change: there are limited of warnings. ![image](https://github.com/microsoft/onnxruntime/assets/10530022/f270b8d5-2944-49d2-a253-c07057d641a0) --- .../optimizer/concat_slice_elimination.cc | 5 +- .../core/optimizer/constant_sharing.cc | 5 +- onnxruntime/core/optimizer/reshape_fusion.cc | 5 +- .../core/framework/gradient_graph_builder.cc | 14 +- .../ortmodule/_graph_execution_manager.py | 170 +++++++-------- .../_graph_execution_manager_factory.py | 6 +- .../training/ortmodule/_inference_manager.py | 12 +- .../python/training/ortmodule/_logger.py | 203 +++++++++++++----- .../training/ortmodule/_runtime_inspector.py | 2 +- .../training/ortmodule/_torch_module_ort.py | 6 +- .../training/ortmodule/_training_manager.py | 19 +- .../python/training/ortmodule/_utils.py | 16 ++ .../python/training/ortmodule/options.py | 37 +++- .../python/training/ortmodule/ortmodule.py | 9 +- .../python/orttraining_test_ortmodule_api.py | 49 +++++ .../orttraining_test_ortmodule_fallback.py | 6 +- 16 files changed, 388 insertions(+), 176 deletions(-) diff --git a/onnxruntime/core/optimizer/concat_slice_elimination.cc b/onnxruntime/core/optimizer/concat_slice_elimination.cc index dcee53e5c8..f7a2b3be44 100644 --- a/onnxruntime/core/optimizer/concat_slice_elimination.cc +++ b/onnxruntime/core/optimizer/concat_slice_elimination.cc @@ -33,7 +33,10 @@ Status ConcatSliceElimination::ApplyImpl(Graph& graph, bool& modified, int graph modified = true; } } - LOGS(logger, INFO) << "Total fused concat node count: " << fused_count; + + if (fused_count > 0) { + LOGS(logger, INFO) << "Total fused concat node count: " << fused_count; + } return Status::OK(); } diff --git a/onnxruntime/core/optimizer/constant_sharing.cc b/onnxruntime/core/optimizer/constant_sharing.cc index 116061a542..a3c5a72ee7 100644 --- a/onnxruntime/core/optimizer/constant_sharing.cc +++ b/onnxruntime/core/optimizer/constant_sharing.cc @@ -252,8 +252,9 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve modified = true; } - - LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count; + if (shared_count > 0) { + LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count; + } return Status::OK(); } diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index c6e52299ca..7768a835d5 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -50,7 +50,10 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c modified = true; } } - LOGS(logger, INFO) << "Total fused reshape node count: " << fused_count; + + if (fused_count > 0) { + LOGS(logger, INFO) << "Total fused reshape node count: " << fused_count; + } return Status::OK(); } diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index c18bbb9066..c9cda92972 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -164,8 +164,8 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c const std::unordered_set* edges = GetStopGradientEdges(*n); for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) { if (edges != nullptr && edges->count(edge_it->GetDstArgIndex())) { - LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() - << " of node: " << n->Name(); + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name(); continue; } const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()]; @@ -173,13 +173,13 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c if (nullptr != type_proto && type_proto->value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) { const int32_t type = type_proto->tensor_type().elem_type(); if (GRAD_ALLOWED_TYPES.find(type) == GRAD_ALLOWED_TYPES.end()) { - LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() - << " of node: " << n->Name() << "because element type is: " << type; + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name() << "because element type is: " << type; continue; } } else { - LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() - << " of node: " << n->Name() << "because it is not a Tensor type"; + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name() << "because it is not a Tensor type"; continue; } @@ -280,7 +280,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set* p_init const std::unordered_set* edges = GetStopGradientEdges(next_node); if (edges != nullptr && edges->count(edge_it->GetDstArgIndex())) { - LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() << " of node: " << next_node.Name(); continue; } diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index f0e19724e9..aaa189cb45 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -33,7 +33,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType from ._runtime_inspector import RuntimeInspector -from ._utils import check_function_has_param +from ._utils import check_function_has_param, get_rank from .options import DebugOptions, LogLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -104,7 +104,6 @@ class GraphExecutionManager(GraphExecutionInterface): # Input and output infos (including schema) for exported model. self._input_info: Optional[_InputInfo] = None self._module_output_schema: Optional[_ModelInputOutputSchemaType] = None - self._warning_log_detected_during_export = False # Device where the model is placed. self._device: Optional[torch.device] = _utils.get_device_from_module(module) @@ -253,7 +252,8 @@ class GraphExecutionManager(GraphExecutionInterface): return session_options, providers, provider_options - @_logger.TrackTime(_logger.TimeTrackerPhase.EXPORT) + @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) + @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) def _export_model(self, *inputs, **kwargs) -> bool: # 1. Set the self._device from the user module # 2. Verify input schema matches the schema used on the previous model export @@ -304,90 +304,88 @@ class GraphExecutionManager(GraphExecutionInterface): TODO: How to support dynamic axes? Dimensions are determined by samples """ - with _logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level) as suppress_output: - # Setup dynamic axes for onnx model - self._input_info = _io.parse_inputs_for_onnx_export( - self._module_parameters, None, input_schema, inputs, kwargs - ) - ( - output_names, - output_dynamic_axes, - self._module_output_schema, - ) = _io.parse_outputs_for_onnx_export_and_extract_schema( - self._original_module, inputs, kwargs, self._logger - ) - self._input_info.dynamic_axes.update(output_dynamic_axes) - # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs - self._flattened_module._input_info = self._input_info + # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) + # INFO -> FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) + # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) + # Be noted: rank 0 log only is controlled by logger configured in _logger.py + torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO + self._logger.info("Exporting the PyTorch model to ONNX...") - # Export torch.nn.Module to ONNX - f = io.BytesIO() + # Setup dynamic axes for onnx model + self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) + ( + output_names, + output_dynamic_axes, + self._module_output_schema, + ) = _io.parse_outputs_for_onnx_export_and_extract_schema(self._original_module, inputs, kwargs, self._logger) + self._input_info.dynamic_axes.update(output_dynamic_axes) - # Deepcopy inputs, since input values may change after model run. - # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). - # Therefore, deepcopy only the data component of the input tensors for export. - sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs) - # NOTE: Flattening the input will change the 'input schema', resulting in a re-export - sample_inputs_as_tuple = tuple( - self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device) - ) - # Ops behaving differently under train/eval mode need to be exported with the - # correct training flag to reflect the expected behavior. - # For example, the Dropout node in a model is dropped under eval mode. - assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" + # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs + self._flattened_module._input_info = self._input_info - try: - with torch.no_grad(): - required_export_kwargs = { - "input_names": self._input_info.names, - "output_names": output_names, - "opset_version": self._runtime_options.onnx_opset_version, - "do_constant_folding": False, - "training": self._export_mode, - "dynamic_axes": self._input_info.dynamic_axes, - "verbose": self._debug_options.logging.log_level < LogLevel.WARNING, - "export_params": False, - "keep_initializers_as_inputs": True, - } + # Export torch.nn.Module to ONNX + f = io.BytesIO() - if check_function_has_param(torch.onnx.export, "autograd_inlining"): - # From some PyTorch version, autograd_inlining is a valid argument. - # We allow it to be True if custom autograd function is disabled (where autograd.Function - # anyway is not supported in ONNX until it can be inlined). - required_export_kwargs[ - "autograd_inlining" - ] = not self._runtime_options.enable_custom_autograd_function + # Deepcopy inputs, since input values may change after model run. + # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). + # Therefore, deepcopy only the data component of the input tensors for export. + sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs) + # NOTE: Flattening the input will change the 'input schema', resulting in a re-export + sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)) + # Ops behaving differently under train/eval mode need to be exported with the + # correct training flag to reflect the expected behavior. + # For example, the Dropout node in a model is dropped under eval mode. + assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" - invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() - assert ( - len(invalid_args) == 0 - ), f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." - torch.onnx.export( - self._flattened_module, - sample_inputs_as_tuple, - f, - **required_export_kwargs, - **self._export_extra_kwargs, - ) - except Exception as e: - raise wrap_exception( # noqa: B904 - ORTModuleONNXModelException, - RuntimeError( - f"There was an error while exporting the PyTorch model to ONNX: " - f"\n\n{_utils.get_exception_as_string(e)}" - ), + try: + with torch.no_grad(): + required_export_kwargs = { + "input_names": self._input_info.names, + "output_names": output_names, + "opset_version": self._runtime_options.onnx_opset_version, + "do_constant_folding": False, + "training": self._export_mode, + "dynamic_axes": self._input_info.dynamic_axes, + "verbose": torch_exporter_verbose_log, + "export_params": False, + "keep_initializers_as_inputs": True, + } + + if check_function_has_param(torch.onnx.export, "autograd_inlining"): + # From some PyTorch version, autograd_inlining is a valid argument. + # We allow it to be True if custom autograd function is disabled (where autograd.Function + # anyway is not supported in ONNX until it can be inlined). + required_export_kwargs[ + "autograd_inlining" + ] = not self._runtime_options.enable_custom_autograd_function + + invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() + + if len(invalid_args) != 0: + error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." + raise RuntimeError(error_msg) + + torch.onnx.export( + self._flattened_module, + sample_inputs_as_tuple, + f, + **required_export_kwargs, + **self._export_extra_kwargs, ) - exported_model = onnx.load_model_from_string(f.getvalue()) - - exported_model = _post_process_after_export( - exported_model, self._runtime_options.enable_custom_autograd_function + except Exception as e: + raise wrap_exception( # noqa: B904 + ORTModuleONNXModelException, + RuntimeError( + f"There was an error while exporting the PyTorch model to ONNX: " + f"\n\n{_utils.get_exception_as_string(e)}" + ), ) + exported_model = onnx.load_model_from_string(f.getvalue()) - # If anything was captured by suppress_output during export, set the flag to - # raise a single user warning letting users know in the log. - if suppress_output.tell() > 0: - self._warning_log_detected_during_export = True + exported_model = _post_process_after_export( + exported_model, self._runtime_options.enable_custom_autograd_function + ) return exported_model @@ -411,7 +409,8 @@ class GraphExecutionManager(GraphExecutionInterface): graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer return graph_transformer_config - @_logger.TrackTime(_logger.TimeTrackerPhase.GRAPH_BUILDER_INIT) + @_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT) + @_logger.SuppressLogs(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT) def _initialize_graph_builder(self): """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder""" @@ -483,7 +482,7 @@ class GraphExecutionManager(GraphExecutionInterface): _utils.reinitialize_graph_execution_manager(self) - @_logger.TrackTime(_logger.TimeTrackerPhase.DETECTION) + @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) def _enable_conditional_optimizations( self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict ): @@ -544,11 +543,7 @@ class GraphExecutionManager(GraphExecutionInterface): self._runtime_inspector.enable_memory_inspector(self._original_module) def _log_feature_stats(self): - rank = 0 - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - - if rank != 0: + if get_rank() != 0: return feature_map: List[Tuple[str, bool, str]] = [ @@ -638,11 +633,8 @@ class GraphExecutionManager(GraphExecutionInterface): switch_str = "ON" if feature_tuple[1] else "OFF" stat += f"{feature_tuple[0]:<20}:\t{switch_str:<10}:\t{feature_tuple[2]:<80}\n" - stat += f"\n{_logger.LogColor.WARNING}There were one or more warnings or errors raised while exporting the PyTorch model.\n" - stat += f"Please enable INFO level logging with DebugOptions to view all warnings and errors.{_logger.LogColor.ENDC}\n\n" - # Collect ORTModule overheads for different phases. - stat += f"{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n" + stat += f"\n{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n" stat += f"Versions: ONNX Runtime - {onnxruntime.__version__}, ONNX - {onnx.__version__}\n\n" stat += f"{_logger.LogColor.HEADER}************************************************************************{_logger.LogColor.ENDC}\n\n" diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py index 1b09eafee9..104cc0a894 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py @@ -15,7 +15,11 @@ from .options import DebugOptions class GraphExecutionManagerFactory: def __init__( - self, module: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger + self, + module: _FlattenedModule, + debug_options: DebugOptions, + fallback_manager: _FallbackManager, + logger: Logger, ): self._training_manager = TrainingManager(module, debug_options, fallback_manager, logger) self._inference_manager = InferenceManager(module, debug_options, fallback_manager, logger) diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 96936250f3..24215364d6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -15,7 +15,7 @@ from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_alg from ._execution_agent import InferenceAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._logger import TimeTrackerPhase, TrackTime +from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime from ._utils import save_tuning_results, set_tuning_results from .options import DebugOptions, _SkipCheck @@ -111,7 +111,7 @@ class InferenceManager(GraphExecutionManager): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False or not self._onnx_models.exported_model ): - self.time_tracker.start(TimeTrackerPhase.EndToEnd) + self.time_tracker.start(ORTModuleInitPhase.EndToEnd) # Exporting module to ONNX for the first time build_graph = self._export_model(*inputs, **kwargs) @@ -151,7 +151,7 @@ class InferenceManager(GraphExecutionManager): # Create execution session creates the inference_session self._create_execution_agent() - self.time_tracker.end(TimeTrackerPhase.EndToEnd) + self.time_tracker.end(ORTModuleInitPhase.EndToEnd) self._log_feature_stats() if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: @@ -201,7 +201,8 @@ class InferenceManager(GraphExecutionManager): if self._fallback_manager.is_pending(): return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs) - @TrackTime(TimeTrackerPhase.BUILD_GRAPH) + @TrackTime(ORTModuleInitPhase.BUILD_GRAPH) + @SuppressLogs(ORTModuleInitPhase.BUILD_GRAPH) def _build_graph(self, graph_transformer_config): """Build an inference graph using the module_graph_builder""" @@ -214,7 +215,8 @@ class InferenceManager(GraphExecutionManager): self._export_mode, ) - @TrackTime(TimeTrackerPhase.CREATE_SESSION) + @TrackTime(ORTModuleInitPhase.CREATE_SESSION) + @SuppressLogs(ORTModuleInitPhase.CREATE_SESSION) def _create_execution_agent(self): """Creates an InferenceAgent that can run forward graph on an inference model""" diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index f075897434..76a4035855 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -3,16 +3,21 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import io import logging +import os import sys +import tempfile +import textwrap import time from contextlib import contextmanager from enum import IntEnum -from typing import Callable, Dict, List +from functools import partial +from typing import Callable, Dict, List, Optional from onnxruntime.capi._pybind_state import Severity +from ._utils import get_rank, get_world_size + class LogLevel(IntEnum): VERBOSE = 0 @@ -22,34 +27,6 @@ class LogLevel(IntEnum): FATAL = 4 -@contextmanager -def suppress_os_stream_output(suppress_stdout=True, suppress_stderr=True, log_level=LogLevel.WARNING): - """Suppress output from being printed to stdout and stderr if log_level is WARNING or higher. - - If there is any output detected, a single warning is issued in the context - """ - - # stdout and stderr is written to a tempfile instead - stdout = sys.stdout - stderr = sys.stderr - - suppress_logs = log_level >= LogLevel.WARNING - - fo = io.StringIO() - - try: - if suppress_stdout and suppress_logs: - sys.stdout = fo - if suppress_stderr and suppress_logs: - sys.stderr = fo - yield fo - finally: - if suppress_stdout: - sys.stdout = stdout - if suppress_stderr: - sys.stderr = stderr - - ORTMODULE_LOG_LEVEL_MAP: Dict[LogLevel, List[int]] = { LogLevel.VERBOSE: [Severity.VERBOSE, logging.DEBUG], LogLevel.INFO: [Severity.INFO, logging.INFO], @@ -67,6 +44,21 @@ def ortmodule_loglevel_to_python_loglevel(loglevel: LogLevel) -> int: return ORTMODULE_LOG_LEVEL_MAP.get(loglevel, [Severity.WARNING, logging.WARNING])[1] +def configure_ortmodule_logger(log_level: LogLevel) -> logging.Logger: + """Configure the logger for ortmodule according to following rules. + 1. If multiple processes are used, the rank will be appended + to the logger name. + 2. If the log level is greater than info, the logger will be + disabled for non-zero ranks. + """ + rank_info = f".rank-{get_rank()}" if get_world_size() > 1 else "" + logger = logging.getLogger(f"orttraining{rank_info}") + # Disable the logger for non-zero ranks when level > info + logger.disabled = log_level > LogLevel.INFO and get_rank() != 0 + logger.setLevel(ortmodule_loglevel_to_python_loglevel(log_level)) + return logger + + class LogColor: HEADER = "\033[95m" BLUE = "\033[94m" @@ -79,26 +71,26 @@ class LogColor: UNDERLINE = "\033[4m" -class TimeTrackerPhase(IntEnum): - EndToEnd = 0 # The total overhead of ORT first-time initialization - EXPORT = 1 # The latency of preparing and exporting the model to ONNX - GRAPH_BUILDER_INIT = 2 # The latency of initializing the graph builder - DETECTION = 3 # The latency of runtime detection - BUILD_GRAPH = 4 # The latency of optimizing forward graph (and building the gradient graph for training). - CREATE_SESSION = 5 # The latency of creating the session +class ORTModuleInitPhase(IntEnum): + EndToEnd = 0 # The end to end of ORT first-time initialization + EXPORT = 1 # The phase of preparing and exporting the model to ONNX + GRAPH_BUILDER_INIT = 2 # The phase of initializing the graph builder + DETECTION = 3 # The phase of runtime detection + BUILD_GRAPH = 4 # The phase of optimizing forward graph (and building the gradient graph for training). + CREATE_SESSION = 5 # The phase of creating the session def to_string(self) -> str: - if self == TimeTrackerPhase.EndToEnd: + if self == ORTModuleInitPhase.EndToEnd: return "end to end" - if self == TimeTrackerPhase.EXPORT: + if self == ORTModuleInitPhase.EXPORT: return "export" - elif self == TimeTrackerPhase.GRAPH_BUILDER_INIT: + elif self == ORTModuleInitPhase.GRAPH_BUILDER_INIT: return "graph builder init" - elif self == TimeTrackerPhase.DETECTION: + elif self == ORTModuleInitPhase.DETECTION: return "runtime detection" - elif self == TimeTrackerPhase.BUILD_GRAPH: + elif self == ORTModuleInitPhase.BUILD_GRAPH: return "graph building" - elif self == TimeTrackerPhase.CREATE_SESSION: + elif self == ORTModuleInitPhase.CREATE_SESSION: return "session creation" else: return "invalid" @@ -112,24 +104,24 @@ class TimeTracker: def __init__( self, ): - self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(TimeTrackerPhase) - self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(TimeTrackerPhase) + self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase) + self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase) - def start(self, phase: TimeTrackerPhase): + def start(self, phase: ORTModuleInitPhase): self.starts_[phase] = time.time() - def end(self, phase: TimeTrackerPhase): + def end(self, phase: ORTModuleInitPhase): self.ends_[phase] = time.time() - def _get_duration(self, phase: TimeTrackerPhase): + def _get_duration(self, phase: ORTModuleInitPhase): if self.ends_[phase] == TimeTracker.NOT_RECORD or self.starts_[phase] == TimeTracker.NOT_RECORD: return TimeTracker.NOT_RECORD return self.ends_[phase] - self.starts_[phase] def to_string(self, log_details=False) -> str: - end_to_end_str = self._get_duration(TimeTrackerPhase.EndToEnd) + end_to_end_str = self._get_duration(ORTModuleInitPhase.EndToEnd) end_to_end_str = f"{end_to_end_str:.2f}" if end_to_end_str != TimeTracker.NOT_RECORD else "N/A" - export_str = self._get_duration(TimeTrackerPhase.EXPORT) + export_str = self._get_duration(ORTModuleInitPhase.EXPORT) export_str = f"{export_str:.2f}" if export_str != TimeTracker.NOT_RECORD else "N/A" overhead_title_str = ( f"Total ORT initialization overhead is {end_to_end_str}s where export takes {export_str}s.\n" @@ -139,9 +131,9 @@ class TimeTracker: return overhead_title_str duration_summaries = [] - for phase in TimeTrackerPhase: + for phase in ORTModuleInitPhase: _get_duration = self._get_duration(phase) - if phase in [TimeTrackerPhase.EndToEnd, TimeTrackerPhase.EXPORT]: + if phase in [ORTModuleInitPhase.EndToEnd, ORTModuleInitPhase.EXPORT]: continue val = ( @@ -155,7 +147,7 @@ class TimeTracker: class TrackTime: """A function decorator to track time spent in different phases of ORT backend first-time initialization.""" - def __init__(self, phase: TimeTrackerPhase): + def __init__(self, phase: ORTModuleInitPhase): self.phase = phase def __call__(self, func: Callable): @@ -168,3 +160,108 @@ class TrackTime: return result return wrapper + + +@contextmanager +def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None): + """Suppress output from being printed to stdout and stderr. + + If on_exit is not None, it will be called when the context manager exits. + """ + if enable: + # stdout and stderr is written to a tempfile instead + with tempfile.TemporaryFile() as fp: + try: + # Store original stdout and stderr file no. + old_stdout = os.dup(sys.stdout.fileno()) + old_stderr = os.dup(sys.stderr.fileno()) + + # Redirect stdout and stderr (printed from Python or C++) to the file. + os.dup2(fp.fileno(), sys.stdout.fileno()) + os.dup2(fp.fileno(), sys.stderr.fileno()) + yield + finally: + sys.stdout.flush() + sys.stderr.flush() + + # Restore stdout and stderr. + os.dup2(old_stdout, sys.stdout.fileno()) + os.dup2(old_stderr, sys.stderr.fileno()) + + if on_exit: + on_exit(fp) + else: + yield + + +def _log_with_filter(logger: logging.Logger, record_filters: Optional[List[str]], name: Optional[str], fo): + """Log the content by filtering with list of string patterns. + Args: + logger: The logger to log the content. + record_filters: The list of string patterns to filter the content. + If record_filters is None, the full content will be logged. + name: The name of log filter. + fo: The file object to read the content. + """ + if fo.tell() > 0: + if logger.disabled: + return + + fo.seek(0) + suppress_output_messages = fo.readlines() + if record_filters: + filtered_messages = [] + filtered_lines = 0 + for suppressed_message in suppress_output_messages: + msg = suppressed_message.decode("utf-8") + found = False + for warning in record_filters: + if warning in msg: + found = True + filtered_lines += 1 + break + + if not found: + filtered_messages.extend(textwrap.wrap(msg, 180)) + if filtered_messages: + filtered_messages.insert(0, f"[{name}] Filtered logs ({filtered_lines} records suppressed):") + logger.warning("\n ".join(filtered_messages)) + else: + out_messages = [] + for suppressed_message in suppress_output_messages: + out_messages.extend(textwrap.wrap(suppressed_message.decode("utf-8"), 180)) + if out_messages: + out_messages.insert(0, f"[{name}] Full logs:") + logger.warning("\n ".join(out_messages)) + + +class SuppressLogs: + """A function decorator to suppress in different phases of ORT backend first-time initialization.""" + + def __init__(self, phase: ORTModuleInitPhase, is_ort_filter=True): + self.phase = phase + self.is_ort_filter = is_ort_filter + + def __call__(self, func: Callable): + def wrapper(graph_execution_manager, *args, **kwargs): + if not hasattr(graph_execution_manager, "_logger"): + raise RuntimeError("The class of the function to be tracked must have a '_logger' attribute.") + + if not hasattr(graph_execution_manager, "_debug_options"): + raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.") + + with _suppress_os_stream_output( + enable=graph_execution_manager._debug_options.log_level >= LogLevel.INFO, + on_exit=partial( + _log_with_filter, + graph_execution_manager._logger, + graph_execution_manager._debug_options.onnxruntime_log_filter + if self.is_ort_filter + else graph_execution_manager._debug_options.torch_exporter_filter, + self.phase.to_string(), + ), + ): + result = func(graph_execution_manager, *args, **kwargs) + return result + + return wrapper diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index a871ed08b9..dda909e8cb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -114,7 +114,7 @@ class InputDensityObserver: """ def __init__(self, logger: Logger, log_steps=1): - self._logger: Logger = logger + self._logger = logger self._embedding_graph_input_to_padding_idx_map = {} self._loss_label_graph_input_to_ignore_idx_map = {} self._stats = [] diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py index dba68be944..1255909022 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py @@ -19,7 +19,11 @@ T = TypeVar("T", bound="torch.nn.Module") class TorchModuleORT(TorchModuleInterface): def __init__( - self, module: torch.nn.Module, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger + self, + module: torch.nn.Module, + debug_options: DebugOptions, + fallback_manager: _FallbackManager, + logger: Logger, ): super().__init__(module) self._flattened_module = _io._FlattenedModule(module) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 4d192b041e..cb8561867a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -18,7 +18,7 @@ from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPo from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo from ._io import _FlattenedModule, _InputInfo -from ._logger import TimeTrackerPhase, TrackTime +from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results from .graph_transformer_registry import GraphTransformerRegistry @@ -32,7 +32,11 @@ class TrainingManager(GraphExecutionManager): """ def __init__( - self, model: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger + self, + model: _FlattenedModule, + debug_options: DebugOptions, + fallback_manager: _FallbackManager, + logger: Logger, ): super().__init__(model, debug_options, fallback_manager, logger) @@ -246,7 +250,7 @@ class TrainingManager(GraphExecutionManager): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False or not self._onnx_models.exported_model ): - self.time_tracker.start(TimeTrackerPhase.EndToEnd) + self.time_tracker.start(ORTModuleInitPhase.EndToEnd) build_gradient_graph = self._export_model(*inputs, **kwargs) @@ -302,7 +306,7 @@ class TrainingManager(GraphExecutionManager): self._runtime_options.enable_grad_acc_optimization, self._flattened_module, self._graph_info ) - self.time_tracker.end(TimeTrackerPhase.EndToEnd) + self.time_tracker.end(ORTModuleInitPhase.EndToEnd) self._log_feature_stats() self._gradient_accumulation_manager.maybe_update_cache_before_run() @@ -349,7 +353,8 @@ class TrainingManager(GraphExecutionManager): if self._fallback_manager.is_pending(): return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs) - @TrackTime(TimeTrackerPhase.BUILD_GRAPH) + @TrackTime(ORTModuleInitPhase.BUILD_GRAPH) + @SuppressLogs(ORTModuleInitPhase.BUILD_GRAPH) def _build_graph(self, graph_transformer_config): """Build an optimized gradient graph using the module_graph_builder""" @@ -394,7 +399,8 @@ class TrainingManager(GraphExecutionManager): else: self._gradient_map.append(-1) - @TrackTime(TimeTrackerPhase.CREATE_SESSION) + @TrackTime(ORTModuleInitPhase.CREATE_SESSION) + @SuppressLogs(ORTModuleInitPhase.CREATE_SESSION) def _create_execution_agent(self): """Creates a TrainingAgent that can run the forward and backward graph on the training model""" @@ -427,6 +433,7 @@ class TrainingManager(GraphExecutionManager): ] * len(bw_fetches_names) local_device_rank = self._device.index if device_type == "ort" else _utils.get_device_index(self._device) + self._execution_agent = TrainingAgent( self._onnx_models.optimized_model.SerializeToString(), fw_feed_names, diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 43bbd5dd21..8aba7c8826 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -444,3 +444,19 @@ def set_tuning_results(session, is_training, tuning_results_path): if os.path.isfile(tuning_result_file): with open(tuning_result_file, encoding="utf-8") as f: session.set_tuning_results(json.load(f)) + + +def get_rank() -> int: + """Returns the rank of the current process. If distributed training is not initialized, returns 0.""" + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + + return 0 + + +def get_world_size() -> int: + """Returns the world size of the current process. If distributed training is not initialized, returns 1.""" + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + + return 1 diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 6eb1e70991..7758ff518a 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -7,12 +7,14 @@ from enum import IntFlag from functools import reduce from logging import Logger +from packaging import version + from onnxruntime.capi import _pybind_state as C from onnxruntime.training import ortmodule from ._fallback import _FallbackPolicy from ._logger import LogLevel -from ._utils import parse_os_env_skip_check_flags +from ._utils import get_runtime_pytorch_version, parse_os_env_skip_check_flags class _SaveOnnxOptions: @@ -131,6 +133,39 @@ class DebugOptions: return self._logging + @property + def torch_exporter_filter(self): + """Accessor for the filter export logs configuration.""" + if self.log_level >= LogLevel.INFO and get_runtime_pytorch_version() < version.parse("2.0"): + return [ + # WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. + # WARNING: The shape inference of com.microsoft::PythonOp type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. + # WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. + # WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. + "type is missing, so it may result in wrong shape inference", + # Warning: Checker does not support models with experimental ops: ATen + "Checker does not support models with experimental ops:", + "Dropout is a training op and should not be exported in inference mode.", + # Warning: Shape inference does not support models with experimental operators: ATen + "Shape inference does not support models with experimental operators:", + # Warning: Unsupported operator Trilu. No schema registered for this operator. + # Warning: Unsupported operator ATen. No schema registered for this operator. + # Warning: Unsupported operator SoftmaxCrossEntropyLossInternal. No schema registered for this operator. + "No schema registered for this operator.", + ] + + return None + + @property + def onnxruntime_log_filter(self): + """Accessor for the filter onnxruntime logs configuration.""" + if self.log_level >= LogLevel.INFO: + return [ + "CleanUnusedInitializersAndNodeArgs] Removing initializer", + "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED", + ] + return None + class _SkipCheck(IntFlag): """Enumeration to specify which checks should be skipped, allowing faster execution""" diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 7978695f0e..b5c52bdaef 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -3,8 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- # isort: skip_file -# Import ordering is important in this module to aviod circular dependencies -import logging +# Import ordering is important in this module to avoid circular dependencies + from ._torch_module_factory import TorchModuleFactory from ._torch_module_ort import TorchModuleORT from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry @@ -12,7 +12,7 @@ from ._custom_gradient_registry import CustomGradientRegistry from . import _utils from .options import DebugOptions from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException -from ._logger import ortmodule_loglevel_to_python_loglevel +from ._logger import configure_ortmodule_logger from onnxruntime.training import ortmodule from onnxruntime.tools import pytorch_export_contrib_ops @@ -53,8 +53,7 @@ class ORTModule(torch.nn.Module): if not debug_options: debug_options = DebugOptions() - self._logger = logging.getLogger(__name__) - self._logger.setLevel(ortmodule_loglevel_to_python_loglevel(debug_options.logging.log_level)) + self._logger = configure_ortmodule_logger(debug_options.logging.log_level) # Fallback settings self._fallback_manager = _FallbackManager( diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 625c1ce0d4..8e3f70ded7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6061,3 +6061,52 @@ def test_e2e_padding_elimination(): assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node] assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node] del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] + + +@pytest.mark.skipif( + Version(torch.__version__) >= Version("1.13.0"), + reason="PyTorch since 1.13 don't output expected warning messages any more", +) +@pytest.mark.parametrize("log_level", [LogLevel.VERBOSE, LogLevel.INFO, LogLevel.WARNING]) +def test_ortmodule_log_level_control(log_level, caplog): + class NeuralNetCrossEntropyLoss(torch.nn.Module): + def __init__(self, num_embeddings, embedding_dim): + super().__init__() + self.embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=1) + + def forward(self, input, positions): + output = torch.transpose(self.embedding(input), 0, 1) + ignored_index = output.size(1) + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index) + return loss_fct(output, positions) + + device = "cuda" + num_embeddings, embedding_dim = 32, 128 + pt_model = NeuralNetCrossEntropyLoss(num_embeddings, embedding_dim).to(device) + + ort_model = ORTModule(pt_model, DebugOptions(log_level=log_level)) + use_fp16 = True + + def run_step(model, input, positions): + with amp.autocast(use_fp16): + loss = model(input, positions) + loss.backward() + return loss + + N = random.randint(16, 32) # noqa: N806 + input = torch.randint(high=num_embeddings, size=(N,), dtype=torch.int64, device=device) + positions = torch.randint(high=N, size=(embedding_dim,), dtype=torch.int64, device=device) + _ = run_step(ort_model, input, positions) + + found_missing_inference_log = False + for record in caplog.records: + msg = record.getMessage() + print(msg) + if "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing" in msg: + found_missing_inference_log = True + break + + if log_level == LogLevel.VERBOSE: + assert found_missing_inference_log + else: + assert not found_missing_inference_log diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index 020ed06513..5b2ef9ad28 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -564,7 +564,7 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback, caplog): if i == 0: # For the first time, run ORTModule, feature map is logged as warning # And the fallback warning is logged. - assert len(caplog.records) == 2 + assert len(caplog.records) >= 2 else: # For the other time, only the fallback warning is logged. assert len(caplog.records) == 1 @@ -576,8 +576,8 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback, caplog): if i == 0: # For the first time, run ORTModule, feature map is logged as warning # And the fallback warning is logged. - assert len(caplog.records) == 2 - assert "Fallback to PyTorch due to exception" in caplog.records[1].message + assert len(caplog.records) >= 2 + assert "Fallback to PyTorch due to exception" in caplog.records[-1].message caplog.clear() else: # For the other time, no fallback warning will be logged because