mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Expand: add additional supported types. (#364)
This commit is contained in:
parent
ea816615eb
commit
d040b452cb
3 changed files with 142 additions and 6 deletions
|
|
@ -210,6 +210,17 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Uns
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int16_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int32_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int64_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint8_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint16_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint32_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint64_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, bool, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, MLFloat16, Expand);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If);
|
||||
|
|
@ -440,6 +451,17 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int16_t, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int32_t, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int64_t, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint8_t, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint16_t, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint32_t, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint64_t, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, bool, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, MLFloat16, Expand)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If)>());
|
||||
|
|
|
|||
|
|
@ -1001,12 +1001,26 @@ Status Expand_8<T>::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Expand,
|
||||
8,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Expand_8<float>);
|
||||
#define REG_EXPAND_KERNEL(TYPE) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
Expand, \
|
||||
8, \
|
||||
TYPE, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
Expand_8<TYPE>);
|
||||
|
||||
REG_EXPAND_KERNEL(float)
|
||||
REG_EXPAND_KERNEL(double)
|
||||
REG_EXPAND_KERNEL(int8_t)
|
||||
REG_EXPAND_KERNEL(int16_t)
|
||||
REG_EXPAND_KERNEL(int32_t)
|
||||
REG_EXPAND_KERNEL(int64_t)
|
||||
REG_EXPAND_KERNEL(uint8_t)
|
||||
REG_EXPAND_KERNEL(uint16_t)
|
||||
REG_EXPAND_KERNEL(uint32_t)
|
||||
REG_EXPAND_KERNEL(uint64_t)
|
||||
REG_EXPAND_KERNEL(bool)
|
||||
REG_EXPAND_KERNEL(MLFloat16)
|
||||
|
||||
template <>
|
||||
Status Scale<float>::Compute(OpKernelContext* ctx) const {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "core/util/math.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -877,6 +878,105 @@ TEST(MathOpTest, Expand_8_1x3) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x3_int32) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<int32_t>("data_0", {1}, {1});
|
||||
test.AddInput<int64_t>("data_1", {2}, {3, 3});
|
||||
test.AddOutput<int32_t>("result", {3, 3},
|
||||
{1, 1, 1,
|
||||
1, 1, 1,
|
||||
1, 1, 1});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x1_int32) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<int32_t>("data_0", {3}, {1, 2, 3});
|
||||
test.AddInput<int64_t>("data_1", {2}, {3, 1});
|
||||
test.AddOutput<int32_t>("result", {3, 3},
|
||||
{1, 2, 3,
|
||||
1, 2, 3,
|
||||
1, 2, 3});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_1x3_int32) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<int32_t>("data_0", {3, 1}, {1, 2, 3});
|
||||
test.AddInput<int64_t>("data_1", {2}, {1, 3});
|
||||
test.AddOutput<int32_t>("result", {3, 3},
|
||||
{1, 1, 1,
|
||||
2, 2, 2,
|
||||
3, 3, 3});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x3_int64) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<int64_t>("data_0", {1}, {1});
|
||||
test.AddInput<int64_t>("data_1", {2}, {3, 3});
|
||||
test.AddOutput<int64_t>("result", {3, 3},
|
||||
{1, 1, 1,
|
||||
1, 1, 1,
|
||||
1, 1, 1});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x1_int64) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<int64_t>("data_0", {3}, {1, 2, 3});
|
||||
test.AddInput<int64_t>("data_1", {2}, {3, 1});
|
||||
test.AddOutput<int64_t>("result", {3, 3},
|
||||
{1, 2, 3,
|
||||
1, 2, 3,
|
||||
1, 2, 3});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_1x3_int64) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<int64_t>("data_0", {3, 1}, {1, 2, 3});
|
||||
test.AddInput<int64_t>("data_1", {2}, {1, 3});
|
||||
test.AddOutput<int64_t>("result", {3, 3},
|
||||
{1, 1, 1,
|
||||
2, 2, 2,
|
||||
3, 3, 3});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x3_float16) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(math::floatToHalf(1.0f))});
|
||||
test.AddInput<int64_t>("data_1", {2}, {3, 3});
|
||||
test.AddOutput<MLFloat16>("result", {3, 3},
|
||||
{MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
|
||||
MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
|
||||
MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f))});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x1_float16) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<MLFloat16>("data_0", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
|
||||
test.AddInput<int64_t>("data_1", {2}, {3, 1});
|
||||
test.AddOutput<MLFloat16>("result", {3, 3},
|
||||
{MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)),
|
||||
MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)),
|
||||
MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_1x3_float16) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<MLFloat16>("data_0", {3, 1}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
|
||||
test.AddInput<int64_t>("data_1", {2}, {1, 3});
|
||||
test.AddOutput<MLFloat16>("result", {3, 3},
|
||||
{MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
|
||||
MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)),
|
||||
MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f))});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Scale) {
|
||||
OpTester test("Scale");
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
|
|
|
|||
Loading…
Reference in a new issue