OnnxRuntime Eager: Implement log_softmax with ONNX Ops (#12190)

* share CHECK_STATUS

* log_softmax
This commit is contained in:
msftlincoln 2022-07-15 15:03:08 -04:00 committed by GitHub
parent 9bca8405aa
commit a2dc6d32fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 110 additions and 5 deletions

View file

@ -170,7 +170,7 @@ hand_implemented = {
"aten::argmax.out": SignatureOnly(),
"aten::nonzero": Transpose(NonZero("self")),
"aten::nonzero.out": SignatureOnly(),
"aten::_log_softmax.out": MakeTorchFallback(),
"aten::_log_softmax.out": SignatureOnly(),
"aten::nll_loss_forward.output": MakeTorchFallback(),
"aten::nll_loss_backward.grad_input": MakeTorchFallback(),
"aten::_log_softmax_backward_data.out": MakeTorchFallback(),

View file

@ -198,10 +198,6 @@ class ORTGen:
writer.writeline('#include "ort_aten.h"')
writer.writeline('#include "ort_log.h"')
writer.writeline()
writer.writeline(
'#define CHECK_STATUS(status) if(!status.IsOK()) { std::stringstream err; err << "ORT return failure (line " << __LINE__ << "): " << status.ErrorMessage(); throw std::runtime_error(err.str()); }'
)
writer.writeline()
writer.push_namespace("torch_ort")
writer.push_namespace("eager")
writer.writeline()

View file

@ -241,6 +241,19 @@ onnx::AttributeProto create_ort_attribute(
return attr;
}
onnx::AttributeProto create_ort_attribute(
const char* name,
const std::vector<int64_t> values) {
onnx::AttributeProto attr;
attr.set_name(name);
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
for (size_t i = 0; i < values.size(); i++)
attr.add_ints(values[i]);
return attr;
}
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
}
@ -976,6 +989,89 @@ at::Tensor& nonzero_out(
return out;
}
// aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
at::Tensor& _log_softmax_out(
const at::Tensor& self,
int64_t dim,
bool half_to_float,
// *,
at::Tensor& out) {
ORT_LOG_FN(self, dim, half_to_float, out);
if (
!IsSupportedType(self, {at::kBFloat16,at::kDouble,at::kFloat,at::kHalf})) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
}
auto& invoker = GetORTInvoker(self.device());
// resize the output and then create output ort value to be updated.
resize_output(invoker, dynamic_cast<ORTTensorImpl*>(out.unsafeGetTensorImpl()), self.sizes());
auto ort_input_out = create_ort_value(invoker, out);
auto ort_input_0_self = create_ort_value(invoker, self);
// Check dimensions (according to symbolic_opset9).
// Onnx only supports log_softmax with dim -1, otherwise transpose required.
int64_t ndim = self.dim();
if (dim < 0) {
dim += ndim;
}
bool need_transpose = ndim != dim + 1;
// Use transpose to switch the needed dimension to -1
// This requires specifying all of the dimensions in order and then
// swapping the last one with the one specified.
std::vector<int64_t> axes;
std::vector<OrtValue> ort_outputs_0_Transpose(1);
if (need_transpose) {
axes.reserve(ndim);
for (int64_t i = 0; i < ndim; i++)
axes.push_back(i);
axes[dim] = ndim-1;
axes[ndim-1] = dim;
dim = ndim-1;
NodeAttributes attrs_0(1);
attrs_0["perm"] = create_ort_attribute("perm", axes);
auto status = invoker.Invoke("Transpose", {
std::move(ort_input_0_self),
}, ort_outputs_0_Transpose, &attrs_0);
CHECK_STATUS(status);
}
NodeAttributes attrs_1(1);
attrs_1["axis"] = create_ort_attribute(
"axis", dim, at::ScalarType::Int);
std::vector<OrtValue> ort_outputs_1_LogSoftmax(1);
if (!need_transpose) {
ort_outputs_1_LogSoftmax[0] = ort_input_out;
}
auto status = invoker.Invoke("LogSoftmax", {
std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self),
}, ort_outputs_1_LogSoftmax, &attrs_1);
CHECK_STATUS(status);
std::vector<OrtValue> ort_outputs_2_Transpose(1);
if (need_transpose) {
ort_outputs_2_Transpose[0] = ort_input_out;
NodeAttributes attrs_2(1);
attrs_2["perm"] = create_ort_attribute("perm", axes);;
status = invoker.Invoke("Transpose", {
std::move(ort_outputs_1_LogSoftmax[0]),
}, ort_outputs_2_Transpose, &attrs_2);
CHECK_STATUS(status);
}
return out;
}
} // namespace aten
//#pragma endregion

View file

@ -11,6 +11,8 @@
#include "ort_log.h"
#include "ort_tensor.h"
#define CHECK_STATUS(status) if(!status.IsOK()) { std::stringstream err; err << "ORT return failure (line " << __LINE__ << "): " << status.ErrorMessage(); throw std::runtime_error(err.str()); }
namespace torch_ort {
namespace eager {

View file

@ -232,6 +232,17 @@ class OrtOpTests(unittest.TestCase):
ort_result = torch.softmax(ort_tensor, dim=1)
assert torch.allclose(cpu_result, ort_result.cpu())
def test_log_softmax(self):
device = self.get_device()
cpu_tensor = torch.rand(3, 5)
ort_tensor = cpu_tensor.to(device)
cpu_result_a = torch.log_softmax(cpu_tensor, dim=1)
ort_result_a = torch.log_softmax(ort_tensor, dim=1)
assert torch.allclose(cpu_result_a, ort_result_a.cpu())
cpu_result_b = torch.log_softmax(cpu_tensor, dim=0)
ort_result_b = torch.log_softmax(ort_tensor, dim=0)
assert torch.allclose(cpu_result_b, ort_result_b.cpu())
def test_addmm(self):
device = self.get_device()
size = 4