diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index d86eb4adeb..8b2f9a863a 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -64,3 +64,4 @@ torch.cuda.manual_seed = override_torch_cuda_manual_seed # ORTModule must be loaded only after all validation passes from .ortmodule import ORTModule +from .debug_options import DebugOptions, LogLevel diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5ce581e3ce..c3ef898849 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -3,7 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from . import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext +from .debug_options import DebugOptions, LogLevel +from . import _utils, _io, _logger, torch_cpp_extensions as _cpp_ext, _onnx_models from ._custom_autograd_function_exporter import _post_process_after_export from ._graph_execution_interface import GraphExecutionInterface from onnxruntime.training.ortmodule import ONNX_OPSET_VERSION @@ -54,7 +55,7 @@ class _SkipCheck(IntFlag): return _SkipCheck.SKIP_CHECK_DISABLED in self class GraphExecutionManager(GraphExecutionInterface): - def __init__(self, module): + def __init__(self, module, debug_options: DebugOptions): """Manages building and execution of onnx graphs This class is an abstract class and should not directly be instantiated. 
@@ -69,8 +70,8 @@ class GraphExecutionManager(GraphExecutionInterface): # Original and flattened (tranformed) output module self._flattened_module = module - # Exported model - self._onnx_model = None + # onnx models + self._onnx_models = _onnx_models.ONNXModels() # Model after inference optimization or gradient building. self._optimized_onnx_model = None @@ -87,8 +88,7 @@ class GraphExecutionManager(GraphExecutionInterface): self._skip_check = _SkipCheck.SKIP_CHECK_DISABLED # Debug flags - self._save_onnx = False - self._save_onnx_prefix = '' + self._debug_options = debug_options # Graph transformer config # Specify cast propagation strategy. Currently three strategies are available, NONE, INSERT-AND-REDUCE and FLOOD-FILL @@ -126,9 +126,6 @@ class GraphExecutionManager(GraphExecutionInterface): self._input_info = None self._module_output_schema = None - # Log level - self._loglevel = _logger.LogLevel.WARNING - # TODO: Single device support for now self._device = _utils.get_device_from_module(module) @@ -137,7 +134,7 @@ class GraphExecutionManager(GraphExecutionInterface): # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. 
for input_parameter in self._module_parameters: if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: - if self._loglevel <= _logger.LogLevel.WARNING: + if self._debug_options.logging.log_level <= LogLevel.WARNING: warnings.warn("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!", UserWarning) @@ -194,7 +191,7 @@ class GraphExecutionManager(GraphExecutionInterface): else: self._graph_builder.build() - self._optimized_onnx_model = onnx.load_model_from_string(self._graph_builder.get_model()) + self._onnx_models.optimized_model = onnx.load_model_from_string(self._graph_builder.get_model()) self._graph_info = self._graph_builder.get_graph_info() def _get_session_config(self): @@ -222,11 +219,7 @@ class GraphExecutionManager(GraphExecutionInterface): # default to PRIORITY_BASED execution order session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. - session_options.log_severity_level = int(self._loglevel) - - # enable dumping optimized training graph - if self._save_onnx: - session_options.optimized_model_filepath = self._save_onnx_prefix + '_training_optimized.onnx' + session_options.log_severity_level = int(self._debug_options.logging.log_level) return session_options, providers, provider_options @@ -243,18 +236,21 @@ class GraphExecutionManager(GraphExecutionInterface): # or the user explicitly changed model parameters after the onnx export. 
schema = _io._extract_schema({'args': copy.copy(inputs), 'kwargs': copy.copy(kwargs)}) - if self._onnx_model and schema == self._input_info.schema: + if self._onnx_models.exported_model and schema == self._input_info.schema: # All required models have already been exported previously return False self._set_device_from_module(inputs, kwargs) - self._onnx_model = self._get_exported_model(*inputs, **kwargs) - _cpp_ext._load_aten_op_executor_cpp_extension_if_needed(self._onnx_model) - if self._save_onnx: - onnx.save(self._onnx_model, self._save_onnx_prefix + '_torch_exporter.onnx') + self._onnx_models.exported_model = self._get_exported_model(*inputs, **kwargs) + _cpp_ext._load_aten_op_executor_cpp_extension_if_needed(self._onnx_models.exported_model) + if self._debug_options.save_onnx_models.save: + self._onnx_models.save_exported_model(self._debug_options.save_onnx_models.path, + self._debug_options.save_onnx_models.name_prefix, + self._export_mode) if self._run_symbolic_shape_infer: - self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True) + self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes(self._onnx_models.exported_model, + auto_merge=True, guess_output_rank=True) return True @@ -292,7 +288,7 @@ class GraphExecutionManager(GraphExecutionInterface): try: with torch.set_grad_enabled(self._enable_custom_autograd_function), \ - _logger.suppress_os_stream_output(log_level=self._loglevel): + _logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level): torch.onnx.export(self._flattened_module, sample_inputs_as_tuple, f, @@ -302,7 +298,7 @@ class GraphExecutionManager(GraphExecutionInterface): do_constant_folding=False, training=self._export_mode, dynamic_axes=self._input_info.dynamic_axes, - verbose=self._loglevel < _logger.LogLevel.WARNING, + verbose=self._debug_options.logging.log_level < LogLevel.WARNING, export_params=False, keep_initializers_as_inputs=True) except 
RuntimeError as e: @@ -337,7 +333,7 @@ class GraphExecutionManager(GraphExecutionInterface): # All initializer names along with user inputs are a part of the onnx graph inputs # since the onnx model was exported with the flag keep_initializers_as_inputs=True - onnx_initializer_names = {p.name for p in self._onnx_model.graph.input} + onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} # TODO: PyTorch exporter bug: changes the initializer order in ONNX model initializer_names = [name for name, _ in self._flattened_module.named_parameters() @@ -353,12 +349,12 @@ class GraphExecutionManager(GraphExecutionInterface): grad_builder_config.build_gradient_graph = training grad_builder_config.graph_transformer_config = self._get_graph_transformer_config() grad_builder_config.enable_caching = self._enable_grad_acc_optimization - grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._loglevel) + grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._debug_options.logging.log_level) self._graph_builder = C.OrtModuleGraphBuilder() # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way # and are kept as they appear in the exported onnx model. - self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config) + self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config) # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train # a set (unordered_set in the backend) that does not require a copy on each reference. 
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py index 6d61b586a6..cd839ed9e7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py @@ -8,9 +8,9 @@ from ._inference_manager import InferenceManager class GraphExecutionManagerFactory(object): - def __init__(self, module): - self._training_manager = TrainingManager(module) - self._inference_manager = InferenceManager(module) + def __init__(self, module, debug_options): + self._training_manager = TrainingManager(module, debug_options) + self._inference_manager = InferenceManager(module, debug_options) def __call__(self, is_training): if is_training: diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index ba688e0b7d..70bd7c80d9 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -18,8 +18,8 @@ class InferenceManager(GraphExecutionManager): InferenceManager is resposible for building and running the forward graph of the inference model """ - def __init__(self, model): - super().__init__(model) + def __init__(self, model, debug_options): + super().__init__(model, debug_options) self._export_mode = torch.onnx.TrainingMode.EVAL @staticmethod @@ -67,10 +67,6 @@ class InferenceManager(GraphExecutionManager): # If model was exported, then initialize the graph builder self._initialize_graph_builder(training=False) - # Save the onnx model if the model was exported - if self._save_onnx: - onnx.save(self._onnx_model, self._save_onnx_prefix + '_exported_inference_model.onnx') - # Build the inference graph if build_graph: self._build_graph() @@ 
-86,7 +82,7 @@ class InferenceManager(GraphExecutionManager): self._create_execution_agent() user_outputs, _ = InferenceManager.execution_session_run_forward(self._execution_agent, - self._optimized_onnx_model, + self._onnx_models.optimized_model, self._device, *_io._combine_input_buffers_initializers( self._graph_initializers, @@ -104,12 +100,14 @@ class InferenceManager(GraphExecutionManager): """Build an optimized inference graph using the module_graph_builder""" super()._build_graph() - if self._save_onnx: - onnx.save(self._optimized_onnx_model, self._save_onnx_prefix + '_inference.onnx') + if self._debug_options.save_onnx_models.save: + self._onnx_models.save_optimized_model(self._debug_options.save_onnx_models.path, + self._debug_options.save_onnx_models.name_prefix, + self._export_mode) def _create_execution_agent(self): """Creates an InferenceAgent that can run forward graph on an inference model""" session_options, providers, provider_options = self._get_session_config() - self._execution_agent = InferenceAgent(self._optimized_onnx_model.SerializeToString(), + self._execution_agent = InferenceAgent(self._onnx_models.optimized_model.SerializeToString(), session_options, providers, provider_options) diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index 27b3eab27e..c503df765f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -10,6 +10,7 @@ import io import sys import warnings + class LogLevel(IntEnum): VERBOSE = 0 INFO = 1 diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py new file mode 100644 index 0000000000..9cd88f56d7 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. +# _onnx_models.py + +from dataclasses import dataclass +from filelock import SoftFileLock +import onnx +import os +import torch + +def _get_onnx_file_name(name_prefix, name, export_mode): + suffix = 'training' if export_mode == torch.onnx.TrainingMode.TRAINING else 'inference' + return f"{name_prefix}_{name}_{suffix}.onnx" + +def _save_model(model: onnx.ModelProto, file_path: str): + onnx.save(model, file_path) + +@dataclass +class ONNXModels: + """Encapsulates all ORTModule onnx models.""" + + exported_model: onnx.ModelProto = None + optimized_model: onnx.ModelProto = None + + def save_exported_model(self, path, name_prefix, export_mode): + # save the ortmodule exported model + _save_model(self.exported_model, os.path.join(path, + _get_onnx_file_name(name_prefix, 'torch_exported', export_mode))) + + def save_optimized_model(self, path, name_prefix, export_mode): + # save the ortmodule optimized model + _save_model(self.optimized_model, os.path.join(path, + _get_onnx_file_name(name_prefix, 'optimized', export_mode))) diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module.py b/orttraining/orttraining/python/training/ortmodule/_torch_module.py index e4f7ba5fcc..9117a91110 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module.py @@ -3,6 +3,7 @@ # _torch_module.py from . 
import _io +from .debug_options import DebugOptions from ._graph_execution_manager_factory import GraphExecutionManagerFactory from ._torch_module_interface import TorchModuleInterface @@ -16,7 +17,7 @@ T = TypeVar('T', bound='torch.nn.Module') class TorchModule(TorchModuleInterface): - def __init__(self, module: torch.nn.Module): + def __init__(self, module: torch.nn.Module, debug_options: DebugOptions): super(TorchModule, self).__init__(module) self._flattened_module = _io._FlattenedModule(module) @@ -36,7 +37,7 @@ class TorchModule(TorchModuleInterface): functools.update_wrapper( self.forward.__func__, self._original_module.forward.__func__) - self._execution_manager = GraphExecutionManagerFactory(self._flattened_module) + self._execution_manager = GraphExecutionManagerFactory(self._flattened_module, debug_options) def _apply(self, fn): """Override original method to delegate execution to the flattened PyTorch user module""" diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py index 3aaf40dca6..835b67d725 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py @@ -6,7 +6,7 @@ from ._torch_module import TorchModule class TorchModuleFactory: - def __call__(self, module): + def __call__(self, module, debug_options): """Creates a TorchModule instance based on the input module.""" - return TorchModule(module) + return TorchModule(module, debug_options) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 8395e87ed4..d7606eee40 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -10,7 +10,6 @@ from ._execution_agent import 
TrainingAgent from onnxruntime.capi import _pybind_state as C from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type -import onnx import torch import warnings from torch.utils.dlpack import from_dlpack, to_dlpack @@ -22,8 +21,8 @@ class TrainingManager(GraphExecutionManager): TrainingManager is resposible for building and running the forward and backward graph of the training model """ - def __init__(self, model): - super().__init__(model) + def __init__(self, model, debug_options): + super().__init__(model, debug_options) self._export_mode = torch.onnx.TrainingMode.TRAINING @staticmethod @@ -67,7 +66,7 @@ class TrainingManager(GraphExecutionManager): self._initialize_graph_builder(training=True) input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, - self._onnx_model, + self._onnx_models.exported_model, inputs, kwargs) @@ -122,7 +121,7 @@ class TrainingManager(GraphExecutionManager): _utils._check_same_device(self._device, "Input argument to forward", *inputs) user_outputs, ctx.run_info = TrainingManager.execution_session_run_forward(self._execution_agent, - self._optimized_onnx_model, + self._onnx_models.optimized_model, *inputs) # Disable materializing grads then None object will not be @@ -233,30 +232,30 @@ class TrainingManager(GraphExecutionManager): super()._build_graph() - if self._save_onnx: - onnx.save(self._optimized_onnx_model, self._save_onnx_prefix + '_training.onnx') - inference_optimized_model = onnx.load_model_from_string(self._graph_builder.get_inference_optimized_model()) - onnx.save(inference_optimized_model, self._save_onnx_prefix + '_inference_optimized.onnx') + if self._debug_options.save_onnx_models.save: + self._onnx_models.save_optimized_model(self._debug_options.save_onnx_models.path, + self._debug_options.save_onnx_models.name_prefix, + self._export_mode) def _create_execution_agent(self): """Creates a TrainingAgent that can run the forward and backward graph on the training model""" 
session_options, providers, provider_options = self._get_session_config() - fw_feed_names = [input.name for input in self._optimized_onnx_model.graph.input] + fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input] fw_outputs_device_info = [ C.OrtDevice(get_ort_device_type(self._device.type), C.OrtDevice.default_memory(), _utils.get_device_index(self._device) )] * len(self._graph_info.user_output_names) - bw_fetches_names = [output.name for output in self._optimized_onnx_model.graph.output] + bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output] bw_outputs_device_info = [ C.OrtDevice(get_ort_device_type(self._device.type), C.OrtDevice.default_memory(), _utils.get_device_index(self._device) )] * len(bw_fetches_names) - self._execution_agent = TrainingAgent(self._optimized_onnx_model.SerializeToString(), + self._execution_agent = TrainingAgent(self._onnx_models.optimized_model.SerializeToString(), fw_feed_names, fw_outputs_device_info, bw_fetches_names, diff --git a/orttraining/orttraining/python/training/ortmodule/debug_options.py b/orttraining/orttraining/python/training/ortmodule/debug_options.py new file mode 100644 index 0000000000..a373e7c91b --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/debug_options.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# debug_options.py + +import os + +from ._logger import LogLevel + +class _SaveOnnxOptions: + """Configurable option to save ORTModule intermediate onnx models.""" + + # class variable + _path_environment_key = 'ORTMODULE_SAVE_ONNX_PATH' + + def __init__(self, save, name_prefix): + self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix) + + def _extract_info(self, save, name_prefix): + # get the destination path from os env variable + destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, os.getcwd()) + # perform validation only when save is True + if save: + self._validate(save, name_prefix, destination_path) + return save, name_prefix, destination_path + + def _validate(self, save, name_prefix, destination_path): + # check if directory is writable + if not os.access(destination_path, os.W_OK): + raise OSError(f"Directory {destination_path} is not writable. Please set the {_SaveOnnxOptions._path_environment_key} environment variable to a writable path.") + + # check if input prefix is a string + if not isinstance(name_prefix, str): + raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.") + + # if save_onnx is set, save_onnx_prefix must be a non empty string + if not name_prefix: + raise ValueError("onnx_prefix must be provided when save_onnx is set.") + + @property + def save(self): + return self._save + + @property + def name_prefix(self): + return self._name_prefix + + @property + def path(self): + return self._path + + +class _LoggingOptions: + """Configurable option to set the log level in ORTModule.""" + + # class variable + _log_level_environment_key = 'ORTMODULE_LOG_LEVEL' + + def __init__(self, log_level): + self._log_level = self._extract_info(log_level) + + def _extract_info(self, log_level): + # get the log_level from os env variable + # os env variable log level supersedes the locally provided one + self._validate(log_level) + log_level = 
LogLevel[os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)] + return log_level + + def _validate(self, log_level): + # check if log_level is an instance of LogLevel + if not isinstance(log_level, LogLevel): + raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.") + + @property + def log_level(self): + return self._log_level + +class DebugOptions: + """Configurable debugging options for ORTModule. + + DebugOptions provides a way to configure ORTModule debug flags. + + Args: + log_level (:obj:`LogLevel`, optional): Configure ORTModule log level. Defaults to LogLevel.WARNING. + log_level can also be set by setting the environment variable "ORTMODULE_LOG_LEVEL" to one of + "VERBOSE", "INFO", "WARNING", "ERROR", "FATAL". In case both are set, the environment variable + takes precedence. + save_onnx (:obj:`bool`, optional): Configure ORTModule to save onnx models. Defaults to False. + The output directory of the onnx models by default is set to the current working directory. + To change the output directory, the environment variable "ORTMODULE_SAVE_ONNX_PATH" can be + set to the destination directory path. + onnx_prefix (:obj:`str`, optional): Name prefix to the ORTModule ONNX models saved file names. + Must be provided if save_onnx is True + + Raises: + OSError: If save_onnx is True and output directory is not writable. + TypeError: If save_onnx is True and name_prefix is not a valid string. Or if + log_level is not an instance of LogLevel. + ValueError: If save_onnx is True and name_prefix is an empty string. 
+ + """ + + def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix=''): + self._save_onnx_models = _SaveOnnxOptions(save_onnx, onnx_prefix) + self._logging = _LoggingOptions(log_level) + + @property + def save_onnx_models(self): + """Accessor for the save_onnx_models debug flag.""" + + return self._save_onnx_models + + @property + def logging(self): + """Accessor for the logging debug flag.""" + + return self._logging diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 20c8cb6c48..998cdc4010 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -6,6 +6,7 @@ from ._torch_module_factory import TorchModuleFactory from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry from ._custom_gradient_registry import CustomGradientRegistry +from .debug_options import DebugOptions from onnxruntime.training import register_custom_ops_pytorch_exporter @@ -21,10 +22,19 @@ class ORTModule(torch.nn.Module): ORTModule specializes the user's :class:`torch.nn.Module` model, providing :meth:`~torch.nn.Module.forward`, :meth:`~torch.nn.Module.backward` along with all others :class:`torch.nn.Module`'s APIs. + + Args: + module (torch.nn.Module): User's PyTorch module that ORTModule specializes + debug_options (:obj:`DebugOptions`, optional): debugging options for ORTModule. """ - def __init__(self, module): - self._torch_module = TorchModuleFactory()(module) + def __init__(self, module, debug_options=None): + # Python default arguments are evaluated on function definition + # and not on function invocation. So, if debug_options is not provided, + # instantiate it inside the function. 
+ if not debug_options: + debug_options = DebugOptions() + self._torch_module = TorchModuleFactory()(module, debug_options) # Create forward dynamically, so each ORTModule instance will have its own copy. # This is needed to be able to copy the forward signatures from the original PyTorch models diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index 98bfff95ae..eeff0b738d 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -124,7 +124,7 @@ def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0): def is_dynamic_axes(model): # Check inputs - for inp in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.input: + for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input: shape = inp.type.tensor_type.shape if shape: for dim in shape.dim: @@ -132,7 +132,7 @@ def is_dynamic_axes(model): return False # Check outputs - for out in model._torch_module._execution_manager(model._is_training())._optimized_onnx_model.graph.output: + for out in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.output: shape = out.type.tensor_type.shape if shape: for dim in shape.dim: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 27d3beac6f..062889f1fd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -16,8 +16,9 @@ from collections import OrderedDict from collections import namedtuple from inspect import signature import tempfile +import os -from onnxruntime.training.ortmodule import ORTModule, _utils, _io +from onnxruntime.training.ortmodule import ORTModule, _utils, _io, 
DebugOptions, LogLevel import _test_helpers # Import autocasting libs @@ -2714,10 +2715,10 @@ def test_changing_bool_input_re_exports_model(bool_arguments): input1 = torch.randn(N, D_in, device=device) ort_model(input1, bool_arguments[0]) - exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model + exported_model1 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model ort_model(input1, bool_arguments[1]) - exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_model + exported_model2 = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model assert exported_model1 != exported_model2 @@ -2877,7 +2878,7 @@ def test_unused_parameters_does_not_unnecssarily_reinitilize(model): _ = ort_model(x) input_info = _io.parse_inputs_for_onnx_export(training_manager._module_parameters, - training_manager._onnx_model, + training_manager._onnx_models.exported_model, x, {}) @@ -3038,3 +3039,95 @@ def test_ortmodule_nested_list_input(): y = copy.deepcopy(x) _test_helpers.assert_values_are_close(pt_model(x), ort_model(y)) + +@pytest.mark.parametrize("mode", ['training', 'inference']) +def test_debug_options_save_onnx_models_os_environment(mode): + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + # Create a temporary directory for the onnx_models + with tempfile.TemporaryDirectory() as temporary_dir: + os.environ['ORTMODULE_SAVE_ONNX_PATH'] = temporary_dir + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_model')) + if mode == 'inference': + ort_model.eval() + x = torch.randn(N, D_in, device=device) + _ = ort_model(x) + + # assert that the onnx models have been saved + assert os.path.exists(os.path.join(temporary_dir, f"my_model_torch_exported_{mode}.onnx")) + assert os.path.exists(os.path.join(temporary_dir, 
f"my_model_optimized_{mode}.onnx")) + del os.environ['ORTMODULE_SAVE_ONNX_PATH'] + +@pytest.mark.parametrize("mode", ['training', 'inference']) +def test_debug_options_save_onnx_models_cwd(mode): + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_cwd_model')) + if mode == 'inference': + ort_model.eval() + x = torch.randn(N, D_in, device=device) + _ = ort_model(x) + + # assert that the onnx models have been saved + assert os.path.exists(os.path.join(os.getcwd(), f"my_cwd_model_torch_exported_{mode}.onnx")) + assert os.path.exists(os.path.join(os.getcwd(), f"my_cwd_model_optimized_{mode}.onnx")) + + os.remove(os.path.join(os.getcwd(), f"my_cwd_model_torch_exported_{mode}.onnx")) + os.remove(os.path.join(os.getcwd(), f"my_cwd_model_optimized_{mode}.onnx")) + +def test_debug_options_save_onnx_models_validate_fail_on_non_writable_dir(): + + os.environ['ORTMODULE_SAVE_ONNX_PATH'] = '/non/existent/directory' + with pytest.raises(Exception) as ex_info: + _ = DebugOptions(save_onnx=True, onnx_prefix='my_model') + assert "Directory /non/existent/directory is not writable." in str(ex_info.value) + del os.environ['ORTMODULE_SAVE_ONNX_PATH'] + +def test_debug_options_save_onnx_models_validate_fail_on_non_str_prefix(): + prefix = 23 + with pytest.raises(Exception) as ex_info: + _ = DebugOptions(save_onnx=True, onnx_prefix=prefix) + assert f"Expected name prefix of type str, got {type(prefix)}." in str(ex_info.value) + +def test_debug_options_save_onnx_models_validate_fail_on_no_prefix(): + with pytest.raises(Exception) as ex_info: + _ = DebugOptions(save_onnx=True) + assert f"onnx_prefix must be provided when save_onnx is set." 
in str(ex_info.value) + +def test_debug_options_log_level(): + # NOTE: This test will output verbose logging + + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(model, DebugOptions(log_level=LogLevel.VERBOSE)) + x = torch.randn(N, D_in, device=device) + _ = ort_model(x) + + # assert that the logging is done in verbose mode + assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.VERBOSE + +def test_debug_options_log_level_os_environment(): + # NOTE: This test will output info logging + + os.environ['ORTMODULE_LOG_LEVEL'] = 'INFO' + device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + ort_model = ORTModule(model) + x = torch.randn(N, D_in, device=device) + _ = ort_model(x) + + # assert that the logging is done in info mode + assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.INFO + del os.environ['ORTMODULE_LOG_LEVEL'] + +def test_debug_options_log_level_validation_fails_on_type_mismatch(): + log_level = 'some_string' + with pytest.raises(Exception) as ex_info: + _ = DebugOptions(log_level=log_level) + assert f"Expected log_level of type LogLevel, got {type(log_level)}." 
in str(ex_info.value) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index 383e26e607..5d7508d3b8 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -17,7 +17,7 @@ import datetime import onnxruntime -from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule import ORTModule, DebugOptions def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): # ======================================== @@ -376,11 +376,10 @@ def main(): ) if not args.pytorch_only: - model = ORTModule(model) + # Just for future debugging + debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassification') - # Just for future debugging - model._torch_module._execution_manager(model._is_training())._save_onnx = False - model._torch_module._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification' + model = ORTModule(model, debug_options) # Tell pytorch to run this model on the GPU. 
if torch.cuda.is_available() and not args.no_cuda: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py index 7b6e0afd4c..c73eec9509 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py @@ -17,7 +17,7 @@ import datetime import onnxruntime -from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule import ORTModule, DebugOptions def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, args): # ======================================== @@ -377,11 +377,11 @@ def main(): ) if not args.pytorch_only: - model = ORTModule(model) + # Just for future debugging + debug_options = DebugOptions(save_onnx=False, onnx_prefix='BertForSequenceClassificationAutoCast') + model = ORTModule(model, debug_options) - model._torch_module._execution_manager(is_training=True)._save_onnx = True - model._torch_module._execution_manager(is_training=True)._save_onnx_prefix = 'BertForSequenceClassification' model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True # Tell pytorch to run this model on the GPU. 
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py index eab0339c9b..2b12b7d67c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py @@ -9,14 +9,13 @@ $ deepspeed orttraining_test_ortmodule_deepspeed_zero_stage_1.py \ ``` """ import argparse -import logging import torch import time from torchvision import datasets, transforms import torch.distributed as dist import onnxruntime -from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel import deepspeed @@ -190,17 +189,20 @@ def main(): model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device) if not args.pytorch_only: print('Training MNIST on ORTModule....') - model = ORTModule(model) - - # TODO: change it to False to stop saving ONNX models - model._save_onnx = True - model._save_onnx_prefix = 'MNIST' # Set log level - numeric_level = getattr(logging, args.log_level.upper(), None) - if not isinstance(numeric_level, int): + log_level_mapping = {"DEBUG": LogLevel.VERBOSE, + "INFO": LogLevel.INFO, + "WARNING": LogLevel.WARNING, + "ERROR": LogLevel.ERROR, + "CRITICAL": LogLevel.FATAL} + log_level = log_level_mapping.get(args.log_level.upper(), None) + if not isinstance(log_level, LogLevel): raise ValueError('Invalid log level: %s' % args.log_level) - logging.basicConfig(level=numeric_level) + debug_options = DebugOptions(log_level=log_level, save_onnx=False, onnx_prefix='MNIST') + + model = ORTModule(model, debug_options) + else: print('Training MNIST on vanilla PyTorch....') diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py index 
a806c67c32..8aeab01308 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py @@ -5,7 +5,7 @@ import time from torchvision import datasets, transforms import onnxruntime -from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule import ORTModule, DebugOptions class NeuralNet(torch.nn.Module): @@ -167,11 +167,11 @@ def main(): model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device) if not args.pytorch_only: print('Training MNIST on ORTModule....') - model = ORTModule(model) - # TODO: change it to False to stop saving ONNX models - model._save_onnx = True - model._save_onnx_prefix = 'MNIST' + # Just for future debugging + debug_options = DebugOptions(save_onnx=False, onnx_prefix='MNIST') + + model = ORTModule(model, debug_options) # Set log level numeric_level = getattr(logging, args.log_level.upper(), None)