Manage ORTModule configurations consistently (#16396)

### Manage ORTModule options

Move all env vars that are used to switch features ON/OFF into runtime
options so that they are managed consistently.


Note: each feature switch is assigned in two phases: first its default value,
then an override from the corresponding env var (if the user specified one).
Env vars therefore take the highest priority when both phases explicitly
provide a value for the same feature.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2023-06-27 19:19:36 +08:00 committed by GitHub
parent 403bebfb51
commit a49bb85cfe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 506 additions and 407 deletions

View file

@ -47,7 +47,7 @@ More options for **developers**.
+ from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
+ model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))
```
Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/debug_options.py) for more details.
Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/options.py) for more details.
### 2.1 Environment Variables
@ -99,12 +99,13 @@ The output directory of the onnx models by default is set to the current working
> Overall users should be aware that ORT performance boost might be trivial when they explicitly allow it.
#### ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT
#### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
- **Description**: By default, all torch.autograd.Function classes will be exported to ORT PythonOp. There are some cases where you might consider disabling it. For example, if you have confirmed that those torch.autograd.Function classes define computations that can be exported inline by PyTorch, and that it is safe to train with the inline-exported ONNX graph, then you can disable it; as a result, ORT has more opportunities to optimize the graph.
```bash
export ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT=1
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=1 # Enable
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=0 # Disable
```
An alternative way to disable it without using an environment variable:

View file

@ -38,7 +38,7 @@ def export_gradient_graph(
"""
# Make sure that loss nodes that expect multiple outputs are set up.
CustomOpSymbolicRegistry.register_all()
CustomOpSymbolicRegistry.register_all(opset_version)
if not isinstance(gradient_graph_path, str):
gradient_graph_path = str(gradient_graph_path)

View file

@ -120,7 +120,7 @@ def _are_deterministic_algorithms_enabled():
return ORTMODULE_IS_DETERMINISTIC
from .debug_options import DebugOptions, LogLevel # noqa: E402, F401
from .options import DebugOptions, LogLevel # noqa: E402, F401
# ORTModule must be loaded only after all validation passes
from .ortmodule import ORTModule # noqa: E402, F401

View file

@ -3,6 +3,9 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from onnxruntime.capi._pybind_state import is_torch_interop_default_on
from onnxruntime.training import ortmodule
class Enabler:
def __init__(self):
@ -83,10 +86,7 @@ def enable_custom_autograd_support(to_enable=True):
custom_autograd_function_enabler.state = False
from onnxruntime.capi._pybind_state import is_torch_interop_default_on # noqa: E402
from onnxruntime.training import ortmodule # noqa: E402
# Enable the custom autograd by default when PythonOp backend support is enabled during build.
enable_custom_autograd_support(
not ortmodule._defined_from_envvar("ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT", 0) and is_torch_interop_default_on()
ortmodule._defined_from_envvar("ORTMODULE_ENABLE_CUSTOM_AUTOGRAD", 1) and is_torch_interop_default_on()
)

View file

@ -12,8 +12,6 @@ from packaging.version import Version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
from onnxruntime.training import ortmodule
from ._utils import get_runtime_pytorch_version
# Mapping from pytorch scalar type to onnx scalar type.
@ -97,13 +95,13 @@ class CustomOpSymbolicRegistry:
cls._SYMBOLICS[domain + "::" + name] = fn
@classmethod
def register_all(cls):
def register_all(cls, onnx_opset_version):
for name, fn in cls._SYMBOLICS.items():
# Symbolic name is in format: domain::name
register_custom_op_symbolic(
name,
fn,
ortmodule._defined_from_envvar("ORTMODULE_ONNX_OPSET_VERSION", ortmodule.ONNX_OPSET_VERSION),
onnx_opset_version,
)

View file

@ -9,8 +9,6 @@ import io
import logging
import os
from abc import ABC, abstractmethod # noqa: F401
from enum import IntFlag
from functools import reduce
from typing import Dict, List, Optional, Tuple
import onnx
@ -20,9 +18,8 @@ from torch.utils.cpp_extension import ROCM_HOME
import onnxruntime
from onnxruntime.capi import _pybind_state as C
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from onnxruntime.training import ortmodule
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _runtime_inspector, _utils
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
from ._custom_autograd_function_exporter import _post_process_after_export
from ._fallback import (
ORTModuleDeviceException,
@ -35,7 +32,8 @@ from ._fallback import (
from ._gradient_accumulation_manager import GradientAccumulationManager
from ._graph_execution_interface import GraphExecutionInterface
from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType
from .debug_options import DebugOptions, LogLevel
from ._runtime_inspector import RuntimeInspector
from .options import DebugOptions, LogLevel, _RuntimeOptions
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
@ -49,28 +47,6 @@ class _RunStateInfo:
self.output_info = output_info
class _SkipCheck(IntFlag):
"""Enumeration to specify which checks should be skipped, allowing faster execution"""
SKIP_CHECK_DISABLED = 1
SKIP_CHECK_DEVICE = 2
SKIP_CHECK_BUILD_GRADIENT = 4
SKIP_CHECK_EXECUTION_AGENT = 8
def is_set(self, check):
"""Check whether `check` is set on the `_SkipCheck instance
SKIP_CHECK_DISABLED implies the check will return False
"""
return not _SkipCheck.is_disabled(self) and check in self
def is_disabled(self):
"""Check whether `_SkipCheck.SKIP_CHECK_DISABLED is set on the `_SkipCheck instance"""
return _SkipCheck.SKIP_CHECK_DISABLED in self
class GraphExecutionManager(GraphExecutionInterface):
def __init__(
self,
@ -89,6 +65,7 @@ class GraphExecutionManager(GraphExecutionInterface):
self._logger = logger
self._runtime_options = _RuntimeOptions(self._logger)
# Original and flattened (transformed) output module
self._flattened_module = module
@ -102,44 +79,12 @@ class GraphExecutionManager(GraphExecutionInterface):
self._graph_initializer_names_to_train = set()
self._graph_initializers: List[torch.nn.parameter.Parameter] = []
# Update constant ONNX_OPSET_VERSION with env var ORTMODULE_ONNX_OPSET_VERSION
# if defined.
ortmodule.ONNX_OPSET_VERSION = ortmodule._defined_from_envvar(
"ORTMODULE_ONNX_OPSET_VERSION", ortmodule.ONNX_OPSET_VERSION, warn=True
)
# TrainingAgent or InferenceAgent
self._execution_agent = None
# Indicators of some logic have been executed previously and thus could be skipped for faster training
# default is enabled, if not defined in os env
self._skip_check = _SkipCheck(
_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT
)
if os.getenv("ORTMODULE_SKIPCHECK_POLICY") is not None:
self._skip_check = reduce(
lambda x, y: x | y,
[_SkipCheck[name] for name in _utils.parse_os_env_skip_check_flags("ORTMODULE_SKIPCHECK_POLICY")],
)
self._first_skip_check_warning = True
self._rt_inspector = _runtime_inspector.RuntimeInspector(self._logger)
# Graph transformer config
# Specify cast propagation strategy. Currently, three strategies are available, NONE, INSERT-AND-REDUCE and FLOOD-FILL
# The default is FLOOD_FILL, expand FP16 computation regions in the graph using allowed opcodes for the given level.
self._propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.FLOOD_FILL
# Optimize by moving Cast operations if propagate_cast_ops_level is non-negative.
# - If the _propagate_cast_ops_level is set to zero, then the transformation considers only the opcodes specified by _propagate_cast_ops_allow
# as "FP16 safe", to insert/(re)move cast operations before/after to perform such operations in reduced (16-bit) precision.
# - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by propagate_cast_ops_allow, use onnxruntime
# predetermined list of opcodes considered safe to move before/after the cast operation.
# - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform any computation such as Transpose, Split, Reshape, etc.,
# or the computation is actually in Float such as GeLU, etc.
# whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, Dropout, LayerNormalization, etc.
self._propagate_cast_ops_level = 1
# List of opcodes to be considered safe to move before/after the cast operation if propagate_cast_ops_level is zero.
self._propagate_cast_ops_allow = []
self._runtime_inspector = RuntimeInspector(self._logger)
# Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
# To be instantiated in the concrete implementation of GraphExecutionManager
@ -149,18 +94,6 @@ class GraphExecutionManager(GraphExecutionInterface):
# It cannot overlap with required/immutable arguments (validated in runtime)
self._export_extra_kwargs = {}
# default execution order is priority-based for both dynamic/static shape input for now
# if we observe the benefit of static shape, we can expose this flag to the user
self._use_static_shape = False
# flag to enable symbolic shape inference for dynamic shape inputs to improve performance
self._run_symbolic_shape_infer = True
# PyTorch custom Autograd function support
from ._custom_autograd_function import custom_autograd_function_enabler
self._enable_custom_autograd_function = custom_autograd_function_enabler.state
# Input and output infos (including schema) for exported model.
self._input_info: Optional[_InputInfo] = None
self._module_output_schema: Optional[_ModelInputOutputSchemaType] = None
@ -182,36 +115,9 @@ class GraphExecutionManager(GraphExecutionInterface):
self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None)
self._use_external_gpu_allocator = True
# assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True
self._get_torch_gpu_allocator_function_addresses()
# WIP feature to enable caching in Gradient accumulation scenario.
self._enable_grad_acc_optimization = False
self._gradient_accumulation_manager = GradientAccumulationManager()
# Memory-aware gradient builder.
self._use_memory_efficient_gradient = False
# Enable compute optimizer by default. Allowed to be disabled via an environment variable for
# convergence parity investigation.
self._enable_compute_optimizer = (
ortmodule._defined_from_envvar("ORTMODULE_ENABLE_COMPUTE_OPTIMIZER", 1, warn=True) == 1
)
self._enable_sparse_optimizer = (
self._enable_compute_optimizer
and ortmodule._defined_from_envvar("ORTMODULE_ENABLE_SPARSE_OPTIMIZER", 1, warn=True) == 1
)
self._enable_embedding_sparse_optimizer = (
self._enable_compute_optimizer
and self._enable_sparse_optimizer
and ortmodule._defined_from_envvar("ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER", 0, warn=True) == 1
)
self._print_input_density = ortmodule._defined_from_envvar("ORTMODULE_PRINT_INPUT_DENSITY", 0, warn=True) == 1
self._print_memory_stat = ortmodule._defined_from_envvar("ORTMODULE_PRINT_MEMORY_STATS", 0, warn=True) == 1
self._enable_memory_optimizer = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True)
# Flag to re-export the model due to attribute change on the original module.
# Re-export will be avoided if _skip_check is enabled.
self._original_model_has_changed = False
@ -219,23 +125,11 @@ class GraphExecutionManager(GraphExecutionInterface):
# Load ATen operator executor extension.
load_aten_op_executor_cpp_extension()
self._feature_map: List[List[str]] = [
["ATen Executor", "ON", "Dispatch ATen operators to ORT's ATen executor"],
[
"Cast Propagation",
"ON" if self._propagate_cast_ops_level > 0 else "OFF",
f"Level {self._propagate_cast_ops_level} enabled",
],
["Custom Function", "ON", "Support custom torch.autograd.Function export and execution"],
[
"Memory Optimizer",
"ON" if self._enable_memory_optimizer else "OFF",
"Enable with env ORTMODULE_MEMORY_OPT_CONFIG=<config>",
],
]
# Assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True
self._get_torch_gpu_allocator_function_addresses()
def _get_torch_gpu_allocator_function_addresses(self):
if self._use_external_gpu_allocator and torch.cuda.is_available():
if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
# CPP extension to get torch GPU allocator's alloc and free function addresses
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator
@ -272,7 +166,7 @@ class GraphExecutionManager(GraphExecutionInterface):
pass
def _build_graph(self, config):
if self._use_static_shape:
if self._runtime_options.use_static_shape:
self._graph_builder.build(config, self._input_info.shape)
else:
self._graph_builder.build(config)
@ -294,14 +188,10 @@ class GraphExecutionManager(GraphExecutionInterface):
provider_option_map = {"device_id": str(self._device.index)}
if not self.is_rocm_pytorch:
# Set Conv algo search mode to HEURISTIC by default, which is the same as PyTorch's default setting.
conv_algo_search = ortmodule._defined_from_envvar("ORTMODULE_CONV_ALGO_SEARCH", "HEURISTIC", warn=True)
if conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]:
self._logger.warning("Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE.")
conv_algo_search = "HEURISTIC"
provider_option_map["cudnn_conv_algo_search"] = conv_algo_search
provider_option_map["cudnn_conv_algo_search"] = self._runtime_options.conv_algo_search
provider_option_map["cudnn_conv_use_max_workspace"] = "1"
provider_option_map["cudnn_conv1d_pad_to_nc1d"] = "1"
if self._use_external_gpu_allocator:
if self._runtime_options.use_external_gpu_allocator:
provider_option_map["gpu_external_alloc"] = str(self._torch_alloc)
provider_option_map["gpu_external_free"] = str(self._torch_free)
provider_option_map["gpu_external_empty_cache"] = str(self._torch_empty_cache)
@ -323,11 +213,13 @@ class GraphExecutionManager(GraphExecutionInterface):
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
session_options.log_severity_level = int(self._debug_options.logging.log_level)
# Disable memory alleviation by default. Allow user to enable it via environment variable.
alleviation_config = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True)
probe_level = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", "1", warn=True)
session_options.add_session_config_entry("optimization.enable_memory_optimizer", alleviation_config)
session_options.add_session_config_entry("optimization.enable_memory_probe_recompute_level", probe_level)
session_options.add_session_config_entry(
"optimization.enable_memory_optimizer", self._runtime_options.memory_optimizer_config
)
session_options.add_session_config_entry(
"optimization.enable_memory_probe_recompute_level", self._runtime_options.probe_level
)
# Disable weight prepacking
session_options.add_session_config_entry("session.disable_prepacking", "1")
@ -375,7 +267,7 @@ class GraphExecutionManager(GraphExecutionInterface):
self._export_mode,
)
if self._run_symbolic_shape_infer:
if self._runtime_options.run_symbolic_shape_infer:
self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes(
self._onnx_models.exported_model, auto_merge=True, guess_output_rank=True
)
@ -433,7 +325,7 @@ class GraphExecutionManager(GraphExecutionInterface):
required_export_kwargs = {
"input_names": self._input_info.names,
"output_names": output_names,
"opset_version": ortmodule.ONNX_OPSET_VERSION,
"opset_version": self._runtime_options.onnx_opset_version,
"do_constant_folding": False,
"training": self._export_mode,
"dynamic_axes": self._input_info.dynamic_axes,
@ -462,8 +354,12 @@ class GraphExecutionManager(GraphExecutionInterface):
)
exported_model = onnx.load_model_from_string(f.getvalue())
exported_model = _post_process_after_export(exported_model, self._enable_custom_autograd_function)
exported_model = _post_process_after_export(
exported_model, self._runtime_options.enable_custom_autograd_function
)
# If anything was captured by suppress_output during export, set the flag to
# raise a single user warning letting users know in the log.
if suppress_output.tell() > 0:
self._warning_log_detected_during_export = True
@ -486,10 +382,10 @@ class GraphExecutionManager(GraphExecutionInterface):
def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfiguration:
graph_transformer_config = C.TrainingGraphTransformerConfiguration()
graph_transformer_config.propagate_cast_ops_config = C.PropagateCastOpsConfiguration()
graph_transformer_config.propagate_cast_ops_config.level = self._propagate_cast_ops_level
graph_transformer_config.propagate_cast_ops_config.allow = self._propagate_cast_ops_allow
graph_transformer_config.propagate_cast_ops_config.strategy = self._propagate_cast_ops_strategy
graph_transformer_config.enable_compute_optimizer = self._enable_compute_optimizer
graph_transformer_config.propagate_cast_ops_config.level = self._runtime_options.propagate_cast_ops_level
graph_transformer_config.propagate_cast_ops_config.allow = self._runtime_options.propagate_cast_ops_allow
graph_transformer_config.propagate_cast_ops_config.strategy = self._runtime_options.propagate_cast_ops_strategy
graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer
return graph_transformer_config
def _initialize_graph_builder(self):
@ -515,11 +411,11 @@ class GraphExecutionManager(GraphExecutionInterface):
grad_builder_config.initializer_names_to_train = initializer_names_to_train
grad_builder_config.input_names_require_grad = self._input_info.require_grad_names
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
grad_builder_config.enable_caching = self._enable_grad_acc_optimization
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(
self._debug_options.logging.log_level
)
grad_builder_config.use_memory_efficient_gradient = self._use_memory_efficient_gradient
grad_builder_config.use_memory_efficient_gradient = self._runtime_options.use_memory_efficient_gradient
self._graph_builder = C.OrtModuleGraphBuilder()
# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
@ -576,27 +472,14 @@ class GraphExecutionManager(GraphExecutionInterface):
enable sparsity-based optimization.
"""
self._feature_map.extend(
[
[
"Compute Optimizer",
"ON" if self._enable_compute_optimizer else "OFF",
"Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0",
],
[
" -FLOPReduction",
"ON" if self._enable_compute_optimizer else "OFF",
"Reduce FLOPs by upstreaming shrinking-sized ops",
],
]
)
# Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density.
if self._enable_sparse_optimizer or self._print_input_density:
self._rt_inspector.enable_input_inspector(
if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density:
self._runtime_inspector.enable_input_inspector(
self._onnx_models.exported_model, self._graph_builder.get_graph_info().user_input_names
)
if self._enable_sparse_optimizer:
if self._runtime_options.enable_sparse_optimizer:
detected_device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(
inputs, kwargs
)
@ -609,41 +492,31 @@ class GraphExecutionManager(GraphExecutionInterface):
inputs,
kwargs,
detected_device,
self._rt_inspector,
self._runtime_inspector,
)
# Enable sparsity-based optimization when applicable.
if len(label_sparsity_results) > 0:
graph_transformer_config.sparse_label_input_names = list(label_sparsity_results.keys())
self._logger.info("Label sparsity-based optimization is ON for %s", label_sparsity_results)
sparsity_stat_str = ",".join([f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()])
self._feature_map.append(
[
" -LabelSparsityOpt",
"ON",
f"Input density: {sparsity_stat_str}, switch: ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1/0",
]
self._runtime_options.label_sparsity_ratio = ",".join(
[f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()]
)
if self._enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
sparsity_stat_str = ",".join([f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()])
self._feature_map.append(
[
" -EmbedSparsityOpt",
"ON",
f"Input density: {sparsity_stat_str}, switch: ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1/0",
]
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
)
# If users don't want to print input density, disable the input density observer to avoid overhead
# when looping through inputs during training.
if not self._print_input_density:
self._rt_inspector.disable_input_inspector()
if not self._runtime_options.print_input_density:
self._runtime_inspector.disable_input_inspector()
if self._print_memory_stat:
self._rt_inspector.enable_memory_inspector(self._original_module)
if self._runtime_options.print_memory_stat:
self._runtime_inspector.enable_memory_inspector(self._original_module)
def _log_feature_stats(self):
rank = 0
@ -653,12 +526,57 @@ class GraphExecutionManager(GraphExecutionInterface):
if rank != 0:
return
self._feature_map.append(
[
feature_map: List[Tuple[str, bool, str]] = [
("ATen Executor", True, "Dispatch ATen operators to ORT's ATen executor"),
(
"Cast Propagation",
self._runtime_options.propagate_cast_ops_level > 0,
f"Level {self._runtime_options.propagate_cast_ops_level} enabled",
),
(
"Custom Function",
self._runtime_options.enable_custom_autograd_function,
"Support custom torch.autograd.Function export and execution",
),
(
"Memory Optimizer",
len(self._runtime_options.memory_optimizer_config) > 0,
"Enable with env ORTMODULE_MEMORY_OPT_CONFIG=<config>",
),
]
if self._runtime_options.enable_compute_optimizer:
feature_map.extend(
[
(
"Compute Optimizer",
self._runtime_options.enable_compute_optimizer,
"Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0",
),
(
" -FLOPReduction",
self._runtime_options.enable_compute_optimizer,
"Reduce FLOPs by upstreaming shrinking-sized ops",
),
]
)
if len(self._runtime_options.label_sparsity_ratio) > 0:
feature_map.append(
(" -LabelSparsityOpt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}")
)
if len(self._runtime_options.embed_sparsity_ratio) > 0:
feature_map.append(
(" -EmbedSparsityOpt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}")
)
feature_map.append(
(
"Auto Fallback",
"ON" if self._fallback_manager.policy is not _FallbackPolicy.FALLBACK_DISABLE else "OFF",
self._runtime_options.fallback_policy is not _FallbackPolicy.FALLBACK_DISABLE,
"Fallback to PyTorch when encountering unsupported ops",
]
)
)
mode = "training" if self._export_mode == torch.onnx.TrainingMode.TRAINING else "inference"
@ -666,11 +584,10 @@ class GraphExecutionManager(GraphExecutionInterface):
stat = f"\n\n{_logger.LogColor.HEADER}***** ONNX Runtime Training (ORTModule) is accelerating your model *****{_logger.LogColor.ENDC}\n\n"
stat += f"ORTModule is enabled with following features ON/OFF for [{mode}] mode:\n\n"
for feature_tuple in self._feature_map:
stat += f"{feature_tuple[0]:<20}:\t{feature_tuple[1]:<10}:\t{feature_tuple[2]:<80}\n"
for feature_tuple in feature_map:
switch_str = "ON" if feature_tuple[1] else "OFF"
stat += f"{feature_tuple[0]:<20}:\t{switch_str:<10}:\t{feature_tuple[2]:<80}\n"
# If anything was captured in fo, raise a single user warning letting users know that there was
# any warning or error that was raised
stat += f"\n{_logger.LogColor.WARNING}There were one or more warnings or errors raised while exporting the PyTorch model.\n"
stat += f"Please enable INFO level logging with DebugOptions to view all warnings and errors.{_logger.LogColor.ENDC}\n\n"
stat += f"Export duration: {self._export_duration_in_ms:.0f} milliseconds\n"

View file

@ -10,7 +10,7 @@ from ._fallback import _FallbackManager
from ._inference_manager import InferenceManager
from ._io import _FlattenedModule
from ._training_manager import TrainingManager
from .debug_options import DebugOptions
from .options import DebugOptions
class GraphExecutionManagerFactory:

View file

@ -14,8 +14,8 @@ from onnxruntime.capi import _pybind_state as C
from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils
from ._execution_agent import InferenceAgent
from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck
from .debug_options import DebugOptions
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
from .options import DebugOptions, _SkipCheck
class InferenceManager(GraphExecutionManager):
@ -92,21 +92,21 @@ class InferenceManager(GraphExecutionManager):
try:
# Issue at most one warning message about fast path
if self._first_skip_check_warning is True and self._skip_check.is_disabled() is False:
if self._first_skip_check_warning is True and self._runtime_options.skip_check.is_disabled() is False:
self._first_skip_check_warning = False
self._logger.warning(
"Fast path enabled - skipping checks. rebuild gradient graph: %s, execution agent recreation: %s, "
"device check: %s",
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
)
# If exporting module to ONNX for the first time, this skip check will not take effect.
# It will only take effect on subsequent forward calls.
build_graph = False
if (
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
or not self._onnx_models.exported_model
):
# Exporting module to ONNX for the first time
@ -129,7 +129,10 @@ class InferenceManager(GraphExecutionManager):
# If creating the execution agent for the first time, this skip check will not take effect.
# It will only take effect on subsequent forward calls.
create_execution_session = False
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent:
if (
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False
or not self._execution_agent
):
module_device = _utils.get_device_from_module(self._original_module)
create_execution_session = (
@ -146,7 +149,7 @@ class InferenceManager(GraphExecutionManager):
# Create execution session creates the inference_session
self._create_execution_agent()
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
# Assert that the input and model device match
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
@ -158,7 +161,7 @@ class InferenceManager(GraphExecutionManager):
inputs,
kwargs,
self._device,
self._rt_inspector,
self._runtime_inspector,
)
user_outputs, _ = InferenceManager.execution_session_run_forward(

View file

@ -6,7 +6,7 @@ from logging import Logger
from ._fallback import _FallbackManager
from ._torch_module_ort import TorchModuleORT
from .debug_options import DebugOptions
from .options import DebugOptions
class TorchModuleFactory:

View file

@ -12,7 +12,7 @@ from . import _io, _utils
from ._fallback import ORTModuleTorchModelException, _FallbackManager, wrap_exception
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
from ._torch_module_interface import TorchModuleInterface
from .debug_options import DebugOptions
from .options import DebugOptions
T = TypeVar("T", bound="torch.nn.Module")

View file

@ -16,10 +16,10 @@ from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_alg
from ._execution_agent import TrainingAgent
from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy
from ._gradient_accumulation_manager import GradientAccumulationManager
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
from ._io import _FlattenedModule, _InputInfo
from ._runtime_inspector import Phase
from .debug_options import DebugOptions
from .options import DebugOptions, _SkipCheck
class TrainingManager(GraphExecutionManager):
@ -104,9 +104,9 @@ class TrainingManager(GraphExecutionManager):
Module outputs are returned to the user
"""
self._rt_inspector.inspect_memory(Phase.PRE_FORWARD)
self._runtime_inspector.inspect_memory(Phase.PRE_FORWARD)
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
# Assert that the input and model device match
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
@ -139,7 +139,7 @@ class TrainingManager(GraphExecutionManager):
for idx in self._graph_info.output_grad_indices_non_differentiable:
ctx.mark_non_differentiable(user_outputs[idx])
self._rt_inspector.inspect_memory(Phase.POST_FORWARD)
self._runtime_inspector.inspect_memory(Phase.POST_FORWARD)
return user_outputs
@ -147,10 +147,10 @@ class TrainingManager(GraphExecutionManager):
def backward(ctx, *grad_outputs):
"""Performs backward pass based on grad wrt module output"""
self._rt_inspector.inspect_memory(Phase.PRE_BACKWARD)
self._runtime_inspector.inspect_memory(Phase.PRE_BACKWARD)
assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
# Unpack saved_tensor to trigger version detection that catches inplace corruption
@ -198,7 +198,7 @@ class TrainingManager(GraphExecutionManager):
# This version only works if backward_outputs is an OrtValueVector.
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
self._rt_inspector.inspect_memory(Phase.POST_BACKWARD)
self._runtime_inspector.inspect_memory(Phase.POST_BACKWARD)
return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
@ -226,21 +226,21 @@ class TrainingManager(GraphExecutionManager):
return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)
try:
if self._first_skip_check_warning is True and self._skip_check.is_disabled() is False:
if self._first_skip_check_warning is True and self._runtime_options.skip_check.is_disabled() is False:
# Only change this after the first time a warning is issued.
self._first_skip_check_warning = False
self._logger.info(
"Fast path enabled - skipping checks.Rebuild graph: %s, Execution agent: %s, Device check: %s",
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
)
# If exporting module to ONNX for the first time, this skip check will not take effect.
# It will only take effect on subsequent forward calls.
build_gradient_graph = False
if (
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
or not self._onnx_models.exported_model
):
build_gradient_graph = self._export_model(*inputs, **kwargs)
@ -274,7 +274,10 @@ class TrainingManager(GraphExecutionManager):
# If creating the execution agent for the first time, this skip check will not take effect.
# It will only take effect on subsequent forward calls.
create_execution_session = False
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent:
if (
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False
or not self._execution_agent
):
device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(
inputs, kwargs
)
@ -292,7 +295,7 @@ class TrainingManager(GraphExecutionManager):
self._create_execution_agent()
self._gradient_accumulation_manager.initialize(
self._enable_grad_acc_optimization, self._flattened_module, self._graph_info
self._runtime_options.enable_grad_acc_optimization, self._flattened_module, self._graph_info
)
self._gradient_accumulation_manager.maybe_update_cache_before_run()
@ -305,7 +308,7 @@ class TrainingManager(GraphExecutionManager):
inputs,
kwargs,
self._device,
self._rt_inspector,
self._runtime_inspector,
)
return _io.unflatten_user_output(

View file

@ -1,118 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# debug_options.py
import os
from ._logger import LogLevel
class _SaveOnnxOptions:
"""Configurable option to save ORTModule intermediate onnx models."""
# class variable
_path_environment_key = "ORTMODULE_SAVE_ONNX_PATH"
def __init__(self, save, name_prefix):
self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix)
def _extract_info(self, save, name_prefix):
# get the destination path from os env variable
destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, os.getcwd())
# perform validation only when save is True
if save:
self._validate(save, name_prefix, destination_path)
return save, name_prefix, destination_path
def _validate(self, save, name_prefix, destination_path):
# check if directory is writable
if not os.access(destination_path, os.W_OK):
raise OSError(
f"Directory {destination_path} is not writable. Please set the {_SaveOnnxOptions._path_environment_key} environment variable to a writable path."
)
# check if input prefix is a string
if not isinstance(name_prefix, str):
raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.")
# if save_onnx is set, save_onnx_prefix must be a non empty string
if not name_prefix:
raise ValueError("onnx_prefix must be provided when save_onnx is set.")
@property
def save(self):
return self._save
@property
def name_prefix(self):
return self._name_prefix
@property
def path(self):
return self._path
class _LoggingOptions:
    """Holds the effective log level used by ORTModule.

    Args:
        log_level: The level requested in code; must be a ``LogLevel`` instance.

    Raises:
        TypeError: If ``log_level`` is not a ``LogLevel`` instance.
    """

    # Environment variable whose value, when set, replaces the level passed in code.
    _log_level_environment_key = "ORTMODULE_LOG_LEVEL"

    def __init__(self, log_level):
        self._log_level = self._extract_info(log_level)

    def _extract_info(self, log_level):
        """Return the effective level: the env variable's value if set, else ``log_level``."""
        self._validate(log_level)

        # The OS environment variable supersedes the locally provided level; its value is
        # looked up by name in LogLevel.
        level_name = os.environ.get(_LoggingOptions._log_level_environment_key, log_level.name)
        return LogLevel[level_name]

    def _validate(self, log_level):
        """Ensure ``log_level`` is a ``LogLevel`` instance."""
        if isinstance(log_level, LogLevel):
            return
        raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.")

    @property
    def log_level(self) -> LogLevel:
        """The effective log level resolved at construction time."""
        return self._log_level
class DebugOptions:
    """Debugging configuration for ORTModule: log level and ONNX model saving.

    Args:
        log_level (:obj:`LogLevel`, optional): ORTModule log level. Defaults to LogLevel.WARNING.
            The environment variable "ORTMODULE_LOG_LEVEL" may also be set to one of
            "VERBOSE", "INFO", "WARNING", "ERROR", "FATAL"; when both are given, the
            environment variable takes precedence.
        save_onnx (:obj:`bool`, optional): Whether ORTModule saves its onnx models. Defaults to False.
            Models are written to the current working directory unless the environment
            variable "ORTMODULE_SAVE_ONNX_PATH" points to another directory.
        onnx_prefix (:obj:`str`, optional): Name prefix for the saved ORTModule ONNX model files.
            Must be provided if save_onnx is True.

    Raises:
        OSError: If save_onnx is True and the output directory is not writable.
        TypeError: If save_onnx is True and onnx_prefix is not a string, or if
            log_level is not an instance of LogLevel.
        ValueError: If save_onnx is True and onnx_prefix is an empty string.
    """

    def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix=""):
        # The save options are built (and validated) before the logging options.
        onnx_saving = _SaveOnnxOptions(save_onnx, onnx_prefix)
        self._save_onnx_models = onnx_saving

        log_settings = _LoggingOptions(log_level)
        self._logging = log_settings

    @property
    def save_onnx_models(self):
        """The ONNX saving configuration object."""
        return self._save_onnx_models

    @property
    def logging(self):
        """The logging configuration object."""
        return self._logging

View file

@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# debug_options.py
# _hierarchical_ortmodule.py
import tempfile
import warnings
import torch
from .... import ortmodule
from ... import ORTModule
from ...debug_options import DebugOptions, LogLevel
from onnxruntime.training import ortmodule
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule.options import DebugOptions, LogLevel
# nn.Module's in this set are considered exportable to ONNX.
# For other nn.Module's, torch.onnx.export is called to check if

View file

@ -9,11 +9,9 @@ from functools import reduce
from types import SimpleNamespace
from onnxruntime.capi import _pybind_state as C
from onnxruntime.training import ortmodule
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
from onnxruntime.training.ortmodule.options import DebugOptions, LogLevel, _SaveOnnxOptions, _SkipCheck
from ..._fallback import _FallbackPolicy
from ..._graph_execution_manager import _SkipCheck
from ...debug_options import DebugOptions, LogLevel, _SaveOnnxOptions
from . import JSON_PATH_ENVIRONMENT_KEY
log = logging.getLogger(__name__)
@ -39,15 +37,15 @@ def _load_propagate_cast_ops(ortmodule_config_accessor, data):
log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. Loading attributes from file.")
def _update_strategy():
ortmodule_config_accessor._propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.__members__[
ortmodule_config_accessor._runtime_options.propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.__members__[
data.PropagateCastOps.Strategy
]
def _update_level():
ortmodule_config_accessor._propagate_cast_ops_level = data.PropagateCastOps.Level
ortmodule_config_accessor._runtime_options.propagate_cast_ops_level = data.PropagateCastOps.Level
def _update_allow():
ortmodule_config_accessor._propagate_cast_ops_allow = data.PropagateCastOps.Allow
ortmodule_config_accessor._runtime_options.propagate_cast_ops_allow = data.PropagateCastOps.Allow
key_to_function_mapping = {"Strategy": _update_strategy, "Level": _update_level, "Allow": _update_allow}
@ -64,8 +62,7 @@ def _load_use_external_gpu_allocator(ortmodule_config_accessor, data):
assert isinstance(
data.UseExternalGPUAllocator, bool
), f"{_load_use_external_gpu_allocator.loading_key} must be a boolean"
ortmodule_config_accessor._use_external_gpu_allocator = data.UseExternalGPUAllocator
ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses()
ortmodule_config_accessor._runtime_options.use_external_gpu_allocator = data.UseExternalGPUAllocator
def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
@ -79,7 +76,11 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
assert isinstance(
data.EnableCustomAutogradFunction, bool
), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean"
ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
enable_custom_autograd_support(data.EnableCustomAutogradFunction)
ortmodule_config_accessor._runtime_options.enable_custom_autograd_function = data.EnableCustomAutogradFunction
def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
@ -91,7 +92,7 @@ def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
assert isinstance(
data.EnableGradAccOptimization, bool
), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean"
ortmodule_config_accessor._enable_grad_acc_optimization = data.EnableGradAccOptimization
ortmodule_config_accessor._runtime_options.enable_grad_acc_optimization = data.EnableGradAccOptimization
def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
@ -103,7 +104,7 @@ def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
assert isinstance(
data.RunSymbolicShapeInference, bool
), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean"
ortmodule_config_accessor._run_symbolic_shape_infer = data.RunSymbolicShapeInference
ortmodule_config_accessor._runtime_options.run_symbolic_shape_infer = data.RunSymbolicShapeInference
def _load_use_static_shape(ortmodule_config_accessor, data):
@ -113,7 +114,7 @@ def _load_use_static_shape(ortmodule_config_accessor, data):
log.info(f"Found keyword {_load_use_static_shape.loading_key} in json. Loading attributes from file.")
assert isinstance(data.UseStaticShape, bool), f"{_load_use_static_shape.loading_key} must be a boolean"
ortmodule_config_accessor._use_static_shape = data.UseStaticShape
ortmodule_config_accessor._runtime_options.use_static_shape = data.UseStaticShape
def _load_skip_check(ortmodule_config_accessor, data):
@ -124,7 +125,7 @@ def _load_skip_check(ortmodule_config_accessor, data):
skip_check = reduce(lambda x, y: x | y, [_SkipCheck[name] for name in data.SkipCheck])
if skip_check.value > 0:
ortmodule_config_accessor._skip_check = skip_check
ortmodule_config_accessor._runtime_options.skip_check = skip_check
def _load_debug_options(ortmodule_config_accessor, data):
@ -177,7 +178,7 @@ def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data):
assert isinstance(
data.UseMemoryEfficientGradient, bool
), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean"
ortmodule_config_accessor._use_memory_efficient_gradient = data.UseMemoryEfficientGradient
ortmodule_config_accessor._runtime_options.use_memory_efficient_gradient = data.UseMemoryEfficientGradient
def _load_fallback_policy(ortmodule_config_accessor, data):
@ -198,7 +199,7 @@ def _load_onnx_opset_version(ortmodule_config_accessor, data):
log.info(f"Found keyword {_load_onnx_opset_version.loading_key} in json. Loading attributes from file.")
assert isinstance(data.OnnxOpsetVersion, int), f"{_load_onnx_opset_version.loading_key} must be an int"
ortmodule.ONNX_OPSET_VERSION = data.OnnxOpsetVersion
ortmodule_config_accessor._runtime_options.onnx_opset_version = data.OnnxOpsetVersion
def _define_load_function_keys():

View file

@ -0,0 +1,287 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# options.py
import os
from enum import IntFlag
from functools import reduce
from logging import Logger
from onnxruntime.capi import _pybind_state as C
from onnxruntime.training import ortmodule
from ._fallback import _FallbackPolicy
from ._logger import LogLevel
from ._utils import parse_os_env_skip_check_flags
class _SaveOnnxOptions:
    """Configurable option to save ORTModule intermediate onnx models.

    Args:
        save: Whether the intermediate ONNX models should be saved.
        name_prefix: Prefix for the saved model file names; only validated when
            ``save`` is truthy.
        path: Preferred output directory. An empty string (the default) or ``None``
            means "use the current working directory". When the
            ``ORTMODULE_SAVE_ONNX_PATH`` environment variable is set, it takes
            precedence over this argument.

    Raises:
        OSError: If ``save`` is truthy and the resolved directory is not writable.
        TypeError: If ``save`` is truthy and ``name_prefix`` is not a ``str``.
        ValueError: If ``save`` is truthy and ``name_prefix`` is an empty string.
    """

    # class variable: environment variable that overrides the output directory
    _path_environment_key = "ORTMODULE_SAVE_ONNX_PATH"

    def __init__(self, save, name_prefix, path: str = ""):
        # `path` is optional so that the two-argument form `_SaveOnnxOptions(save, prefix)` works.
        self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix, path)

    def _extract_info(self, save, name_prefix, path: str = ""):
        """Resolve the destination directory and return ``(save, name_prefix, directory)``."""
        # Truthiness check instead of len(): an empty string and None both fall back to the cwd.
        default_path = path if path else os.getcwd()

        # The environment variable, when present, wins over the caller-provided path.
        destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, default_path)

        # perform validation only when save is True
        if save:
            self._validate(save, name_prefix, destination_path)

        return save, name_prefix, destination_path

    def _validate(self, save, name_prefix, destination_path):
        """Raise if the directory is not writable or the name prefix is unusable."""
        # check if directory is writable
        if not os.access(destination_path, os.W_OK):
            raise OSError(
                f"Directory {destination_path} is not writable. Please set the "
                f"{_SaveOnnxOptions._path_environment_key} environment variable to a writable path."
            )

        # check if input prefix is a string
        if not isinstance(name_prefix, str):
            raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.")

        # if save_onnx is set, save_onnx_prefix must be a non empty string
        if not name_prefix:
            raise ValueError("onnx_prefix must be provided when save_onnx is set.")

    @property
    def save(self):
        """Whether saving of ONNX models was requested."""
        return self._save

    @property
    def name_prefix(self):
        """Prefix for the saved ONNX model file names."""
        return self._name_prefix

    @property
    def path(self):
        """Output directory resolved at construction time."""
        return self._path
