diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index d9af3cb42d..8355c7c7dc 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -210,6 +210,17 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Uns class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int16_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int32_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int64_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint8_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint16_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint32_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint64_t, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, bool, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, MLFloat16, Expand); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If); @@ -440,6 +451,17 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 5ebbd9512d..f12feeb5d3 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -1001,12 +1001,26 @@ Status Expand_8::Compute(OpKernelContext* context) const { return Status::OK(); } -ONNX_CPU_OPERATOR_TYPED_KERNEL( - Expand, - 8, - float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Expand_8); +#define REG_EXPAND_KERNEL(TYPE) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + Expand, \ + 8, \ + TYPE, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Expand_8); + +REG_EXPAND_KERNEL(float) +REG_EXPAND_KERNEL(double) +REG_EXPAND_KERNEL(int8_t) +REG_EXPAND_KERNEL(int16_t) +REG_EXPAND_KERNEL(int32_t) +REG_EXPAND_KERNEL(int64_t) +REG_EXPAND_KERNEL(uint8_t) +REG_EXPAND_KERNEL(uint16_t) +REG_EXPAND_KERNEL(uint32_t) +REG_EXPAND_KERNEL(uint64_t) +REG_EXPAND_KERNEL(bool) +REG_EXPAND_KERNEL(MLFloat16) template <> Status Scale::Compute(OpKernelContext* ctx) const { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 90650e6df4..d3f7bc0508 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "core/util/math.h" namespace onnxruntime { namespace test { @@ -877,6 +878,105 @@ TEST(MathOpTest, Expand_8_1x3) { test.Run(); } +TEST(MathOpTest, Expand_8_3x3_int32) { + OpTester test("Expand", 8); + test.AddInput("data_0", {1}, {1}); + test.AddInput("data_1", {2}, {3, 3}); + test.AddOutput("result", {3, 3}, + {1, 1, 1, + 1, 1, 1, + 1, 1, 1}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_3x1_int32) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3}, {1, 2, 3}); + test.AddInput("data_1", {2}, {3, 1}); + test.AddOutput("result", {3, 3}, + {1, 2, 3, + 1, 2, 3, + 1, 2, 3}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_1x3_int32) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 1}, {1, 2, 3}); + test.AddInput("data_1", {2}, {1, 3}); + test.AddOutput("result", {3, 3}, + {1, 1, 1, + 2, 2, 2, + 3, 3, 3}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_3x3_int64) { + OpTester test("Expand", 8); + test.AddInput("data_0", {1}, {1}); + test.AddInput("data_1", {2}, {3, 3}); + test.AddOutput("result", {3, 3}, + {1, 1, 1, + 1, 1, 1, + 1, 1, 1}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_3x1_int64) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3}, {1, 2, 3}); + test.AddInput("data_1", {2}, {3, 1}); + test.AddOutput("result", {3, 3}, + {1, 2, 3, + 1, 2, 3, + 1, 2, 3}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_1x3_int64) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 1}, {1, 2, 3}); + test.AddInput("data_1", {2}, {1, 3}); + test.AddOutput("result", {3, 3}, + {1, 1, 1, + 2, 2, 2, + 3, 3, 3}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_3x3_float16) { + OpTester test("Expand", 8); + test.AddInput("data_0", {1}, {MLFloat16(math::floatToHalf(1.0f))}); + test.AddInput("data_1", {2}, {3, 3}); + test.AddOutput("result", {3, 3}, + {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), + MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), + MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f))}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_3x1_float16) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))}); + test.AddInput("data_1", {2}, {3, 1}); + test.AddOutput("result", {3, 3}, + {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)), + MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)), + MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_1x3_float16) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 1}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))}); + test.AddInput("data_1", {2}, {1, 3}); + test.AddOutput("result", {3, 3}, + {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), + MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), + MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f))}); + test.Run(); +} + TEST(MathOpTest, Scale) { OpTester test("Scale"); std::vector dims{2, 2};