mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-19 02:03:52 +00:00
Manage ORTModule configurations consistently (#16396)
### Manage ORTModule options Move all env vars that used for feature ON/OFF into runtime options for consistent managements. Be noted: the features' switch are assigned in 2 phases: default values, overwritten by env vars (if specified by users). So env vars take the highest priority when all 2 phases both given value explicitly for one feature. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
403bebfb51
commit
a49bb85cfe
19 changed files with 506 additions and 407 deletions
|
|
@ -47,7 +47,7 @@ More options for **developers**.
|
|||
+ from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
|
||||
+ model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))
|
||||
```
|
||||
Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/debug_options.py) for more details.
|
||||
Check [DebugOptions implementation](../orttraining/orttraining/python/training/ortmodule/options.py) for more details.
|
||||
|
||||
### 2.1 Environment Variables
|
||||
|
||||
|
|
@ -99,12 +99,13 @@ The output directory of the onnx models by default is set to the current working
|
|||
> Overall users should be aware that ORT performance boost might be trivial when they explicitly allow it.
|
||||
|
||||
|
||||
#### ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT
|
||||
#### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD
|
||||
|
||||
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
|
||||
- **Description**: By default, all torch.autograd.Function classes will be exported to ORT PythonOp. There are some cases where you might consider disable it. For example, if you confirmed those torch.autograd.Function classes defined computations that could be inline exported by PyTorch, and it is safe to use the inline exported ONNX graph to train, then you can disable it, as a result, ORT has more opportunities to optimize more.
|
||||
```bash
|
||||
export ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT=1
|
||||
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=1 # Enable
|
||||
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=0 # Disable
|
||||
```
|
||||
|
||||
An alternative to disable without using environment variable:
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def export_gradient_graph(
|
|||
"""
|
||||
|
||||
# Make sure that loss nodes that expect multiple outputs are set up.
|
||||
CustomOpSymbolicRegistry.register_all()
|
||||
CustomOpSymbolicRegistry.register_all(opset_version)
|
||||
|
||||
if not isinstance(gradient_graph_path, str):
|
||||
gradient_graph_path = str(gradient_graph_path)
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ def _are_deterministic_algorithms_enabled():
|
|||
return ORTMODULE_IS_DETERMINISTIC
|
||||
|
||||
|
||||
from .debug_options import DebugOptions, LogLevel # noqa: E402, F401
|
||||
from .options import DebugOptions, LogLevel # noqa: E402, F401
|
||||
|
||||
# ORTModule must be loaded only after all validation passes
|
||||
from .ortmodule import ORTModule # noqa: E402, F401
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from onnxruntime.capi._pybind_state import is_torch_interop_default_on
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
|
||||
class Enabler:
|
||||
def __init__(self):
|
||||
|
|
@ -83,10 +86,7 @@ def enable_custom_autograd_support(to_enable=True):
|
|||
custom_autograd_function_enabler.state = False
|
||||
|
||||
|
||||
from onnxruntime.capi._pybind_state import is_torch_interop_default_on # noqa: E402
|
||||
from onnxruntime.training import ortmodule # noqa: E402
|
||||
|
||||
# Enable the custom autograd by default when PythonOp backend support is enabled during build.
|
||||
enable_custom_autograd_support(
|
||||
not ortmodule._defined_from_envvar("ORTMODULE_DISABLE_CUSTOM_AUTOGRAD_SUPPORT", 0) and is_torch_interop_default_on()
|
||||
ortmodule._defined_from_envvar("ORTMODULE_ENABLE_CUSTOM_AUTOGRAD", 1) and is_torch_interop_default_on()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,8 +12,6 @@ from packaging.version import Version
|
|||
from torch.onnx import register_custom_op_symbolic
|
||||
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
|
||||
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from ._utils import get_runtime_pytorch_version
|
||||
|
||||
# Mapping from pytorch scalar type to onnx scalar type.
|
||||
|
|
@ -97,13 +95,13 @@ class CustomOpSymbolicRegistry:
|
|||
cls._SYMBOLICS[domain + "::" + name] = fn
|
||||
|
||||
@classmethod
|
||||
def register_all(cls):
|
||||
def register_all(cls, onnx_opset_version):
|
||||
for name, fn in cls._SYMBOLICS.items():
|
||||
# Symbolic name is in format: domain::name
|
||||
register_custom_op_symbolic(
|
||||
name,
|
||||
fn,
|
||||
ortmodule._defined_from_envvar("ORTMODULE_ONNX_OPSET_VERSION", ortmodule.ONNX_OPSET_VERSION),
|
||||
onnx_opset_version,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ import io
|
|||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod # noqa: F401
|
||||
from enum import IntFlag
|
||||
from functools import reduce
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import onnx
|
||||
|
|
@ -20,9 +18,8 @@ from torch.utils.cpp_extension import ROCM_HOME
|
|||
import onnxruntime
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _runtime_inspector, _utils
|
||||
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
|
||||
from ._custom_autograd_function_exporter import _post_process_after_export
|
||||
from ._fallback import (
|
||||
ORTModuleDeviceException,
|
||||
|
|
@ -35,7 +32,8 @@ from ._fallback import (
|
|||
from ._gradient_accumulation_manager import GradientAccumulationManager
|
||||
from ._graph_execution_interface import GraphExecutionInterface
|
||||
from ._io import _FlattenedModule, _InputInfo, _ModelInputOutputSchemaType
|
||||
from .debug_options import DebugOptions, LogLevel
|
||||
from ._runtime_inspector import RuntimeInspector
|
||||
from .options import DebugOptions, LogLevel, _RuntimeOptions
|
||||
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
|
||||
|
||||
|
||||
|
|
@ -49,28 +47,6 @@ class _RunStateInfo:
|
|||
self.output_info = output_info
|
||||
|
||||
|
||||
class _SkipCheck(IntFlag):
|
||||
"""Enumeration to specify which checks should be skipped, allowing faster execution"""
|
||||
|
||||
SKIP_CHECK_DISABLED = 1
|
||||
SKIP_CHECK_DEVICE = 2
|
||||
SKIP_CHECK_BUILD_GRADIENT = 4
|
||||
SKIP_CHECK_EXECUTION_AGENT = 8
|
||||
|
||||
def is_set(self, check):
|
||||
"""Check whether `check` is set on the `_SkipCheck instance
|
||||
|
||||
SKIP_CHECK_DISABLED implies the check will return False
|
||||
"""
|
||||
|
||||
return not _SkipCheck.is_disabled(self) and check in self
|
||||
|
||||
def is_disabled(self):
|
||||
"""Check whether `_SkipCheck.SKIP_CHECK_DISABLED is set on the `_SkipCheck instance"""
|
||||
|
||||
return _SkipCheck.SKIP_CHECK_DISABLED in self
|
||||
|
||||
|
||||
class GraphExecutionManager(GraphExecutionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -89,6 +65,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
self._logger = logger
|
||||
|
||||
self._runtime_options = _RuntimeOptions(self._logger)
|
||||
# Original and flattened (transformed) output module
|
||||
self._flattened_module = module
|
||||
|
||||
|
|
@ -102,44 +79,12 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._graph_initializer_names_to_train = set()
|
||||
self._graph_initializers: List[torch.nn.parameter.Parameter] = []
|
||||
|
||||
# Update constant ONNX_OPSET_VERSION with env var ORTMODULE_ONNX_OPSET_VERSION
|
||||
# if defined.
|
||||
ortmodule.ONNX_OPSET_VERSION = ortmodule._defined_from_envvar(
|
||||
"ORTMODULE_ONNX_OPSET_VERSION", ortmodule.ONNX_OPSET_VERSION, warn=True
|
||||
)
|
||||
|
||||
# TrainingAgent or InferenceAgent
|
||||
self._execution_agent = None
|
||||
|
||||
# Indicators of some logic have been executed previously and thus could be skipped for faster training
|
||||
# default is enabled, if not defined in os env
|
||||
self._skip_check = _SkipCheck(
|
||||
_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT
|
||||
)
|
||||
if os.getenv("ORTMODULE_SKIPCHECK_POLICY") is not None:
|
||||
self._skip_check = reduce(
|
||||
lambda x, y: x | y,
|
||||
[_SkipCheck[name] for name in _utils.parse_os_env_skip_check_flags("ORTMODULE_SKIPCHECK_POLICY")],
|
||||
)
|
||||
self._first_skip_check_warning = True
|
||||
|
||||
self._rt_inspector = _runtime_inspector.RuntimeInspector(self._logger)
|
||||
|
||||
# Graph transformer config
|
||||
# Specify cast propagation strategy. Currently, three strategies are available, NONE, INSERT-AND-REDUCE and FLOOD-FILL
|
||||
# The default is FLOOD_FILL, expand FP16 computation regions in the graph using allowed opcodes for the given level.
|
||||
self._propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.FLOOD_FILL
|
||||
# Optimize by moving Cast operations if propagate_cast_ops_level is non-negative.
|
||||
# - If the _propagate_cast_ops_level is set to zero, then the transformation considers only the opcodes specified by _propagate_cast_ops_allow
|
||||
# as "FP16 safe", to insert/(re)move cast operations before/after to perform such operations in reduced (16-bit) precision.
|
||||
# - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by propagate_cast_ops_allow, use onnxruntime
|
||||
# predetermined list of opcodes considered safe to move before/after the cast operation.
|
||||
# - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform any computation such as Transpose, Split, Reshape, etc.,
|
||||
# or the computation is actually in Float such as GeLU, etc.
|
||||
# whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, Dropout, LayerNormalization, etc.
|
||||
self._propagate_cast_ops_level = 1
|
||||
# List of opcodes to be considered safe to move before/after the cast operation if propagate_cast_ops_level is zero.
|
||||
self._propagate_cast_ops_allow = []
|
||||
self._runtime_inspector = RuntimeInspector(self._logger)
|
||||
|
||||
# Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
|
||||
# To be instantiated in the concrete implementation of GraphExecutionManager
|
||||
|
|
@ -149,18 +94,6 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# It cannot overlap with required/immutable arguments (validated in runtime)
|
||||
self._export_extra_kwargs = {}
|
||||
|
||||
# default execution order is priority-based for both dynamic/static shape input for now
|
||||
# if we observe the benefit of static shape, we can expose this flag to the user
|
||||
self._use_static_shape = False
|
||||
|
||||
# flag to enable symbolic shape inference for dynamic shape inputs to improve performance
|
||||
self._run_symbolic_shape_infer = True
|
||||
|
||||
# PyTorch custom Autograd function support
|
||||
from ._custom_autograd_function import custom_autograd_function_enabler
|
||||
|
||||
self._enable_custom_autograd_function = custom_autograd_function_enabler.state
|
||||
|
||||
# Input and output infos (including schema) for exported model.
|
||||
self._input_info: Optional[_InputInfo] = None
|
||||
self._module_output_schema: Optional[_ModelInputOutputSchemaType] = None
|
||||
|
|
@ -182,36 +115,9 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None)
|
||||
|
||||
self._use_external_gpu_allocator = True
|
||||
# assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True
|
||||
self._get_torch_gpu_allocator_function_addresses()
|
||||
|
||||
# WIP feature to enable caching in Gradient accumulation scenario.
|
||||
self._enable_grad_acc_optimization = False
|
||||
self._gradient_accumulation_manager = GradientAccumulationManager()
|
||||
|
||||
# Memory-aware gradient builder.
|
||||
self._use_memory_efficient_gradient = False
|
||||
|
||||
# Enable compute optimizer by default. Allowed to be disabled via an environment variable for
|
||||
# convergence parity investigation.
|
||||
self._enable_compute_optimizer = (
|
||||
ortmodule._defined_from_envvar("ORTMODULE_ENABLE_COMPUTE_OPTIMIZER", 1, warn=True) == 1
|
||||
)
|
||||
self._enable_sparse_optimizer = (
|
||||
self._enable_compute_optimizer
|
||||
and ortmodule._defined_from_envvar("ORTMODULE_ENABLE_SPARSE_OPTIMIZER", 1, warn=True) == 1
|
||||
)
|
||||
self._enable_embedding_sparse_optimizer = (
|
||||
self._enable_compute_optimizer
|
||||
and self._enable_sparse_optimizer
|
||||
and ortmodule._defined_from_envvar("ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER", 0, warn=True) == 1
|
||||
)
|
||||
|
||||
self._print_input_density = ortmodule._defined_from_envvar("ORTMODULE_PRINT_INPUT_DENSITY", 0, warn=True) == 1
|
||||
self._print_memory_stat = ortmodule._defined_from_envvar("ORTMODULE_PRINT_MEMORY_STATS", 0, warn=True) == 1
|
||||
self._enable_memory_optimizer = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True)
|
||||
|
||||
# Flag to re-export the model due to attribute change on the original module.
|
||||
# Re-export will be avoided if _skip_check is enabled.
|
||||
self._original_model_has_changed = False
|
||||
|
|
@ -219,23 +125,11 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# Load ATen operator executor extension.
|
||||
load_aten_op_executor_cpp_extension()
|
||||
|
||||
self._feature_map: List[List[str]] = [
|
||||
["ATen Executor", "ON", "Dispatch ATen operators to ORT's ATen executor"],
|
||||
[
|
||||
"Cast Propagation",
|
||||
"ON" if self._propagate_cast_ops_level > 0 else "OFF",
|
||||
f"Level {self._propagate_cast_ops_level} enabled",
|
||||
],
|
||||
["Custom Function", "ON", "Support custom torch.autograd.Function export and execution"],
|
||||
[
|
||||
"Memory Optimizer",
|
||||
"ON" if self._enable_memory_optimizer else "OFF",
|
||||
"Enable with env ORTMODULE_MEMORY_OPT_CONFIG=<config>",
|
||||
],
|
||||
]
|
||||
# Assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True
|
||||
self._get_torch_gpu_allocator_function_addresses()
|
||||
|
||||
def _get_torch_gpu_allocator_function_addresses(self):
|
||||
if self._use_external_gpu_allocator and torch.cuda.is_available():
|
||||
if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available():
|
||||
# CPP extension to get torch GPU allocator's alloc and free function addresses
|
||||
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator
|
||||
|
||||
|
|
@ -272,7 +166,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
pass
|
||||
|
||||
def _build_graph(self, config):
|
||||
if self._use_static_shape:
|
||||
if self._runtime_options.use_static_shape:
|
||||
self._graph_builder.build(config, self._input_info.shape)
|
||||
else:
|
||||
self._graph_builder.build(config)
|
||||
|
|
@ -294,14 +188,10 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
provider_option_map = {"device_id": str(self._device.index)}
|
||||
if not self.is_rocm_pytorch:
|
||||
# Set Conv algo search mode to HEURISTIC by default, which is the same as PyTorch's default setting.
|
||||
conv_algo_search = ortmodule._defined_from_envvar("ORTMODULE_CONV_ALGO_SEARCH", "HEURISTIC", warn=True)
|
||||
if conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]:
|
||||
self._logger.warning("Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE.")
|
||||
conv_algo_search = "HEURISTIC"
|
||||
provider_option_map["cudnn_conv_algo_search"] = conv_algo_search
|
||||
provider_option_map["cudnn_conv_algo_search"] = self._runtime_options.conv_algo_search
|
||||
provider_option_map["cudnn_conv_use_max_workspace"] = "1"
|
||||
provider_option_map["cudnn_conv1d_pad_to_nc1d"] = "1"
|
||||
if self._use_external_gpu_allocator:
|
||||
if self._runtime_options.use_external_gpu_allocator:
|
||||
provider_option_map["gpu_external_alloc"] = str(self._torch_alloc)
|
||||
provider_option_map["gpu_external_free"] = str(self._torch_free)
|
||||
provider_option_map["gpu_external_empty_cache"] = str(self._torch_empty_cache)
|
||||
|
|
@ -323,11 +213,13 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
|
||||
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
|
||||
session_options.log_severity_level = int(self._debug_options.logging.log_level)
|
||||
# Disable memory alleviation by default. Allow user to enable it via environment variable.
|
||||
alleviation_config = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True)
|
||||
probe_level = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", "1", warn=True)
|
||||
session_options.add_session_config_entry("optimization.enable_memory_optimizer", alleviation_config)
|
||||
session_options.add_session_config_entry("optimization.enable_memory_probe_recompute_level", probe_level)
|
||||
|
||||
session_options.add_session_config_entry(
|
||||
"optimization.enable_memory_optimizer", self._runtime_options.memory_optimizer_config
|
||||
)
|
||||
session_options.add_session_config_entry(
|
||||
"optimization.enable_memory_probe_recompute_level", self._runtime_options.probe_level
|
||||
)
|
||||
# Disable weight prepacking
|
||||
session_options.add_session_config_entry("session.disable_prepacking", "1")
|
||||
|
||||
|
|
@ -375,7 +267,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._export_mode,
|
||||
)
|
||||
|
||||
if self._run_symbolic_shape_infer:
|
||||
if self._runtime_options.run_symbolic_shape_infer:
|
||||
self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes(
|
||||
self._onnx_models.exported_model, auto_merge=True, guess_output_rank=True
|
||||
)
|
||||
|
|
@ -433,7 +325,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
required_export_kwargs = {
|
||||
"input_names": self._input_info.names,
|
||||
"output_names": output_names,
|
||||
"opset_version": ortmodule.ONNX_OPSET_VERSION,
|
||||
"opset_version": self._runtime_options.onnx_opset_version,
|
||||
"do_constant_folding": False,
|
||||
"training": self._export_mode,
|
||||
"dynamic_axes": self._input_info.dynamic_axes,
|
||||
|
|
@ -462,8 +354,12 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
)
|
||||
exported_model = onnx.load_model_from_string(f.getvalue())
|
||||
|
||||
exported_model = _post_process_after_export(exported_model, self._enable_custom_autograd_function)
|
||||
exported_model = _post_process_after_export(
|
||||
exported_model, self._runtime_options.enable_custom_autograd_function
|
||||
)
|
||||
|
||||
# If anything was captured by suppress_output during export, set the flag to
|
||||
# raise a single user warning letting users know in the log.
|
||||
if suppress_output.tell() > 0:
|
||||
self._warning_log_detected_during_export = True
|
||||
|
||||
|
|
@ -486,10 +382,10 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
def _get_graph_transformer_config(self) -> C.TrainingGraphTransformerConfiguration:
|
||||
graph_transformer_config = C.TrainingGraphTransformerConfiguration()
|
||||
graph_transformer_config.propagate_cast_ops_config = C.PropagateCastOpsConfiguration()
|
||||
graph_transformer_config.propagate_cast_ops_config.level = self._propagate_cast_ops_level
|
||||
graph_transformer_config.propagate_cast_ops_config.allow = self._propagate_cast_ops_allow
|
||||
graph_transformer_config.propagate_cast_ops_config.strategy = self._propagate_cast_ops_strategy
|
||||
graph_transformer_config.enable_compute_optimizer = self._enable_compute_optimizer
|
||||
graph_transformer_config.propagate_cast_ops_config.level = self._runtime_options.propagate_cast_ops_level
|
||||
graph_transformer_config.propagate_cast_ops_config.allow = self._runtime_options.propagate_cast_ops_allow
|
||||
graph_transformer_config.propagate_cast_ops_config.strategy = self._runtime_options.propagate_cast_ops_strategy
|
||||
graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer
|
||||
return graph_transformer_config
|
||||
|
||||
def _initialize_graph_builder(self):
|
||||
|
|
@ -515,11 +411,11 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
grad_builder_config.initializer_names_to_train = initializer_names_to_train
|
||||
grad_builder_config.input_names_require_grad = self._input_info.require_grad_names
|
||||
grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING
|
||||
grad_builder_config.enable_caching = self._enable_grad_acc_optimization
|
||||
grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization
|
||||
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(
|
||||
self._debug_options.logging.log_level
|
||||
)
|
||||
grad_builder_config.use_memory_efficient_gradient = self._use_memory_efficient_gradient
|
||||
grad_builder_config.use_memory_efficient_gradient = self._runtime_options.use_memory_efficient_gradient
|
||||
self._graph_builder = C.OrtModuleGraphBuilder()
|
||||
|
||||
# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
|
||||
|
|
@ -576,27 +472,14 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
enable sparsity-based optimization.
|
||||
|
||||
"""
|
||||
self._feature_map.extend(
|
||||
[
|
||||
[
|
||||
"Compute Optimizer",
|
||||
"ON" if self._enable_compute_optimizer else "OFF",
|
||||
"Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0",
|
||||
],
|
||||
[
|
||||
" -FLOPReduction",
|
||||
"ON" if self._enable_compute_optimizer else "OFF",
|
||||
"Reduce FLOPs by upstreaming shrinking-sized ops",
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
# Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density.
|
||||
if self._enable_sparse_optimizer or self._print_input_density:
|
||||
self._rt_inspector.enable_input_inspector(
|
||||
if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density:
|
||||
self._runtime_inspector.enable_input_inspector(
|
||||
self._onnx_models.exported_model, self._graph_builder.get_graph_info().user_input_names
|
||||
)
|
||||
|
||||
if self._enable_sparse_optimizer:
|
||||
if self._runtime_options.enable_sparse_optimizer:
|
||||
detected_device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(
|
||||
inputs, kwargs
|
||||
)
|
||||
|
|
@ -609,41 +492,31 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
inputs,
|
||||
kwargs,
|
||||
detected_device,
|
||||
self._rt_inspector,
|
||||
self._runtime_inspector,
|
||||
)
|
||||
|
||||
# Enable sparsity-based optimization when applicable.
|
||||
if len(label_sparsity_results) > 0:
|
||||
graph_transformer_config.sparse_label_input_names = list(label_sparsity_results.keys())
|
||||
self._logger.info("Label sparsity-based optimization is ON for %s", label_sparsity_results)
|
||||
sparsity_stat_str = ",".join([f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()])
|
||||
self._feature_map.append(
|
||||
[
|
||||
" -LabelSparsityOpt",
|
||||
"ON",
|
||||
f"Input density: {sparsity_stat_str}, switch: ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1/0",
|
||||
]
|
||||
self._runtime_options.label_sparsity_ratio = ",".join(
|
||||
[f"{k}:{v:.0f}%" for k, v in label_sparsity_results.items()]
|
||||
)
|
||||
|
||||
if self._enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
|
||||
if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
|
||||
graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
|
||||
self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
|
||||
sparsity_stat_str = ",".join([f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()])
|
||||
self._feature_map.append(
|
||||
[
|
||||
" -EmbedSparsityOpt",
|
||||
"ON",
|
||||
f"Input density: {sparsity_stat_str}, switch: ORTMODULE_ENABLE_SPARSE_OPTIMIZER=1/0",
|
||||
]
|
||||
self._runtime_options.embed_sparsity_ratio = ",".join(
|
||||
[f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
|
||||
)
|
||||
|
||||
# If users don't want to print input density, disable the input density observer to avoid overhead
|
||||
# when looping through inputs during training.
|
||||
if not self._print_input_density:
|
||||
self._rt_inspector.disable_input_inspector()
|
||||
if not self._runtime_options.print_input_density:
|
||||
self._runtime_inspector.disable_input_inspector()
|
||||
|
||||
if self._print_memory_stat:
|
||||
self._rt_inspector.enable_memory_inspector(self._original_module)
|
||||
if self._runtime_options.print_memory_stat:
|
||||
self._runtime_inspector.enable_memory_inspector(self._original_module)
|
||||
|
||||
def _log_feature_stats(self):
|
||||
rank = 0
|
||||
|
|
@ -653,12 +526,57 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
if rank != 0:
|
||||
return
|
||||
|
||||
self._feature_map.append(
|
||||
[
|
||||
feature_map: List[Tuple[str, bool, str]] = [
|
||||
("ATen Executor", True, "Dispatch ATen operators to ORT's ATen executor"),
|
||||
(
|
||||
"Cast Propagation",
|
||||
self._runtime_options.propagate_cast_ops_level > 0,
|
||||
f"Level {self._runtime_options.propagate_cast_ops_level} enabled",
|
||||
),
|
||||
(
|
||||
"Custom Function",
|
||||
self._runtime_options.enable_custom_autograd_function,
|
||||
"Support custom torch.autograd.Function export and execution",
|
||||
),
|
||||
(
|
||||
"Memory Optimizer",
|
||||
len(self._runtime_options.memory_optimizer_config) > 0,
|
||||
"Enable with env ORTMODULE_MEMORY_OPT_CONFIG=<config>",
|
||||
),
|
||||
]
|
||||
|
||||
if self._runtime_options.enable_compute_optimizer:
|
||||
feature_map.extend(
|
||||
[
|
||||
(
|
||||
"Compute Optimizer",
|
||||
self._runtime_options.enable_compute_optimizer,
|
||||
"Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0",
|
||||
),
|
||||
(
|
||||
" -FLOPReduction",
|
||||
self._runtime_options.enable_compute_optimizer,
|
||||
"Reduce FLOPs by upstreaming shrinking-sized ops",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if len(self._runtime_options.label_sparsity_ratio) > 0:
|
||||
feature_map.append(
|
||||
(" -LabelSparsityOpt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}")
|
||||
)
|
||||
|
||||
if len(self._runtime_options.embed_sparsity_ratio) > 0:
|
||||
feature_map.append(
|
||||
(" -EmbedSparsityOpt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}")
|
||||
)
|
||||
|
||||
feature_map.append(
|
||||
(
|
||||
"Auto Fallback",
|
||||
"ON" if self._fallback_manager.policy is not _FallbackPolicy.FALLBACK_DISABLE else "OFF",
|
||||
self._runtime_options.fallback_policy is not _FallbackPolicy.FALLBACK_DISABLE,
|
||||
"Fallback to PyTorch when encountering unsupported ops",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
mode = "training" if self._export_mode == torch.onnx.TrainingMode.TRAINING else "inference"
|
||||
|
|
@ -666,11 +584,10 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
stat = f"\n\n{_logger.LogColor.HEADER}***** ONNX Runtime Training (ORTModule) is accelerating your model *****{_logger.LogColor.ENDC}\n\n"
|
||||
stat += f"ORTModule is enabled with following features ON/OFF for [{mode}] mode:\n\n"
|
||||
for feature_tuple in self._feature_map:
|
||||
stat += f"{feature_tuple[0]:<20}:\t{feature_tuple[1]:<10}:\t{feature_tuple[2]:<80}\n"
|
||||
for feature_tuple in feature_map:
|
||||
switch_str = "ON" if feature_tuple[1] else "OFF"
|
||||
stat += f"{feature_tuple[0]:<20}:\t{switch_str:<10}:\t{feature_tuple[2]:<80}\n"
|
||||
|
||||
# If anything was captured in fo, raise a single user warning letting users know that there was
|
||||
# any warning or error that was raised
|
||||
stat += f"\n{_logger.LogColor.WARNING}There were one or more warnings or errors raised while exporting the PyTorch model.\n"
|
||||
stat += f"Please enable INFO level logging with DebugOptions to view all warnings and errors.{_logger.LogColor.ENDC}\n\n"
|
||||
stat += f"Export duration: {self._export_duration_in_ms:.0f} milliseconds\n"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from ._fallback import _FallbackManager
|
|||
from ._inference_manager import InferenceManager
|
||||
from ._io import _FlattenedModule
|
||||
from ._training_manager import TrainingManager
|
||||
from .debug_options import DebugOptions
|
||||
from .options import DebugOptions
|
||||
|
||||
|
||||
class GraphExecutionManagerFactory:
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ from onnxruntime.capi import _pybind_state as C
|
|||
from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils
|
||||
from ._execution_agent import InferenceAgent
|
||||
from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy
|
||||
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck
|
||||
from .debug_options import DebugOptions
|
||||
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
|
||||
from .options import DebugOptions, _SkipCheck
|
||||
|
||||
|
||||
class InferenceManager(GraphExecutionManager):
|
||||
|
|
@ -92,21 +92,21 @@ class InferenceManager(GraphExecutionManager):
|
|||
|
||||
try:
|
||||
# Issue at most one warning message about fast path
|
||||
if self._first_skip_check_warning is True and self._skip_check.is_disabled() is False:
|
||||
if self._first_skip_check_warning is True and self._runtime_options.skip_check.is_disabled() is False:
|
||||
self._first_skip_check_warning = False
|
||||
self._logger.warning(
|
||||
"Fast path enabled - skipping checks. rebuild gradient graph: %s, execution agent recreation: %s, "
|
||||
"device check: %s",
|
||||
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
|
||||
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
|
||||
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
|
||||
)
|
||||
|
||||
# If exporting module to ONNX for the first time, this skip check will not take effect.
|
||||
# It will only take effect on subsequent forward calls.
|
||||
build_graph = False
|
||||
if (
|
||||
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
|
||||
or not self._onnx_models.exported_model
|
||||
):
|
||||
# Exporting module to ONNX for the first time
|
||||
|
|
@ -129,7 +129,10 @@ class InferenceManager(GraphExecutionManager):
|
|||
# If creating the execution agent for the first time, this skip check will not take effect.
|
||||
# It will only take effect on subsequent forward calls.
|
||||
create_execution_session = False
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent:
|
||||
if (
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False
|
||||
or not self._execution_agent
|
||||
):
|
||||
module_device = _utils.get_device_from_module(self._original_module)
|
||||
|
||||
create_execution_session = (
|
||||
|
|
@ -146,7 +149,7 @@ class InferenceManager(GraphExecutionManager):
|
|||
# Create execution session creates the inference_session
|
||||
self._create_execution_agent()
|
||||
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
# Assert that the input and model device match
|
||||
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
|
||||
|
||||
|
|
@ -158,7 +161,7 @@ class InferenceManager(GraphExecutionManager):
|
|||
inputs,
|
||||
kwargs,
|
||||
self._device,
|
||||
self._rt_inspector,
|
||||
self._runtime_inspector,
|
||||
)
|
||||
|
||||
user_outputs, _ = InferenceManager.execution_session_run_forward(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from logging import Logger
|
|||
|
||||
from ._fallback import _FallbackManager
|
||||
from ._torch_module_ort import TorchModuleORT
|
||||
from .debug_options import DebugOptions
|
||||
from .options import DebugOptions
|
||||
|
||||
|
||||
class TorchModuleFactory:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from . import _io, _utils
|
|||
from ._fallback import ORTModuleTorchModelException, _FallbackManager, wrap_exception
|
||||
from ._graph_execution_manager_factory import GraphExecutionManagerFactory
|
||||
from ._torch_module_interface import TorchModuleInterface
|
||||
from .debug_options import DebugOptions
|
||||
from .options import DebugOptions
|
||||
|
||||
T = TypeVar("T", bound="torch.nn.Module")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,10 +16,10 @@ from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_alg
|
|||
from ._execution_agent import TrainingAgent
|
||||
from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy
|
||||
from ._gradient_accumulation_manager import GradientAccumulationManager
|
||||
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck
|
||||
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
|
||||
from ._io import _FlattenedModule, _InputInfo
|
||||
from ._runtime_inspector import Phase
|
||||
from .debug_options import DebugOptions
|
||||
from .options import DebugOptions, _SkipCheck
|
||||
|
||||
|
||||
class TrainingManager(GraphExecutionManager):
|
||||
|
|
@ -104,9 +104,9 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
Module outputs are returned to the user
|
||||
"""
|
||||
self._rt_inspector.inspect_memory(Phase.PRE_FORWARD)
|
||||
self._runtime_inspector.inspect_memory(Phase.PRE_FORWARD)
|
||||
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
# Assert that the input and model device match
|
||||
_utils._check_same_device(self._device, "Input argument to forward", *inputs)
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
for idx in self._graph_info.output_grad_indices_non_differentiable:
|
||||
ctx.mark_non_differentiable(user_outputs[idx])
|
||||
|
||||
self._rt_inspector.inspect_memory(Phase.POST_FORWARD)
|
||||
self._runtime_inspector.inspect_memory(Phase.POST_FORWARD)
|
||||
|
||||
return user_outputs
|
||||
|
||||
|
|
@ -147,10 +147,10 @@ class TrainingManager(GraphExecutionManager):
|
|||
def backward(ctx, *grad_outputs):
|
||||
"""Performs backward pass based on grad wrt module output"""
|
||||
|
||||
self._rt_inspector.inspect_memory(Phase.PRE_BACKWARD)
|
||||
self._runtime_inspector.inspect_memory(Phase.PRE_BACKWARD)
|
||||
|
||||
assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
|
||||
|
||||
# Unpack saved_tensor to trigger version detection that catches inplace corruption
|
||||
|
|
@ -198,7 +198,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
# This version only works if backward_outputs is an OrtValueVector.
|
||||
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
|
||||
|
||||
self._rt_inspector.inspect_memory(Phase.POST_BACKWARD)
|
||||
self._runtime_inspector.inspect_memory(Phase.POST_BACKWARD)
|
||||
|
||||
return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
|
||||
|
||||
|
|
@ -226,21 +226,21 @@ class TrainingManager(GraphExecutionManager):
|
|||
return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)
|
||||
|
||||
try:
|
||||
if self._first_skip_check_warning is True and self._skip_check.is_disabled() is False:
|
||||
if self._first_skip_check_warning is True and self._runtime_options.skip_check.is_disabled() is False:
|
||||
# Only change this after the firs time a warning is issued.
|
||||
self._first_skip_check_warning = False
|
||||
self._logger.info(
|
||||
"Fast path enabled - skipping checks.Rebuild graph: %s, Execution agent: %s, Device check: %s",
|
||||
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
|
||||
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
|
||||
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE),
|
||||
)
|
||||
|
||||
# If exporting module to ONNX for the first time, this skip check will not take effect.
|
||||
# It will only take effect on subsequent forward calls.
|
||||
build_gradient_graph = False
|
||||
if (
|
||||
self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
|
||||
or not self._onnx_models.exported_model
|
||||
):
|
||||
build_gradient_graph = self._export_model(*inputs, **kwargs)
|
||||
|
|
@ -274,7 +274,10 @@ class TrainingManager(GraphExecutionManager):
|
|||
# If creating the execution agent for the first time, this skip check will not take effect.
|
||||
# It will only take effect on subsequent forward calls.
|
||||
create_execution_session = False
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent:
|
||||
if (
|
||||
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False
|
||||
or not self._execution_agent
|
||||
):
|
||||
device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(
|
||||
inputs, kwargs
|
||||
)
|
||||
|
|
@ -292,7 +295,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
self._create_execution_agent()
|
||||
|
||||
self._gradient_accumulation_manager.initialize(
|
||||
self._enable_grad_acc_optimization, self._flattened_module, self._graph_info
|
||||
self._runtime_options.enable_grad_acc_optimization, self._flattened_module, self._graph_info
|
||||
)
|
||||
|
||||
self._gradient_accumulation_manager.maybe_update_cache_before_run()
|
||||
|
|
@ -305,7 +308,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
inputs,
|
||||
kwargs,
|
||||
self._device,
|
||||
self._rt_inspector,
|
||||
self._runtime_inspector,
|
||||
)
|
||||
|
||||
return _io.unflatten_user_output(
|
||||
|
|
|
|||
|
|
@ -1,118 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# debug_options.py
|
||||
|
||||
import os
|
||||
|
||||
from ._logger import LogLevel
|
||||
|
||||
|
||||
class _SaveOnnxOptions:
|
||||
"""Configurable option to save ORTModule intermediate onnx models."""
|
||||
|
||||
# class variable
|
||||
_path_environment_key = "ORTMODULE_SAVE_ONNX_PATH"
|
||||
|
||||
def __init__(self, save, name_prefix):
|
||||
self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix)
|
||||
|
||||
def _extract_info(self, save, name_prefix):
|
||||
# get the destination path from os env variable
|
||||
destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, os.getcwd())
|
||||
# perform validation only when save is True
|
||||
if save:
|
||||
self._validate(save, name_prefix, destination_path)
|
||||
return save, name_prefix, destination_path
|
||||
|
||||
def _validate(self, save, name_prefix, destination_path):
|
||||
# check if directory is writable
|
||||
if not os.access(destination_path, os.W_OK):
|
||||
raise OSError(
|
||||
f"Directory {destination_path} is not writable. Please set the {_SaveOnnxOptions._path_environment_key} environment variable to a writable path."
|
||||
)
|
||||
|
||||
# check if input prefix is a string
|
||||
if not isinstance(name_prefix, str):
|
||||
raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.")
|
||||
|
||||
# if save_onnx is set, save_onnx_prefix must be a non empty string
|
||||
if not name_prefix:
|
||||
raise ValueError("onnx_prefix must be provided when save_onnx is set.")
|
||||
|
||||
@property
|
||||
def save(self):
|
||||
return self._save
|
||||
|
||||
@property
|
||||
def name_prefix(self):
|
||||
return self._name_prefix
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._path
|
||||
|
||||
|
||||
class _LoggingOptions:
|
||||
"""Configurable option to set the log level in ORTModule."""
|
||||
|
||||
# class variable
|
||||
_log_level_environment_key = "ORTMODULE_LOG_LEVEL"
|
||||
|
||||
def __init__(self, log_level):
|
||||
self._log_level = self._extract_info(log_level)
|
||||
|
||||
def _extract_info(self, log_level):
|
||||
# get the log_level from os env variable
|
||||
# OS environment variable log level superseeds the locally provided one
|
||||
self._validate(log_level)
|
||||
log_level = LogLevel[os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)]
|
||||
return log_level
|
||||
|
||||
def _validate(self, log_level):
|
||||
# check if log_level is an instance of LogLevel
|
||||
if not isinstance(log_level, LogLevel):
|
||||
raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.")
|
||||
|
||||
@property
|
||||
def log_level(self) -> LogLevel:
|
||||
return self._log_level
|
||||
|
||||
|
||||
class DebugOptions:
|
||||
"""Configurable debugging options for ORTModule.
|
||||
|
||||
Args:
|
||||
log_level (:obj:`LogLevel`, optional): Configure ORTModule log level. Defaults to LogLevel.WARNING.
|
||||
log_level can also be set by setting the environment variable "ORTMODULE_LOG_LEVEL" to one of
|
||||
"VERBOSE", "INFO", "WARNING", "ERROR", "FATAL". In case both are set, the environment variable
|
||||
takes precedence.
|
||||
save_onnx (:obj:`bool`, optional): Configure ORTModule to save onnx models. Defaults to False.
|
||||
The output directory of the onnx models by default is set to the current working directory.
|
||||
To change the output directory, the environment variable "ORTMODULE_SAVE_ONNX_PATH" can be
|
||||
set to the destination directory path.
|
||||
onnx_prefix (:obj:`str`, optional): Name prefix to the ORTModule ONNX models saved file names.
|
||||
Must be provided if save_onnx is True
|
||||
|
||||
Raises:
|
||||
OSError: If save_onnx is True and output directory is not writable.
|
||||
TypeError: If save_onnx is True and name_prefix is not a valid string. Or if
|
||||
log_level is not an instance of LogLevel.
|
||||
ValueError: If save_onnx is True and name_prefix is an empty string.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix=""):
|
||||
self._save_onnx_models = _SaveOnnxOptions(save_onnx, onnx_prefix)
|
||||
self._logging = _LoggingOptions(log_level)
|
||||
|
||||
@property
|
||||
def save_onnx_models(self):
|
||||
"""Accessor for the ONNX saving configuration."""
|
||||
|
||||
return self._save_onnx_models
|
||||
|
||||
@property
|
||||
def logging(self):
|
||||
"""Accessor for the logging configuration."""
|
||||
|
||||
return self._logging
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# debug_options.py
|
||||
# _hierarchical_ortmodule.py
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from .... import ortmodule
|
||||
from ... import ORTModule
|
||||
from ...debug_options import DebugOptions, LogLevel
|
||||
from onnxruntime.training import ortmodule
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule.options import DebugOptions, LogLevel
|
||||
|
||||
# nn.Module's in this set are considered exportable to ONNX.
|
||||
# For other nn.Module's, torch.onnx.export is called to check if
|
||||
|
|
|
|||
|
|
@ -9,11 +9,9 @@ from functools import reduce
|
|||
from types import SimpleNamespace
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.training import ortmodule
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
from onnxruntime.training.ortmodule.options import DebugOptions, LogLevel, _SaveOnnxOptions, _SkipCheck
|
||||
|
||||
from ..._fallback import _FallbackPolicy
|
||||
from ..._graph_execution_manager import _SkipCheck
|
||||
from ...debug_options import DebugOptions, LogLevel, _SaveOnnxOptions
|
||||
from . import JSON_PATH_ENVIRONMENT_KEY
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -39,15 +37,15 @@ def _load_propagate_cast_ops(ortmodule_config_accessor, data):
|
|||
log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. Loading attributes from file.")
|
||||
|
||||
def _update_strategy():
|
||||
ortmodule_config_accessor._propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.__members__[
|
||||
ortmodule_config_accessor._runtime_options.propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.__members__[
|
||||
data.PropagateCastOps.Strategy
|
||||
]
|
||||
|
||||
def _update_level():
|
||||
ortmodule_config_accessor._propagate_cast_ops_level = data.PropagateCastOps.Level
|
||||
ortmodule_config_accessor._runtime_options.propagate_cast_ops_level = data.PropagateCastOps.Level
|
||||
|
||||
def _update_allow():
|
||||
ortmodule_config_accessor._propagate_cast_ops_allow = data.PropagateCastOps.Allow
|
||||
ortmodule_config_accessor._runtime_options.propagate_cast_ops_allow = data.PropagateCastOps.Allow
|
||||
|
||||
key_to_function_mapping = {"Strategy": _update_strategy, "Level": _update_level, "Allow": _update_allow}
|
||||
|
||||
|
|
@ -64,8 +62,7 @@ def _load_use_external_gpu_allocator(ortmodule_config_accessor, data):
|
|||
assert isinstance(
|
||||
data.UseExternalGPUAllocator, bool
|
||||
), f"{_load_use_external_gpu_allocator.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._use_external_gpu_allocator = data.UseExternalGPUAllocator
|
||||
ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses()
|
||||
ortmodule_config_accessor._runtime_options.use_external_gpu_allocator = data.UseExternalGPUAllocator
|
||||
|
||||
|
||||
def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
|
||||
|
|
@ -79,7 +76,11 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
|
|||
assert isinstance(
|
||||
data.EnableCustomAutogradFunction, bool
|
||||
), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction
|
||||
|
||||
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
|
||||
|
||||
enable_custom_autograd_support(data.EnableCustomAutogradFunction)
|
||||
ortmodule_config_accessor._runtime_options.enable_custom_autograd_function = data.EnableCustomAutogradFunction
|
||||
|
||||
|
||||
def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
|
||||
|
|
@ -91,7 +92,7 @@ def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
|
|||
assert isinstance(
|
||||
data.EnableGradAccOptimization, bool
|
||||
), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._enable_grad_acc_optimization = data.EnableGradAccOptimization
|
||||
ortmodule_config_accessor._runtime_options.enable_grad_acc_optimization = data.EnableGradAccOptimization
|
||||
|
||||
|
||||
def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
|
||||
|
|
@ -103,7 +104,7 @@ def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data):
|
|||
assert isinstance(
|
||||
data.RunSymbolicShapeInference, bool
|
||||
), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._run_symbolic_shape_infer = data.RunSymbolicShapeInference
|
||||
ortmodule_config_accessor._runtime_options.run_symbolic_shape_infer = data.RunSymbolicShapeInference
|
||||
|
||||
|
||||
def _load_use_static_shape(ortmodule_config_accessor, data):
|
||||
|
|
@ -113,7 +114,7 @@ def _load_use_static_shape(ortmodule_config_accessor, data):
|
|||
log.info(f"Found keyword {_load_use_static_shape.loading_key} in json. Loading attributes from file.")
|
||||
|
||||
assert isinstance(data.UseStaticShape, bool), f"{_load_use_static_shape.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._use_static_shape = data.UseStaticShape
|
||||
ortmodule_config_accessor._runtime_options.use_static_shape = data.UseStaticShape
|
||||
|
||||
|
||||
def _load_skip_check(ortmodule_config_accessor, data):
|
||||
|
|
@ -124,7 +125,7 @@ def _load_skip_check(ortmodule_config_accessor, data):
|
|||
|
||||
skip_check = reduce(lambda x, y: x | y, [_SkipCheck[name] for name in data.SkipCheck])
|
||||
if skip_check.value > 0:
|
||||
ortmodule_config_accessor._skip_check = skip_check
|
||||
ortmodule_config_accessor._runtime_options.skip_check = skip_check
|
||||
|
||||
|
||||
def _load_debug_options(ortmodule_config_accessor, data):
|
||||
|
|
@ -177,7 +178,7 @@ def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data):
|
|||
assert isinstance(
|
||||
data.UseMemoryEfficientGradient, bool
|
||||
), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._use_memory_efficient_gradient = data.UseMemoryEfficientGradient
|
||||
ortmodule_config_accessor._runtime_options.use_memory_efficient_gradient = data.UseMemoryEfficientGradient
|
||||
|
||||
|
||||
def _load_fallback_policy(ortmodule_config_accessor, data):
|
||||
|
|
@ -198,7 +199,7 @@ def _load_onnx_opset_version(ortmodule_config_accessor, data):
|
|||
log.info(f"Found keyword {_load_onnx_opset_version.loading_key} in json. Loading attributes from file.")
|
||||
|
||||
assert isinstance(data.OnnxOpsetVersion, int), f"{_load_onnx_opset_version.loading_key} must be an int"
|
||||
ortmodule.ONNX_OPSET_VERSION = data.OnnxOpsetVersion
|
||||
ortmodule_config_accessor._runtime_options.onnx_opset_version = data.OnnxOpsetVersion
|
||||
|
||||
|
||||
def _define_load_function_keys():
|
||||
|
|
|
|||
287
orttraining/orttraining/python/training/ortmodule/options.py
Normal file
287
orttraining/orttraining/python/training/ortmodule/options.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# options.py
|
||||
|
||||
import os
|
||||
from enum import IntFlag
|
||||
from functools import reduce
|
||||
from logging import Logger
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from ._fallback import _FallbackPolicy
|
||||
from ._logger import LogLevel
|
||||
from ._utils import parse_os_env_skip_check_flags
|
||||
|
||||
|
||||
class _SaveOnnxOptions:
|
||||
"""Configurable option to save ORTModule intermediate onnx models."""
|
||||
|
||||
# class variable
|
||||
_path_environment_key = "ORTMODULE_SAVE_ONNX_PATH"
|
||||
|
||||
def __init__(self, save, name_prefix, path: str):
|
||||
self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix, path)
|
||||
|
||||
def _extract_info(self, save, name_prefix, path: str):
|
||||
# get the destination path from os env variable
|
||||
default_path = path if len(path) > 0 else os.getcwd()
|
||||
destination_path = os.getenv(_SaveOnnxOptions._path_environment_key, default_path)
|
||||
# perform validation only when save is True
|
||||
if save:
|
||||
self._validate(save, name_prefix, destination_path)
|
||||
return save, name_prefix, destination_path
|
||||
|
||||
def _validate(self, save, name_prefix, destination_path):
|
||||
# check if directory is writable
|
||||
if not os.access(destination_path, os.W_OK):
|
||||
raise OSError(
|
||||
f"Directory {destination_path} is not writable. Please set the "
|
||||
f"{_SaveOnnxOptions._path_environment_key} environment variable to a writable path."
|
||||
)
|
||||
|
||||
# check if input prefix is a string
|
||||
if not isinstance(name_prefix, str):
|
||||
raise TypeError(f"Expected name prefix of type str, got {type(name_prefix)}.")
|
||||
|
||||
# if save_onnx is set, save_onnx_prefix must be a non empty string
|
||||
if not name_prefix:
|
||||
raise ValueError("onnx_prefix must be provided when save_onnx is set.")
|
||||
|
||||
@property
|
||||
def save(self):
|
||||
return self._save
|
||||
|
||||
@property
|
||||
def name_prefix(self):
|
||||
return self._name_prefix
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._path
|
||||
|
||||
|
||||
class _LoggingOptions:
|
||||
"""Configurable option to set the log level in ORTModule."""
|
||||
|
||||
# class variable
|
||||
_log_level_environment_key = "ORTMODULE_LOG_LEVEL"
|
||||
|
||||
def __init__(self, log_level):
|
||||
self._log_level = self._extract_info(log_level)
|
||||
|
||||
def _extract_info(self, log_level):
|
||||
# get the log_level from os env variable
|
||||
# OS environment variable log level superseeds the locally provided one
|
||||
self._validate(log_level)
|
||||
log_level = LogLevel[os.getenv(_LoggingOptions._log_level_environment_key, log_level.name)]
|
||||
return log_level
|
||||
|
||||
def _validate(self, log_level):
|
||||
# check if log_level is an instance of LogLevel
|
||||
if not isinstance(log_level, LogLevel):
|
||||
raise TypeError(f"Expected log_level of type LogLevel, got {type(log_level)}.")
|
||||
|
||||
@property
|
||||
def log_level(self) -> LogLevel:
|
||||
return self._log_level
|
||||
|
||||
|
||||
class DebugOptions:
|
||||
"""Configurable debugging options for ORTModule.
|
||||
|
||||
Args:
|
||||
log_level (:obj:`LogLevel`, optional): Configure ORTModule log level. Defaults to LogLevel.WARNING.
|
||||
log_level can also be set by setting the environment variable "ORTMODULE_LOG_LEVEL" to one of
|
||||
"VERBOSE", "INFO", "WARNING", "ERROR", "FATAL". In case both are set, the environment variable
|
||||
takes precedence.
|
||||
save_onnx (:obj:`bool`, optional): Configure ORTModule to save onnx models. Defaults to False.
|
||||
The output directory of the onnx models by default is set to the current working directory.
|
||||
To change the output directory, the environment variable "ORTMODULE_SAVE_ONNX_PATH" can be
|
||||
set to the destination directory path.
|
||||
onnx_prefix (:obj:`str`, optional): Name prefix to the ORTModule ONNX models saved file names.
|
||||
Must be provided if save_onnx is True
|
||||
|
||||
Raises:
|
||||
OSError: If save_onnx is True and output directory is not writable.
|
||||
TypeError: If save_onnx is True and name_prefix is not a valid string. Or if
|
||||
log_level is not an instance of LogLevel.
|
||||
ValueError: If save_onnx is True and name_prefix is an empty string.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix="", save_path="", config=None):
|
||||
self.log_level = log_level
|
||||
self.save_onnx = save_onnx
|
||||
self.onnx_prefix = onnx_prefix
|
||||
|
||||
self._save_onnx_models = _SaveOnnxOptions(self.save_onnx, self.onnx_prefix, save_path)
|
||||
self._logging = _LoggingOptions(self.log_level)
|
||||
|
||||
@property
|
||||
def save_onnx_models(self):
|
||||
"""Accessor for the ONNX saving configuration."""
|
||||
|
||||
return self._save_onnx_models
|
||||
|
||||
@property
|
||||
def logging(self):
|
||||
"""Accessor for the logging configuration."""
|
||||
|
||||
return self._logging
|
||||
|
||||
|
||||
class _SkipCheck(IntFlag):
|
||||
"""Enumeration to specify which checks should be skipped, allowing faster execution"""
|
||||
|
||||
SKIP_CHECK_DISABLED = 1
|
||||
SKIP_CHECK_DEVICE = 2
|
||||
SKIP_CHECK_BUILD_GRADIENT = 4
|
||||
SKIP_CHECK_EXECUTION_AGENT = 8
|
||||
|
||||
def is_set(self, check):
|
||||
"""Check whether `check` is set on the `_SkipCheck instance
|
||||
|
||||
SKIP_CHECK_DISABLED implies the check will return False
|
||||
"""
|
||||
|
||||
return not _SkipCheck.is_disabled(self) and check in self
|
||||
|
||||
def is_disabled(self):
|
||||
"""Check whether `_SkipCheck.SKIP_CHECK_DISABLED is set on the `_SkipCheck instance"""
|
||||
|
||||
return _SkipCheck.SKIP_CHECK_DISABLED in self
|
||||
|
||||
|
||||
class _RuntimeOptions:
|
||||
"""Configurable runtime options for ORTModule."""
|
||||
|
||||
def __init__(self, logger: Logger):
|
||||
"""Constructor for Options.
|
||||
|
||||
Initially set all the options to their default values, then override them with the values
|
||||
from the environment variables.
|
||||
"""
|
||||
self._logger = logger
|
||||
|
||||
self.onnx_opset_version = ortmodule.ONNX_OPSET_VERSION
|
||||
self.conv_algo_search = "HEURISTIC"
|
||||
|
||||
# Configuration for cast optimization.
|
||||
# Specify cast propagation strategy. Currently, three strategies are available:
|
||||
# NONE, INSERT-AND-REDUCE and FLOOD-FILL
|
||||
# The default is FLOOD_FILL, expand FP16 computation regions in the graph using
|
||||
# allowed opcodes for the given level.
|
||||
self.propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.FLOOD_FILL
|
||||
# Optimize by moving Cast operations if propagate_cast_ops_level is non-negative.
|
||||
# - If the propagate_cast_ops_level is set to zero, then the transformation considers only the opcodes
|
||||
# specified by propagate_cast_ops_allow as "FP16 safe", to insert/(re)move cast operations before/after
|
||||
# to perform such operations in reduced (16-bit) precision.
|
||||
# - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by
|
||||
# propagate_cast_ops_allow, use onnxruntime predetermined list of opcodes considered safe to move
|
||||
# before/after the cast operation.
|
||||
# - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform
|
||||
# any computation such as Transpose, Split, Reshape, etc., or the computation is actually in Float
|
||||
# such as GeLU, etc.
|
||||
# - Whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using
|
||||
# contrib ops, Dropout, LayerNormalization, etc.
|
||||
self.propagate_cast_ops_level = 1
|
||||
# List of opcodes to be considered safe to move before/after the cast operation if propagate_cast_ops_level
|
||||
# is zero.
|
||||
self.propagate_cast_ops_allow = []
|
||||
|
||||
# default execution order is priority-based for both dynamic/static shape input for now
|
||||
# if we observe the benefit of static shape, we can expose this flag to the user
|
||||
self.use_static_shape = False
|
||||
|
||||
# flag to enable symbolic shape inference for dynamic shape inputs to improve performance
|
||||
self.run_symbolic_shape_infer = True
|
||||
|
||||
# PyTorch custom Autograd function support
|
||||
from ._custom_autograd_function import custom_autograd_function_enabler
|
||||
|
||||
self.enable_custom_autograd_function = custom_autograd_function_enabler.state
|
||||
|
||||
self.use_external_gpu_allocator = True
|
||||
|
||||
# WIP feature to enable caching in Gradient accumulation scenario.
|
||||
self.enable_grad_acc_optimization = False
|
||||
|
||||
# Memory-aware gradient builder.
|
||||
self.use_memory_efficient_gradient = False
|
||||
|
||||
# Configuration for compute optimization.
|
||||
self.enable_compute_optimizer = True
|
||||
self.enable_sparse_optimizer = True
|
||||
self.label_sparsity_ratio = ""
|
||||
self.embed_sparsity_ratio = ""
|
||||
self.enable_embedding_sparse_optimizer = False # TODO(pengwa): remove once validation on more models are done.
|
||||
|
||||
# Configuration for memory optimization.
|
||||
self.memory_optimizer_config = ""
|
||||
self.probe_level = "1"
|
||||
|
||||
# Configuration for dev tools.
|
||||
self.print_input_density = False
|
||||
self.print_memory_stat = False
|
||||
|
||||
# Configuration for fallback.
|
||||
self.fallback_policy = ortmodule.ORTMODULE_FALLBACK_POLICY
|
||||
|
||||
# Configuration for skip check.
|
||||
# Indicators of some logic have been executed previously and thus could be skipped for faster training
|
||||
# default is enabled, if not defined in os env
|
||||
self.skip_check = _SkipCheck(
|
||||
_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT
|
||||
)
|
||||
|
||||
# Override the feature config if it exists in os env.
|
||||
self._override_from_env_vars()
|
||||
|
||||
def _override_from_env_vars(self):
|
||||
self.onnx_opset_version = int(os.getenv("ORTMODULE_ONNX_OPSET_VERSION", self.onnx_opset_version))
|
||||
self.conv_algo_search = os.getenv("ORTMODULE_CONV_ALGO_SEARCH", self.conv_algo_search)
|
||||
if self.conv_algo_search not in ["HEURISTIC", "EXHAUSTIVE"]:
|
||||
self._logger.warning("Invalid value of env CONV_ALGO_SEARCH. Must be HEURISTIC or EXHAUSTIVE.")
|
||||
self.conv_algo_search = "HEURISTIC"
|
||||
|
||||
# Configuration for compute optimization.
|
||||
compute_optimizer_reset = False
|
||||
if "ORTMODULE_ENABLE_COMPUTE_OPTIMIZER" in os.environ:
|
||||
self.enable_compute_optimizer = int(os.getenv("ORTMODULE_ENABLE_COMPUTE_OPTIMIZER")) == 1
|
||||
compute_optimizer_reset = True
|
||||
|
||||
if "ORTMODULE_ENABLE_SPARSE_OPTIMIZER" in os.environ or compute_optimizer_reset:
|
||||
if "ORTMODULE_ENABLE_SPARSE_OPTIMIZER" in os.environ:
|
||||
self.enable_sparse_optimizer = int(os.getenv("ORTMODULE_ENABLE_SPARSE_OPTIMIZER")) == 1
|
||||
self.enable_sparse_optimizer = self.enable_compute_optimizer and self.enable_sparse_optimizer
|
||||
|
||||
# TODO(pengwa): remove once validation on more models are done.
|
||||
if "ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER" in os.environ:
|
||||
self.enable_embedding_sparse_optimizer = (
|
||||
self.enable_sparse_optimizer and int(os.getenv("ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER")) == 1
|
||||
)
|
||||
|
||||
# Configuration for memory optimization.
|
||||
self.memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config)
|
||||
self.probe_level = os.getenv("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", self.probe_level)
|
||||
|
||||
# Configuration for dev tools.
|
||||
if "ORTMODULE_PRINT_INPUT_DENSITY" in os.environ:
|
||||
self.print_input_density = int(os.getenv("ORTMODULE_PRINT_INPUT_DENSITY")) == 1
|
||||
if "ORTMODULE_PRINT_MEMORY_STATS" in os.environ:
|
||||
self.print_memory_stat = int(os.getenv("ORTMODULE_PRINT_MEMORY_STATS")) == 1
|
||||
|
||||
# Configuration for fallback.
|
||||
if "ORTMODULE_FALLBACK_POLICY" in os.environ:
|
||||
self.fallback_policy = os.getenv("ORTMODULE_FALLBACK_POLICY")
|
||||
if isinstance(self.fallback_policy, str):
|
||||
self.fallback_policy = _FallbackPolicy[self.fallback_policy]
|
||||
|
||||
# Configuration for skip check.
|
||||
if "ORTMODULE_SKIPCHECK_POLICY" in os.environ:
|
||||
self.skip_check = reduce(
|
||||
lambda x, y: x | y,
|
||||
[_SkipCheck[name] for name in parse_os_env_skip_check_flags("ORTMODULE_SKIPCHECK_POLICY")],
|
||||
)
|
||||
|
|
@ -10,7 +10,7 @@ from ._torch_module_ort import TorchModuleORT
|
|||
from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry
|
||||
from ._custom_gradient_registry import CustomGradientRegistry
|
||||
from . import _utils
|
||||
from .debug_options import DebugOptions
|
||||
from .options import DebugOptions
|
||||
from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException
|
||||
from ._logger import ortmodule_loglevel_to_python_loglevel
|
||||
from onnxruntime.training import ortmodule
|
||||
|
|
@ -77,7 +77,9 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# Support contrib OPs
|
||||
pytorch_export_contrib_ops.register()
|
||||
CustomOpSymbolicRegistry.register_all()
|
||||
CustomOpSymbolicRegistry.register_all(
|
||||
self._torch_module._execution_manager(module.training)._runtime_options.onnx_opset_version
|
||||
)
|
||||
CustomGradientRegistry.register_all()
|
||||
|
||||
# Warn user if there are name collisions between user model's and ORTModule attributes
|
||||
|
|
@ -320,7 +322,9 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
# Re-register contrib OPs
|
||||
pytorch_export_contrib_ops.register()
|
||||
CustomOpSymbolicRegistry.register_all()
|
||||
CustomOpSymbolicRegistry.register_all(
|
||||
self._torch_module._execution_manager(self.module.training)._runtime_options.onnx_opset_version
|
||||
)
|
||||
CustomGradientRegistry.register_all()
|
||||
|
||||
# Re-initialize the ORTModule forward method
|
||||
|
|
|
|||
|
|
@ -30,16 +30,9 @@ from transformers.modeling_outputs import SequenceClassifierOutput
|
|||
|
||||
import onnxruntime.training.ortmodule as ortmodule_module
|
||||
from onnxruntime.training.optim import AdamWMode, FusedAdam
|
||||
from onnxruntime.training.ortmodule import (
|
||||
DebugOptions,
|
||||
LogLevel,
|
||||
ORTModule,
|
||||
_fallback,
|
||||
_graph_execution_manager,
|
||||
_io,
|
||||
_utils,
|
||||
)
|
||||
from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils
|
||||
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
|
||||
from onnxruntime.training.ortmodule.options import _SkipCheck
|
||||
|
||||
DEFAULT_OPSET = 15
|
||||
|
||||
|
|
@ -4454,7 +4447,7 @@ def test_ortmodule_gradient_accumulation_optimization_correctness():
|
|||
|
||||
# model with optimization enabled
|
||||
opt_model = ORTModule(copy.deepcopy(pt_model))
|
||||
opt_model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
|
||||
opt_model._torch_module._execution_manager(is_training=True)._runtime_options.enable_grad_acc_optimization = True
|
||||
opt_optimizer = torch.optim.Adam(opt_model.parameters())
|
||||
|
||||
def run_step(model, x):
|
||||
|
|
@ -4879,10 +4872,10 @@ def test_ortmodule_ortmodule_method_attribute_copy():
|
|||
@pytest.mark.parametrize(
|
||||
"policy_str, policy",
|
||||
[
|
||||
("SKIP_CHECK_DISABLED", _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED),
|
||||
("SKIP_CHECK_DEVICE", _graph_execution_manager._SkipCheck.SKIP_CHECK_DEVICE),
|
||||
("SKIP_CHECK_BUILD_GRADIENT", _graph_execution_manager._SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
|
||||
("SKIP_CHECK_EXECUTION_AGENT", _graph_execution_manager._SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
|
||||
("SKIP_CHECK_DISABLED", _SkipCheck.SKIP_CHECK_DISABLED),
|
||||
("SKIP_CHECK_DEVICE", _SkipCheck.SKIP_CHECK_DEVICE),
|
||||
("SKIP_CHECK_BUILD_GRADIENT", _SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
|
||||
("SKIP_CHECK_EXECUTION_AGENT", _SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
|
||||
],
|
||||
)
|
||||
def test_ortmodule_skip_check_load_from_os_env(policy_str, policy):
|
||||
|
|
@ -4893,7 +4886,7 @@ def test_ortmodule_skip_check_load_from_os_env(policy_str, policy):
|
|||
ort_model = ORTModule(model)
|
||||
|
||||
for training_mode in [False, True]:
|
||||
assert ort_model._torch_module._execution_manager(training_mode)._skip_check == policy
|
||||
assert ort_model._torch_module._execution_manager(training_mode)._runtime_options.skip_check == policy
|
||||
|
||||
del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
|
||||
|
||||
|
|
@ -5227,10 +5220,8 @@ def test_sigmoid_grad_opset13():
|
|||
N, D_in, H, D_out = 120, 15360, 500, 15360 # noqa: N806
|
||||
pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device)
|
||||
|
||||
old_opst_cst = ortmodule_module.ONNX_OPSET_VERSION
|
||||
old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None)
|
||||
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13"
|
||||
assert ortmodule_module.ONNX_OPSET_VERSION == 15
|
||||
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
|
|
@ -5256,21 +5247,25 @@ def test_sigmoid_grad_opset13():
|
|||
del os.environ["ORTMODULE_ONNX_OPSET_VERSION"]
|
||||
else:
|
||||
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset
|
||||
assert ortmodule_module.ONNX_OPSET_VERSION == 13
|
||||
ortmodule_module.ONNX_OPSET_VERSION = old_opst_cst
|
||||
|
||||
assert ort_model._torch_module._execution_manager(True)._runtime_options.onnx_opset_version == 13
|
||||
|
||||
|
||||
@pytest.mark.parametrize("opset_version", [12, 13, 14, 15])
|
||||
def test_opset_version_change(opset_version):
|
||||
original_env = None
|
||||
if "ORTMODULE_ONNX_OPSET_VERSION" in os.environ:
|
||||
original_env = os.environ["ORTMODULE_ONNX_OPSET_VERSION"]
|
||||
del os.environ["ORTMODULE_ONNX_OPSET_VERSION"]
|
||||
|
||||
device = "cuda"
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
|
||||
ort_model = ORTModule(model)
|
||||
|
||||
ortmodule_module.ONNX_OPSET_VERSION = opset_version
|
||||
ort_model = ORTModule(model)
|
||||
|
||||
# Make sure model runs without any exception
|
||||
prediction = ort_model(x)
|
||||
|
|
@ -5282,6 +5277,9 @@ def test_opset_version_change(opset_version):
|
|||
exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model
|
||||
assert exported_model.opset_import[0].version == opset_version
|
||||
|
||||
if original_env is not None:
|
||||
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = original_env
|
||||
|
||||
|
||||
def test_serialize_ortmodule():
|
||||
device = "cuda"
|
||||
|
|
|
|||
|
|
@ -420,7 +420,7 @@ def main():
|
|||
)
|
||||
|
||||
model = ORTModule(model, debug_options)
|
||||
model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True
|
||||
model._torch_module._execution_manager(is_training=True)._runtime_options.enable_grad_acc_optimization = True
|
||||
|
||||
# Tell pytorch to run this model on the GPU.
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
|
|
|
|||
|
|
@ -38,27 +38,29 @@ def test_load_config_from_json_1():
|
|||
ort_model_attributes = model._torch_module._execution_manager(training_mode)
|
||||
|
||||
# test propagate cast ops
|
||||
assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL
|
||||
assert ort_model_attributes._propagate_cast_ops_level == 3
|
||||
assert ort_model_attributes._propagate_cast_ops_allow == ["ABC", "DEF"]
|
||||
assert (
|
||||
ort_model_attributes._runtime_options.propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL
|
||||
)
|
||||
assert ort_model_attributes._runtime_options.propagate_cast_ops_level == 3
|
||||
assert ort_model_attributes._runtime_options.propagate_cast_ops_allow == ["ABC", "DEF"]
|
||||
|
||||
# test use external gpu allocator
|
||||
assert ort_model_attributes._use_external_gpu_allocator is False
|
||||
assert ort_model_attributes._runtime_options.use_external_gpu_allocator is False
|
||||
|
||||
# test enable custom autograd function
|
||||
assert ort_model_attributes._enable_custom_autograd_function is True
|
||||
assert ort_model_attributes._runtime_options.enable_custom_autograd_function is True
|
||||
|
||||
# test use static shape
|
||||
assert ort_model_attributes._use_static_shape is True
|
||||
assert ort_model_attributes._runtime_options.use_static_shape is True
|
||||
|
||||
# test run symbolic shape inference
|
||||
assert ort_model_attributes._run_symbolic_shape_infer is False
|
||||
assert ort_model_attributes._runtime_options.run_symbolic_shape_infer is False
|
||||
|
||||
# test enable grad acc optimization
|
||||
assert ort_model_attributes._enable_grad_acc_optimization is True
|
||||
assert ort_model_attributes._runtime_options.enable_grad_acc_optimization is True
|
||||
|
||||
# test skip check
|
||||
assert ort_model_attributes._skip_check.value == 14
|
||||
assert ort_model_attributes._runtime_options.skip_check.value == 14
|
||||
|
||||
# test debug options
|
||||
assert ort_model_attributes._debug_options.save_onnx_models.save is True
|
||||
|
|
@ -66,13 +68,13 @@ def test_load_config_from_json_1():
|
|||
assert ort_model_attributes._debug_options.logging.log_level.name == "VERBOSE"
|
||||
|
||||
# test use memory aware gradient builder.
|
||||
assert ort_model_attributes._use_memory_efficient_gradient is False
|
||||
assert ort_model_attributes._runtime_options.use_memory_efficient_gradient is False
|
||||
|
||||
# test fallback policy
|
||||
assert ort_model_attributes._fallback_manager.policy.value == 1
|
||||
|
||||
# assert onnx opset version
|
||||
assert ortmodule.ONNX_OPSET_VERSION == 13
|
||||
assert ort_model_attributes._runtime_options.onnx_opset_version == 13
|
||||
|
||||
|
||||
def test_load_config_from_json_2():
|
||||
|
|
@ -91,27 +93,30 @@ def test_load_config_from_json_2():
|
|||
ort_model_attributes = model._torch_module._execution_manager(training_mode)
|
||||
|
||||
# test propagate cast ops
|
||||
assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.INSERT_AND_REDUCE
|
||||
assert ort_model_attributes._propagate_cast_ops_level == 5
|
||||
assert ort_model_attributes._propagate_cast_ops_allow == ["XYZ", "PQR"]
|
||||
assert (
|
||||
ort_model_attributes._runtime_options.propagate_cast_ops_strategy
|
||||
== C.PropagateCastOpsStrategy.INSERT_AND_REDUCE
|
||||
)
|
||||
assert ort_model_attributes._runtime_options.propagate_cast_ops_level == 5
|
||||
assert ort_model_attributes._runtime_options.propagate_cast_ops_allow == ["XYZ", "PQR"]
|
||||
|
||||
# test use external gpu allocator
|
||||
assert ort_model_attributes._use_external_gpu_allocator is True
|
||||
assert ort_model_attributes._runtime_options.use_external_gpu_allocator is True
|
||||
|
||||
# test enable custom autograd function
|
||||
assert ort_model_attributes._enable_custom_autograd_function is False
|
||||
assert ort_model_attributes._runtime_options.enable_custom_autograd_function is False
|
||||
|
||||
# test use static shape
|
||||
assert ort_model_attributes._use_static_shape is False
|
||||
assert ort_model_attributes._runtime_options.use_static_shape is False
|
||||
|
||||
# test run symbolic shape inference
|
||||
assert ort_model_attributes._run_symbolic_shape_infer is True
|
||||
assert ort_model_attributes._runtime_options.run_symbolic_shape_infer is True
|
||||
|
||||
# test enable grad acc optimization
|
||||
assert ort_model_attributes._enable_grad_acc_optimization is False
|
||||
assert ort_model_attributes._runtime_options.enable_grad_acc_optimization is False
|
||||
|
||||
# test skip check
|
||||
assert ort_model_attributes._skip_check.value == 10
|
||||
assert ort_model_attributes._runtime_options.skip_check.value == 10
|
||||
|
||||
# test debug options
|
||||
assert ort_model_attributes._debug_options.save_onnx_models.save is True
|
||||
|
|
@ -119,10 +124,10 @@ def test_load_config_from_json_2():
|
|||
assert ort_model_attributes._debug_options.logging.log_level.name == "INFO"
|
||||
|
||||
# test use memory aware gradient builder.
|
||||
assert ort_model_attributes._use_memory_efficient_gradient is True
|
||||
assert ort_model_attributes._runtime_options.use_memory_efficient_gradient is True
|
||||
|
||||
# test fallback policy
|
||||
assert ort_model_attributes._fallback_manager.policy.value == 250
|
||||
|
||||
# assert onnx opset version
|
||||
assert ortmodule.ONNX_OPSET_VERSION == 12
|
||||
assert ort_model_attributes._runtime_options.onnx_opset_version == 12
|
||||
|
|
|
|||
Loading…
Reference in a new issue