class _LoggingOptions:
    """Holds the effective log level used by ORTModule.

    Args:
        log_level: The level requested in code; must be a ``LogLevel`` instance.

    Raises:
        TypeError: If ``log_level`` is not a ``LogLevel`` instance.
    """

    # Environment variable whose value, when set, replaces the level passed in code.
    _log_level_environment_key = "ORTMODULE_LOG_LEVEL"

    def __init__(self, log_level):
        self._log_level = self._extract_info(log_level)

    def _extract_info(self, log_level):
        """Return the effective level: the env variable's value if set, else ``log_level``."""
        self._validate(log_level)

        # The OS environment variable supersedes the locally provided level; its value is
        # looked up by name in LogLevel.
        level_name = os.environ.get(_LoggingOptions._log_level_environment_key, log_level.name)
        return LogLevel[level_name]

    def _validate(self, log_level):
        """Ensure ``log_level`` is a ``LogLevel`` instance."""
        if isinstance(log_level, LogLevel):
            return
        raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.")

    @property
    def log_level(self) -> LogLevel:
        """The effective log level resolved at construction time."""
        return self._log_level
class DebugOptions:
    """Debugging configuration for ORTModule: log level and ONNX model saving.

    Args:
        log_level (:obj:`LogLevel`, optional): ORTModule log level. Defaults to LogLevel.WARNING.
            The environment variable "ORTMODULE_LOG_LEVEL" may also be set to one of
            "VERBOSE", "INFO", "WARNING", "ERROR", "FATAL"; when both are given, the
            environment variable takes precedence.
        save_onnx (:obj:`bool`, optional): Whether ORTModule saves its onnx models. Defaults to False.
        onnx_prefix (:obj:`str`, optional): Name prefix for the saved ORTModule ONNX model files.
            Must be provided if save_onnx is True.
        save_path (:obj:`str`, optional): Output directory for the saved onnx models. An empty
            string (the default) selects the current working directory. The environment
            variable "ORTMODULE_SAVE_ONNX_PATH", when set, takes precedence over this value.
        config (optional): Accepted by the constructor but not used by it.

    Raises:
        OSError: If save_onnx is True and the output directory is not writable.
        TypeError: If save_onnx is True and onnx_prefix is not a string, or if
            log_level is not an instance of LogLevel.
        ValueError: If save_onnx is True and onnx_prefix is an empty string.
    """

    def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix="", save_path="", config=None):
        # Keep the raw, caller-provided values as public attributes.
        self.log_level, self.save_onnx, self.onnx_prefix = log_level, save_onnx, onnx_prefix

        # The save options are built (and validated) before the logging options.
        onnx_saving = _SaveOnnxOptions(save_onnx, onnx_prefix, save_path)
        self._save_onnx_models = onnx_saving

        log_settings = _LoggingOptions(log_level)
        self._logging = log_settings

    @property
    def save_onnx_models(self):
        """The ONNX saving configuration object."""
        return self._save_onnx_models

    @property
    def logging(self):
        """The logging configuration object."""
        return self._logging
