mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Fix input schema extrator for ORTModule (#8098)
This commit is contained in:
parent
7701c8703e
commit
5c2e1bbb0a
2 changed files with 14 additions and 1 deletions
|
|
@ -291,7 +291,7 @@ def _extract_schema(data):
|
|||
elif isinstance(data, torch.Tensor):
|
||||
return _TensorStub(dtype=str(data.dtype), shape_dims=len(data.size()))
|
||||
|
||||
if isinstance(data, abc.Sequence):
|
||||
if isinstance(data, abc.Sequence) and not isinstance(data, str):
|
||||
sequence_type = type(data)
|
||||
data = list(data)
|
||||
for idx in range(len(data)):
|
||||
|
|
|
|||
|
|
@ -2808,3 +2808,16 @@ def test_hf_save_pretrained():
|
|||
|
||||
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||
assert p1.data.ne(p2.data).sum() == 0
|
||||
|
||||
def test_input_with_string_exception():
|
||||
class MyStrNet(torch.nn.Module):
|
||||
def forward(self, x, my_str):
|
||||
if my_str.lower() == 'hello':
|
||||
print('hi')
|
||||
return x
|
||||
|
||||
model = MyStrNet()
|
||||
model = ORTModule(model)
|
||||
with pytest.raises(TypeError) as ex_info:
|
||||
_ = model(torch.randn(1, 2), 'hello')
|
||||
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
|
||||
|
|
|
|||
Loading…
Reference in a new issue