mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Ignore all string inputs to ORTModule AB#1310803 (#9344)
This commit is contained in:
parent
4698b73725
commit
20eaed43e5
5 changed files with 46 additions and 31 deletions
|
|
@ -11,7 +11,7 @@ import warnings
|
|||
import gc
|
||||
|
||||
from ._fallback import _FallbackManager, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception
|
||||
|
||||
from ._utils import warn_of_constant_inputs
|
||||
|
||||
class _OutputIdentityOp(torch.autograd.Function):
|
||||
'''Internal class used to prepend Identity ops in model's outputs
|
||||
|
|
@ -141,6 +141,9 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
|
|||
# The exporter handles input lists by expanding them so that each
|
||||
# element of the list is its own input.
|
||||
# ORTModule must match this behavior by also expanding the inputs.
|
||||
if current_input is None or isinstance(current_input, str):
|
||||
# Drop all None and string inputs
|
||||
return
|
||||
if isinstance(current_input, abc.Sequence):
|
||||
# If the input is a sequence (like a list), expand the list so that
|
||||
# each element of the list is an input by itself
|
||||
|
|
@ -151,7 +154,7 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
|
|||
# each element of the dict is an input by itself
|
||||
for _, val in current_input.items():
|
||||
_expand_inputs(val, non_none_inputs)
|
||||
elif current_input is not None:
|
||||
else:
|
||||
# else just collect all the non none inputs within non_none_inputs
|
||||
non_none_inputs.append(current_input)
|
||||
|
||||
|
|
@ -314,8 +317,13 @@ def _extract_schema(data):
|
|||
"""Extract the data schema by replacing every torch.Tensor value with _TensorStub"""
|
||||
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
elif isinstance(data, str):
|
||||
warn_of_constant_inputs(data)
|
||||
return data
|
||||
elif _PrimitiveType.is_primitive_type(data):
|
||||
if isinstance(data, bool):
|
||||
warn_of_constant_inputs(data)
|
||||
return _TensorStub(dtype=_PrimitiveType.get_primitive_dtype(data), shape_dims=0)
|
||||
# Depth first traversal to iterate over the data to replace every tensor with a stub
|
||||
elif isinstance(data, torch.Tensor):
|
||||
|
|
@ -324,7 +332,7 @@ def _extract_schema(data):
|
|||
# Instead of replacing the tensor with a stub in the original user input, build the stubbed_schema
|
||||
# from scratch from the user input.
|
||||
stubbed_schema = None
|
||||
if isinstance(data, abc.Sequence) and not isinstance(data, str):
|
||||
if isinstance(data, abc.Sequence):
|
||||
sequence_type = type(data)
|
||||
stubbed_schema = [_extract_schema(val) for val in data]
|
||||
try:
|
||||
|
|
@ -431,8 +439,8 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
|
|||
def _add_input(name, input, onnx_graph, onnx_graph_input_names):
|
||||
"""Returns number of expanded non none inputs that _add_input processed"""
|
||||
|
||||
if input is None:
|
||||
# Drop all None inputs and return 0.
|
||||
if input is None or isinstance(input, str):
|
||||
# Drop all None and string inputs and return 0.
|
||||
return 0
|
||||
|
||||
num_expanded_non_none_inputs = 0
|
||||
|
|
|
|||
|
|
@ -226,3 +226,7 @@ def switch_backend_to_pytorch(ortmodule, pytorch_module):
|
|||
ortmodule._load_state_dict_pre_hooks = pytorch_module._load_state_dict_pre_hooks
|
||||
ortmodule._modules = pytorch_module._modules
|
||||
ortmodule.forward = pytorch_module.forward
|
||||
|
||||
def warn_of_constant_inputs(data):
|
||||
warnings.warn(f"Received input of type {type(data)} which may be treated as a constant by ORT by default."
|
||||
" Please consider moving constant arguments to the model constructor.")
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@
|
|||
import torch
|
||||
|
||||
|
||||
class MyStrNet(torch.nn.Module):
|
||||
def forward(self, x, my_str):
|
||||
if my_str.lower() == 'hello':
|
||||
print('hi')
|
||||
class MyCustomClassInputNet(torch.nn.Module):
|
||||
def forward(self, x, custom_class_obj):
|
||||
if custom_class_obj.x == 1:
|
||||
return x+1
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -318,7 +318,7 @@ class NeuralNetCustomClassOutput(torch.nn.Module):
|
|||
class MyStrNet(torch.nn.Module):
|
||||
def forward(self, x, my_str):
|
||||
if my_str.lower() == 'hello':
|
||||
print('hi')
|
||||
return x+1
|
||||
return x
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
|
|
@ -3321,23 +3321,18 @@ def test_hf_save_pretrained():
|
|||
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||
assert p1.data.ne(p2.data).sum() == 0
|
||||
|
||||
def test_input_with_string_exception():
|
||||
def test_ortmodule_string_inputs_are_ignored():
|
||||
|
||||
pt_model = MyStrNet()
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = torch.randn(1, 2)
|
||||
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
|
||||
# Fallback
|
||||
pt_out = pt_model(x, 'hello')
|
||||
ort_out = pt_model(x, 'hello')
|
||||
_test_helpers.assert_values_are_close(pt_out, ort_out)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
|
||||
_ = ort_model(x, 'hello')
|
||||
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
|
||||
with pytest.warns(UserWarning) as warning_record:
|
||||
out = ort_model(x, 'hello')
|
||||
|
||||
assert len(warning_record) == 2
|
||||
assert "Received input of type <class 'str'> which may be treated as a constant by ORT by default." in warning_record[1].message.args[0]
|
||||
_test_helpers.assert_values_are_close(out, x+1)
|
||||
|
||||
def test_ortmodule_list_input():
|
||||
class ListNet(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as
|
|||
import _test_helpers
|
||||
from _orttraining_ortmodule_models import (NeuralNetSinglePositionalArgument,
|
||||
NeuralNetCustomClassOutput,
|
||||
MyStrNet,
|
||||
MyCustomClassInputNet,
|
||||
MyCustomFunctionReluModel)
|
||||
|
||||
# PyTorch model definitions for tests
|
||||
|
|
@ -254,10 +254,14 @@ def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy
|
|||
os.environ['ORTMODULE_FALLBACK_POLICY'] = policy
|
||||
os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback)
|
||||
|
||||
pt_model = MyStrNet()
|
||||
pt_model = MyCustomClassInputNet()
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
inputs = torch.randn(1, 2)
|
||||
|
||||
class CustomClass:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
ort_model.train(is_training)
|
||||
pt_model.train(is_training)
|
||||
|
||||
|
|
@ -267,17 +271,21 @@ def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy
|
|||
if i > 0 and persist_fallback:
|
||||
assert ort_model._torch_module._execution_manager(
|
||||
is_training=is_training)._fallback_manager._exception is not None
|
||||
ort_out = ort_model(inputs, 'hello')
|
||||
pt_out = pt_model(inputs, 'hello')
|
||||
ort_out = ort_model(inputs, CustomClass(1))
|
||||
pt_out = pt_model(inputs, CustomClass(1))
|
||||
_test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
|
||||
else:
|
||||
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
|
||||
_ = ort_model(torch.randn(1, 2), 'hello')
|
||||
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
|
||||
_ = ort_model(torch.randn(1, 2), CustomClass(1))
|
||||
assert "ORTModule does not support the following model data"\
|
||||
" type <class 'orttraining_test_ortmodule_fallback."\
|
||||
"test_ortmodule_fallback_input.<locals>.CustomClass'>" in str(ex_info.value)
|
||||
else:
|
||||
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
|
||||
_ = ort_model(torch.randn(1, 2), 'hello')
|
||||
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
|
||||
_ = ort_model(torch.randn(1, 2), CustomClass(1))
|
||||
assert "ORTModule does not support the following model data"\
|
||||
" type <class 'orttraining_test_ortmodule_fallback."\
|
||||
"test_ortmodule_fallback_input.<locals>.CustomClass'>" in str(ex_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback",
|
||||
|
|
|
|||
Loading…
Reference in a new issue