Add aten::_softmax to eager ops. (#11820)

This commit is contained in:
Wil Brady 2022-06-13 13:05:26 -04:00 committed by GitHub
parent 7582644f57
commit b0e027c661
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 0 deletions

View file

@ -138,6 +138,7 @@ hand_implemented = {
"aten::_local_scalar_dense": MakeTorchFallback(),
"aten::gt.Scalar_out": MakeTorchFallback(),
"aten::equal": MakeTorchFallback(),
"aten::_softmax": Softmax("self", axis="dim"),
}
# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439

View file

@ -144,6 +144,14 @@ class OrtOpTests(unittest.TestCase):
cpu_tensor_copied = ort_tensor.cpu()
assert cpu_tensor_copied.stride() == (0, 0, 0)
def test_softmax(self):
device = self.get_device()
cpu_tensor = torch.rand(3, 5)
ort_tensor = cpu_tensor.to(device)
cpu_result = torch.softmax(cpu_tensor, dim=1)
ort_result = torch.softmax(ort_tensor, dim=1)
assert torch.allclose(cpu_result, ort_result.cpu())
if __name__ == "__main__":
unittest.main()