From b0e027c661769f4d5edd5a893e4cb7bd6b8ff053 Mon Sep 17 00:00:00 2001 From: Wil Brady <25513670+WilBrady@users.noreply.github.com> Date: Mon, 13 Jun 2022 13:05:26 -0400 Subject: [PATCH] Add aten::_softmax to eager ops. (#11820) --- orttraining/orttraining/eager/opgen/opgen/atenops.py | 1 + orttraining/orttraining/eager/test/ort_ops.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 225c2edad3..a470224260 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -138,6 +138,7 @@ hand_implemented = { "aten::_local_scalar_dense": MakeTorchFallback(), "aten::gt.Scalar_out": MakeTorchFallback(), "aten::equal": MakeTorchFallback(), + "aten::_softmax": Softmax("self", axis="dim"), } # Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439 diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index 62c12f01f8..515967d4a5 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -144,6 +144,14 @@ class OrtOpTests(unittest.TestCase): cpu_tensor_copied = ort_tensor.cpu() assert cpu_tensor_copied.stride() == (0, 0, 0) + def test_softmax(self): + device = self.get_device() + cpu_tensor = torch.rand(3, 5) + ort_tensor = cpu_tensor.to(device) + cpu_result = torch.softmax(cpu_tensor, dim=1) + ort_result = torch.softmax(ort_tensor, dim=1) + assert torch.allclose(cpu_result, ort_result.cpu()) + if __name__ == "__main__": unittest.main()