Fix input schema extrator for ORTModule (#8098)

This commit is contained in:
Thiago Crepaldi 2021-06-18 21:47:49 -07:00 committed by GitHub
parent 7701c8703e
commit 5c2e1bbb0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 1 deletions

View file

@ -291,7 +291,7 @@ def _extract_schema(data):
elif isinstance(data, torch.Tensor):
return _TensorStub(dtype=str(data.dtype), shape_dims=len(data.size()))
if isinstance(data, abc.Sequence):
if isinstance(data, abc.Sequence) and not isinstance(data, str):
sequence_type = type(data)
data = list(data)
for idx in range(len(data)):

View file

@ -2808,3 +2808,16 @@ def test_hf_save_pretrained():
for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert p1.data.ne(p2.data).sum() == 0
def test_input_with_string_exception():
class MyStrNet(torch.nn.Module):
def forward(self, x, my_str):
if my_str.lower() == 'hello':
print('hi')
return x
model = MyStrNet()
model = ORTModule(model)
with pytest.raises(TypeError) as ex_info:
_ = model(torch.randn(1, 2), 'hello')
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)