mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
OnnxRuntime Eager: Implement log_softmax with ONNX Ops (#12190)
* share CHECK_STATUS * log_softmax
This commit is contained in:
parent
9bca8405aa
commit
a2dc6d32fc
5 changed files with 110 additions and 5 deletions
|
|
@ -170,7 +170,7 @@ hand_implemented = {
|
|||
"aten::argmax.out": SignatureOnly(),
|
||||
"aten::nonzero": Transpose(NonZero("self")),
|
||||
"aten::nonzero.out": SignatureOnly(),
|
||||
"aten::_log_softmax.out": MakeTorchFallback(),
|
||||
"aten::_log_softmax.out": SignatureOnly(),
|
||||
"aten::nll_loss_forward.output": MakeTorchFallback(),
|
||||
"aten::nll_loss_backward.grad_input": MakeTorchFallback(),
|
||||
"aten::_log_softmax_backward_data.out": MakeTorchFallback(),
|
||||
|
|
|
|||
|
|
@ -198,10 +198,6 @@ class ORTGen:
|
|||
writer.writeline('#include "ort_aten.h"')
|
||||
writer.writeline('#include "ort_log.h"')
|
||||
writer.writeline()
|
||||
writer.writeline(
|
||||
'#define CHECK_STATUS(status) if(!status.IsOK()) { std::stringstream err; err << "ORT return failure (line " << __LINE__ << "): " << status.ErrorMessage(); throw std::runtime_error(err.str()); }'
|
||||
)
|
||||
writer.writeline()
|
||||
writer.push_namespace("torch_ort")
|
||||
writer.push_namespace("eager")
|
||||
writer.writeline()
|
||||
|
|
|
|||
|
|
@ -241,6 +241,19 @@ onnx::AttributeProto create_ort_attribute(
|
|||
return attr;
|
||||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
const std::vector<int64_t> values) {
|
||||
onnx::AttributeProto attr;
|
||||
attr.set_name(name);
|
||||
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
|
||||
|
||||
for (size_t i = 0; i < values.size(); i++)
|
||||
attr.add_ints(values[i]);
|
||||
|
||||
return attr;
|
||||
}
|
||||
|
||||
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
|
||||
return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
|
||||
}
|
||||
|
|
@ -976,6 +989,89 @@ at::Tensor& nonzero_out(
|
|||
return out;
|
||||
}
|
||||
|
||||
// aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
|
||||
at::Tensor& _log_softmax_out(
|
||||
const at::Tensor& self,
|
||||
int64_t dim,
|
||||
bool half_to_float,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
ORT_LOG_FN(self, dim, half_to_float, out);
|
||||
|
||||
if (
|
||||
!IsSupportedType(self, {at::kBFloat16,at::kDouble,at::kFloat,at::kHalf})) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
// resize the output and then create output ort value to be updated.
|
||||
resize_output(invoker, dynamic_cast<ORTTensorImpl*>(out.unsafeGetTensorImpl()), self.sizes());
|
||||
auto ort_input_out = create_ort_value(invoker, out);
|
||||
auto ort_input_0_self = create_ort_value(invoker, self);
|
||||
|
||||
// Check dimensions (according to symbolic_opset9).
|
||||
// Onnx only supports log_softmax with dim -1, otherwise transpose required.
|
||||
int64_t ndim = self.dim();
|
||||
if (dim < 0) {
|
||||
dim += ndim;
|
||||
}
|
||||
bool need_transpose = ndim != dim + 1;
|
||||
|
||||
// Use transpose to switch the needed dimension to -1
|
||||
// This requires specifying all of the dimensions in order and then
|
||||
// swapping the last one with the one specified.
|
||||
std::vector<int64_t> axes;
|
||||
std::vector<OrtValue> ort_outputs_0_Transpose(1);
|
||||
if (need_transpose) {
|
||||
axes.reserve(ndim);
|
||||
for (int64_t i = 0; i < ndim; i++)
|
||||
axes.push_back(i);
|
||||
|
||||
axes[dim] = ndim-1;
|
||||
axes[ndim-1] = dim;
|
||||
dim = ndim-1;
|
||||
|
||||
NodeAttributes attrs_0(1);
|
||||
attrs_0["perm"] = create_ort_attribute("perm", axes);
|
||||
auto status = invoker.Invoke("Transpose", {
|
||||
std::move(ort_input_0_self),
|
||||
}, ort_outputs_0_Transpose, &attrs_0);
|
||||
CHECK_STATUS(status);
|
||||
}
|
||||
|
||||
NodeAttributes attrs_1(1);
|
||||
attrs_1["axis"] = create_ort_attribute(
|
||||
"axis", dim, at::ScalarType::Int);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_1_LogSoftmax(1);
|
||||
if (!need_transpose) {
|
||||
ort_outputs_1_LogSoftmax[0] = ort_input_out;
|
||||
}
|
||||
|
||||
auto status = invoker.Invoke("LogSoftmax", {
|
||||
std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self),
|
||||
}, ort_outputs_1_LogSoftmax, &attrs_1);
|
||||
CHECK_STATUS(status);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_2_Transpose(1);
|
||||
|
||||
if (need_transpose) {
|
||||
ort_outputs_2_Transpose[0] = ort_input_out;
|
||||
|
||||
NodeAttributes attrs_2(1);
|
||||
attrs_2["perm"] = create_ort_attribute("perm", axes);;
|
||||
|
||||
status = invoker.Invoke("Transpose", {
|
||||
std::move(ort_outputs_1_LogSoftmax[0]),
|
||||
}, ort_outputs_2_Transpose, &attrs_2);
|
||||
CHECK_STATUS(status);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace aten
|
||||
|
||||
//#pragma endregion
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@
|
|||
#include "ort_log.h"
|
||||
#include "ort_tensor.h"
|
||||
|
||||
#define CHECK_STATUS(status) if(!status.IsOK()) { std::stringstream err; err << "ORT return failure (line " << __LINE__ << "): " << status.ErrorMessage(); throw std::runtime_error(err.str()); }
|
||||
|
||||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
||||
|
|
|
|||
|
|
@ -232,6 +232,17 @@ class OrtOpTests(unittest.TestCase):
|
|||
ort_result = torch.softmax(ort_tensor, dim=1)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
|
||||
def test_log_softmax(self):
|
||||
device = self.get_device()
|
||||
cpu_tensor = torch.rand(3, 5)
|
||||
ort_tensor = cpu_tensor.to(device)
|
||||
cpu_result_a = torch.log_softmax(cpu_tensor, dim=1)
|
||||
ort_result_a = torch.log_softmax(ort_tensor, dim=1)
|
||||
assert torch.allclose(cpu_result_a, ort_result_a.cpu())
|
||||
cpu_result_b = torch.log_softmax(cpu_tensor, dim=0)
|
||||
ort_result_b = torch.log_softmax(ort_tensor, dim=0)
|
||||
assert torch.allclose(cpu_result_b, ort_result_b.cpu())
|
||||
|
||||
def test_addmm(self):
|
||||
device = self.get_device()
|
||||
size = 4
|
||||
|
|
|
|||
Loading…
Reference in a new issue