From a2dc6d32fca1d13ffbdbc03672f5192375956763 Mon Sep 17 00:00:00 2001
From: msftlincoln <107071614+msftlincoln@users.noreply.github.com>
Date: Fri, 15 Jul 2022 15:03:08 -0400
Subject: [PATCH] OnnxRuntime Eager: Implement log_softmax with ONNX Ops (#12190)

* share CHECK_STATUS

* log_softmax

Summary of the changes in this patch:

- CHECK_STATUS is no longer written into the generated source by
  generator.py; it is defined once in ort_aten.h instead, so the
  hand-written ops in ort_aten.cpp can use it as well.
- aten::_log_softmax.out switches from MakeTorchFallback() to
  SignatureOnly() and gets a hand-written implementation that invokes the
  ONNX LogSoftmax op. When the requested dim is not the last one, the
  input is transposed so that dim becomes the last axis, LogSoftmax is
  applied, and the result is transposed back into `out`.
- A create_ort_attribute overload taking a vector of integers is added to
  build the INTS-typed "perm" attribute for Transpose.
- half_to_float is logged and forwarded to the CPU fallback only; the ORT
  path does not act on it.
- test_log_softmax covers dim=1 (no transpose) and dim=0 (transpose).
---
 .../orttraining/eager/opgen/opgen/atenops.py  |  2 +-
 .../eager/opgen/opgen/generator.py            |  4 -
 orttraining/orttraining/eager/ort_aten.cpp    | 96 +++++++++++++++++++
 orttraining/orttraining/eager/ort_aten.h      |  2 +
 orttraining/orttraining/eager/test/ort_ops.py | 11 +++
 5 files changed, 110 insertions(+), 5 deletions(-)

diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py
index 62b8678b33..923618f853 100644
--- a/orttraining/orttraining/eager/opgen/opgen/atenops.py
+++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py
@@ -170,7 +170,7 @@ hand_implemented = {
     "aten::argmax.out": SignatureOnly(),
     "aten::nonzero": Transpose(NonZero("self")),
     "aten::nonzero.out": SignatureOnly(),
-    "aten::_log_softmax.out": MakeTorchFallback(),
+    "aten::_log_softmax.out": SignatureOnly(),
     "aten::nll_loss_forward.output": MakeTorchFallback(),
     "aten::nll_loss_backward.grad_input": MakeTorchFallback(),
     "aten::_log_softmax_backward_data.out": MakeTorchFallback(),
diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py
index 5741017c98..9b4941ed6f 100644
--- a/orttraining/orttraining/eager/opgen/opgen/generator.py
+++ b/orttraining/orttraining/eager/opgen/opgen/generator.py
@@ -198,10 +198,6 @@ class ORTGen:
         writer.writeline('#include "ort_aten.h"')
         writer.writeline('#include "ort_log.h"')
         writer.writeline()
-        writer.writeline(
-            '#define CHECK_STATUS(status) if(!status.IsOK()) { std::stringstream err; err << "ORT return failure (line " << __LINE__ << "): " << status.ErrorMessage(); throw std::runtime_error(err.str()); }'
-        )
-        writer.writeline()
         writer.push_namespace("torch_ort")
         writer.push_namespace("eager")
         writer.writeline()
diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp
index e66f968312..58b22723e8 100644
--- a/orttraining/orttraining/eager/ort_aten.cpp
+++ b/orttraining/orttraining/eager/ort_aten.cpp
@@ -241,6 +241,19 @@ onnx::AttributeProto create_ort_attribute(
   return attr;
 }
 
+onnx::AttributeProto create_ort_attribute(
+  const char* name,
+  const std::vector<int64_t> values) {
+  onnx::AttributeProto attr;
+  attr.set_name(name);
+  attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
+
+  for (size_t i = 0; i < values.size(); i++)
+    attr.add_ints(values[i]);
+
+  return attr;
+}
+
 bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
   return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
 }
@@ -976,6 +989,89 @@ at::Tensor& nonzero_out(
   return out;
 }
 
+// aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+at::Tensor& _log_softmax_out(
+  const at::Tensor& self,
+  int64_t dim,
+  bool half_to_float,
+  // *,
+  at::Tensor& out) {
+  ORT_LOG_FN(self, dim, half_to_float, out);
+
+  if (
+    !IsSupportedType(self, {at::kBFloat16,at::kDouble,at::kFloat,at::kHalf})) {
+    return at::native::call_fallback_fn<
+      &at::native::cpu_fallback,
+      ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
+  }
+  auto& invoker = GetORTInvoker(self.device());
+
+  // resize the output and then create output ort value to be updated.
+  resize_output(invoker, dynamic_cast<ORTTensorImpl*>(out.unsafeGetTensorImpl()), self.sizes());
+  auto ort_input_out = create_ort_value(invoker, out);
+  auto ort_input_0_self = create_ort_value(invoker, self);
+
+  // Check dimensions (according to symbolic_opset9).
+  // ONNX only supports log_softmax with dim -1, otherwise transpose required.
+  int64_t ndim = self.dim();
+  if (dim < 0) {
+    dim += ndim;
+  }
+  bool need_transpose = ndim != dim + 1;
+
+  // Use transpose to switch the needed dimension to -1
+  // This requires specifying all of the dimensions in order and then
+  // swapping the last one with the one specified.
+  std::vector<int64_t> axes;
+  std::vector<OrtValue> ort_outputs_0_Transpose(1);
+  if (need_transpose) {
+    axes.reserve(ndim);
+    for (int64_t i = 0; i < ndim; i++)
+      axes.push_back(i);
+
+    axes[dim] = ndim-1;
+    axes[ndim-1] = dim;
+    dim = ndim-1;
+
+    NodeAttributes attrs_0(1);
+    attrs_0["perm"] = create_ort_attribute("perm", axes);
+    auto status = invoker.Invoke("Transpose", {
+      std::move(ort_input_0_self),
+    }, ort_outputs_0_Transpose, &attrs_0);
+    CHECK_STATUS(status);
+  }
+
+  NodeAttributes attrs_1(1);
+  attrs_1["axis"] = create_ort_attribute(
+    "axis", dim, at::ScalarType::Int);
+
+  std::vector<OrtValue> ort_outputs_1_LogSoftmax(1);
+  if (!need_transpose) {
+    ort_outputs_1_LogSoftmax[0] = ort_input_out;
+  }
+
+  auto status = invoker.Invoke("LogSoftmax", {
+    std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self),
+  }, ort_outputs_1_LogSoftmax, &attrs_1);
+  CHECK_STATUS(status);
+
+  std::vector<OrtValue> ort_outputs_2_Transpose(1);
+
+  if (need_transpose) {
+    ort_outputs_2_Transpose[0] = ort_input_out;
+
+    NodeAttributes attrs_2(1);
+    attrs_2["perm"] = create_ort_attribute("perm", axes);
+
+    status = invoker.Invoke("Transpose", {
+      std::move(ort_outputs_1_LogSoftmax[0]),
+    }, ort_outputs_2_Transpose, &attrs_2);
+    CHECK_STATUS(status);
+  }
+
+  return out;
+}
+
 } // namespace aten
 
 //#pragma endregion
