mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Add aten::_softmax to eager ops. (#11820)
This commit is contained in:
parent
7582644f57
commit
b0e027c661
2 changed files with 9 additions and 0 deletions
|
|
@ -138,6 +138,7 @@ hand_implemented = {
|
|||
"aten::_local_scalar_dense": MakeTorchFallback(),
|
||||
"aten::gt.Scalar_out": MakeTorchFallback(),
|
||||
"aten::equal": MakeTorchFallback(),
|
||||
"aten::_softmax": Softmax("self", axis="dim"),
|
||||
}
|
||||
|
||||
# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439
|
||||
|
|
|
|||
|
|
@ -144,6 +144,14 @@ class OrtOpTests(unittest.TestCase):
|
|||
cpu_tensor_copied = ort_tensor.cpu()
|
||||
assert cpu_tensor_copied.stride() == (0, 0, 0)
|
||||
|
||||
def test_softmax(self):
|
||||
device = self.get_device()
|
||||
cpu_tensor = torch.rand(3, 5)
|
||||
ort_tensor = cpu_tensor.to(device)
|
||||
cpu_result = torch.softmax(cpu_tensor, dim=1)
|
||||
ort_result = torch.softmax(ort_tensor, dim=1)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue