mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Add support for string with operator Expand (#5751)
This commit is contained in:
parent
4094a09a56
commit
8c74df2068
3 changed files with 44 additions and 3 deletions
|
|
@ -212,6 +212,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, uint64_t, Expand);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, bool, Expand);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, MLFloat16, Expand);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, string, Expand);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, If);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Loop);
|
||||
|
|
@ -451,6 +452,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint64_t, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
|
||||
|
|
@ -882,6 +884,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, MLFloat16,
|
||||
Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, string,
|
||||
Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, If)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
|
||||
|
|
@ -1264,6 +1268,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
|
||||
Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string,
|
||||
Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Erf)>,
|
||||
// REVIEW(codemzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
|
||||
// However these types work on GPU implementation.
|
||||
|
|
|
|||
|
|
@ -1259,21 +1259,23 @@ Status Expand_8<T>::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REG_EXPAND_KERNEL(TYPE) \
|
||||
#define REG_EXPAND_KERNEL_WITH_TYPE_NAME(TYPE, TYPE_NAME) \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
Expand, \
|
||||
8, \
|
||||
12, \
|
||||
TYPE, \
|
||||
TYPE_NAME, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
Expand_8<TYPE>); \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
Expand, \
|
||||
13, \
|
||||
TYPE, \
|
||||
TYPE_NAME, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
Expand_8<TYPE>);
|
||||
|
||||
#define REG_EXPAND_KERNEL(TYPE) REG_EXPAND_KERNEL_WITH_TYPE_NAME(TYPE, TYPE)
|
||||
|
||||
REG_EXPAND_KERNEL(float)
|
||||
REG_EXPAND_KERNEL(double)
|
||||
REG_EXPAND_KERNEL(int8_t)
|
||||
|
|
@ -1286,6 +1288,7 @@ REG_EXPAND_KERNEL(uint32_t)
|
|||
REG_EXPAND_KERNEL(uint64_t)
|
||||
REG_EXPAND_KERNEL(bool)
|
||||
REG_EXPAND_KERNEL(MLFloat16)
|
||||
REG_EXPAND_KERNEL_WITH_TYPE_NAME(std::string, string)
|
||||
|
||||
template <>
|
||||
Status Erf<float>::Compute(OpKernelContext* context) const {
|
||||
|
|
|
|||
|
|
@ -1851,6 +1851,38 @@ TEST(MathOpTest, Expand_8_1x3_float16) {
|
|||
MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f))});
|
||||
test.Run();
|
||||
}
|
||||
TEST(MathOpTest, Expand_8_3x3_string) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<std::string>("data_0", {1}, {"1"});
|
||||
test.AddInput<int64_t>("data_1", {2}, {3, 3});
|
||||
test.AddOutput<std::string>("result", {3, 3},
|
||||
{"1", "1", "1",
|
||||
"1", "1", "1",
|
||||
"1", "1", "1"});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x1_string) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<std::string>("data_0", {3}, {"1", "2", "3"});
|
||||
test.AddInput<int64_t>("data_1", {2}, {3, 1});
|
||||
test.AddOutput<std::string>("result", {3, 3},
|
||||
{"1", "2", "3",
|
||||
"1", "2", "3",
|
||||
"1", "2", "3"});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_1x3_string) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<std::string>("data_0", {3, 1}, {"1", "2", "3"});
|
||||
test.AddInput<int64_t>("data_1", {2}, {1, 3});
|
||||
test.AddOutput<std::string>("result", {3, 3},
|
||||
{"1", "1", "1",
|
||||
"2", "2", "2",
|
||||
"3", "3", "3"});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Erf) {
|
||||
OpTester test("Erf", 9);
|
||||
|
|
|
|||
Loading…
Reference in a new issue