diff --git a/onnxruntime/core/optimizer/concat_slice_elimination.cc b/onnxruntime/core/optimizer/concat_slice_elimination.cc index dcee53e5c8..f7a2b3be44 100644 --- a/onnxruntime/core/optimizer/concat_slice_elimination.cc +++ b/onnxruntime/core/optimizer/concat_slice_elimination.cc @@ -33,7 +33,10 @@ Status ConcatSliceElimination::ApplyImpl(Graph& graph, bool& modified, int graph modified = true; } } - LOGS(logger, INFO) << "Total fused concat node count: " << fused_count; + + if (fused_count > 0) { + LOGS(logger, INFO) << "Total fused concat node count: " << fused_count; + } return Status::OK(); } diff --git a/onnxruntime/core/optimizer/constant_sharing.cc b/onnxruntime/core/optimizer/constant_sharing.cc index 116061a542..a3c5a72ee7 100644 --- a/onnxruntime/core/optimizer/constant_sharing.cc +++ b/onnxruntime/core/optimizer/constant_sharing.cc @@ -252,8 +252,9 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve modified = true; } - - LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count; + if (shared_count > 0) { + LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count; + } return Status::OK(); } diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index c6e52299ca..7768a835d5 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -50,7 +50,10 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c modified = true; } } - LOGS(logger, INFO) << "Total fused reshape node count: " << fused_count; + + if (fused_count > 0) { + LOGS(logger, INFO) << "Total fused reshape node count: " << fused_count; + } return Status::OK(); } diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index c18bbb9066..c9cda92972 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -164,8 +164,8 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c const std::unordered_set* edges = GetStopGradientEdges(*n); for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) { if (edges != nullptr && edges->count(edge_it->GetDstArgIndex())) { - LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() - << " of node: " << n->Name(); + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name(); continue; } const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()]; @@ -173,13 +173,13 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c if (nullptr != type_proto && type_proto->value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) { const int32_t type = type_proto->tensor_type().elem_type(); if (GRAD_ALLOWED_TYPES.find(type) == GRAD_ALLOWED_TYPES.end()) { - LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() - << " of node: " << n->Name() << "because element type is: " << type; + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name() << "because element type is: " << type; continue; } } else { - LOGS(logger_, INFO) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() - << " of node: " << n->Name() << "because it is not a Tensor type"; + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name() << "because it is not a Tensor type"; continue; } @@ -280,7 +280,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set* p_init const std::unordered_set* edges = GetStopGradientEdges(next_node); if (edges != nullptr && edges->count(edge_it->GetDstArgIndex())) { - LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() << " of node: " << next_node.Name(); continue; } diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index f0e19724e9..aaa189cb45 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -33,7 +33,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType from ._runtime_inspector import RuntimeInspector -from ._utils import check_function_has_param +from ._utils import check_function_has_param, get_rank from .options import DebugOptions, LogLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -104,7 +104,6 @@ class GraphExecutionManager(GraphExecutionInterface): # Input and output infos (including schema) for exported model. self._input_info: Optional[_InputInfo] = None self._module_output_schema: Optional[_ModelInputOutputSchemaType] = None - self._warning_log_detected_during_export = False # Device where the model is placed. self._device: Optional[torch.device] = _utils.get_device_from_module(module) @@ -253,7 +252,8 @@ class GraphExecutionManager(GraphExecutionInterface): return session_options, providers, provider_options - @_logger.TrackTime(_logger.TimeTrackerPhase.EXPORT) + @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT) + @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False) def _export_model(self, *inputs, **kwargs) -> bool: # 1. Set the self._device from the user module # 2. Verify input schema matches the schema used on the previous model export @@ -304,90 +304,88 @@ class GraphExecutionManager(GraphExecutionInterface): TODO: How to support dynamic axes? Dimensions are determined by samples """ - with _logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level) as suppress_output: - # Setup dynamic axes for onnx model - self._input_info = _io.parse_inputs_for_onnx_export( - self._module_parameters, None, input_schema, inputs, kwargs - ) - ( - output_names, - output_dynamic_axes, - self._module_output_schema, - ) = _io.parse_outputs_for_onnx_export_and_extract_schema( - self._original_module, inputs, kwargs, self._logger - ) - self._input_info.dynamic_axes.update(output_dynamic_axes) - # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs - self._flattened_module._input_info = self._input_info + # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) + # INFO -> FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) + # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) + # Be noted: rank 0 log only is controlled by logger configured in _logger.py + torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO + self._logger.info("Exporting the PyTorch model to ONNX...") - # Export torch.nn.Module to ONNX - f = io.BytesIO() + # Setup dynamic axes for onnx model + self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) + ( + output_names, + output_dynamic_axes, + self._module_output_schema, + ) = _io.parse_outputs_for_onnx_export_and_extract_schema(self._original_module, inputs, kwargs, self._logger) + self._input_info.dynamic_axes.update(output_dynamic_axes) - # Deepcopy inputs, since input values may change after model run. - # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). - # Therefore, deepcopy only the data component of the input tensors for export. - sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs) - # NOTE: Flattening the input will change the 'input schema', resulting in a re-export - sample_inputs_as_tuple = tuple( - self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device) - ) - # Ops behaving differently under train/eval mode need to be exported with the - # correct training flag to reflect the expected behavior. - # For example, the Dropout node in a model is dropped under eval mode. - assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" + # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs + self._flattened_module._input_info = self._input_info - try: - with torch.no_grad(): - required_export_kwargs = { - "input_names": self._input_info.names, - "output_names": output_names, - "opset_version": self._runtime_options.onnx_opset_version, - "do_constant_folding": False, - "training": self._export_mode, - "dynamic_axes": self._input_info.dynamic_axes, - "verbose": self._debug_options.logging.log_level < LogLevel.WARNING, - "export_params": False, - "keep_initializers_as_inputs": True, - } + # Export torch.nn.Module to ONNX + f = io.BytesIO() - if check_function_has_param(torch.onnx.export, "autograd_inlining"): - # From some PyTorch version, autograd_inlining is a valid argument. - # We allow it to be True if custom autograd function is disabled (where autograd.Function - # anyway is not supported in ONNX until it can be inlined). - required_export_kwargs[ - "autograd_inlining" - ] = not self._runtime_options.enable_custom_autograd_function + # Deepcopy inputs, since input values may change after model run. + # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). + # Therefore, deepcopy only the data component of the input tensors for export. + sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs) + # NOTE: Flattening the input will change the 'input schema', resulting in a re-export + sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)) + # Ops behaving differently under train/eval mode need to be exported with the + # correct training flag to reflect the expected behavior. + # For example, the Dropout node in a model is dropped under eval mode. + assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" - invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() - assert ( - len(invalid_args) == 0 - ), f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." - torch.onnx.export( - self._flattened_module, - sample_inputs_as_tuple, - f, - **required_export_kwargs, - **self._export_extra_kwargs, - ) - except Exception as e: - raise wrap_exception( # noqa: B904 - ORTModuleONNXModelException, - RuntimeError( - f"There was an error while exporting the PyTorch model to ONNX: " - f"\n\n{_utils.get_exception_as_string(e)}" - ), + try: + with torch.no_grad(): + required_export_kwargs = { + "input_names": self._input_info.names, + "output_names": output_names, + "opset_version": self._runtime_options.onnx_opset_version, + "do_constant_folding": False, + "training": self._export_mode, + "dynamic_axes": self._input_info.dynamic_axes, + "verbose": torch_exporter_verbose_log, + "export_params": False, + "keep_initializers_as_inputs": True, + } + + if check_function_has_param(torch.onnx.export, "autograd_inlining"): + # From some PyTorch version, autograd_inlining is a valid argument. + # We allow it to be True if custom autograd function is disabled (where autograd.Function + # anyway is not supported in ONNX until it can be inlined). + required_export_kwargs[ + "autograd_inlining" + ] = not self._runtime_options.enable_custom_autograd_function + + invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() + + if len(invalid_args) != 0: + error_msg = f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." + raise RuntimeError(error_msg) + + torch.onnx.export( + self._flattened_module, + sample_inputs_as_tuple, + f, + **required_export_kwargs, + **self._export_extra_kwargs, ) - exported_model = onnx.load_model_from_string(f.getvalue()) - - exported_model = _post_process_after_export( - exported_model, self._runtime_options.enable_custom_autograd_function + except Exception as e: + raise wrap_exception( # noqa: B904 + ORTModuleONNXModelException, + RuntimeError( + f"There was an error while exporting the PyTorch model to ONNX: " + f"\n\n{_utils.get_exception_as_string(e)}" + ), ) + exported_model = onnx.load_model_from_string(f.getvalue()) - # If anything was captured by suppress_output during export, set the flag to - # raise a single user warning letting users know in the log. - if suppress_output.tell() > 0: - self._warning_log_detected_during_export = True + exported_model = _post_process_after_export( + exported_model, self._runtime_options.enable_custom_autograd_function + ) return exported_model @@ -411,7 +409,8 @@ class GraphExecutionManager(GraphExecutionInterface): graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer return graph_transformer_config - @_logger.TrackTime(_logger.TimeTrackerPhase.GRAPH_BUILDER_INIT) + @_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT) + @_logger.SuppressLogs(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT) def _initialize_graph_builder(self): """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder""" @@ -483,7 +482,7 @@ class GraphExecutionManager(GraphExecutionInterface): _utils.reinitialize_graph_execution_manager(self) - @_logger.TrackTime(_logger.TimeTrackerPhase.DETECTION) + @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) def _enable_conditional_optimizations( self, graph_transformer_config: C.TrainingGraphTransformerConfiguration, inputs: Tuple, kwargs: Dict ): @@ -544,11 +543,7 @@ class GraphExecutionManager(GraphExecutionInterface): self._runtime_inspector.enable_memory_inspector(self._original_module) def _log_feature_stats(self): - rank = 0 - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - - if rank != 0: + if get_rank() != 0: return feature_map: List[Tuple[str, bool, str]] = [ @@ -638,11 +633,8 @@ class GraphExecutionManager(GraphExecutionInterface): switch_str = "ON" if feature_tuple[1] else "OFF" stat += f"{feature_tuple[0]:<20}:\t{switch_str:<10}:\t{feature_tuple[2]:<80}\n" - stat += f"\n{_logger.LogColor.WARNING}There were one or more warnings or errors raised while exporting the PyTorch model.\n" - stat += f"Please enable INFO level logging with DebugOptions to view all warnings and errors.{_logger.LogColor.ENDC}\n\n" - # Collect ORTModule overheads for different phases. - stat += f"{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n" + stat += f"\n{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n" stat += f"Versions: ONNX Runtime - {onnxruntime.__version__}, ONNX - {onnx.__version__}\n\n" stat += f"{_logger.LogColor.HEADER}************************************************************************{_logger.LogColor.ENDC}\n\n" diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py index 1b09eafee9..104cc0a894 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py @@ -15,7 +15,11 @@ from .options import DebugOptions class GraphExecutionManagerFactory: def __init__( - self, module: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger + self, + module: _FlattenedModule, + debug_options: DebugOptions, + fallback_manager: _FallbackManager, + logger: Logger, ): self._training_manager = TrainingManager(module, debug_options, fallback_manager, logger) self._inference_manager = InferenceManager(module, debug_options, fallback_manager, logger) diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 96936250f3..24215364d6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -15,7 +15,7 @@ from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_alg from ._execution_agent import InferenceAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo -from ._logger import TimeTrackerPhase, TrackTime +from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime from ._utils import save_tuning_results, set_tuning_results from .options import DebugOptions, _SkipCheck @@ -111,7 +111,7 @@ class InferenceManager(GraphExecutionManager): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False or not self._onnx_models.exported_model ): - self.time_tracker.start(TimeTrackerPhase.EndToEnd) + self.time_tracker.start(ORTModuleInitPhase.EndToEnd) # Exporting module to ONNX for the first time build_graph = self._export_model(*inputs, **kwargs) @@ -151,7 +151,7 @@ class InferenceManager(GraphExecutionManager): # Create execution session creates the inference_session self._create_execution_agent() - self.time_tracker.end(TimeTrackerPhase.EndToEnd) + self.time_tracker.end(ORTModuleInitPhase.EndToEnd) self._log_feature_stats() if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: @@ -201,7 +201,8 @@ class InferenceManager(GraphExecutionManager): if self._fallback_manager.is_pending(): return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs) - @TrackTime(TimeTrackerPhase.BUILD_GRAPH) + @TrackTime(ORTModuleInitPhase.BUILD_GRAPH) + @SuppressLogs(ORTModuleInitPhase.BUILD_GRAPH) def _build_graph(self, graph_transformer_config): """Build an inference graph using the module_graph_builder""" @@ -214,7 +215,8 @@ class InferenceManager(GraphExecutionManager): self._export_mode, ) - @TrackTime(TimeTrackerPhase.CREATE_SESSION) + @TrackTime(ORTModuleInitPhase.CREATE_SESSION) + @SuppressLogs(ORTModuleInitPhase.CREATE_SESSION) def _create_execution_agent(self): """Creates an InferenceAgent that can run forward graph on an inference model""" diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index f075897434..76a4035855 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -3,16 +3,21 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import io import logging +import os import sys +import tempfile +import textwrap import time from contextlib import contextmanager from enum import IntEnum -from typing import Callable, Dict, List +from functools import partial +from typing import Callable, Dict, List, Optional from onnxruntime.capi._pybind_state import Severity +from ._utils import get_rank, get_world_size + class LogLevel(IntEnum): VERBOSE = 0 @@ -22,34 +27,6 @@ class LogLevel(IntEnum): FATAL = 4 -@contextmanager -def suppress_os_stream_output(suppress_stdout=True, suppress_stderr=True, log_level=LogLevel.WARNING): - """Suppress output from being printed to stdout and stderr if log_level is WARNING or higher. - - If there is any output detected, a single warning is issued in the context - """ - - # stdout and stderr is written to a tempfile instead - stdout = sys.stdout - stderr = sys.stderr - - suppress_logs = log_level >= LogLevel.WARNING - - fo = io.StringIO() - - try: - if suppress_stdout and suppress_logs: - sys.stdout = fo - if suppress_stderr and suppress_logs: - sys.stderr = fo - yield fo - finally: - if suppress_stdout: - sys.stdout = stdout - if suppress_stderr: - sys.stderr = stderr - - ORTMODULE_LOG_LEVEL_MAP: Dict[LogLevel, List[int]] = { LogLevel.VERBOSE: [Severity.VERBOSE, logging.DEBUG], LogLevel.INFO: [Severity.INFO, logging.INFO], @@ -67,6 +44,21 @@ def ortmodule_loglevel_to_python_loglevel(loglevel: LogLevel) -> int: return ORTMODULE_LOG_LEVEL_MAP.get(loglevel, [Severity.WARNING, logging.WARNING])[1] +def configure_ortmodule_logger(log_level: LogLevel) -> logging.Logger: + """Configure the logger for ortmodule according to following rules. + 1. If multiple processes are used, the rank will be appended + to the logger name. + 2. If the log level is greater than info, the logger will be + disabled for non-zero ranks. + """ + rank_info = f".rank-{get_rank()}" if get_world_size() > 1 else "" + logger = logging.getLogger(f"orttraining{rank_info}") + # Disable the logger for non-zero ranks when level > info + logger.disabled = log_level > LogLevel.INFO and get_rank() != 0 + logger.setLevel(ortmodule_loglevel_to_python_loglevel(log_level)) + return logger + + class LogColor: HEADER = "\033[95m" BLUE = "\033[94m" @@ -79,26 +71,26 @@ class LogColor: UNDERLINE = "\033[4m" -class TimeTrackerPhase(IntEnum): - EndToEnd = 0 # The total overhead of ORT first-time initialization - EXPORT = 1 # The latency of preparing and exporting the model to ONNX - GRAPH_BUILDER_INIT = 2 # The latency of initializing the graph builder - DETECTION = 3 # The latency of runtime detection - BUILD_GRAPH = 4 # The latency of optimizing forward graph (and building the gradient graph for training). - CREATE_SESSION = 5 # The latency of creating the session +class ORTModuleInitPhase(IntEnum): + EndToEnd = 0 # The end to end of ORT first-time initialization + EXPORT = 1 # The phase of preparing and exporting the model to ONNX + GRAPH_BUILDER_INIT = 2 # The phase of initializing the graph builder + DETECTION = 3 # The phase of runtime detection + BUILD_GRAPH = 4 # The phase of optimizing forward graph (and building the gradient graph for training). + CREATE_SESSION = 5 # The phase of creating the session def to_string(self) -> str: - if self == TimeTrackerPhase.EndToEnd: + if self == ORTModuleInitPhase.EndToEnd: return "end to end" - if self == TimeTrackerPhase.EXPORT: + if self == ORTModuleInitPhase.EXPORT: return "export" - elif self == TimeTrackerPhase.GRAPH_BUILDER_INIT: + elif self == ORTModuleInitPhase.GRAPH_BUILDER_INIT: return "graph builder init" - elif self == TimeTrackerPhase.DETECTION: + elif self == ORTModuleInitPhase.DETECTION: return "runtime detection" - elif self == TimeTrackerPhase.BUILD_GRAPH: + elif self == ORTModuleInitPhase.BUILD_GRAPH: return "graph building" - elif self == TimeTrackerPhase.CREATE_SESSION: + elif self == ORTModuleInitPhase.CREATE_SESSION: return "session creation" else: return "invalid" @@ -112,24 +104,24 @@ class TimeTracker: def __init__( self, ): - self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(TimeTrackerPhase) - self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(TimeTrackerPhase) + self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase) + self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase) - def start(self, phase: TimeTrackerPhase): + def start(self, phase: ORTModuleInitPhase): self.starts_[phase] = time.time() - def end(self, phase: TimeTrackerPhase): + def end(self, phase: ORTModuleInitPhase): self.ends_[phase] = time.time() - def _get_duration(self, phase: TimeTrackerPhase): + def _get_duration(self, phase: ORTModuleInitPhase): if self.ends_[phase] == TimeTracker.NOT_RECORD or self.starts_[phase] == TimeTracker.NOT_RECORD: return TimeTracker.NOT_RECORD return self.ends_[phase] - self.starts_[phase] def to_string(self, log_details=False) -> str: - end_to_end_str = self._get_duration(TimeTrackerPhase.EndToEnd) + end_to_end_str = self._get_duration(ORTModuleInitPhase.EndToEnd) end_to_end_str = f"{end_to_end_str:.2f}" if end_to_end_str != TimeTracker.NOT_RECORD else "N/A" - export_str = self._get_duration(TimeTrackerPhase.EXPORT) + export_str = self._get_duration(ORTModuleInitPhase.EXPORT) export_str = f"{export_str:.2f}" if export_str != TimeTracker.NOT_RECORD else "N/A" overhead_title_str = ( f"Total ORT initialization overhead is {end_to_end_str}s where export takes {export_str}s.\n" @@ -139,9 +131,9 @@ class TimeTracker: return overhead_title_str duration_summaries = [] - for phase in TimeTrackerPhase: + for phase in ORTModuleInitPhase: _get_duration = self._get_duration(phase) - if phase in [TimeTrackerPhase.EndToEnd, TimeTrackerPhase.EXPORT]: + if phase in [ORTModuleInitPhase.EndToEnd, ORTModuleInitPhase.EXPORT]: continue val = ( @@ -155,7 +147,7 @@ class TimeTracker: class TrackTime: """A function decorator to track time spent in different phases of ORT backend first-time initialization.""" - def __init__(self, phase: TimeTrackerPhase): + def __init__(self, phase: ORTModuleInitPhase): self.phase = phase def __call__(self, func: Callable): @@ -168,3 +160,108 @@ class TrackTime: return result return wrapper + + +@contextmanager +def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None): + """Suppress output from being printed to stdout and stderr. + + If on_exit is not None, it will be called when the context manager exits. + """ + if enable: + # stdout and stderr is written to a tempfile instead + with tempfile.TemporaryFile() as fp: + try: + # Store original stdout and stderr file no. + old_stdout = os.dup(sys.stdout.fileno()) + old_stderr = os.dup(sys.stderr.fileno()) + + # Redirect stdout and stderr (printed from Python or C++) to the file. + os.dup2(fp.fileno(), sys.stdout.fileno()) + os.dup2(fp.fileno(), sys.stderr.fileno()) + yield + finally: + sys.stdout.flush() + sys.stderr.flush() + + # Restore stdout and stderr. + os.dup2(old_stdout, sys.stdout.fileno()) + os.dup2(old_stderr, sys.stderr.fileno()) + + if on_exit: + on_exit(fp) + else: + yield + + +def _log_with_filter(logger: logging.Logger, record_filters: Optional[List[str]], name: Optional[str], fo): + """Log the content by filtering with list of string patterns. + Args: + logger: The logger to log the content. + record_filters: The list of string patterns to filter the content. + If record_filters is None, the full content will be logged. + name: The name of log filter. + fo: The file object to read the content. + """ + if fo.tell() > 0: + if logger.disabled: + return + + fo.seek(0) + suppress_output_messages = fo.readlines() + if record_filters: + filtered_messages = [] + filtered_lines = 0 + for suppressed_message in suppress_output_messages: + msg = suppressed_message.decode("utf-8") + found = False + for warning in record_filters: + if warning in msg: + found = True + filtered_lines += 1 + break + + if not found: + filtered_messages.extend(textwrap.wrap(msg, 180)) + if filtered_messages: + filtered_messages.insert(0, f"[{name}] Filtered logs ({filtered_lines} records suppressed):") + logger.warning("\n ".join(filtered_messages)) + else: + out_messages = [] + for suppressed_message in suppress_output_messages: + out_messages.extend(textwrap.wrap(suppressed_message.decode("utf-8"), 180)) + if out_messages: + out_messages.insert(0, f"[{name}] Full logs:") + logger.warning("\n ".join(out_messages)) + + +class SuppressLogs: + """A function decorator to suppress in different phases of ORT backend first-time initialization.""" + + def __init__(self, phase: ORTModuleInitPhase, is_ort_filter=True): + self.phase = phase + self.is_ort_filter = is_ort_filter + + def __call__(self, func: Callable): + def wrapper(graph_execution_manager, *args, **kwargs): + if not hasattr(graph_execution_manager, "_logger"): + raise RuntimeError("The class of the function to be tracked must have a '_logger' attribute.") + + if not hasattr(graph_execution_manager, "_debug_options"): + raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.") + + with _suppress_os_stream_output( + enable=graph_execution_manager._debug_options.log_level >= LogLevel.INFO, + on_exit=partial( + _log_with_filter, + graph_execution_manager._logger, + graph_execution_manager._debug_options.onnxruntime_log_filter + if self.is_ort_filter + else graph_execution_manager._debug_options.torch_exporter_filter, + self.phase.to_string(), + ), + ): + result = func(graph_execution_manager, *args, **kwargs) + return result + + return wrapper diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index a871ed08b9..dda909e8cb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -114,7 +114,7 @@ class InputDensityObserver: """ def __init__(self, logger: Logger, log_steps=1): - self._logger: Logger = logger + self._logger = logger self._embedding_graph_input_to_padding_idx_map = {} self._loss_label_graph_input_to_ignore_idx_map = {} self._stats = [] diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py index dba68be944..1255909022 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py @@ -19,7 +19,11 @@ T = TypeVar("T", bound="torch.nn.Module") class TorchModuleORT(TorchModuleInterface): def __init__( - self, module: torch.nn.Module, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger + self, + module: torch.nn.Module, + debug_options: DebugOptions, + fallback_manager: _FallbackManager, + logger: Logger, ): super().__init__(module) self._flattened_module = _io._FlattenedModule(module) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 4d192b041e..cb8561867a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -18,7 +18,7 @@ from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPo from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo from ._io import _FlattenedModule, _InputInfo -from ._logger import TimeTrackerPhase, TrackTime +from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results from .graph_transformer_registry import GraphTransformerRegistry @@ -32,7 +32,11 @@ class TrainingManager(GraphExecutionManager): """ def __init__( - self, model: _FlattenedModule, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger + self, + model: _FlattenedModule, + debug_options: DebugOptions, + fallback_manager: _FallbackManager, + logger: Logger, ): super().__init__(model, debug_options, fallback_manager, logger) @@ -246,7 +250,7 @@ class TrainingManager(GraphExecutionManager): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False or not self._onnx_models.exported_model ): - self.time_tracker.start(TimeTrackerPhase.EndToEnd) + self.time_tracker.start(ORTModuleInitPhase.EndToEnd) build_gradient_graph = self._export_model(*inputs, **kwargs) @@ -302,7 +306,7 @@ class TrainingManager(GraphExecutionManager): self._runtime_options.enable_grad_acc_optimization, self._flattened_module, self._graph_info ) - self.time_tracker.end(TimeTrackerPhase.EndToEnd) + self.time_tracker.end(ORTModuleInitPhase.EndToEnd) self._log_feature_stats() self._gradient_accumulation_manager.maybe_update_cache_before_run() @@ -349,7 +353,8 @@ class TrainingManager(GraphExecutionManager): if self._fallback_manager.is_pending(): return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs) - @TrackTime(TimeTrackerPhase.BUILD_GRAPH) + @TrackTime(ORTModuleInitPhase.BUILD_GRAPH) + @SuppressLogs(ORTModuleInitPhase.BUILD_GRAPH) def _build_graph(self, graph_transformer_config): """Build an optimized gradient graph using the module_graph_builder""" @@ -394,7 +399,8 @@ class TrainingManager(GraphExecutionManager): else: self._gradient_map.append(-1) - @TrackTime(TimeTrackerPhase.CREATE_SESSION) + @TrackTime(ORTModuleInitPhase.CREATE_SESSION) + @SuppressLogs(ORTModuleInitPhase.CREATE_SESSION) def _create_execution_agent(self): """Creates a TrainingAgent that can run the forward and backward graph on the training model""" @@ -427,6 +433,7 @@ class TrainingManager(GraphExecutionManager): ] * len(bw_fetches_names) local_device_rank = self._device.index if device_type == "ort" else _utils.get_device_index(self._device) + self._execution_agent = TrainingAgent( self._onnx_models.optimized_model.SerializeToString(), fw_feed_names, diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 43bbd5dd21..8aba7c8826 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -444,3 +444,19 @@ def set_tuning_results(session, is_training, tuning_results_path): if os.path.isfile(tuning_result_file): with open(tuning_result_file, encoding="utf-8") as f: session.set_tuning_results(json.load(f)) + + +def get_rank() -> int: + """Returns the rank of the current process. If distributed training is not initialized, returns 0.""" + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + + return 0 + + +def get_world_size() -> int: + """Returns the world size of the current process. If distributed training is not initialized, returns 1.""" + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + + return 1 diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 6eb1e70991..7758ff518a 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -7,12 +7,14 @@ from enum import IntFlag from functools import reduce from logging import Logger +from packaging import version + from onnxruntime.capi import _pybind_state as C from onnxruntime.training import ortmodule from ._fallback import _FallbackPolicy from ._logger import LogLevel -from ._utils import parse_os_env_skip_check_flags +from ._utils import get_runtime_pytorch_version, parse_os_env_skip_check_flags class _SaveOnnxOptions: @@ -131,6 +133,39 @@ class DebugOptions: return self._logging + @property + def torch_exporter_filter(self): + """Accessor for the filter export logs configuration.""" + if self.log_level >= LogLevel.INFO and get_runtime_pytorch_version() < version.parse("2.0"): + return [ + # WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. + # WARNING: The shape inference of com.microsoft::PythonOp type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. + # WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. + # WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. + "type is missing, so it may result in wrong shape inference", + # Warning: Checker does not support models with experimental ops: ATen + "Checker does not support models with experimental ops:", + "Dropout is a training op and should not be exported in inference mode.", + # Warning: Shape inference does not support models with experimental operators: ATen + "Shape inference does not support models with experimental operators:", + # Warning: Unsupported operator Trilu. No schema registered for this operator. + # Warning: Unsupported operator ATen. No schema registered for this operator. + # Warning: Unsupported operator SoftmaxCrossEntropyLossInternal. No schema registered for this operator. + "No schema registered for this operator.", + ] + + return None + + @property + def onnxruntime_log_filter(self): + """Accessor for the filter onnxruntime logs configuration.""" + if self.log_level >= LogLevel.INFO: + return [ + "CleanUnusedInitializersAndNodeArgs] Removing initializer", + "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED", + ] + return None + class _SkipCheck(IntFlag): """Enumeration to specify which checks should be skipped, allowing faster execution""" diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 7978695f0e..b5c52bdaef 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -3,8 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- # isort: skip_file -# Import ordering is important in this module to aviod circular dependencies -import logging +# Import ordering is important in this module to avoid circular dependencies + from ._torch_module_factory import TorchModuleFactory from ._torch_module_ort import TorchModuleORT from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry @@ -12,7 +12,7 @@ from ._custom_gradient_registry import CustomGradientRegistry from . import _utils from .options import DebugOptions from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException -from ._logger import ortmodule_loglevel_to_python_loglevel +from ._logger import configure_ortmodule_logger from onnxruntime.training import ortmodule from onnxruntime.tools import pytorch_export_contrib_ops @@ -53,8 +53,7 @@ class ORTModule(torch.nn.Module): if not debug_options: debug_options = DebugOptions() - self._logger = logging.getLogger(__name__) - self._logger.setLevel(ortmodule_loglevel_to_python_loglevel(debug_options.logging.log_level)) + self._logger = configure_ortmodule_logger(debug_options.logging.log_level) # Fallback settings self._fallback_manager = _FallbackManager( diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 625c1ce0d4..8e3f70ded7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6061,3 +6061,52 @@ def test_e2e_padding_elimination(): assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node] assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node] del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] + + +@pytest.mark.skipif( + Version(torch.__version__) >= Version("1.13.0"), + reason="PyTorch since 1.13 don't output expected warning messages any more", +) +@pytest.mark.parametrize("log_level", [LogLevel.VERBOSE, LogLevel.INFO, LogLevel.WARNING]) +def test_ortmodule_log_level_control(log_level, caplog): + class NeuralNetCrossEntropyLoss(torch.nn.Module): + def __init__(self, num_embeddings, embedding_dim): + super().__init__() + self.embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=1) + + def forward(self, input, positions): + output = torch.transpose(self.embedding(input), 0, 1) + ignored_index = output.size(1) + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index) + return loss_fct(output, positions) + + device = "cuda" + num_embeddings, embedding_dim = 32, 128 + pt_model = NeuralNetCrossEntropyLoss(num_embeddings, embedding_dim).to(device) + + ort_model = ORTModule(pt_model, DebugOptions(log_level=log_level)) + use_fp16 = True + + def run_step(model, input, positions): + with amp.autocast(use_fp16): + loss = model(input, positions) + loss.backward() + return loss + + N = random.randint(16, 32) # noqa: N806 + input = torch.randint(high=num_embeddings, size=(N,), dtype=torch.int64, device=device) + positions = torch.randint(high=N, size=(embedding_dim,), dtype=torch.int64, device=device) + _ = run_step(ort_model, input, positions) + + found_missing_inference_log = False + for record in caplog.records: + msg = record.getMessage() + print(msg) + if "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing" in msg: + found_missing_inference_log = True + break + + if log_level == LogLevel.VERBOSE: + assert found_missing_inference_log + else: + assert not found_missing_inference_log diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index 020ed06513..5b2ef9ad28 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -564,7 +564,7 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback, caplog): if i == 0: # For the first time, run ORTModule, feature map is logged as warning # And the fallback warning is logged. - assert len(caplog.records) == 2 + assert len(caplog.records) >= 2 else: # For the other time, only the fallback warning is logged. assert len(caplog.records) == 1 @@ -576,8 +576,8 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback, caplog): if i == 0: # For the first time, run ORTModule, feature map is logged as warning # And the fallback warning is logged. - assert len(caplog.records) == 2 - assert "Fallback to PyTorch due to exception" in caplog.records[1].message + assert len(caplog.records) >= 2 + assert "Fallback to PyTorch due to exception" in caplog.records[-1].message caplog.clear() else: # For the other time, no fallback warning will be logged because