RegexFullMatch operator (#18002)

### Description
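Adds a CPU execution provider kernel for the ONNX opset 20 `RegexFullMatch` operator (string tensor in, bool tensor out), implemented with RE2. The commit registers the kernel, documents it in the operator kernel table, adds unit tests, and removes the `test_regex_full_match_*` entries from the ONNX backend test exclusion list.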



### Motivation and Context
Closes https://github.com/microsoft/onnxruntime/issues/17594.
This commit is contained in:
Aditya Goel 2024-01-11 23:50:07 +00:00 committed by GitHub
parent 08cf4fbcad
commit d8962d67f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 177 additions and 3 deletions

View file

@ -305,6 +305,7 @@ Do not modify directly.*
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|RegexFullMatch|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(string)<br/> **T2** = tensor(bool)|
|Relu|*in* X:**T**<br> *out* Y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8)|
|||13|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|

View file

@ -990,6 +990,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch);
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
@ -2449,6 +2450,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "regex_full_match.h"
#include "core/common/common.h"
namespace onnxruntime {
// Registers the RegexFullMatch kernel class for the CPU execution provider at operator
// version 20. The "T1" type constraint is limited to string tensors and the "T2" type
// constraint to bool tensors.
ONNX_CPU_OPERATOR_KERNEL(
    RegexFullMatch,
    20,
    KernelDefBuilder()
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<std::string>())
        .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
    RegexFullMatch);
// Constructs the kernel, initializing re_ from the node's "pattern" string attribute.
// ORT_ENFORCE rejects the kernel when RE2 reports the pattern as not ok, and the failure
// message includes the offending pattern text.
RegexFullMatch::RegexFullMatch(const OpKernelInfo& info)
    : OpKernel(info),
      re_{info.GetAttr<std::string>("pattern")} {
  ORT_ENFORCE(re_.ok(), "Invalid regex pattern: ", re_.pattern());
}
// Evaluates the regex against every string of input 0.
//
// Output 0 is requested with the same shape as the input; element i receives the result of
// RE2::FullMatch for input element i against re_. Always returns Status::OK().
Status RegexFullMatch::Compute(OpKernelContext* context) const {
  const auto* input_tensor = context->Input<Tensor>(0);
  const auto strings = input_tensor->template DataAsSpan<std::string>();

  auto* output_tensor = context->Output(0, input_tensor->Shape());
  auto matches = output_tensor->template MutableDataAsSpan<bool>();

  // Walk the input in order and advance the output cursor in lockstep, one result per string.
  auto match_cursor = matches.begin();
  for (auto string_cursor = strings.begin(); string_cursor != strings.end(); ++string_cursor) {
    *match_cursor = RE2::FullMatch(*string_cursor, re_);
    ++match_cursor;
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/op_kernel.h"
#include "re2/re2.h"
namespace onnxruntime {
// Operator kernel for RegexFullMatch. Holds an RE2 object as a member and exposes the
// standard OpKernel entry point.
class RegexFullMatch final : public OpKernel {
 public:
  // Creates the kernel from the supplied kernel info.
  explicit RegexFullMatch(const OpKernelInfo& info);

  // Runs the kernel for one invocation using the given context.
  Status Compute(OpKernelContext* context) const override;

 private:
  // Regular expression object owned by this kernel instance.
  RE2 re_;
};
} // namespace onnxruntime

View file

@ -0,0 +1,119 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
static void RunTest(const std::initializer_list<int64_t>& dims, const std::initializer_list<std::string>& input, const std::string& pattern, const std::initializer_list<bool>& output) {
OpTester test("RegexFullMatch", 20, kOnnxDomain);
test.AddAttribute("pattern", pattern);
test.AddInput<std::string>("Input", dims, input);
test.AddOutput<bool>("Output", dims, output);
test.Run();
}
// Host names: the two ".com" entries are expected to match, the ".co.uk" entry is not.
TEST(RegexFullMatch, WebsiteMatch) {
  RunTest({3, 1},
          {"www.google.com", "www.facebook.com", "www.bbc.co.uk"},
          R"(www\.[\w.-]+\.\bcom\b)",
          {true, true, false});
}
// E-mail addresses in a 2x2 tensor: only the yahoo/gmail addresses are expected to match.
TEST(RegexFullMatch, EmailMatch) {
  RunTest({2, 2},
          {"account@gmail.com", "account@hotmail.com", "not email", "account@yahoo.com"},
          R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))",
          {true, false, false, true});
}
// Inputs and patterns containing multibyte characters in several scripts. Each case pairs
// strings that should match with near-misses that should not.
TEST(RegexFullMatch, MultibyteMatch) {
  // Single accented character versus its unaccented counterpart.
  RunTest({1, 2}, {"ä", "a"}, "ä", {true, false});

  // French: the cedilla must be present in the pattern to match.
  RunTest({1}, {"une cédille like in Besançon"}, R"(.*Besançon.*)", {true});
  RunTest({1}, {"une cédille like in Besançon"}, R"(.*Besancon.*)", {false});

  // German: the umlaut must be present in the pattern to match.
  RunTest({1}, {"Mit freundlichen Grüßen"}, R"(.*Grüßen$)", {true});
  RunTest({1}, {"Mit freundlichen Grüßen"}, R"(.*Grußen$)", {false});

  // Cyrillic prefix match.
  RunTest({3},
          {"онедельник", "Понедельник", "недельник"},
          R"(^Понед.*)",
          {false, true, false});

  // Japanese prefix match.
  RunTest({3},
          {"thank you", "どうもありがとうございます", "こんにちは世界"},
          R"(^こんにちは世界.*)",
          {false, false, true});

  // Devanagari suffix match that requires at least one preceding character.
  RunTest({3},
          {"नमस्ते, आपसे मिलकर अच्छा लगा", "नमस्ते", "स्वागत एवं नमस्ते"},
          R"(.+नमस्ते$)",
          {false, false, true});

  // Chinese prefix combined with a literal trailing question mark.
  RunTest({3},
          {"你好,你好吗?", "你好呀", "你好呀!"},
          R"(^你好.*\?$)",
          {true, false, false});
}
// A pattern with an unterminated character class is expected to make the run fail with the
// kernel's "Invalid regex pattern" message, which echoes the pattern text.
TEST(RegexFullMatch, InvalidPattern) {
  OpTester tester("RegexFullMatch", 20, kOnnxDomain);
  tester.AddAttribute("pattern", R"([a-z)");
  tester.AddInput<std::string>("Input", {1}, {"abcdef"});
  tester.AddOutput<bool>("Output", {1}, {false});
  tester.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern: [a-z");
}
// A pattern built from a raw byte sequence (including bytes such as 0xC0/0xC1 and a trailing
// 0x00) is expected to make the run fail with the "Invalid regex pattern" message.
TEST(RegexFullMatch, NonUtf8Pattern) {
  const uint8_t invalid_bytes[] = {0xC0, 0xC1, 0x41, 0x42, 0xC3, 0x80, 0xC2, 0x80, 0xC2, 0xC3, 0xC4, 0x00};

  OpTester test("RegexFullMatch", 20, kOnnxDomain);
  // Pass the explicit length so every byte, including the trailing 0x00, is part of the
  // pattern. A named cast replaces the C-style cast and keeps the buffer const.
  test.AddAttribute("pattern",
                    std::string(reinterpret_cast<const char*>(invalid_bytes), sizeof(invalid_bytes)));
  test.AddInput<std::string>("Input", {1}, {"abcd"});
  test.AddOutput<bool>("Output", {1}, {false});
  test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern");
}
} // namespace test
} // namespace onnxruntime

View file

@ -248,9 +248,6 @@
"^test_image_decoder_decode_pnm_rgb",
"^test_image_decoder_decode_tiff_rgb",
"^test_image_decoder_decode_webp_rgb",
"^test_regex_full_match_basic",
"^test_regex_full_match_email_domain",
"^test_regex_full_match_empty",
"^test_string_split_basic",
"^test_string_split_consecutive_delimiters",
"^test_string_split_empty_string_delimiter",