[ONNX] Support .item() export & NumberType to tensor conversion (#55697) (#57594)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57594

Support .item() export & NumberType to tensor conversion

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D28393516

Pulled By: SplitInfinity

fbshipit-source-id: 94d0aec0a8fe144ee2567dc3c9c19fcb18ed21fa

Co-authored-by: BowenBao <bowbao@microsoft.com>
This commit is contained in:
BowenBao 2021-05-13 13:37:10 -07:00 committed by Facebook GitHub Bot
parent 061c7a1e17
commit 3bc8a2264d
3 changed files with 16 additions and 0 deletions

View file

@ -1499,6 +1499,8 @@ inline TypePtr TensorType::fromNumberType(TypePtr typ) {
return TensorType::createContiguous(at::kDouble, at::kCPU, {});
} else if (typ->isSubtypeOf(BoolType::get())) {
return TensorType::createContiguous(at::kBool, at::kCPU, {});
} else if (typ->kind() == NumberType::Kind) {
return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt);
}
TORCH_CHECK(false, "Unknown number type: ", typ->str());
}

View file

@ -4263,6 +4263,16 @@ class TestONNXRuntime(unittest.TestCase):
ind = torch.tensor(-2, dtype=torch.long)
self.run_test(GetItemModel(), (x, y, z, ind))
def test_item(self):
class M(torch.nn.Module):
def forward(self, x, y, i: int):
return int(x[y[i]].item())
x = torch.arange(6, dtype=torch.float)
y = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
i = 3
self.run_test(torch.jit.script(M()), (x, y, i))
@disableScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable.
@skipIfUnsupportedMinOpsetVersion(9)
def test_nonzero(self):

View file

@ -2945,6 +2945,10 @@ def __getitem_(g, self, i):
return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i)
def item(g, self):
return self
def take(g, self, index):
self_flattened = g.op('Reshape', self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
out = index_select(g, self_flattened, 0, index)