diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index c3c550a5b4a..b6202151bc6 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -1499,6 +1499,8 @@ inline TypePtr TensorType::fromNumberType(TypePtr typ) {
     return TensorType::createContiguous(at::kDouble, at::kCPU, {});
   } else if (typ->isSubtypeOf(BoolType::get())) {
     return TensorType::createContiguous(at::kBool, at::kCPU, {});
+  } else if (typ->kind() == NumberType::Kind) {
+    return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt);
   }
   TORCH_CHECK(false, "Unknown number type: ", typ->str());
 }
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 610c7fc464a..8f8a38dcf1b 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -4263,6 +4263,16 @@ class TestONNXRuntime(unittest.TestCase):
         ind = torch.tensor(-2, dtype=torch.long)
         self.run_test(GetItemModel(), (x, y, z, ind))
 
+    def test_item(self):
+        class M(torch.nn.Module):
+            def forward(self, x, y, i: int):
+                return int(x[y[i]].item())
+
+        x = torch.arange(6, dtype=torch.float)
+        y = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
+        i = 3
+        self.run_test(torch.jit.script(M()), (x, y, i))
+
     @disableScriptTest()  # torch.nonzero(x, as_tuple=True) is not scriptable.
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_nonzero(self):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 522367c2c41..d59d46e2727 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -2945,6 +2945,10 @@ def __getitem_(g, self, i):
     return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i)
 
 
+def item(g, self):
+    return self
+
+
 def take(g, self, index):
     self_flattened = g.op('Reshape', self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
     out = index_select(g, self_flattened, 0, index)