diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index c6cef974d3..fe71da6800 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -212,6 +212,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, uint64_t, Expand); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, bool, Expand); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, MLFloat16, Expand); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, string, Expand); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Loop); @@ -451,6 +452,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint64_t, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Expand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); @@ -882,6 +884,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Expand)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // REVIEW(codemzs): ConstEigenVectorArrayMap.cast::Compute(OpKernelContext* context) const { return Status::OK(); } -#define REG_EXPAND_KERNEL(TYPE) \ +#define REG_EXPAND_KERNEL_WITH_TYPE_NAME(TYPE, TYPE_NAME) \ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ Expand, \ 8, \ 12, \ - TYPE, \ + TYPE_NAME, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Expand_8); \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ Expand, \ 13, \ - TYPE, \ + TYPE_NAME, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Expand_8); +#define REG_EXPAND_KERNEL(TYPE) REG_EXPAND_KERNEL_WITH_TYPE_NAME(TYPE, TYPE) + REG_EXPAND_KERNEL(float) REG_EXPAND_KERNEL(double) REG_EXPAND_KERNEL(int8_t) @@ -1286,6 +1288,7 @@ REG_EXPAND_KERNEL(uint32_t) REG_EXPAND_KERNEL(uint64_t) REG_EXPAND_KERNEL(bool) REG_EXPAND_KERNEL(MLFloat16) +REG_EXPAND_KERNEL_WITH_TYPE_NAME(std::string, string) template <> Status Erf::Compute(OpKernelContext* context) const { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 043865e080..f09fe00798 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1851,6 +1851,38 @@ TEST(MathOpTest, Expand_8_1x3_float16) { MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f))}); test.Run(); } +TEST(MathOpTest, Expand_8_3x3_string) { + OpTester test("Expand", 8); + test.AddInput("data_0", {1}, {"1"}); + test.AddInput("data_1", {2}, {3, 3}); + test.AddOutput("result", {3, 3}, + {"1", "1", "1", + "1", "1", "1", + "1", "1", "1"}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_3x1_string) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3}, {"1", "2", "3"}); + test.AddInput("data_1", {2}, {3, 1}); + test.AddOutput("result", {3, 3}, + {"1", "2", "3", + "1", "2", "3", + "1", "2", "3"}); + test.Run(); +} + +TEST(MathOpTest, Expand_8_1x3_string) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 1}, {"1", "2", "3"}); + test.AddInput("data_1", {2}, {1, 3}); + test.AddOutput("result", {3, 3}, + {"1", "1", "1", + "2", "2", "2", + "3", "3", "3"}); + test.Run(); +} TEST(MathOpTest, Erf) { OpTester test("Erf", 9);