diff --git a/orttraining/orttraining/eager/ort_aten.h b/orttraining/orttraining/eager/ort_aten.h
index 88ddb15a70..daee92f614 100644
--- a/orttraining/orttraining/eager/ort_aten.h
+++ b/orttraining/orttraining/eager/ort_aten.h
@@ -11,6 +11,8 @@
 #include "ort_log.h"
 #include "ort_tensor.h"
 
+#define CHECK_STATUS(status) if(!status.IsOK()) { std::stringstream err; err << "ORT return failure (line " << __LINE__ << "): " << status.ErrorMessage(); throw std::runtime_error(err.str()); }
+
 namespace torch_ort {
 namespace eager {
 
diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py
index d278b865d6..31e5b57499 100644
--- a/orttraining/orttraining/eager/test/ort_ops.py
+++ b/orttraining/orttraining/eager/test/ort_ops.py
@@ -232,6 +232,17 @@ class OrtOpTests(unittest.TestCase):
         ort_result = torch.softmax(ort_tensor, dim=1)
         assert torch.allclose(cpu_result, ort_result.cpu())
 
+    def test_log_softmax(self):
+        device = self.get_device()
+        cpu_tensor = torch.rand(3, 5)
+        ort_tensor = cpu_tensor.to(device)
+        cpu_result_a = torch.log_softmax(cpu_tensor, dim=1)
+        ort_result_a = torch.log_softmax(ort_tensor, dim=1)
+        assert torch.allclose(cpu_result_a, ort_result_a.cpu())
+        cpu_result_b = torch.log_softmax(cpu_tensor, dim=0)
+        ort_result_b = torch.log_softmax(ort_tensor, dim=0)
+        assert torch.allclose(cpu_result_b, ort_result_b.cpu())
+
     def test_addmm(self):
         device = self.get_device()
         size = 4