From a49bb85cfec6f0c4da0a483aaa235a4841b9aedf Mon Sep 17 00:00:00 2001 From: pengwa Date: Tue, 27 Jun 2023 19:19:36 +0800 Subject: [PATCH] Manage ORTModule configurations consistently (#16396) ### Manage ORTModule options Move all env vars that used for feature ON/OFF into runtime options for consistent managements. Be noted: the features' switch are assigned in 2 phases: default values, overwritten by env vars (if specified by users). So env vars take the highest priority when all 2 phases both given value explicitly for one feature. ### Motivation and Context --- docs/ORTModule_Training_Guidelines.md | 7 +- .../gradient_graph/_gradient_graph_tools.py | 2 +- .../python/training/ortmodule/__init__.py | 2 +- .../ortmodule/_custom_autograd_function.py | 8 +- .../ortmodule/_custom_op_symbolic_registry.py | 6 +- .../ortmodule/_graph_execution_manager.py | 277 ++++++----------- .../_graph_execution_manager_factory.py | 2 +- .../training/ortmodule/_inference_manager.py | 23 +- .../ortmodule/_torch_module_factory.py | 2 +- .../training/ortmodule/_torch_module_ort.py | 2 +- .../training/ortmodule/_training_manager.py | 35 ++- .../training/ortmodule/debug_options.py | 118 ------- .../_hierarchical_ortmodule.py | 8 +- .../json_config/_load_config_from_json.py | 33 +- .../python/training/ortmodule/options.py | 287 ++++++++++++++++++ .../python/training/ortmodule/ortmodule.py | 10 +- .../python/orttraining_test_ortmodule_api.py | 40 ++- ...test_ortmodule_bert_classifier_autocast.py | 2 +- ...test_ortmodule_experimental_json_config.py | 49 +-- 19 files changed, 506 insertions(+), 407 deletions(-) delete mode 100644 orttraining/orttraining/python/training/ortmodule/debug_options.py create mode 100644 orttraining/orttraining/python/training/ortmodule/options.py diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 7f3d20b8c1..bd98605a82 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -47,7 +47,7 @@ More options for **developers**. + from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel + model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name")) ``` -Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/debug_options.py) for more details. +Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/options.py) for more details. ### 2.1 Environment Variables @@ -99,12 +99,13 @@ The output directory of the onnx models by default is set to the current working > Overall users should be aware that ORT performance boost might be trivial when they explicitly allow it. -#### ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT +#### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD - **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)* - **Description**: By default, all torch.autograd.Function classes will be exported to ORT PythonOp. There are some cases where you might consider disable it. For example, if you confirmed those torch.autograd.Function classes defined computations that could be inline exported by PyTorch, and it is safe to use the inline exported ONNX graph to train, then you can disable it, as a result, ORT has more opportunities to optimize more. ```bash - export ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT=1 + export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=1 # Enable + export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=0 # Disable ``` An alternative to disable without using environment variable: diff --git a/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py b/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py index a5242ab047..5ab79b3712 100644 --- a/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py +++ b/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py @@ -38,7 +38,7 @@ def export_gradient_graph( """ # Make sure that loss nodes that expect multiple outputs are set up. - CustomOpSymbolicRegistry.register_all() + CustomOpSymbolicRegistry.register_all(opset_version) if not isinstance(gradient_graph_path, str): gradient_graph_path = str(gradient_graph_path) diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 99ca9b7a61..a6ead5a2c9 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -120,7 +120,7 @@ def _are_deterministic_algorithms_enabled(): return ORTMODULE_IS_DETERMINISTIC -from .debug_options import DebugOptions, LogLevel # noqa: E402, F401 +from .options import DebugOptions, LogLevel # noqa: E402, F401 # ORTModule must be loaded only after all validation passes from .ortmodule import ORTModule # noqa: E402, F401 diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index 31a7f07f3c..fece1be20c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -3,6 +3,9 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from onnxruntime.capi._pybind_state import is_torch_interop_default_on +from onnxruntime.training import ortmodule + class Enabler: def __init__(self): @@ -83,10 +86,7 @@ def enable_custom_autograd_support(to_enable=True): custom_autograd_function_enabler.state = False -from onnxruntime.capi._pybind_state import is_torch_interop_default_on # noqa: E402 -from onnxruntime.training import ortmodule # noqa: E402 - # Enable the custom autograd by default when PythonOp backend support is enabled during build. enable_custom_autograd_support( - not ortmodule._defined_from_envvar("ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT", 0) and is_torch_interop_default_on() + ortmodule._defined_from_envvar("ORTMODULE_ENABLE_CUSTOM_AUTOGRAD", 1) and is_torch_interop_default_on() ) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index fbab6d2b8c..b4bee8a17f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -12,8 +12,6 @@ from packaging.version import Version from torch.onnx import register_custom_op_symbolic from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args -from onnxruntime.training import ortmodule - from ._utils import get_runtime_pytorch_version # Mapping from pytorch scalar type to onnx scalar type. @@ -97,13 +95,13 @@ class CustomOpSymbolicRegistry: cls._SYMBOLICS[domain + "::" + name] = fn @classmethod - def register_all(cls): + def register_all(cls, onnx_opset_version): for name, fn in cls._SYMBOLICS.items(): # Symbolic name is in format: domain::name register_custom_op_symbolic( name, fn, - ortmodule._defined_from_envvar("ORTMODULE_ONNX_OPSET_VERSION", ortmodule.ONNX_OPSET_VERSION), + onnx_opset_version, ) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index cb6d9879ec..3732e6cc4c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -9,8 +9,6 @@ import io import logging import os from abc import ABC, abstractmethod # noqa: F401 -from enum import IntFlag -from functools import reduce from typing import Dict, List, Optional, Tuple import onnx @@ -20,9 +18,8 @@ from torch.utils.cpp_extension import ROCM_HOME import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training import ortmodule -from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _runtime_inspector, _utils +from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils from ._custom_autograd_function_exporter import _post_process_after_export from ._fallback import ( ORTModuleDeviceException, @@ -35,7 +32,8 @@ from ._fallback import ( from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_interface import GraphExecutionInterface from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType -from .debug_options import DebugOptions, LogLevel +from ._runtime_inspector import RuntimeInspector +from .options import DebugOptions, LogLevel, _RuntimeOptions from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension @@ -49,28 +47,6 @@ class _RunStateInfo: self.output_info = output_info -class _SkipCheck(IntFlag): - """Enumeration to specify which checks should be skipped, allowing faster execution""" - - SKIP_CHECK_DISABLED = 1 - SKIP_CHECK_DEVICE = 2 - SKIP_CHECK_BUILD_GRADIENT = 4 - SKIP_CHECK_EXECUTION_AGENT = 8 - - def is_set(self, check): - """Check whether `check` is set on the `_SkipCheck instance - - SKIP_CHECK_DISABLED implies the check will return False - """ - - return not _SkipCheck.is_disabled(self) and check in self - - def is_disabled(self): - """Check whether `_SkipCheck.SKIP_CHECK_DISABLED is set on the `_SkipCheck instance""" - - return _SkipCheck.SKIP_CHECK_DISABLED in self - - class GraphExecutionManager(GraphExecutionInterface): def __init__( self, @@ -89,6 +65,7 @@ class GraphExecutionManager(GraphExecutionInterface): self._logger = logger + self._runtime_options = _RuntimeOptions(self._logger) # Original and flattened (transformed) output module self._flattened_module = module @@ -102,44 +79,12 @@ class GraphExecutionManager(GraphExecutionInterface): self._graph_initializer_names_to_train = set() self._graph_initializers: List[torch.nn.parameter.Parameter] = [] - # Update constant ONNX_OPSET_VERSION with env var ORTMODULE_ONNX_OPSET_VERSION - # if defined. - ortmodule.ONNX_OPSET_VERSION = ortmodule._defined_from_envvar( - "ORTMODULE_ONNX_OPSET_VERSION", ortmodule.ONNX_OPSET_VERSION, warn=True - ) - # TrainingAgent or InferenceAgent self._execution_agent = None - # Indicators of some logic have been executed previously and thus could be skipped for faster training - # default is enabled, if not defined in os env - self._skip_check = _SkipCheck( - _SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT - ) - if os.getenv("ORTMODULE_SKIPCHECK_POLICY") is not None: - self._skip_check = reduce( - lambda x, y: x | y, - [_SkipCheck[name] for name in _utils.parse_os_env_skip_check_flags("ORTMODULE_SKIPCHECK_POLICY")], - ) self._first_skip_check_warning = True - self._rt_inspector = _runtime_inspector.RuntimeInspector(self._logger) - - # Graph transformer config - # Specify cast propagation strategy. Currently, three strategies are available, NONE, INSERT-AND-REDUCE and FLOOD-FILL - # The default is FLOOD_FILL, expand FP16 computation regions in the graph using allowed opcodes for the given level. - self._propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.FLOOD_FILL - # Optimize by moving Cast operations if propagate_cast_ops_level is non-negative. - # - If the _propagate_cast_ops_level is set to zero, then the transformation considers only the opcodes specified by _propagate_cast_ops_allow - # as "FP16 safe", to insert/(re)move cast operations before/after to perform such operations in reduced (16-bit) precision. - # - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by propagate_cast_ops_allow, use onnxruntime - # predetermined list of opcodes considered safe to move before/after the cast operation. - # - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform any computation such as Transpose, Split, Reshape, etc., - # or the computation is actually in Float such as GeLU, etc. - # whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, Dropout, LayerNormalization, etc. - self._propagate_cast_ops_level = 1 - # List of opcodes to be considered safe to move before/after the cast operation if propagate_cast_ops_level is zero. - self._propagate_cast_ops_allow = [] + self._runtime_inspector = RuntimeInspector(self._logger) # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL # To be instantiated in the concrete implementation of GraphExecutionManager @@ -149,18 +94,6 @@ class GraphExecutionManager(GraphExecutionInterface): # It cannot overlap with required/immutable arguments (validated in runtime) self._export_extra_kwargs = {} - # default execution order is priority-based for both dynamic/static shape input for now - # if we observe the benefit of static shape, we can expose this flag to the user - self._use_static_shape = False - - # flag to enable symbolic shape inference for dynamic shape inputs to improve performance - self._run_symbolic_shape_infer = True - - # PyTorch custom Autograd function support - from ._custom_autograd_function import custom_autograd_function_enabler - - self._enable_custom_autograd_function = custom_autograd_function_enabler.state - # Input and output infos (including schema) for exported model. self._input_info: Optional[_InputInfo] = None self._module_output_schema: Optional[_ModelInputOutputSchemaType] = None @@ -182,36 +115,9 @@ class GraphExecutionManager(GraphExecutionInterface): self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) - self._use_external_gpu_allocator = True - # assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True - self._get_torch_gpu_allocator_function_addresses() - # WIP feature to enable caching in Gradient accumulation scenario. - self._enable_grad_acc_optimization = False self._gradient_accumulation_manager = GradientAccumulationManager() - # Memory-aware gradient builder. - self._use_memory_efficient_gradient = False - - # Enable compute optimizer by default. Allowed to be disabled via an environment variable for - # convergence parity investigation. - self._enable_compute_optimizer = ( - ortmodule._defined_from_envvar("ORTMODULE_ENABLE_COMPUTE_OPTIMIZER", 1, warn=True) == 1 - ) - self._enable_sparse_optimizer = ( - self._enable_compute_optimizer - and ortmodule._defined_from_envvar("ORTMODULE_ENABLE_SPARSE_OPTIMIZER", 1, warn=True) == 1 - ) - self._enable_embedding_sparse_optimizer = ( - self._enable_compute_optimizer - and self._enable_sparse_optimizer - and ortmodule._defined_from_envvar("ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER", 0, warn=True) == 1 - ) - - self._print_input_density = ortmodule._defined_from_envvar("ORTMODULE_PRINT_INPUT_DENSITY", 0, warn=True) == 1 - self._print_memory_stat = ortmodule._defined_from_envvar("ORTMODULE_PRINT_MEMORY_STATS", 0, warn=True) == 1 - self._enable_memory_optimizer = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True) - # Flag to re-export the model due to attribute change on the original module. # Re-export will be avoided if _skip_check is enabled. self._original_model_has_changed = False @@ -219,23 +125,11 @@ class GraphExecutionManager(GraphExecutionInterface): # Load ATen operator executor extension. load_aten_op_executor_cpp_extension() - self._feature_map: List[List[str]] = [ - ["ATen Executor", "ON", "Dispatch ATen operators to ORT's ATen executor"], - [ - "Cast Propagation", - "ON" if self._propagate_cast_ops_level > 0 else "OFF", - f"Level {self._propagate_cast_ops_level} enabled", - ], - ["Custom Function", "ON", "Support custom torch.autograd.Function export and execution"], - [ - "Memory Optimizer", - "ON" if self._enable_memory_optimizer else "OFF", - "Enable with env ORTMODULE_MEMORY_OPT_CONFIG=", - ], - ] + # Assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True + self._get_torch_gpu_allocator_function_addresses() def _get_torch_gpu_allocator_function_addresses(self): - if self._use_external_gpu_allocator and torch.cuda.is_available(): + if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator @@ -272,7 +166,7 @@ class GraphExecutionManager(GraphExecutionInterface): pass def _build_graph(self, config): - if self._use_static_shape: + if self._runtime_options.use_static_shape: self._graph_builder.build(config, self._input_info.shape) else: self._graph_builder.build(config) @@ -294,14 +188,10 @@ class GraphExecutionManager(GraphExecutionInterface): provider_option_map = {"device_id": str(self._device.index)} if not self.is_rocm_pytorch: # Set Conv algo search mode to HEURISTIC by default, which is the same as PyTorch's default setting. - conv_algo_search = ortmodule._defined_from_envvar("ORTMODULE_CONV_ALGO_SEARCH", "HEURISTIC", warn=True) - if conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]: - self._logger.warning("Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE.") - conv_algo_search = "HEURISTIC" - provider_option_map["cudnn_conv_algo_search"] = conv_algo_search + provider_option_map["cudnn_conv_algo_search"] = self._runtime_options.conv_algo_search provider_option_map["cudnn_conv_use_max_workspace"] = "1" provider_option_map["cudnn_conv1d_pad_to_nc1d"] = "1" - if self._use_external_gpu_allocator: + if self._runtime_options.use_external_gpu_allocator: provider_option_map["gpu_external_alloc"] = str(self._torch_alloc) provider_option_map["gpu_external_free"] = str(self._torch_free) provider_option_map["gpu_external_empty_cache"] = str(self._torch_empty_cache) @@ -323,11 +213,13 @@ class GraphExecutionManager(GraphExecutionInterface): session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. session_options.log_severity_level = int(self._debug_options.logging.log_level) - # Disable memory alleviation by default. Allow user to enable it via environment variable. - alleviation_config = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True) - probe_level = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", "1", warn=True) - session_options.add_session_config_entry("optimization.enable_memory_optimizer", alleviation_config) - session_options.add_session_config_entry("optimization.enable_memory_probe_recompute_level", probe_level) + + session_options.add_session_config_entry( + "optimization.enable_memory_optimizer", self._runtime_options.memory_optimizer_config + ) + session_options.add_session_config_entry( + "optimization.enable_memory_probe_recompute_level", self._runtime_options.probe_level + ) # Disable weight prepacking session_options.add_session_config_entry("session.disable_prepacking", "1") @@ -375,7 +267,7 @@ class GraphExecutionManager(GraphExecutionInterface): self._export_mode, ) - if self._run_symbolic_shape_infer: + if self._runtime_options.run_symbolic_shape_infer: self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes( self._onnx_models.exported_model, auto_merge=True, guess_output_rank=True ) @@ -433,7 +325,7 @@ class GraphExecutionManager(GraphExecutionInterface): required_export_kwargs = { "input_names": self._input_info.names, "output_names": output_names, - "opset_version": ortmodule.ONNX_OPSET_VERSION, + "opset_version": self._runtime_options.onnx_opset_version, "do_constant_folding": False, "training": self._export_mode, "dynamic_axes": self._input_info.dynamic_axes, @@ -462,8 +354,12 @@ class GraphExecutionManager(GraphExecutionInterface): ) exported_model = onnx.load_model_from_string(f.getvalue()) - exported_model = _post_process_after_export(exported_model, self._enable_custom_autograd_function) + exported_model = _post_process_after_export( + exported_model, self._runtime_options.enable_custom_autograd_function + ) + # If anything was captured by suppress_output during export, set the flag to + # raise a single user warning letting users know in the log. if suppress_output.tell() > 0: self._warning_log_detected_during_export = True @@ -486,10 +382,10 @@ class GraphExecutionManager(GraphExecutionInterface): def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfiguration: graph_transformer_config = C.TrainingGraphTransformerConfiguration() graph_transformer_config.propagate_cast_ops_config = C.PropagateCastOpsConfiguration() - graph_transformer_config.propagate_cast_ops_config.level = self._propagate_cast_ops_level - graph_transformer_config.propagate_cast_ops_config.allow = self._propagate_cast_ops_allow - graph_transformer_config.propagate_cast_ops_config.strategy = self._propagate_cast_ops_strategy - graph_transformer_config.enable_compute_optimizer = self._enable_compute_optimizer + graph_transformer_config.propagate_cast_ops_config.level = self._runtime_options.propagate_cast_ops_level + graph_transformer_config.propagate_cast_ops_config.allow = self._runtime_options.propagate_cast_ops_allow + graph_transformer_config.propagate_cast_ops_config.strategy = self._runtime_options.propagate_cast_ops_strategy + graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer return graph_transformer_config def _initialize_graph_builder(self): @@ -515,11 +411,11 @@ class GraphExecutionManager(GraphExecutionInterface): grad_builder_config.initializer_names_to_train = initializer_names_to_train grad_builder_config.input_names_require_grad = self._input_info.require_grad_names grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING - grad_builder_config.enable_caching = self._enable_grad_acc_optimization + grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel( self._debug_options.logging.log_level ) - grad_builder_config.use_memory_efficient_gradient = self._use_memory_efficient_gradient + grad_builder_config.use_memory_efficient_gradient = self._runtime_options.use_memory_efficient_gradient self._graph_builder = C.OrtModuleGraphBuilder() # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way @@ -576,27 +472,14 @@ class GraphExecutionManager(GraphExecutionInterface): enable sparsity-based optimization. """ - self._feature_map.extend( - [ - [ - "Compute Optimizer", - "ON" if self._enable_compute_optimizer else "OFF", - "Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0", - ], - [ - " -FLOPReduction", - "ON" if self._enable_compute_optimizer else "OFF", - "Reduce FLOPs by upstreaming shrinking-sized ops", - ], - ] - ) + # Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density. - if self._enable_sparse_optimizer or self._print_input_density: - self._rt_inspector.enable_input_inspector( + if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density: + self._runtime_inspector.enable_input_inspector( self._onnx_models.exported_model, self._graph_builder.get_graph_info().user_input_names ) - if self._enable_sparse_optimizer: + if self._runtime_options.enable_sparse_optimizer: detected_device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( inputs, kwargs ) @@ -609,41 +492,31 @@ class GraphExecutionManager(GraphExecutionInterface): inputs, kwargs, detected_device, - self._rt_inspector, + self._runtime_inspector, ) # Enable sparsity-based optimization when applicable. if len(label_sparsity_results) > 0: graph_transformer_config.sparse_label_input_names = list(label_sparsity_results.keys()) self._logger.info("Label sparsity-based optimization is ON for %s", label_sparsity_results) - sparsity_stat_str = ",".join([f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()]) - self._feature_map.append( - [ - " -LabelSparsityOpt", - "ON", - f"Input density: {sparsity_stat_str}, switch: ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1/0", - ] + self._runtime_options.label_sparsity_ratio = ",".join( + [f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()] ) - if self._enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0: + if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0: graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) - sparsity_stat_str = ",".join([f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]) - self._feature_map.append( - [ - " -EmbedSparsityOpt", - "ON", - f"Input density: {sparsity_stat_str}, switch: ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1/0", - ] + self._runtime_options.embed_sparsity_ratio = ",".join( + [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] ) # If users don't want to print input density, disable the input density observer to avoid overhead # when looping through inputs during training. - if not self._print_input_density: - self._rt_inspector.disable_input_inspector() + if not self._runtime_options.print_input_density: + self._runtime_inspector.disable_input_inspector() - if self._print_memory_stat: - self._rt_inspector.enable_memory_inspector(self._original_module) + if self._runtime_options.print_memory_stat: + self._runtime_inspector.enable_memory_inspector(self._original_module) def _log_feature_stats(self): rank = 0 @@ -653,12 +526,57 @@ class GraphExecutionManager(GraphExecutionInterface): if rank != 0: return - self._feature_map.append( - [ + feature_map: List[Tuple[str, bool, str]] = [ + ("ATen Executor", True, "Dispatch ATen operators to ORT's ATen executor"), + ( + "Cast Propagation", + self._runtime_options.propagate_cast_ops_level > 0, + f"Level {self._runtime_options.propagate_cast_ops_level} enabled", + ), + ( + "Custom Function", + self._runtime_options.enable_custom_autograd_function, + "Support custom torch.autograd.Function export and execution", + ), + ( + "Memory Optimizer", + len(self._runtime_options.memory_optimizer_config) > 0, + "Enable with env ORTMODULE_MEMORY_OPT_CONFIG=", + ), + ] + + if self._runtime_options.enable_compute_optimizer: + feature_map.extend( + [ + ( + "Compute Optimizer", + self._runtime_options.enable_compute_optimizer, + "Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0", + ), + ( + " -FLOPReduction", + self._runtime_options.enable_compute_optimizer, + "Reduce FLOPs by upstreaming shrinking-sized ops", + ), + ] + ) + + if len(self._runtime_options.label_sparsity_ratio) > 0: + feature_map.append( + (" -LabelSparsityOpt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}") + ) + + if len(self._runtime_options.embed_sparsity_ratio) > 0: + feature_map.append( + (" -EmbedSparsityOpt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}") + ) + + feature_map.append( + ( "Auto Fallback", - "ON" if self._fallback_manager.policy is not _FallbackPolicy.FALLBACK_DISABLE else "OFF", + self._runtime_options.fallback_policy is not _FallbackPolicy.FALLBACK_DISABLE, "Fallback to PyTorch when encountering unsupported ops", - ] + ) ) mode = "training" if self._export_mode == torch.onnx.TrainingMode.TRAINING else "inference" @@ -666,11 +584,10 @@ class GraphExecutionManager(GraphExecutionInterface): stat = f"\n\n{_logger.LogColor.HEADER}***** ONNX Runtime Training (ORTModule) is accelerating your model *****{_logger.LogColor.ENDC}\n\n" stat += f"ORTModule is enabled with following features ON/OFF for [{mode}] mode:\n\n" - for feature_tuple in self._feature_map: - stat += f"{feature_tuple[0]:<20}:\t{feature_tuple[1]:<10}:\t{feature_tuple[2]:<80}\n" + for feature_tuple in feature_map: + switch_str = "ON" if feature_tuple[1] else "OFF" + stat += f"{feature_tuple[0]:<20}:\t{switch_str:<10}:\t{feature_tuple[2]:<80}\n" - # If anything was captured in fo, raise a single user warning letting users know that there was - # any warning or error that was raised stat += f"\n{_logger.LogColor.WARNING}There were one or more warnings or errors raised while exporting the PyTorch model.\n" stat += f"Please enable INFO level logging with DebugOptions to view all warnings and errors.{_logger.LogColor.ENDC}\n\n" stat += f"Export duration: {self._export_duration_in_ms:.0f} milliseconds\n" diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py index 7e86b487b5..1b09eafee9 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py @@ -10,7 +10,7 @@ from ._fallback import _FallbackManager from ._inference_manager import InferenceManager from ._io import _FlattenedModule from ._training_manager import TrainingManager -from .debug_options import DebugOptions +from .options import DebugOptions class GraphExecutionManagerFactory: diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 1937ce5184..7fe11c1510 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -14,8 +14,8 @@ from onnxruntime.capi import _pybind_state as C from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils from ._execution_agent import InferenceAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy -from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck -from .debug_options import DebugOptions +from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo +from .options import DebugOptions, _SkipCheck class InferenceManager(GraphExecutionManager): @@ -92,21 +92,21 @@ class InferenceManager(GraphExecutionManager): try: # Issue at most one warning message about fast path - if self._first_skip_check_warning is True and self._skip_check.is_disabled() is False: + if self._first_skip_check_warning is True and self._runtime_options.skip_check.is_disabled() is False: self._first_skip_check_warning = False self._logger.warning( "Fast path enabled - skipping checks. rebuild gradient graph: %s, execution agent recreation: %s, " "device check: %s", - self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT), - self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT), - self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE), + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT), + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT), + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE), ) # If exporting module to ONNX for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. build_graph = False if ( - self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False or not self._onnx_models.exported_model ): # Exporting module to ONNX for the first time @@ -129,7 +129,10 @@ class InferenceManager(GraphExecutionManager): # If creating the execution agent for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. create_execution_session = False - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent: + if ( + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False + or not self._execution_agent + ): module_device = _utils.get_device_from_module(self._original_module) create_execution_session = ( @@ -146,7 +149,7 @@ class InferenceManager(GraphExecutionManager): # Create execution session creates the inference_session self._create_execution_agent() - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: + if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: # Assert that the input and model device match _utils._check_same_device(self._device, "Input argument to forward", *inputs) @@ -158,7 +161,7 @@ class InferenceManager(GraphExecutionManager): inputs, kwargs, self._device, - self._rt_inspector, + self._runtime_inspector, ) user_outputs, _ = InferenceManager.execution_session_run_forward( diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py index 6891361d74..5a1b09803c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_factory.py @@ -6,7 +6,7 @@ from logging import Logger from ._fallback import _FallbackManager from ._torch_module_ort import TorchModuleORT -from .debug_options import DebugOptions +from .options import DebugOptions class TorchModuleFactory: diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py index 449c61c694..dba68be944 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py @@ -12,7 +12,7 @@ from . import _io, _utils from ._fallback import ORTModuleTorchModelException, _FallbackManager, wrap_exception from ._graph_execution_manager_factory import GraphExecutionManagerFactory from ._torch_module_interface import TorchModuleInterface -from .debug_options import DebugOptions +from .options import DebugOptions T = TypeVar("T", bound="torch.nn.Module") diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 10edadf8f9..42901d34a7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -16,10 +16,10 @@ from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_alg from ._execution_agent import TrainingAgent from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy from ._gradient_accumulation_manager import GradientAccumulationManager -from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck +from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo from ._io import _FlattenedModule, _InputInfo from ._runtime_inspector import Phase -from .debug_options import DebugOptions +from .options import DebugOptions, _SkipCheck class TrainingManager(GraphExecutionManager): @@ -104,9 +104,9 @@ class TrainingManager(GraphExecutionManager): Module outputs are returned to the user """ - self._rt_inspector.inspect_memory(Phase.PRE_FORWARD) + self._runtime_inspector.inspect_memory(Phase.PRE_FORWARD) - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: + if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: # Assert that the input and model device match _utils._check_same_device(self._device, "Input argument to forward", *inputs) @@ -139,7 +139,7 @@ class TrainingManager(GraphExecutionManager): for idx in self._graph_info.output_grad_indices_non_differentiable: ctx.mark_non_differentiable(user_outputs[idx]) - self._rt_inspector.inspect_memory(Phase.POST_FORWARD) + self._runtime_inspector.inspect_memory(Phase.POST_FORWARD) return user_outputs @@ -147,10 +147,10 @@ class TrainingManager(GraphExecutionManager): def backward(ctx, *grad_outputs): """Performs backward pass based on grad wrt module output""" - self._rt_inspector.inspect_memory(Phase.PRE_BACKWARD) + self._runtime_inspector.inspect_memory(Phase.PRE_BACKWARD) assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()" - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: + if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs) # Unpack saved_tensor to trigger version detection that catches inplace corruption @@ -198,7 +198,7 @@ class TrainingManager(GraphExecutionManager): # This version only works if backward_outputs is an OrtValueVector. transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) - self._rt_inspector.inspect_memory(Phase.POST_BACKWARD) + self._runtime_inspector.inspect_memory(Phase.POST_BACKWARD) return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) @@ -226,21 +226,21 @@ class TrainingManager(GraphExecutionManager): return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs) try: - if self._first_skip_check_warning is True and self._skip_check.is_disabled() is False: + if self._first_skip_check_warning is True and self._runtime_options.skip_check.is_disabled() is False: # Only change this after the firs time a warning is issued. self._first_skip_check_warning = False self._logger.info( "Fast path enabled - skipping checks.Rebuild graph: %s, Execution agent: %s, Device check: %s", - self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT), - self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT), - self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE), + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT), + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT), + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE), ) # If exporting module to ONNX for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. build_gradient_graph = False if ( - self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False or not self._onnx_models.exported_model ): build_gradient_graph = self._export_model(*inputs, **kwargs) @@ -274,7 +274,10 @@ class TrainingManager(GraphExecutionManager): # If creating the execution agent for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. create_execution_session = False - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent: + if ( + self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False + or not self._execution_agent + ): device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( inputs, kwargs ) @@ -292,7 +295,7 @@ class TrainingManager(GraphExecutionManager): self._create_execution_agent() self._gradient_accumulation_manager.initialize( - self._enable_grad_acc_optimization, self._flattened_module, self._graph_info + self._runtime_options.enable_grad_acc_optimization, self._flattened_module, self._graph_info ) self._gradient_accumulation_manager.maybe_update_cache_before_run() @@ -305,7 +308,7 @@ class TrainingManager(GraphExecutionManager): inputs, kwargs, self._device, - self._rt_inspector, + self._runtime_inspector, ) return _io.unflatten_user_output( diff --git a/orttraining/orttraining/python/training/ortmodule/debug_options.py b/orttraining/orttraining/python/training/ortmodule/debug_options.py deleted file mode 100644 index 9360b53b11..0000000000 --- a/orttraining/orttraining/python/training/ortmodule/debug_options.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# debug_options.py - -import os - -from ._logger import LogLevel - - -class _SaveOnnxOptions: - """Configurable option to save ORTModule intermediate onnx models.""" - - # class variable - _path_environment_key = "ORTMODULE_SAVE_ONNX_PATH" - - def __init__(self, save, name_prefix): - self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix) - - def _extract_info(self, save, name_prefix): - # get the destination path from os env variable - destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, os.getcwd()) - # perform validation only when save is True - if save: - self._validate(save, name_prefix, destination_path) - return save, name_prefix, destination_path - - def _validate(self, save, name_prefix, destination_path): - # check if directory is writable - if not os.access(destination_path, os.W_OK): - raise OSError( - f"Directory {destination_path} is not writable. Please set the {_SaveOnnxOptions._path_environment_key} environment variable to a writable path." - ) - - # check if input prefix is a string - if not isinstance(name_prefix, str): - raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.") - - # if save_onnx is set, save_onnx_prefix must be a non empty string - if not name_prefix: - raise ValueError("onnx_prefix must be provided when save_onnx is set.") - - @property - def save(self): - return self._save - - @property - def name_prefix(self): - return self._name_prefix - - @property - def path(self): - return self._path - - -class _LoggingOptions: - """Configurable option to set the log level in ORTModule.""" - - # class variable - _log_level_environment_key = "ORTMODULE_LOG_LEVEL" - - def __init__(self, log_level): - self._log_level = self._extract_info(log_level) - - def _extract_info(self, log_level): - # get the log_level from os env variable - # OS environment variable log level superseeds the locally provided one - self._validate(log_level) - log_level = LogLevel[os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)] - return log_level - - def _validate(self, log_level): - # check if log_level is an instance of LogLevel - if not isinstance(log_level, LogLevel): - raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.") - - @property - def log_level(self) -> LogLevel: - return self._log_level - - -class DebugOptions: - """Configurable debugging options for ORTModule. - - Args: - log_level (:obj:`LogLevel`, optional): Configure ORTModule log level. Defaults to LogLevel.WARNING. - log_level can also be set by setting the environment variable "ORTMODULE_LOG_LEVEL" to one of - "VERBOSE", "INFO", "WARNING", "ERROR", "FATAL". In case both are set, the environment variable - takes precedence. - save_onnx (:obj:`bool`, optional): Configure ORTModule to save onnx models. Defaults to False. - The output directory of the onnx models by default is set to the current working directory. - To change the output directory, the environment variable "ORTMODULE_SAVE_ONNX_PATH" can be - set to the destination directory path. - onnx_prefix (:obj:`str`, optional): Name prefix to the ORTModule ONNX models saved file names. - Must be provided if save_onnx is True - - Raises: - OSError: If save_onnx is True and output directory is not writable. - TypeError: If save_onnx is True and name_prefix is not a valid string. Or if - log_level is not an instance of LogLevel. - ValueError: If save_onnx is True and name_prefix is an empty string. - - """ - - def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix=""): - self._save_onnx_models = _SaveOnnxOptions(save_onnx, onnx_prefix) - self._logging = _LoggingOptions(log_level) - - @property - def save_onnx_models(self): - """Accessor for the ONNX saving configuration.""" - - return self._save_onnx_models - - @property - def logging(self): - """Accessor for the logging configuration.""" - - return self._logging diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py index cb1715ccd8..993ba915ed 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# debug_options.py +# _hierarchical_ortmodule.py import tempfile import warnings import torch -from .... import ortmodule -from ... import ORTModule -from ...debug_options import DebugOptions, LogLevel +from onnxruntime.training import ortmodule +from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule.options import DebugOptions, LogLevel # nn.Module's in this set are considered exportable to ONNX. # For other nn.Module's, torch.onnx.export is called to check if diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py index 2f1451497f..6ddb159d18 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py @@ -9,11 +9,9 @@ from functools import reduce from types import SimpleNamespace from onnxruntime.capi import _pybind_state as C -from onnxruntime.training import ortmodule +from onnxruntime.training.ortmodule._fallback import _FallbackPolicy +from onnxruntime.training.ortmodule.options import DebugOptions, LogLevel, _SaveOnnxOptions, _SkipCheck -from ..._fallback import _FallbackPolicy -from ..._graph_execution_manager import _SkipCheck -from ...debug_options import DebugOptions, LogLevel, _SaveOnnxOptions from . import JSON_PATH_ENVIRONMENT_KEY log = logging.getLogger(__name__) @@ -39,15 +37,15 @@ def _load_propagate_cast_ops(ortmodule_config_accessor, data): log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. Loading attributes from file.") def _update_strategy(): - ortmodule_config_accessor._propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.__members__[ + ortmodule_config_accessor._runtime_options.propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.__members__[ data.PropagateCastOps.Strategy ] def _update_level(): - ortmodule_config_accessor._propagate_cast_ops_level = data.PropagateCastOps.Level + ortmodule_config_accessor._runtime_options.propagate_cast_ops_level = data.PropagateCastOps.Level def _update_allow(): - ortmodule_config_accessor._propagate_cast_ops_allow = data.PropagateCastOps.Allow + ortmodule_config_accessor._runtime_options.propagate_cast_ops_allow = data.PropagateCastOps.Allow key_to_function_mapping = {"Strategy": _update_strategy, "Level": _update_level, "Allow": _update_allow} @@ -64,8 +62,7 @@ def _load_use_external_gpu_allocator(ortmodule_config_accessor, data): assert isinstance( data.UseExternalGPUAllocator, bool ), f"{_load_use_external_gpu_allocator.loading_key} must be a boolean" - ortmodule_config_accessor._use_external_gpu_allocator = data.UseExternalGPUAllocator - ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses() + ortmodule_config_accessor._runtime_options.use_external_gpu_allocator = data.UseExternalGPUAllocator def _load_enable_custom_autograd_function(ortmodule_config_accessor, data): @@ -79,7 +76,11 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data): assert isinstance( data.EnableCustomAutogradFunction, bool ), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean" - ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction + + from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support + + enable_custom_autograd_support(data.EnableCustomAutogradFunction) + ortmodule_config_accessor._runtime_options.enable_custom_autograd_function = data.EnableCustomAutogradFunction def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data): @@ -91,7 +92,7 @@ def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data): assert isinstance( data.EnableGradAccOptimization, bool ), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean" - ortmodule_config_accessor._enable_grad_acc_optimization = data.EnableGradAccOptimization + ortmodule_config_accessor._runtime_options.enable_grad_acc_optimization = data.EnableGradAccOptimization def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data): @@ -103,7 +104,7 @@ def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data): assert isinstance( data.RunSymbolicShapeInference, bool ), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean" - ortmodule_config_accessor._run_symbolic_shape_infer = data.RunSymbolicShapeInference + ortmodule_config_accessor._runtime_options.run_symbolic_shape_infer = data.RunSymbolicShapeInference def _load_use_static_shape(ortmodule_config_accessor, data): @@ -113,7 +114,7 @@ def _load_use_static_shape(ortmodule_config_accessor, data): log.info(f"Found keyword {_load_use_static_shape.loading_key} in json. Loading attributes from file.") assert isinstance(data.UseStaticShape, bool), f"{_load_use_static_shape.loading_key} must be a boolean" - ortmodule_config_accessor._use_static_shape = data.UseStaticShape + ortmodule_config_accessor._runtime_options.use_static_shape = data.UseStaticShape def _load_skip_check(ortmodule_config_accessor, data): @@ -124,7 +125,7 @@ def _load_skip_check(ortmodule_config_accessor, data): skip_check = reduce(lambda x, y: x | y, [_SkipCheck[name] for name in data.SkipCheck]) if skip_check.value > 0: - ortmodule_config_accessor._skip_check = skip_check + ortmodule_config_accessor._runtime_options.skip_check = skip_check def _load_debug_options(ortmodule_config_accessor, data): @@ -177,7 +178,7 @@ def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data): assert isinstance( data.UseMemoryEfficientGradient, bool ), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean" - ortmodule_config_accessor._use_memory_efficient_gradient = data.UseMemoryEfficientGradient + ortmodule_config_accessor._runtime_options.use_memory_efficient_gradient = data.UseMemoryEfficientGradient def _load_fallback_policy(ortmodule_config_accessor, data): @@ -198,7 +199,7 @@ def _load_onnx_opset_version(ortmodule_config_accessor, data): log.info(f"Found keyword {_load_onnx_opset_version.loading_key} in json. Loading attributes from file.") assert isinstance(data.OnnxOpsetVersion, int), f"{_load_onnx_opset_version.loading_key} must be an int" - ortmodule.ONNX_OPSET_VERSION = data.OnnxOpsetVersion + ortmodule_config_accessor._runtime_options.onnx_opset_version = data.OnnxOpsetVersion def _define_load_function_keys(): diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py new file mode 100644 index 0000000000..b61c59c4ba --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -0,0 +1,287 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# options.py + +import os +from enum import IntFlag +from functools import reduce +from logging import Logger + +from onnxruntime.capi import _pybind_state as C +from onnxruntime.training import ortmodule + +from ._fallback import _FallbackPolicy +from ._logger import LogLevel +from ._utils import parse_os_env_skip_check_flags + + +class _SaveOnnxOptions: + """Configurable option to save ORTModule intermediate onnx models.""" + + # class variable + _path_environment_key = "ORTMODULE_SAVE_ONNX_PATH" + + def __init__(self, save, name_prefix, path: str): + self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix, path) + + def _extract_info(self, save, name_prefix, path: str): + # get the destination path from os env variable + default_path = path if len(path) > 0 else os.getcwd() + destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, default_path) + # perform validation only when save is True + if save: + self._validate(save, name_prefix, destination_path) + return save, name_prefix, destination_path + + def _validate(self, save, name_prefix, destination_path): + # check if directory is writable + if not os.access(destination_path, os.W_OK): + raise OSError( + f"Directory {destination_path} is not writable. Please set the " + f"{_SaveOnnxOptions._path_environment_key} environment variable to a writable path." + ) + + # check if input prefix is a string + if not isinstance(name_prefix, str): + raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.") + + # if save_onnx is set, save_onnx_prefix must be a non empty string + if not name_prefix: + raise ValueError("onnx_prefix must be provided when save_onnx is set.") + + @property + def save(self): + return self._save + + @property + def name_prefix(self): + return self._name_prefix + + @property + def path(self): + return self._path + + +class _LoggingOptions: + """Configurable option to set the log level in ORTModule.""" + + # class variable + _log_level_environment_key = "ORTMODULE_LOG_LEVEL" + + def __init__(self, log_level): + self._log_level = self._extract_info(log_level) + + def _extract_info(self, log_level): + # get the log_level from os env variable + # OS environment variable log level superseeds the locally provided one + self._validate(log_level) + log_level = LogLevel[os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)] + return log_level + + def _validate(self, log_level): + # check if log_level is an instance of LogLevel + if not isinstance(log_level, LogLevel): + raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.") + + @property + def log_level(self) -> LogLevel: + return self._log_level + + +class DebugOptions: + """Configurable debugging options for ORTModule. + + Args: + log_level (:obj:`LogLevel`, optional): Configure ORTModule log level. Defaults to LogLevel.WARNING. + log_level can also be set by setting the environment variable "ORTMODULE_LOG_LEVEL" to one of + "VERBOSE", "INFO", "WARNING", "ERROR", "FATAL". In case both are set, the environment variable + takes precedence. + save_onnx (:obj:`bool`, optional): Configure ORTModule to save onnx models. Defaults to False. + The output directory of the onnx models by default is set to the current working directory. + To change the output directory, the environment variable "ORTMODULE_SAVE_ONNX_PATH" can be + set to the destination directory path. + onnx_prefix (:obj:`str`, optional): Name prefix to the ORTModule ONNX models saved file names. + Must be provided if save_onnx is True + + Raises: + OSError: If save_onnx is True and output directory is not writable. + TypeError: If save_onnx is True and name_prefix is not a valid string. Or if + log_level is not an instance of LogLevel. + ValueError: If save_onnx is True and name_prefix is an empty string. + + """ + + def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix="", save_path="", config=None): + self.log_level = log_level + self.save_onnx = save_onnx + self.onnx_prefix = onnx_prefix + + self._save_onnx_models = _SaveOnnxOptions(self.save_onnx, self.onnx_prefix, save_path) + self._logging = _LoggingOptions(self.log_level) + + @property + def save_onnx_models(self): + """Accessor for the ONNX saving configuration.""" + + return self._save_onnx_models + + @property + def logging(self): + """Accessor for the logging configuration.""" + + return self._logging + + +class _SkipCheck(IntFlag): + """Enumeration to specify which checks should be skipped, allowing faster execution""" + + SKIP_CHECK_DISABLED = 1 + SKIP_CHECK_DEVICE = 2 + SKIP_CHECK_BUILD_GRADIENT = 4 + SKIP_CHECK_EXECUTION_AGENT = 8 + + def is_set(self, check): + """Check whether `check` is set on the `_SkipCheck instance + + SKIP_CHECK_DISABLED implies the check will return False + """ + + return not _SkipCheck.is_disabled(self) and check in self + + def is_disabled(self): + """Check whether `_SkipCheck.SKIP_CHECK_DISABLED is set on the `_SkipCheck instance""" + + return _SkipCheck.SKIP_CHECK_DISABLED in self + + +class _RuntimeOptions: + """Configurable runtime options for ORTModule.""" + + def __init__(self, logger: Logger): + """Constructor for Options. + + Initially set all the options to their default values, then override them with the values + from the environment variables. + """ + self._logger = logger + + self.onnx_opset_version = ortmodule.ONNX_OPSET_VERSION + self.conv_algo_search = "HEURISTIC" + + # Configuration for cast optimization. + # Specify cast propagation strategy. Currently, three strategies are available: + # NONE, INSERT-AND-REDUCE and FLOOD-FILL + # The default is FLOOD_FILL, expand FP16 computation regions in the graph using + # allowed opcodes for the given level. + self.propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.FLOOD_FILL + # Optimize by moving Cast operations if propagate_cast_ops_level is non-negative. + # - If the propagate_cast_ops_level is set to zero, then the transformation considers only the opcodes + # specified by propagate_cast_ops_allow as "FP16 safe", to insert/(re)move cast operations before/after + # to perform such operations in reduced (16-bit) precision. + # - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by + # propagate_cast_ops_allow, use onnxruntime predetermined list of opcodes considered safe to move + # before/after the cast operation. + # - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform + # any computation such as Transpose, Split, Reshape, etc., or the computation is actually in Float + # such as GeLU, etc. + # - Whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using + # contrib ops, Dropout, LayerNormalization, etc. + self.propagate_cast_ops_level = 1 + # List of opcodes to be considered safe to move before/after the cast operation if propagate_cast_ops_level + # is zero. + self.propagate_cast_ops_allow = [] + + # default execution order is priority-based for both dynamic/static shape input for now + # if we observe the benefit of static shape, we can expose this flag to the user + self.use_static_shape = False + + # flag to enable symbolic shape inference for dynamic shape inputs to improve performance + self.run_symbolic_shape_infer = True + + # PyTorch custom Autograd function support + from ._custom_autograd_function import custom_autograd_function_enabler + + self.enable_custom_autograd_function = custom_autograd_function_enabler.state + + self.use_external_gpu_allocator = True + + # WIP feature to enable caching in Gradient accumulation scenario. + self.enable_grad_acc_optimization = False + + # Memory-aware gradient builder. + self.use_memory_efficient_gradient = False + + # Configuration for compute optimization. + self.enable_compute_optimizer = True + self.enable_sparse_optimizer = True + self.label_sparsity_ratio = "" + self.embed_sparsity_ratio = "" + self.enable_embedding_sparse_optimizer = False # TODO(pengwa): remove once validation on more models are done. + + # Configuration for memory optimization. + self.memory_optimizer_config = "" + self.probe_level = "1" + + # Configuration for dev tools. + self.print_input_density = False + self.print_memory_stat = False + + # Configuration for fallback. + self.fallback_policy = ortmodule.ORTMODULE_FALLBACK_POLICY + + # Configuration for skip check. + # Indicators of some logic have been executed previously and thus could be skipped for faster training + # default is enabled, if not defined in os env + self.skip_check = _SkipCheck( + _SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT + ) + + # Override the feature config if it exists in os env. + self._override_from_env_vars() + + def _override_from_env_vars(self): + self.onnx_opset_version = int(os.getenv("ORTMODULE_ONNX_OPSET_VERSION", self.onnx_opset_version)) + self.conv_algo_search = os.getenv("ORTMODULE_CONV_ALGO_SEARCH", self.conv_algo_search) + if self.conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]: + self._logger.warning("Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE.") + self.conv_algo_search = "HEURISTIC" + + # Configuration for compute optimization. + compute_optimizer_reset = False + if "ORTMODULE_ENABLE_COMPUTE_OPTIMIZER" in os.environ: + self.enable_compute_optimizer = int(os.getenv("ORTMODULE_ENABLE_COMPUTE_OPTIMIZER")) == 1 + compute_optimizer_reset = True + + if "ORTMODULE_ENABLE_SPARSE_OPTIMIZER" in os.environ or compute_optimizer_reset: + if "ORTMODULE_ENABLE_SPARSE_OPTIMIZER" in os.environ: + self.enable_sparse_optimizer = int(os.getenv("ORTMODULE_ENABLE_SPARSE_OPTIMIZER")) == 1 + self.enable_sparse_optimizer = self.enable_compute_optimizer and self.enable_sparse_optimizer + + # TODO(pengwa): remove once validation on more models are done. + if "ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER" in os.environ: + self.enable_embedding_sparse_optimizer = ( + self.enable_sparse_optimizer and int(os.getenv("ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER")) == 1 + ) + + # Configuration for memory optimization. + self.memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config) + self.probe_level = os.getenv("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", self.probe_level) + + # Configuration for dev tools. + if "ORTMODULE_PRINT_INPUT_DENSITY" in os.environ: + self.print_input_density = int(os.getenv("ORTMODULE_PRINT_INPUT_DENSITY")) == 1 + if "ORTMODULE_PRINT_MEMORY_STATS" in os.environ: + self.print_memory_stat = int(os.getenv("ORTMODULE_PRINT_MEMORY_STATS")) == 1 + + # Configuration for fallback. + if "ORTMODULE_FALLBACK_POLICY" in os.environ: + self.fallback_policy = os.getenv("ORTMODULE_FALLBACK_POLICY") + if isinstance(self.fallback_policy, str): + self.fallback_policy = _FallbackPolicy[self.fallback_policy] + + # Configuration for skip check. + if "ORTMODULE_SKIPCHECK_POLICY" in os.environ: + self.skip_check = reduce( + lambda x, y: x | y, + [_SkipCheck[name] for name in parse_os_env_skip_check_flags("ORTMODULE_SKIPCHECK_POLICY")], + ) diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index 1af50d6811..7978695f0e 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -10,7 +10,7 @@ from ._torch_module_ort import TorchModuleORT from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry from ._custom_gradient_registry import CustomGradientRegistry from . import _utils -from .debug_options import DebugOptions +from .options import DebugOptions from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException from ._logger import ortmodule_loglevel_to_python_loglevel from onnxruntime.training import ortmodule @@ -77,7 +77,9 @@ class ORTModule(torch.nn.Module): # Support contrib OPs pytorch_export_contrib_ops.register() - CustomOpSymbolicRegistry.register_all() + CustomOpSymbolicRegistry.register_all( + self._torch_module._execution_manager(module.training)._runtime_options.onnx_opset_version + ) CustomGradientRegistry.register_all() # Warn user if there are name collisions between user model's and ORTModule attributes @@ -320,7 +322,9 @@ class ORTModule(torch.nn.Module): # Re-register contrib OPs pytorch_export_contrib_ops.register() - CustomOpSymbolicRegistry.register_all() + CustomOpSymbolicRegistry.register_all( + self._torch_module._execution_manager(self.module.training)._runtime_options.onnx_opset_version + ) CustomGradientRegistry.register_all() # Re-initialize the ORTModule forward method diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 3c1887f834..d7eadc56db 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -30,16 +30,9 @@ from transformers.modeling_outputs import SequenceClassifierOutput import onnxruntime.training.ortmodule as ortmodule_module from onnxruntime.training.optim import AdamWMode, FusedAdam -from onnxruntime.training.ortmodule import ( - DebugOptions, - LogLevel, - ORTModule, - _fallback, - _graph_execution_manager, - _io, - _utils, -) +from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient +from onnxruntime.training.ortmodule.options import _SkipCheck DEFAULT_OPSET = 15 @@ -4454,7 +4447,7 @@ def test_ortmodule_gradient_accumulation_optimization_correctness(): # model with optimization enabled opt_model = ORTModule(copy.deepcopy(pt_model)) - opt_model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True + opt_model._torch_module._execution_manager(is_training=True)._runtime_options.enable_grad_acc_optimization = True opt_optimizer = torch.optim.Adam(opt_model.parameters()) def run_step(model, x): @@ -4879,10 +4872,10 @@ def test_ortmodule_ortmodule_method_attribute_copy(): @pytest.mark.parametrize( "policy_str, policy", [ - ("SKIP_CHECK_DISABLED", _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED), - ("SKIP_CHECK_DEVICE", _graph_execution_manager._SkipCheck.SKIP_CHECK_DEVICE), - ("SKIP_CHECK_BUILD_GRADIENT", _graph_execution_manager._SkipCheck.SKIP_CHECK_BUILD_GRADIENT), - ("SKIP_CHECK_EXECUTION_AGENT", _graph_execution_manager._SkipCheck.SKIP_CHECK_EXECUTION_AGENT), + ("SKIP_CHECK_DISABLED", _SkipCheck.SKIP_CHECK_DISABLED), + ("SKIP_CHECK_DEVICE", _SkipCheck.SKIP_CHECK_DEVICE), + ("SKIP_CHECK_BUILD_GRADIENT", _SkipCheck.SKIP_CHECK_BUILD_GRADIENT), + ("SKIP_CHECK_EXECUTION_AGENT", _SkipCheck.SKIP_CHECK_EXECUTION_AGENT), ], ) def test_ortmodule_skip_check_load_from_os_env(policy_str, policy): @@ -4893,7 +4886,7 @@ def test_ortmodule_skip_check_load_from_os_env(policy_str, policy): ort_model = ORTModule(model) for training_mode in [False, True]: - assert ort_model._torch_module._execution_manager(training_mode)._skip_check == policy + assert ort_model._torch_module._execution_manager(training_mode)._runtime_options.skip_check == policy del os.environ["ORTMODULE_SKIPCHECK_POLICY"] @@ -5227,10 +5220,8 @@ def test_sigmoid_grad_opset13(): N, D_in, H, D_out = 120, 15360, 500, 15360 # noqa: N806 pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device) - old_opst_cst = ortmodule_module.ONNX_OPSET_VERSION old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None) os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13" - assert ortmodule_module.ONNX_OPSET_VERSION == 15 ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -5256,21 +5247,25 @@ def test_sigmoid_grad_opset13(): del os.environ["ORTMODULE_ONNX_OPSET_VERSION"] else: os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset - assert ortmodule_module.ONNX_OPSET_VERSION == 13 - ortmodule_module.ONNX_OPSET_VERSION = old_opst_cst + + assert ort_model._torch_module._execution_manager(True)._runtime_options.onnx_opset_version == 13 @pytest.mark.parametrize("opset_version", [12, 13, 14, 15]) def test_opset_version_change(opset_version): + original_env = None + if "ORTMODULE_ONNX_OPSET_VERSION" in os.environ: + original_env = os.environ["ORTMODULE_ONNX_OPSET_VERSION"] + del os.environ["ORTMODULE_ONNX_OPSET_VERSION"] + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806 x = torch.randn(N, D_in, device=device) model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) - ort_model = ORTModule(model) - ortmodule_module.ONNX_OPSET_VERSION = opset_version + ort_model = ORTModule(model) # Make sure model runs without any exception prediction = ort_model(x) @@ -5282,6 +5277,9 @@ def test_opset_version_change(opset_version): exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model assert exported_model.opset_import[0].version == opset_version + if original_env is not None: + os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = original_env + def test_serialize_ortmodule(): device = "cuda" diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py index 2eeb4f68bb..87c8e66231 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py @@ -420,7 +420,7 @@ def main(): ) model = ORTModule(model, debug_options) - model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True + model._torch_module._execution_manager(is_training=True)._runtime_options.enable_grad_acc_optimization = True # Tell pytorch to run this model on the GPU. if torch.cuda.is_available() and not args.no_cuda: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py index cc46a2db95..4350740690 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py @@ -38,27 +38,29 @@ def test_load_config_from_json_1(): ort_model_attributes = model._torch_module._execution_manager(training_mode) # test propagate cast ops - assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL - assert ort_model_attributes._propagate_cast_ops_level == 3 - assert ort_model_attributes._propagate_cast_ops_allow == ["ABC", "DEF"] + assert ( + ort_model_attributes._runtime_options.propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL + ) + assert ort_model_attributes._runtime_options.propagate_cast_ops_level == 3 + assert ort_model_attributes._runtime_options.propagate_cast_ops_allow == ["ABC", "DEF"] # test use external gpu allocator - assert ort_model_attributes._use_external_gpu_allocator is False + assert ort_model_attributes._runtime_options.use_external_gpu_allocator is False # test enable custom autograd function - assert ort_model_attributes._enable_custom_autograd_function is True + assert ort_model_attributes._runtime_options.enable_custom_autograd_function is True # test use static shape - assert ort_model_attributes._use_static_shape is True + assert ort_model_attributes._runtime_options.use_static_shape is True # test run symbolic shape inference - assert ort_model_attributes._run_symbolic_shape_infer is False + assert ort_model_attributes._runtime_options.run_symbolic_shape_infer is False # test enable grad acc optimization - assert ort_model_attributes._enable_grad_acc_optimization is True + assert ort_model_attributes._runtime_options.enable_grad_acc_optimization is True # test skip check - assert ort_model_attributes._skip_check.value == 14 + assert ort_model_attributes._runtime_options.skip_check.value == 14 # test debug options assert ort_model_attributes._debug_options.save_onnx_models.save is True @@ -66,13 +68,13 @@ def test_load_config_from_json_1(): assert ort_model_attributes._debug_options.logging.log_level.name == "VERBOSE" # test use memory aware gradient builder. - assert ort_model_attributes._use_memory_efficient_gradient is False + assert ort_model_attributes._runtime_options.use_memory_efficient_gradient is False # test fallback policy assert ort_model_attributes._fallback_manager.policy.value == 1 # assert onnx opset version - assert ortmodule.ONNX_OPSET_VERSION == 13 + assert ort_model_attributes._runtime_options.onnx_opset_version == 13 def test_load_config_from_json_2(): @@ -91,27 +93,30 @@ def test_load_config_from_json_2(): ort_model_attributes = model._torch_module._execution_manager(training_mode) # test propagate cast ops - assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.INSERT_AND_REDUCE - assert ort_model_attributes._propagate_cast_ops_level == 5 - assert ort_model_attributes._propagate_cast_ops_allow == ["XYZ", "PQR"] + assert ( + ort_model_attributes._runtime_options.propagate_cast_ops_strategy + == C.PropagateCastOpsStrategy.INSERT_AND_REDUCE + ) + assert ort_model_attributes._runtime_options.propagate_cast_ops_level == 5 + assert ort_model_attributes._runtime_options.propagate_cast_ops_allow == ["XYZ", "PQR"] # test use external gpu allocator - assert ort_model_attributes._use_external_gpu_allocator is True + assert ort_model_attributes._runtime_options.use_external_gpu_allocator is True # test enable custom autograd function - assert ort_model_attributes._enable_custom_autograd_function is False + assert ort_model_attributes._runtime_options.enable_custom_autograd_function is False # test use static shape - assert ort_model_attributes._use_static_shape is False + assert ort_model_attributes._runtime_options.use_static_shape is False # test run symbolic shape inference - assert ort_model_attributes._run_symbolic_shape_infer is True + assert ort_model_attributes._runtime_options.run_symbolic_shape_infer is True # test enable grad acc optimization - assert ort_model_attributes._enable_grad_acc_optimization is False + assert ort_model_attributes._runtime_options.enable_grad_acc_optimization is False # test skip check - assert ort_model_attributes._skip_check.value == 10 + assert ort_model_attributes._runtime_options.skip_check.value == 10 # test debug options assert ort_model_attributes._debug_options.save_onnx_models.save is True @@ -119,10 +124,10 @@ def test_load_config_from_json_2(): assert ort_model_attributes._debug_options.logging.log_level.name == "INFO" # test use memory aware gradient builder. - assert ort_model_attributes._use_memory_efficient_gradient is True + assert ort_model_attributes._runtime_options.use_memory_efficient_gradient is True # test fallback policy assert ort_model_attributes._fallback_manager.policy.value == 250 # assert onnx opset version - assert ortmodule.ONNX_OPSET_VERSION == 12 + assert ort_model_attributes._runtime_options.onnx_opset_version == 12