mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Auto forward non method attribute lookups to the user's model and bind custom methods to ORTModule (#8798)
This commit is contained in:
parent
c343f7cb43
commit
0cc2909573
5 changed files with 375 additions and 8 deletions
|
|
@ -22,6 +22,7 @@ from onnxruntime.capi import _pybind_state as C
|
|||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
from abc import ABC, abstractmethod
|
||||
import copy
|
||||
from functools import reduce
|
||||
import io
|
||||
import inspect
|
||||
import os
|
||||
|
|
@ -94,7 +95,10 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._execution_agent = None
|
||||
|
||||
# indicators of some logic have been executed previously thus could be skipped for faster training
|
||||
self._skip_check = _SkipCheck.SKIP_CHECK_DISABLED
|
||||
self._skip_check = reduce(lambda x, y: x | y,
|
||||
[_SkipCheck[name] for name in
|
||||
_utils.parse_os_env_skip_check_flags('ORTMODULE_SKIPCHECK_POLICY',
|
||||
_SkipCheck.SKIP_CHECK_DISABLED.name)])
|
||||
self._first_skip_check_warning = True
|
||||
|
||||
# Graph transformer config
|
||||
|
|
@ -158,6 +162,10 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# Memory aware gradient builder.
|
||||
self._use_memory_efficient_gradient = False
|
||||
|
||||
# Flag to re-export the model due to attribute change on original module.
|
||||
# Re-export will be avoided if _skip_check is enabled.
|
||||
self._original_model_has_changed = False
|
||||
|
||||
def _get_torch_gpu_allocator_function_addresses(self):
|
||||
if self._use_external_gpu_allocator and torch.cuda.is_available():
|
||||
# CPP extension to get torch GPU allocator's alloc and free function addresses
|
||||
|
|
@ -278,7 +286,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
schema = _io._extract_schema(
|
||||
{'args': copy.copy(inputs), 'kwargs': copy.copy(kwargs)})
|
||||
if self._onnx_models.exported_model and schema == self._input_info.schema:
|
||||
if self._onnx_models.exported_model and schema == self._input_info.schema and not self._original_model_has_changed:
|
||||
# All required models have already been exported previously
|
||||
return False
|
||||
|
||||
|
|
@ -421,3 +429,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# between forward calls.
|
||||
self._graph_initializers = [param for name, param in self._flattened_module.named_parameters()
|
||||
if name in self._graph_initializer_names]
|
||||
|
||||
def signal_model_changed(self):
|
||||
"""Signals the execution manager to re-export the model on the next forward call"""
|
||||
self._original_model_has_changed = True
|
||||
|
|
|
|||
|
|
@ -92,4 +92,4 @@ class TorchModulePytorch(TorchModuleInterface):
|
|||
|
||||
@TorchModuleInterface.module.getter
|
||||
def module(self):
|
||||
return self._original_module.module
|
||||
return self._original_module
|
||||
|
|
|
|||
|
|
@ -7,8 +7,14 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
|
|||
from onnxruntime.capi import _pybind_state as C
|
||||
from ._fallback import _FallbackManager, ORTModuleFallbackException, ORTModuleDeviceException, wrap_exception
|
||||
|
||||
import os
|
||||
import copy
|
||||
import inspect
|
||||
import torch
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
from typing import List
|
||||
import types
|
||||
import warnings
|
||||
|
||||
|
||||
def _ortvalue_to_torch_tensor(ortvalue):
|
||||
|
|
@ -112,3 +118,63 @@ def _create_iobinding(io_binding, inputs, model, device):
|
|||
|
||||
for value_info in model.graph.output:
|
||||
io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device))
|
||||
|
||||
def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn.Module,
|
||||
user_module: torch.nn.Module):
|
||||
"""Warns if there are any common attributes between the user's model and ORTModule and binds user methods to ORTModule
|
||||
|
||||
If there are methods defined on the user's model that ORTModule does not recognize (custom methods),
|
||||
then this function binds these methods to ORTModule.
|
||||
|
||||
Args:
|
||||
ortmodule: the ORTModule instance
|
||||
user_module: the user's torch.nn.Module
|
||||
|
||||
Raises:
|
||||
UserWarning: If there are any overlapping attributes between the ortmodule and user_module (except forward)
|
||||
"""
|
||||
|
||||
ortmodule_attributes = dict(inspect.getmembers(ortmodule))
|
||||
torch_module_attributes = dict(inspect.getmembers(torch.nn.Module()))
|
||||
user_module_attributes = inspect.getmembers(user_module)
|
||||
|
||||
# Check if any user defined attribute collides with ORTModule's attributes
|
||||
for attribute_name, attribute in user_module_attributes:
|
||||
if inspect.ismethod(attribute):
|
||||
# Skip the dunder methods
|
||||
if attribute_name.startswith('__'):
|
||||
continue
|
||||
|
||||
# if the attribute is not a torch attribute, or if the torch attribute
|
||||
# corresponding to attribute_name is not a method or the user attribute
|
||||
# does not equal the torch attribute, then this is a user defined method.
|
||||
if attribute_name not in torch_module_attributes or \
|
||||
not inspect.ismethod(torch_module_attributes[attribute_name]) or \
|
||||
attribute.__func__ != torch_module_attributes[attribute_name].__func__:
|
||||
|
||||
# forward is expected to be defined by the user.
|
||||
if attribute_name == 'forward':
|
||||
continue
|
||||
|
||||
# This is a user defined/overriden method. Check for collisions.
|
||||
if attribute_name in ortmodule_attributes:
|
||||
# This is a user defined method, issue a warning.
|
||||
warnings.warn(f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. "
|
||||
"User Module's method may not be called upon invocation through ORTModule.")
|
||||
else:
|
||||
# This is a custom method, copy it and bind the copy to ORTModule.
|
||||
# This is needed for cases where the user's custom method invokes
|
||||
# the forward method. It should go through ORTModule's forward implementation
|
||||
# and not go through the user defined forward method.
|
||||
ortmodule.__dict__[attribute_name] = types.MethodType(copy.deepcopy(attribute.__func__), ortmodule)
|
||||
else:
|
||||
if attribute_name not in torch_module_attributes and attribute_name in ortmodule_attributes:
|
||||
# This is a user defined attribute that collides with ORTModule
|
||||
if attribute_name in ortmodule_attributes:
|
||||
warnings.warn(f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. "
|
||||
"User Module's attribute may not be returned when trying to retrieve the attribute through ORTModule.")
|
||||
|
||||
def parse_os_env_skip_check_flags(env_name, default_skip_check_str):
|
||||
"""Returns a list of SkipChecks as defined by os env variable env_name or default provided"""
|
||||
|
||||
return os.getenv(env_name, default_skip_check_str).split('|')
|
||||
|
|
|
|||
|
|
@ -5,8 +5,10 @@
|
|||
|
||||
from ._torch_module_factory import TorchModuleFactory
|
||||
from ._torch_module_pytorch import TorchModulePytorch
|
||||
from ._torch_module_ort import TorchModuleORT
|
||||
from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry
|
||||
from ._custom_gradient_registry import CustomGradientRegistry
|
||||
from . import _utils
|
||||
from .debug_options import DebugOptions
|
||||
from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException, ORTModuleTorchModelException, wrap_exception
|
||||
from . import _FALLBACK_INIT_EXCEPTION, MINIMUM_RUNTIME_PYTORCH_VERSION_STR, ORTMODULE_FALLBACK_POLICY, ORTMODULE_FALLBACK_RETRY
|
||||
|
|
@ -15,6 +17,7 @@ from onnxruntime.tools import pytorch_export_contrib_ops
|
|||
import functools
|
||||
import torch
|
||||
from typing import Iterator, Optional, Tuple, TypeVar, Set, Callable
|
||||
import warnings
|
||||
|
||||
# Needed to override PyTorch methods
|
||||
T = TypeVar('T', bound='Module')
|
||||
|
|
@ -32,6 +35,18 @@ class ORTModule(torch.nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, module, debug_options=None):
|
||||
|
||||
# NOTE: torch.nn.Modules that call setattr on their internal attributes regularly
|
||||
# (for example PyTorch Lightning), will trigger regular re-exports. This is
|
||||
# because ORTModule auto detects such setattrs on the original module and
|
||||
# marks the model as stale. This is a known limitation. To disable repeated
|
||||
# re-export checks when not required, please set the environment variable
|
||||
# ORTMODULE_SKIPCHECK_POLICY to SKIP_CHECK_BUILD_GRADIENT|SKIP_CHECK_EXECUTION_AGENT
|
||||
|
||||
# Set _is_initialized attribute first which starts off as False.
|
||||
# This variable will be used for comparing strings in __setattr__ and __getattr__
|
||||
# NOTE: Do not rename/move.
|
||||
self._is_initialized = False
|
||||
# Python default arguments are evaluated on function definition
|
||||
# and not on function invocation. So, if debug_options is not provided,
|
||||
# instantiate it inside the function.
|
||||
|
|
@ -41,12 +56,15 @@ class ORTModule(torch.nn.Module):
|
|||
# Fallback settings
|
||||
self._fallback_manager = _FallbackManager(policy=ORTMODULE_FALLBACK_POLICY,
|
||||
retry=ORTMODULE_FALLBACK_RETRY)
|
||||
|
||||
try:
|
||||
# Read ORTModule module initialization status
|
||||
global _FALLBACK_INIT_EXCEPTION
|
||||
if _FALLBACK_INIT_EXCEPTION:
|
||||
raise _FALLBACK_INIT_EXCEPTION
|
||||
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
self._torch_module = TorchModuleFactory()(module, debug_options, self._fallback_manager)
|
||||
|
||||
# Create forward dynamically, so each ORTModule instance will have its own copy.
|
||||
|
|
@ -68,16 +86,19 @@ class ORTModule(torch.nn.Module):
|
|||
functools.update_wrapper(
|
||||
self.forward.__func__, self._torch_module.forward.__func__)
|
||||
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
# Support contrib OPs
|
||||
pytorch_export_contrib_ops.register()
|
||||
CustomOpSymbolicRegistry.register_all()
|
||||
CustomGradientRegistry.register_all()
|
||||
|
||||
# Warn user if there are name collisions between user model's and ORTModule attributes
|
||||
# And if there are custom methods defined on the user's model, copy and bind them to
|
||||
# ORTModule.
|
||||
_utils.check_for_name_collisions_and_bind_methods_to_ortmodule(self, module)
|
||||
|
||||
except ORTModuleFallbackException as e:
|
||||
self._torch_module = TorchModulePytorch(module)
|
||||
# TODO: Rework after "custom methods" task is designed
|
||||
# TODO: Rework by implementing the "__getattribute__" method.
|
||||
# Assigning all default attributes from user's original torch.nn.Module into ORTModule
|
||||
self._backward_hooks = module._backward_hooks
|
||||
self._forward_hooks = module._forward_hooks
|
||||
|
|
@ -92,7 +113,6 @@ class ORTModule(torch.nn.Module):
|
|||
self.forward = module.forward
|
||||
|
||||
# Exceptions subject to fallback are handled here
|
||||
# import pdb; pdb.set_trace()
|
||||
self._fallback_manager.handle_exception(exception=e,
|
||||
log_level=debug_options.logging.log_level)
|
||||
except Exception as e:
|
||||
|
|
@ -102,6 +122,11 @@ class ORTModule(torch.nn.Module):
|
|||
log_level=debug_options.logging.log_level,
|
||||
override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD)
|
||||
|
||||
# Finally, ORTModule initialization is complete.
|
||||
# Assign self._is_initialized to True after all the ORTModule class attributes have been assigned
|
||||
# else, they will be assigned to self._torch_module.original_module instead.
|
||||
self._is_initialized = True
|
||||
|
||||
# IMPORTANT: DO NOT add code here
|
||||
# This declaration is for automatic document generation purposes only
|
||||
# The actual forward implementation is bound during ORTModule initialization
|
||||
|
|
@ -269,3 +294,36 @@ class ORTModule(torch.nn.Module):
|
|||
"""Override :meth:`~torch.nn.Module.named_modules`"""
|
||||
|
||||
yield from self._torch_module.named_modules(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if '_is_initialized' in self.__dict__ and self.__dict__['_is_initialized'] == True:
|
||||
# If ORTModule is intitialized and attribute is not found in ORTModule,
|
||||
# it must be present in the user's torch.nn.Module. Forward the call to
|
||||
# the user's model.
|
||||
assert '_torch_module' in self.__dict__, "ORTModule does not have a reference to the user's model"
|
||||
return getattr(self.module, name)
|
||||
else:
|
||||
return super(ORTModule, self).__getattr__(name)
|
||||
|
||||
def __setattr__(self, name: str, value) -> None:
|
||||
|
||||
if name in self.__dict__:
|
||||
# If the name is an attribute of ORTModule, update only ORTModule
|
||||
self.__dict__[name] = value
|
||||
|
||||
elif '_is_initialized' in self.__dict__ and self.__dict__['_is_initialized'] == True:
|
||||
|
||||
assert '_torch_module' in self.__dict__, "ORTModule does not have a reference to the user's model"
|
||||
|
||||
# If the name is an attribute of user model, or is a new attribute, update there.
|
||||
# Set the attribute on the user's original module
|
||||
setattr(self.module, name, value)
|
||||
# Signal to execution manager to re-export the model.
|
||||
# Re-export will be avoided if _skip_check is enabled.
|
||||
if isinstance(self._torch_module, TorchModuleORT):
|
||||
for training_mode in [False, True]:
|
||||
self._torch_module._execution_manager(training_mode).signal_model_changed()
|
||||
|
||||
else:
|
||||
# Setting any new attributes should be done on ORTModule only when 'torch_module' is not defined
|
||||
self.__dict__[name] = value
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
# Licensed under the MIT License.
|
||||
# orttraining_test_ortmodule_api.py
|
||||
|
||||
from functools import reduce
|
||||
import itertools
|
||||
import math
|
||||
import random
|
||||
import copy
|
||||
|
|
@ -18,7 +20,7 @@ from inspect import signature
|
|||
import tempfile
|
||||
import os
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback
|
||||
from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback, _graph_execution_manager
|
||||
import _test_helpers
|
||||
|
||||
# Import autocasting libs
|
||||
|
|
@ -3469,3 +3471,232 @@ def test_ortmodule_list_dict_input_with_kwargs_and_registered_buffer():
|
|||
kwargs_input_copy = copy.deepcopy(kwargs_input)
|
||||
|
||||
_test_helpers.assert_values_are_close(pt_model(x, **kwargs_input), ort_model(x_copy, **kwargs_input_copy))
|
||||
|
||||
def test_ortmodule_user_defined_method():
|
||||
class UserDefinedMethodsNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(UserDefinedMethodsNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([12]))
|
||||
|
||||
def forward(self, a):
|
||||
return self.dummy + a
|
||||
|
||||
def custom_method_returns_input(self, user_input):
|
||||
return user_input
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = UserDefinedMethodsNet().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
y = copy.deepcopy(x)
|
||||
|
||||
out = ort_model.custom_method_returns_input(x)
|
||||
assert x is out
|
||||
|
||||
pt_out = pt_model(x)
|
||||
ort_out = ort_model(y)
|
||||
_test_helpers.assert_values_are_close(pt_out, ort_out)
|
||||
|
||||
def test_ortmodule_user_getattr_gets_successfully():
|
||||
class UserDefinedMethodsNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(UserDefinedMethodsNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([12]))
|
||||
|
||||
def forward(self, a):
|
||||
return self.dummy + a
|
||||
|
||||
def custom_method_returns_input(self, user_input):
|
||||
return user_input
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = UserDefinedMethodsNet().to(device)
|
||||
ort_model = ORTModule(pt_model)
|
||||
|
||||
assert ort_model.custom_method_returns_input != pt_model.custom_method_returns_input
|
||||
assert ort_model.custom_method_returns_input.__func__ == pt_model.custom_method_returns_input.__func__
|
||||
assert ort_model.dummy is pt_model.dummy
|
||||
|
||||
@pytest.mark.parametrize("attribute", ['True', 'lambda x : x'])
|
||||
def test_ortmodule_setattr_new_attribute(attribute):
|
||||
class UserNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(UserNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
|
||||
def forward(self, a):
|
||||
return self.dummy + a
|
||||
|
||||
device = 'cuda'
|
||||
pt_model = UserNet().to(device)
|
||||
ort_model = ORTModule(pt_model)
|
||||
ort_model.a_new_attribute = attribute
|
||||
|
||||
assert hasattr(pt_model, 'a_new_attribute')
|
||||
assert pt_model.a_new_attribute == attribute
|
||||
assert 'a_new_attribute' not in ort_model.__dict__
|
||||
|
||||
def test_ortmodule_setattr_on_ortmodule_copied_user_model_attribute():
|
||||
class UserNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(UserNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
|
||||
def forward(self, a):
|
||||
return self.dummy + a
|
||||
|
||||
def custom_method(self, a):
|
||||
return a
|
||||
|
||||
def my_new_custom_method(self, a, b, c):
|
||||
return a + b + c
|
||||
|
||||
device = 'cuda'
|
||||
pt_model = UserNet().to(device)
|
||||
ort_model = ORTModule(pt_model)
|
||||
# custom_method is copied by ORTModule from the users model
|
||||
# and bound to itself
|
||||
ort_model.custom_method = my_new_custom_method
|
||||
# dummy is defined on pt model
|
||||
ort_model.dummy = torch.nn.Parameter(torch.FloatTensor([12]))
|
||||
|
||||
assert hasattr(pt_model, 'dummy')
|
||||
assert torch.eq(pt_model.dummy, torch.nn.Parameter(torch.FloatTensor([12])))
|
||||
assert 'dummy' not in ort_model.__dict__
|
||||
|
||||
assert hasattr(pt_model, 'custom_method')
|
||||
assert pt_model.custom_method is not my_new_custom_method
|
||||
assert ort_model.custom_method is my_new_custom_method
|
||||
|
||||
def test_ortmodule_setattr_ortmodule_attribute():
|
||||
class UserNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(UserNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
|
||||
def forward(self, a):
|
||||
return self.dummy + a
|
||||
|
||||
device = 'cuda'
|
||||
pt_model = UserNet().to(device)
|
||||
ort_model = ORTModule(pt_model)
|
||||
ort_model._torch_module = True
|
||||
|
||||
assert not hasattr(pt_model, '_torch_module')
|
||||
assert '_torch_module' in ort_model.__dict__
|
||||
assert ort_model._torch_module == True
|
||||
|
||||
def test_ortmodule_setattr_signals_model_changed():
|
||||
class UserNet(torch.nn.Module):
|
||||
def __init__(self, input_flag):
|
||||
super(UserNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([10]))
|
||||
self.input_flag = input_flag
|
||||
|
||||
def forward(self, a):
|
||||
if self.input_flag:
|
||||
return self.dummy + a
|
||||
else:
|
||||
return a
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = UserNet(True).to(device)
|
||||
ort_model = ORTModule(pt_model)
|
||||
|
||||
_ = ort_model(torch.randn(N, D_in, device=device))
|
||||
exported_model1 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model
|
||||
|
||||
for training_mode in [False, True]:
|
||||
assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed == False
|
||||
ort_model.input_flag = False
|
||||
|
||||
for training_mode in [False, True]:
|
||||
assert ort_model._torch_module._execution_manager(training_mode)._original_model_has_changed == True
|
||||
|
||||
_ = ort_model(torch.randn(N, D_in, device=device))
|
||||
exported_model2 = ort_model._torch_module._execution_manager(True)._onnx_models.exported_model
|
||||
|
||||
assert exported_model1 != exported_model2
|
||||
|
||||
def test_ortmodule_attribute_name_collision_warning():
|
||||
class UserNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(UserNet, self).__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.FloatTensor([0]))
|
||||
self._torch_module = True
|
||||
|
||||
def forward(self, a):
|
||||
return self.dummy + a
|
||||
|
||||
def load_state_dict(self):
|
||||
pass
|
||||
|
||||
device = 'cuda'
|
||||
pt_model = UserNet().to(device)
|
||||
with pytest.warns(UserWarning) as warning_record:
|
||||
ort_model = ORTModule(pt_model)
|
||||
|
||||
assert len(warning_record) == 2
|
||||
assert "_torch_module collides with ORTModule's attribute name." in warning_record[0].message.args[0]
|
||||
assert "load_state_dict collides with ORTModule's attribute name." in warning_record[1].message.args[0]
|
||||
|
||||
def test_ortmodule_ortmodule_method_attribute_copy():
|
||||
class UserNetWithSelfCallingForward(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super(UserNetWithSelfCallingForward, self).__init__()
|
||||
|
||||
self.fc1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input1):
|
||||
out = self.fc1(input1)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
def run_forward(self, *args, **kwargs):
|
||||
return self(*args, **kwargs)
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
pt_model = UserNetWithSelfCallingForward(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
x_1 = torch.randn(N, D_in, device=device)
|
||||
x_2 = copy.deepcopy(x_1)
|
||||
x_3 = copy.deepcopy(x_1)
|
||||
# Executed on ORTModule
|
||||
out1 = ort_model(x_1)
|
||||
# Executed on ORTModule even though run_forward is not defined on ORTModule
|
||||
out2 = ort_model.run_forward(x_2)
|
||||
# Executed on pytorch module since it is directly invoked from there
|
||||
out3 = pt_model.run_forward(x_3)
|
||||
|
||||
assert torch.equal(out1, out2)
|
||||
_test_helpers.assert_values_are_close(out2, out3)
|
||||
|
||||
assert type(out1.grad_fn).__name__ == '_ORTModuleFunctionBackward'
|
||||
assert type(out2.grad_fn).__name__ == '_ORTModuleFunctionBackward'
|
||||
assert type(out3.grad_fn).__name__ == 'AddmmBackward'
|
||||
|
||||
@pytest.mark.parametrize("policy_str, policy",[
|
||||
('SKIP_CHECK_DISABLED', _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED),
|
||||
('SKIP_CHECK_DEVICE', _graph_execution_manager._SkipCheck.SKIP_CHECK_DEVICE),
|
||||
('SKIP_CHECK_BUILD_GRADIENT', _graph_execution_manager._SkipCheck.SKIP_CHECK_BUILD_GRADIENT),
|
||||
('SKIP_CHECK_EXECUTION_AGENT', _graph_execution_manager._SkipCheck.SKIP_CHECK_EXECUTION_AGENT),
|
||||
])
|
||||
def test_ortmodule_skip_check_load_from_os_env(policy_str, policy):
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
os.environ['ORTMODULE_SKIPCHECK_POLICY'] = policy_str
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(model)
|
||||
|
||||
for training_mode in [False, True]:
|
||||
assert ort_model._torch_module._execution_manager(training_mode)._skip_check == policy
|
||||
|
||||
del os.environ['ORTMODULE_SKIPCHECK_POLICY']
|
||||
|
|
|
|||
Loading…
Reference in a new issue