mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
RegexFullMatch operator (#18002)
### Description <!-- Describe your changes. --> ### Motivation and Context Closes https://github.com/microsoft/onnxruntime/issues/17594.
This commit is contained in:
parent
08cf4fbcad
commit
d8962d67f4
6 changed files with 177 additions and 3 deletions
|
|
@ -305,6 +305,7 @@ Do not modify directly.*
|
|||
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|
||||
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|
||||
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|
||||
|RegexFullMatch|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(string)<br/> **T2** = tensor(bool)|
|
||||
|Relu|*in* X:**T**<br> *out* Y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8)|
|
||||
|||13|**T** = tensor(double), tensor(float)|
|
||||
|||[6, 12]|**T** = tensor(double), tensor(float)|
|
||||
|
|
|
|||
|
|
@ -990,6 +990,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
#endif
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch);
|
||||
|
||||
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
|
||||
|
||||
|
|
@ -2449,6 +2450,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
#endif
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
35
onnxruntime/core/providers/cpu/text/regex_full_match.cc
Normal file
35
onnxruntime/core/providers/cpu/text/regex_full_match.cc
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "regex_full_match.h"
|
||||
#include "core/common/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
RegexFullMatch,
|
||||
20,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<std::string>())
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
|
||||
RegexFullMatch);
|
||||
|
||||
RegexFullMatch::RegexFullMatch(const OpKernelInfo& info) : OpKernel(info), re_{info.GetAttr<std::string>("pattern")} {
|
||||
ORT_ENFORCE(re_.ok(), "Invalid regex pattern: ", re_.pattern());
|
||||
}
|
||||
|
||||
Status RegexFullMatch::Compute(OpKernelContext* context) const {
|
||||
const auto* input_tensor = context->Input<Tensor>(0);
|
||||
const auto input_data = input_tensor->template DataAsSpan<std::string>();
|
||||
auto* output_tensor = context->Output(0, input_tensor->Shape());
|
||||
auto output_data = output_tensor->template MutableDataAsSpan<bool>();
|
||||
auto output_iter = output_data.begin();
|
||||
auto input_iter = input_data.begin();
|
||||
while (input_iter != input_data.end()) {
|
||||
*output_iter = RE2::FullMatch(*input_iter, re_);
|
||||
input_iter++;
|
||||
output_iter++;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
20
onnxruntime/core/providers/cpu/text/regex_full_match.h
Normal file
20
onnxruntime/core/providers/cpu/text/regex_full_match.h
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "re2/re2.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class RegexFullMatch final : public OpKernel {
|
||||
public:
|
||||
explicit RegexFullMatch(const OpKernelInfo& info);
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
RE2 re_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
119
onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
Normal file
119
onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
static void RunTest(const std::initializer_list<int64_t>& dims, const std::initializer_list<std::string>& input, const std::string& pattern, const std::initializer_list<bool>& output) {
|
||||
OpTester test("RegexFullMatch", 20, kOnnxDomain);
|
||||
test.AddAttribute("pattern", pattern);
|
||||
test.AddInput<std::string>("Input", dims, input);
|
||||
test.AddOutput<bool>("Output", dims, output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(RegexFullMatch, WebsiteMatch) {
|
||||
RunTest({3, 1}, {"www.google.com", "www.facebook.com", "www.bbc.co.uk"}, R"(www\.[\w.-]+\.\bcom\b)", {true, true, false});
|
||||
}
|
||||
|
||||
TEST(RegexFullMatch, EmailMatch) {
|
||||
RunTest({2, 2}, {"account@gmail.com", "account@hotmail.com", "not email", "account@yahoo.com"}, R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))", {true, false, false, true});
|
||||
}
|
||||
|
||||
TEST(RegexFullMatch, MultibyteMatch) {
|
||||
RunTest({1, 2}, {"ä", "a"}, "ä", {true, false});
|
||||
RunTest({
|
||||
1,
|
||||
},
|
||||
{"une cédille like in Besançon"}, R"(.*Besançon.*)", {
|
||||
true,
|
||||
});
|
||||
RunTest({
|
||||
1,
|
||||
},
|
||||
{"une cédille like in Besançon"}, R"(.*Besancon.*)", {
|
||||
false,
|
||||
});
|
||||
RunTest({
|
||||
1,
|
||||
},
|
||||
{"Mit freundlichen Grüßen"}, R"(.*Grüßen$)", {
|
||||
true,
|
||||
});
|
||||
RunTest({
|
||||
1,
|
||||
},
|
||||
{"Mit freundlichen Grüßen"}, R"(.*Grußen$)", {
|
||||
false,
|
||||
});
|
||||
RunTest({
|
||||
3,
|
||||
},
|
||||
{"HПонедельник", "Понедельник", "недельник"}, R"(^Понед.*)", {
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
});
|
||||
RunTest({
|
||||
3,
|
||||
},
|
||||
{"thank you", "どうもありがとうございます", "こんにちは世界"}, R"(^こんにちは世界.*)", {
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
});
|
||||
RunTest({
|
||||
3,
|
||||
},
|
||||
{"नमस्ते, आपसे मिलकर अच्छा लगा", "नमस्ते", "स्वागत एवं नमस्ते"}, R"(.+नमस्ते$)", {
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
});
|
||||
RunTest({
|
||||
3,
|
||||
},
|
||||
{"你好,你好吗?", "你好呀", "你好呀!"}, R"(^你好.*\?$)", {
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
});
|
||||
}
|
||||
|
||||
TEST(RegexFullMatch, InvalidPattern) {
|
||||
OpTester test("RegexFullMatch", 20, kOnnxDomain);
|
||||
test.AddAttribute("pattern", R"([a-z)");
|
||||
test.AddInput<std::string>("Input", {
|
||||
1,
|
||||
},
|
||||
{
|
||||
"abcdef",
|
||||
});
|
||||
test.AddOutput<bool>("Output", {
|
||||
1,
|
||||
},
|
||||
{
|
||||
false,
|
||||
});
|
||||
test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern: [a-z");
|
||||
}
|
||||
|
||||
TEST(RegexFullMatch, NonUtf8Pattern) {
|
||||
uint8_t invalid_bytes[] = {0xC0, 0xC1, 0x41, 0x42, 0xC3, 0x80, 0xC2, 0x80, 0xC2, 0xC3, 0xC4, 0x00};
|
||||
OpTester test("RegexFullMatch", 20, kOnnxDomain);
|
||||
test.AddAttribute("pattern", std::string((char*)invalid_bytes, sizeof(invalid_bytes)));
|
||||
test.AddInput<std::string>("Input", {
|
||||
1,
|
||||
},
|
||||
{
|
||||
"abcd",
|
||||
});
|
||||
test.AddOutput<bool>("Output", {
|
||||
1,
|
||||
},
|
||||
{
|
||||
false,
|
||||
});
|
||||
test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern");
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -248,9 +248,6 @@
|
|||
"^test_image_decoder_decode_pnm_rgb",
|
||||
"^test_image_decoder_decode_tiff_rgb",
|
||||
"^test_image_decoder_decode_webp_rgb",
|
||||
"^test_regex_full_match_basic",
|
||||
"^test_regex_full_match_email_domain",
|
||||
"^test_regex_full_match_empty",
|
||||
"^test_string_split_basic",
|
||||
"^test_string_split_consecutive_delimiters",
|
||||
"^test_string_split_empty_string_delimiter",
|
||||
|
|
|
|||
Loading…
Reference in a new issue