# NOTE(review): this patch was recovered from a paste in which all newlines and
# indentation had been collapsed to single spaces. Line breaks follow the hunk
# headers' line counts; leading whitespace is reconstructed (4-space Python
# indentation assumed) -- verify context lines against the target files before applying.
#
# Purpose: stop _extract_schema from treating a `str` input as a generic
# abc.Sequence, and add a test expecting ORTModule to raise TypeError when a
# string is passed as a model input.
diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py
index 0993476b23..086e2fd73c 100644
--- a/orttraining/orttraining/python/training/ortmodule/_io.py
+++ b/orttraining/orttraining/python/training/ortmodule/_io.py
@@ -291,7 +291,7 @@ def _extract_schema(data):
     elif isinstance(data, torch.Tensor):
         return _TensorStub(dtype=str(data.dtype), shape_dims=len(data.size()))
 
-    if isinstance(data, abc.Sequence):
+    if isinstance(data, abc.Sequence) and not isinstance(data, str):
         sequence_type = type(data)
         data = list(data)
         for idx in range(len(data)):
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 0abe778d90..6037267ae9 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -2808,3 +2808,16 @@ def test_hf_save_pretrained():
     for p1, p2 in zip(model1.parameters(), model2.parameters()):
         assert p1.data.ne(p2.data).sum() == 0
 
+
+def test_input_with_string_exception():
+    class MyStrNet(torch.nn.Module):
+        def forward(self, x, my_str):
+            if my_str.lower() == 'hello':
+                print('hi')
+            return x
+
+    model = MyStrNet()
+    model = ORTModule(model)
+    with pytest.raises(TypeError) as ex_info:
+        _ = model(torch.randn(1, 2), 'hello')
+    assert "ORTModule does not support the following model data type " in str(ex_info.value)