mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Support non-default negative axis value and intuitive data type combination for OneHot op (#1317)
* Handle nondefault negative axis value * Support more intuitive data types for this op
This commit is contained in:
parent
2698edbc98
commit
a077ac8df5
4 changed files with 53 additions and 7 deletions
|
|
@ -221,6 +221,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_string_int64_t, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh);
|
||||
|
|
@ -486,6 +487,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_string_int64_t, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh)>,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ REG_ONE_HOT_OP(int64_t, int64_t, int64_t);
|
|||
REG_ONE_HOT_OP(float, int64_t, int64_t);
|
||||
REG_ONE_HOT_OP(int64_t, string, int64_t);
|
||||
REG_ONE_HOT_OP(float, string, int64_t);
|
||||
REG_ONE_HOT_OP(int64_t, float, int64_t);
|
||||
REG_ONE_HOT_OP(float, float, float); // added this to satisfy onnx model tests
|
||||
REG_ONE_HOT_OP(int64_t, int32_t, float); // added this to satisfy onnx model tests
|
||||
|
||||
|
|
@ -120,16 +121,28 @@ Status OneHotOp<in_type, out_type, depth_type>::Compute(OpKernelContext* p_op_ke
|
|||
const auto& indices_dims = indices_shape.GetDims();
|
||||
const auto indices_num_dims = indices_shape.NumDimensions();
|
||||
std::vector<int64_t> output_shape(indices_shape.GetDims());
|
||||
output_shape.insert(axis_ == -1 ? output_shape.end() : output_shape.begin() + axis_,
|
||||
depth_val);
|
||||
|
||||
// output rank is always 1 more than the input rank as a new dimension is added to the input shape
|
||||
const auto output_rank = static_cast<int64_t>(indices_num_dims + 1);
|
||||
if (axis_ >= output_rank || axis_ < -output_rank) {
|
||||
std::ostringstream oss;
|
||||
oss << "'axis' attribute must have a value in the range [" << -output_rank
|
||||
<< "," << indices_num_dims << "]";
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, oss.str());
|
||||
}
|
||||
|
||||
auto true_axis = axis_;
|
||||
if (true_axis < 0)
|
||||
true_axis += output_rank;
|
||||
|
||||
output_shape.insert(output_shape.begin() + true_axis, depth_val);
|
||||
|
||||
// allocate output
|
||||
const auto* values_data = values->Data<out_type>();
|
||||
Tensor* output = p_op_kernel_context->Output(0, TensorShape(output_shape));
|
||||
|
||||
const int64_t axis = (axis_ == -1) ? indices_num_dims : axis_;
|
||||
int64_t prefix_dim_size = 1;
|
||||
for (int64_t i = 0; i < axis; ++i) {
|
||||
for (int64_t i = 0; i < true_axis; ++i) {
|
||||
prefix_dim_size *= indices_dims[i];
|
||||
}
|
||||
const int64_t suffix_dim_size = indices_shape.Size() / prefix_dim_size;
|
||||
|
|
|
|||
|
|
@ -14,9 +14,6 @@ class OneHotOp final : public OpKernel {
|
|||
explicit OneHotOp(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
|
||||
int64_t tmp_axis;
|
||||
if (op_kernel_info.GetAttr<int64_t>("axis", &tmp_axis).IsOK()) {
|
||||
if (tmp_axis < -1) { // as per spec it can be -1 or more
|
||||
ORT_THROW("Value of axis is < -1");
|
||||
}
|
||||
axis_ = tmp_axis;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,6 +51,20 @@ TEST(OneHotOpTest, DefaultAxis_int64_int32_float /*indices, output, depth*/) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, DefaultAxis_int64_float_int64 /*indices, output, depth*/) {
|
||||
OpTester test("OneHot", 9);
|
||||
test.AddInput<int64_t>("indices", {2, 3}, {1, 9, 8, 2, 4, 6});
|
||||
test.AddInput<int64_t>("depth", {1}, {10});
|
||||
test.AddInput<float>("values", {2}, {0, 1});
|
||||
test.AddOutput<float>("output", {2, 3, 10}, {0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
|
||||
0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 1, 0, 0, 0,});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, Axis_0) {
|
||||
OpTester test("OneHot", 9);
|
||||
int64_t axis = 0;
|
||||
|
|
@ -117,6 +131,26 @@ TEST(OneHotOpTest, Axis_2) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, Axis_Negative_NonDefault) {
|
||||
OpTester test("OneHot", 9);
|
||||
int64_t axis = -3;
|
||||
test.AddAttribute("axis", axis);
|
||||
test.AddInput<int64_t>("indices", {2, 3}, {1, 9, 8, 2, 4, 6});
|
||||
test.AddInput<int64_t>("depth", {1}, {10});
|
||||
test.AddInput<int64_t>("values", {2}, {0, 1});
|
||||
test.AddOutput<int64_t>("output", {10, 2, 3}, { 0, 0, 0, 0, 0, 0,
|
||||
1, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 1, 0,
|
||||
0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1,
|
||||
0, 0, 0, 0, 0, 0,
|
||||
0, 0, 1, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0,});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, FloatInt64) {
|
||||
OpTester test("OneHot", 9);
|
||||
test.AddInput<float>("indices", {2, 3}, {1.f, 9.f, 8.f, 2.f, 4.f, 6.f});
|
||||
|
|
|
|||
Loading…
Reference in a new issue