mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57594 Support .item() export & NumberType to tensor conversion Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D28393516 Pulled By: SplitInfinity fbshipit-source-id: 94d0aec0a8fe144ee2567dc3c9c19fcb18ed21fa Co-authored-by: BowenBao <bowbao@microsoft.com>
This commit is contained in:
parent
061c7a1e17
commit
3bc8a2264d
3 changed files with 16 additions and 0 deletions
|
|
@ -1499,6 +1499,8 @@ inline TypePtr TensorType::fromNumberType(TypePtr typ) {
|
|||
return TensorType::createContiguous(at::kDouble, at::kCPU, {});
|
||||
} else if (typ->isSubtypeOf(BoolType::get())) {
|
||||
return TensorType::createContiguous(at::kBool, at::kCPU, {});
|
||||
} else if (typ->kind() == NumberType::Kind) {
|
||||
return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt);
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown number type: ", typ->str());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4263,6 +4263,16 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
ind = torch.tensor(-2, dtype=torch.long)
|
||||
self.run_test(GetItemModel(), (x, y, z, ind))
|
||||
|
||||
def test_item(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y, i: int):
|
||||
return int(x[y[i]].item())
|
||||
|
||||
x = torch.arange(6, dtype=torch.float)
|
||||
y = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
|
||||
i = 3
|
||||
self.run_test(torch.jit.script(M()), (x, y, i))
|
||||
|
||||
@disableScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable.
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
|
||||
def test_nonzero(self):
|
||||
|
|
|
|||
|
|
@ -2945,6 +2945,10 @@ def __getitem_(g, self, i):
|
|||
return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i)
|
||||
|
||||
|
||||
def item(g, self):
|
||||
return self
|
||||
|
||||
|
||||
def take(g, self, index):
|
||||
self_flattened = g.op('Reshape', self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
|
||||
out = index_select(g, self_flattened, 0, index)
|
||||
|
|
|
|||
Loading…
Reference in a new issue