class _SkipCheck(IntFlag):
    """Bit flags naming the checks that may be skipped to speed up execution."""

    SKIP_CHECK_DISABLED = 1
    SKIP_CHECK_DEVICE = 2
    SKIP_CHECK_BUILD_GRADIENT = 4
    SKIP_CHECK_EXECUTION_AGENT = 8

    def is_set(self, check):
        """Tell whether `check` is part of this `_SkipCheck` value.

        When `SKIP_CHECK_DISABLED` is part of the value, the answer is always False.
        """
        if _SkipCheck.is_disabled(self):
            return False
        return check in self

    def is_disabled(self):
        """Tell whether `_SkipCheck.SKIP_CHECK_DISABLED` is part of this `_SkipCheck` value."""
        disabled_flag = _SkipCheck.SKIP_CHECK_DISABLED
        return disabled_flag in self
class _RuntimeOptions:
    """Configurable runtime options for ORTModule.

    Each option is first assigned its default value and may then be overridden by the
    matching ``ORTMODULE_*`` environment variable, so a present environment variable
    wins over the default.
    """

    def __init__(self, logger: Logger):
        """Constructor for Options.

        Initially set all the options to their default values, then override them with the values
        from the environment variables.

        Args:
            logger: Logger used to report invalid environment variable values.
        """
        self._logger = logger

        self.onnx_opset_version = ortmodule.ONNX_OPSET_VERSION
        self.conv_algo_search = "HEURISTIC"

        # Configuration for cast optimization.
        # Specify cast propagation strategy. Currently, three strategies are available:
        # NONE, INSERT-AND-REDUCE and FLOOD-FILL
        # The default is FLOOD_FILL, expand FP16 computation regions in the graph using
        # allowed opcodes for the given level.
        self.propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.FLOOD_FILL

        # Optimize by moving Cast operations if propagate_cast_ops_level is non-negative.
        # - If the propagate_cast_ops_level is set to zero, then the transformation considers only the opcodes
        #   specified by propagate_cast_ops_allow as "FP16 safe", to insert/(re)move cast operations before/after
        #   to perform such operations in reduced (16-bit) precision.
        # - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by
        #   propagate_cast_ops_allow, use onnxruntime predetermined list of opcodes considered safe to move
        #   before/after the cast operation.
        # - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform
        #   any computation such as Transpose, Split, Reshape, etc., or the computation is actually in Float
        #   such as GeLU, etc.
        # - Whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using
        #   contrib ops, Dropout, LayerNormalization, etc.
        self.propagate_cast_ops_level = 1

        # List of opcodes to be considered safe to move before/after the cast operation if propagate_cast_ops_level
        # is zero.
        self.propagate_cast_ops_allow = []

        # default execution order is priority-based for both dynamic/static shape input for now
        # if we observe the benefit of static shape, we can expose this flag to the user
        self.use_static_shape = False

        # flag to enable symbolic shape inference for dynamic shape inputs to improve performance
        self.run_symbolic_shape_infer = True

        # PyTorch custom Autograd function support
        from ._custom_autograd_function import custom_autograd_function_enabler

        self.enable_custom_autograd_function = custom_autograd_function_enabler.state

        self.use_external_gpu_allocator = True

        # WIP feature to enable caching in Gradient accumulation scenario.
        self.enable_grad_acc_optimization = False

        # Memory-aware gradient builder.
        self.use_memory_efficient_gradient = False

        # Configuration for compute optimization.
        self.enable_compute_optimizer = True
        self.enable_sparse_optimizer = True
        self.label_sparsity_ratio = ""
        self.embed_sparsity_ratio = ""
        self.enable_embedding_sparse_optimizer = False  # TODO(pengwa): remove once validation on more models are done.

        # Configuration for memory optimization.
        self.memory_optimizer_config = ""
        self.probe_level = "1"

        # Configuration for dev tools.
        self.print_input_density = False
        self.print_memory_stat = False

        # Configuration for fallback.
        self.fallback_policy = ortmodule.ORTMODULE_FALLBACK_POLICY

        # Configuration for skip check.
        # Indicators of some logic have been executed previously and thus could be skipped for faster training
        # default is enabled, if not defined in os env
        self.skip_check = _SkipCheck(
            _SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT
        )

        # Override the feature config if it exists in os env.
        self._override_from_env_vars()

    @staticmethod
    def _env_flag_is_one(name):
        """Return True when environment variable `name` parses as the integer 1.

        Callers check that the variable is present first; `int()` raises if the value
        is not an integer literal.
        """
        return int(os.getenv(name)) == 1

    def _override_from_env_vars(self):
        """Overwrite the current option values with any ``ORTMODULE_*`` environment variables that are set."""
        self.onnx_opset_version = int(os.getenv("ORTMODULE_ONNX_OPSET_VERSION", self.onnx_opset_version))
        self.conv_algo_search = os.getenv("ORTMODULE_CONV_ALGO_SEARCH", self.conv_algo_search)
        if self.conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]:
            # Name the variable that was actually read so the user knows what to fix,
            # then fall back to the default search mode.
            self._logger.warning(
                "Invalid value of env ORTMODULE_CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE."
            )
            self.conv_algo_search = "HEURISTIC"

        # Configuration for compute optimization.
        compute_optimizer_reset = False
        if "ORTMODULE_ENABLE_COMPUTE_OPTIMIZER" in os.environ:
            self.enable_compute_optimizer = self._env_flag_is_one("ORTMODULE_ENABLE_COMPUTE_OPTIMIZER")
            compute_optimizer_reset = True

        # The sparse optimizer can only be on while the compute optimizer is on, so it is
        # re-evaluated whenever either of the two switches was set through the environment.
        if "ORTMODULE_ENABLE_SPARSE_OPTIMIZER" in os.environ or compute_optimizer_reset:
            if "ORTMODULE_ENABLE_SPARSE_OPTIMIZER" in os.environ:
                self.enable_sparse_optimizer = self._env_flag_is_one("ORTMODULE_ENABLE_SPARSE_OPTIMIZER")
            self.enable_sparse_optimizer = self.enable_compute_optimizer and self.enable_sparse_optimizer

        # TODO(pengwa): remove once validation on more models are done.
        if "ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER" in os.environ:
            self.enable_embedding_sparse_optimizer = self.enable_sparse_optimizer and self._env_flag_is_one(
                "ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"
            )

        # Configuration for memory optimization.
        self.memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config)
        self.probe_level = os.getenv("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", self.probe_level)

        # Configuration for dev tools.
        if "ORTMODULE_PRINT_INPUT_DENSITY" in os.environ:
            self.print_input_density = self._env_flag_is_one("ORTMODULE_PRINT_INPUT_DENSITY")
        if "ORTMODULE_PRINT_MEMORY_STATS" in os.environ:
            self.print_memory_stat = self._env_flag_is_one("ORTMODULE_PRINT_MEMORY_STATS")

        # Configuration for fallback.
        if "ORTMODULE_FALLBACK_POLICY" in os.environ:
            self.fallback_policy = os.getenv("ORTMODULE_FALLBACK_POLICY")
            # A string policy (e.g. one read from the environment) is looked up by name.
            if isinstance(self.fallback_policy, str):
                self.fallback_policy = _FallbackPolicy[self.fallback_policy]

        # Configuration for skip check.
        if "ORTMODULE_SKIPCHECK_POLICY" in os.environ:
            # OR together every flag named in the environment variable.
            self.skip_check = reduce(
                lambda x, y: x | y,
                [_SkipCheck[name] for name in parse_os_env_skip_check_flags("ORTMODULE_SKIPCHECK_POLICY")],
            )

