Add support for string with operator Expand (#5751)

This commit is contained in:
Xavier Dupré 2020-11-10 18:38:20 +01:00 committed by GitHub
parent 4094a09a56
commit 8c74df2068
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 3 deletions

View file

@ -212,6 +212,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, uint64_t, Expand);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, bool, Expand);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, MLFloat16, Expand);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, string, Expand);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Loop);
@ -451,6 +452,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint64_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
@ -882,6 +884,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, MLFloat16,
Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 12, string,
Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, If)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
@ -1264,6 +1268,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string,
Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Erf)>,
// REVIEW(codemzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.

View file

@ -1259,21 +1259,23 @@ Status Expand_8<T>::Compute(OpKernelContext* context) const {
return Status::OK();
}
#define REG_EXPAND_KERNEL(TYPE) \
#define REG_EXPAND_KERNEL_WITH_TYPE_NAME(TYPE, TYPE_NAME) \
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
Expand, \
8, \
12, \
TYPE, \
TYPE_NAME, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
Expand_8<TYPE>); \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
Expand, \
13, \
TYPE, \
TYPE_NAME, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
Expand_8<TYPE>);
#define REG_EXPAND_KERNEL(TYPE) REG_EXPAND_KERNEL_WITH_TYPE_NAME(TYPE, TYPE)
REG_EXPAND_KERNEL(float)
REG_EXPAND_KERNEL(double)
REG_EXPAND_KERNEL(int8_t)
@ -1286,6 +1288,7 @@ REG_EXPAND_KERNEL(uint32_t)
REG_EXPAND_KERNEL(uint64_t)
REG_EXPAND_KERNEL(bool)
REG_EXPAND_KERNEL(MLFloat16)
REG_EXPAND_KERNEL_WITH_TYPE_NAME(std::string, string)
template <>
Status Erf<float>::Compute(OpKernelContext* context) const {

View file

@ -1851,6 +1851,38 @@ TEST(MathOpTest, Expand_8_1x3_float16) {
MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f))});
test.Run();
}
TEST(MathOpTest, Expand_8_3x3_string) {
OpTester test("Expand", 8);
test.AddInput<std::string>("data_0", {1}, {"1"});
test.AddInput<int64_t>("data_1", {2}, {3, 3});
test.AddOutput<std::string>("result", {3, 3},
{"1", "1", "1",
"1", "1", "1",
"1", "1", "1"});
test.Run();
}
TEST(MathOpTest, Expand_8_3x1_string) {
OpTester test("Expand", 8);
test.AddInput<std::string>("data_0", {3}, {"1", "2", "3"});
test.AddInput<int64_t>("data_1", {2}, {3, 1});
test.AddOutput<std::string>("result", {3, 3},
{"1", "2", "3",
"1", "2", "3",
"1", "2", "3"});
test.Run();
}
TEST(MathOpTest, Expand_8_1x3_string) {
OpTester test("Expand", 8);
test.AddInput<std::string>("data_0", {3, 1}, {"1", "2", "3"});
test.AddInput<int64_t>("data_1", {2}, {1, 3});
test.AddOutput<std::string>("result", {3, 3},
{"1", "1", "1",
"2", "2", "2",
"3", "3", "3"});
test.Run();
}
TEST(MathOpTest, Erf) {
OpTester test("Erf", 9);