View file

@ -10,7 +10,7 @@ from ._torch_module_ort import TorchModuleORT
from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry
from ._custom_gradient_registry import CustomGradientRegistry
from . import _utils
from .debug_options import DebugOptions
from .options import DebugOptions
from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException
from ._logger import ortmodule_loglevel_to_python_loglevel
from onnxruntime.training import ortmodule
@ -77,7 +77,9 @@ class ORTModule(torch.nn.Module):
# Support contrib OPs
pytorch_export_contrib_ops.register()
CustomOpSymbolicRegistry.register_all()
CustomOpSymbolicRegistry.register_all(
self._torch_module._execution_manager(module.training)._runtime_options.onnx_opset_version
)
CustomGradientRegistry.register_all()
# Warn user if there are name collisions between user model's and ORTModule attributes
@ -320,7 +322,9 @@ class ORTModule(torch.nn.Module):
# Re-register contrib OPs
pytorch_export_contrib_ops.register()
CustomOpSymbolicRegistry.register_all()
CustomOpSymbolicRegistry.register_all(
self._torch_module._execution_manager(self.module.training)._runtime_options.onnx_opset_version
)
CustomGradientRegistry.register_all()
# Re-initialize the ORTModule forward method

View file

@ -30,16 +30,9 @@ from transformers.modeling_outputs import SequenceClassifierOutput
import onnxruntime.training.ortmodule as ortmodule_module
from onnxruntime.training.optim import AdamWMode, FusedAdam
from onnxruntime.training.ortmodule import (
DebugOptions,
LogLevel,
ORTModule,
_fallback,
_graph_execution_manager,
_io,
_utils,
)
from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
from onnxruntime.training.ortmodule.options import _SkipCheck
DEFAULT_OPSET = 15
@ -4454,7 +4447,7 @@ def test_ortmodule_gradient_accumulation_optimization_correctness():
# model with optimization enabled
opt_model = ORTModule(copy.deepcopy(pt_model))
opt_model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
opt_model._torch_module._execution_manager(is_training=True)._runtime_options.enable_grad_acc_optimization = True
opt_optimizer = torch.optim.Adam(opt_model.parameters())
def run_step(model, x):
@ -4879,10 +4872,10 @@ def test_ortmodule_ortmodule_method_attribute_copy():
@pytest.mark.parametrize(
"policy_str, policy",
[
("SKIP_CHECK_DISABLED", _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED),
("SKIP_CHECK_DEVICE", _graph_execution_manager._SkipCheck.SKIP_CHECK_DEVICE),
("SKIP_CHECK_BUILD_GRADIENT", _graph_execution_manager._SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
("SKIP_CHECK_EXECUTION_AGENT", _graph_execution_manager._SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
("SKIP_CHECK_DISABLED", _SkipCheck.SKIP_CHECK_DISABLED),
("SKIP_CHECK_DEVICE", _SkipCheck.SKIP_CHECK_DEVICE),
("SKIP_CHECK_BUILD_GRADIENT", _SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
("SKIP_CHECK_EXECUTION_AGENT", _SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
],
)
def test_ortmodule_skip_check_load_from_os_env(policy_str, policy):
@ -4893,7 +4886,7 @@ def test_ortmodule_skip_check_load_from_os_env(policy_str, policy):
ort_model = ORTModule(model)
for training_mode in [False, True]:
assert ort_model._torch_module._execution_manager(training_mode)._skip_check == policy
assert ort_model._torch_module._execution_manager(training_mode)._runtime_options.skip_check == policy
del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
@ -5227,10 +5220,8 @@ def test_sigmoid_grad_opset13():
N, D_in, H, D_out = 120, 15360, 500, 15360 # noqa: N806
pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device)
old_opst_cst = ortmodule_module.ONNX_OPSET_VERSION
old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None)
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13"
assert ortmodule_module.ONNX_OPSET_VERSION == 15
ort_model = ORTModule(copy.deepcopy(pt_model))
@ -5256,21 +5247,25 @@ def test_sigmoid_grad_opset13():
del os.environ["ORTMODULE_ONNX_OPSET_VERSION"]
else:
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset
assert ortmodule_module.ONNX_OPSET_VERSION == 13
ortmodule_module.ONNX_OPSET_VERSION = old_opst_cst
assert ort_model._torch_module._execution_manager(True)._runtime_options.onnx_opset_version == 13
@pytest.mark.parametrize("opset_version", [12, 13, 14, 15])
def test_opset_version_change(opset_version):
original_env = None
if "ORTMODULE_ONNX_OPSET_VERSION" in os.environ:
original_env = os.environ["ORTMODULE_ONNX_OPSET_VERSION"]
del os.environ["ORTMODULE_ONNX_OPSET_VERSION"]
device = "cuda"
N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806
x = torch.randn(N, D_in, device=device)
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
ort_model = ORTModule(model)
ortmodule_module.ONNX_OPSET_VERSION = opset_version
ort_model = ORTModule(model)
# Make sure model runs without any exception
prediction = ort_model(x)
@ -5282,6 +5277,9 @@ def test_opset_version_change(opset_version):
exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model
assert exported_model.opset_import[0].version == opset_version
if original_env is not None:
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = original_env
def test_serialize_ortmodule():
device = "cuda"

View file

@ -420,7 +420,7 @@ def main():
)
model = ORTModule(model, debug_options)
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
model._torch_module._execution_manager(is_training=True)._runtime_options.enable_grad_acc_optimization = True
# Tell pytorch to run this model on the GPU.
if torch.cuda.is_available() and not args.no_cuda:

View file

@ -38,27 +38,29 @@ def test_load_config_from_json_1():
ort_model_attributes = model._torch_module._execution_manager(training_mode)
# test propagate cast ops
assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL
assert ort_model_attributes._propagate_cast_ops_level == 3
assert ort_model_attributes._propagate_cast_ops_allow == ["ABC", "DEF"]
assert (
ort_model_attributes._runtime_options.propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL
)
assert ort_model_attributes._runtime_options.propagate_cast_ops_level == 3
assert ort_model_attributes._runtime_options.propagate_cast_ops_allow == ["ABC", "DEF"]
# test use external gpu allocator
assert ort_model_attributes._use_external_gpu_allocator is False
assert ort_model_attributes._runtime_options.use_external_gpu_allocator is False
# test enable custom autograd function
assert ort_model_attributes._enable_custom_autograd_function is True
assert ort_model_attributes._runtime_options.enable_custom_autograd_function is True
# test use static shape
assert ort_model_attributes._use_static_shape is True
assert ort_model_attributes._runtime_options.use_static_shape is True
# test run symbolic shape inference
assert ort_model_attributes._run_symbolic_shape_infer is False
assert ort_model_attributes._runtime_options.run_symbolic_shape_infer is False
# test enable grad acc optimization
assert ort_model_attributes._enable_grad_acc_optimization is True
assert ort_model_attributes._runtime_options.enable_grad_acc_optimization is True
# test skip check
assert ort_model_attributes._skip_check.value == 14
assert ort_model_attributes._runtime_options.skip_check.value == 14
# test debug options
assert ort_model_attributes._debug_options.save_onnx_models.save is True
@ -66,13 +68,13 @@ def test_load_config_from_json_1():
assert ort_model_attributes._debug_options.logging.log_level.name == "VERBOSE"
# test use memory aware gradient builder.
assert ort_model_attributes._use_memory_efficient_gradient is False
assert ort_model_attributes._runtime_options.use_memory_efficient_gradient is False
# test fallback policy
assert ort_model_attributes._fallback_manager.policy.value == 1
# assert onnx opset version
assert ortmodule.ONNX_OPSET_VERSION == 13
assert ort_model_attributes._runtime_options.onnx_opset_version == 13
def test_load_config_from_json_2():
@ -91,27 +93,30 @@ def test_load_config_from_json_2():
ort_model_attributes = model._torch_module._execution_manager(training_mode)
# test propagate cast ops
assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.INSERT_AND_REDUCE
assert ort_model_attributes._propagate_cast_ops_level == 5
assert ort_model_attributes._propagate_cast_ops_allow == ["XYZ", "PQR"]
assert (
ort_model_attributes._runtime_options.propagate_cast_ops_strategy
== C.PropagateCastOpsStrategy.INSERT_AND_REDUCE
)
assert ort_model_attributes._runtime_options.propagate_cast_ops_level == 5
assert ort_model_attributes._runtime_options.propagate_cast_ops_allow == ["XYZ", "PQR"]
# test use external gpu allocator
assert ort_model_attributes._use_external_gpu_allocator is True
assert ort_model_attributes._runtime_options.use_external_gpu_allocator is True
# test enable custom autograd function
assert ort_model_attributes._enable_custom_autograd_function is False
assert ort_model_attributes._runtime_options.enable_custom_autograd_function is False
# test use static shape
assert ort_model_attributes._use_static_shape is False
assert ort_model_attributes._runtime_options.use_static_shape is False
# test run symbolic shape inference
assert ort_model_attributes._run_symbolic_shape_infer is True
assert ort_model_attributes._runtime_options.run_symbolic_shape_infer is True
# test enable grad acc optimization
assert ort_model_attributes._enable_grad_acc_optimization is False
assert ort_model_attributes._runtime_options.enable_grad_acc_optimization is False
# test skip check
assert ort_model_attributes._skip_check.value == 10
assert ort_model_attributes._runtime_options.skip_check.value == 10
# test debug options
assert ort_model_attributes._debug_options.save_onnx_models.save is True
@ -119,10 +124,10 @@ def test_load_config_from_json_2():
assert ort_model_attributes._debug_options.logging.log_level.name == "INFO"
# test use memory aware gradient builder.
assert ort_model_attributes._use_memory_efficient_gradient is True
assert ort_model_attributes._runtime_options.use_memory_efficient_gradient is True
# test fallback policy
assert ort_model_attributes._fallback_manager.policy.value == 250
# assert onnx opset version
assert ortmodule.ONNX_OPSET_VERSION == 12
assert ort_model_attributes._runtime_options.onnx_opset_version == 12