mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Implement StringNormalizer (#69)
* Implement StringNormalizer. Add mixed-language tests, test the case-insensitive path. * Create a locale on the fly. Default locale does not seem to create well. * Add CI language-pack-en to make the default locale available. Catch and translate the locale creation exception to make the message meaningful. * Make sure locales are configured on Ubuntu.
This commit is contained in:
parent
005f9dca96
commit
fbb23a9ed0
7 changed files with 540 additions and 5 deletions
|
|
@ -12,8 +12,8 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
using ::ONNX_NAMESPACE::AttributeProto;
|
||||
using ::ONNX_NAMESPACE::OPTIONAL;
|
||||
using ::ONNX_NAMESPACE::OpSchema;
|
||||
using ::ONNX_NAMESPACE::OPTIONAL;
|
||||
|
||||
void RegisterContribSchemas() {
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp)
|
||||
|
|
@ -452,6 +452,39 @@ The bounding box coordinates corresponding to the selected indices can then be o
|
|||
->set_dim_value(1);
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(StringNormalizer)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.Input(0, "X", "Strings to normalize", "T")
|
||||
.Output(0, "Y", "Normalized strings", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(string)"},
|
||||
"Input/Output is a string tensor")
|
||||
.Attr(
|
||||
"casechangeaction",
|
||||
"string enum that cases output to be lowercased/uppercases/unchanged. Valid values are \"LOWER\", \"UPPER\", \"NONE\"",
|
||||
AttributeProto::STRING)
|
||||
.Attr(
|
||||
"is_case_sensitive",
|
||||
"Boolean. Whether the identification of stop words in X is case-sensitive.",
|
||||
AttributeProto::INT)
|
||||
.Attr(
|
||||
"stopwords",
|
||||
"List of stop words",
|
||||
AttributeProto::STRINGS,
|
||||
OPTIONAL)
|
||||
.Attr(
|
||||
"locale",
|
||||
"Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US",
|
||||
AttributeProto::STRING,
|
||||
OPTIONAL)
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type();
|
||||
output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING);
|
||||
})
|
||||
.SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC");
|
||||
}
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
|
||||
|
|
@ -461,6 +494,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression);
|
||||
|
||||
void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
||||
|
|
@ -474,6 +508,7 @@ void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression)>());
|
||||
}
|
||||
} // namespace contrib
|
||||
|
|
|
|||
244
onnxruntime/contrib_ops/cpu/string_normalizer.cc
Normal file
244
onnxruntime/contrib_ops/cpu/string_normalizer.cc
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_normalizer.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/tensor.h"
|
||||
|
||||
#include <codecvt>
|
||||
#include <locale>
|
||||
#include <functional>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
|
||||
StringNormalizer,
|
||||
1,
|
||||
string,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<std::string>()),
|
||||
contrib::StringNormalizer);
|
||||
|
||||
namespace string_normalizer {
|
||||
// Sentinel values handed to std::wstring_convert: the converter returns these
// instead of throwing when a conversion fails, and callers compare the result
// against them to detect the failure.
// NOTE(review): an input whose text is exactly "Conversion Error" cannot be
// told apart from a failed conversion by that comparison.
const std::string conv_error = "Conversion Error";
const std::wstring wconv_error = L"Conversion Error";
|
||||
// performs tolower/toupper in-place
|
||||
inline void ChangeCase(const std::locale& loc, StringNormalizer::CaseAction caseaction,
|
||||
std::wstring& wstr) {
|
||||
assert(caseaction != StringNormalizer::NONE);
|
||||
if (caseaction == StringNormalizer::LOWER) {
|
||||
std::transform(wstr.begin(), wstr.end(), wstr.begin(),
|
||||
[&loc](wchar_t ch) { return std::tolower(ch, loc); });
|
||||
} else {
|
||||
std::transform(wstr.begin(), wstr.end(), wstr.begin(),
|
||||
[&loc](wchar_t ch) { return std::toupper(ch, loc); });
|
||||
}
|
||||
}
|
||||
|
||||
template <class ForwardIter>
|
||||
Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx,
|
||||
const std::locale& loc,
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>>& converter,
|
||||
size_t N, size_t C,
|
||||
StringNormalizer::CaseAction caseaction) {
|
||||
std::vector<int64_t> output_dims;
|
||||
if (N == 1) {
|
||||
output_dims.push_back(1);
|
||||
}
|
||||
|
||||
// Empty output case
|
||||
if (C == 0) {
|
||||
output_dims.push_back(1);
|
||||
TensorShape output_shape(output_dims);
|
||||
auto output_ten = ctx->Output(0, output_shape);
|
||||
auto output_default = output_ten->template MutableData<std::string>();
|
||||
new (output_default) std::string();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
output_dims.push_back(C);
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
auto output_tensor = ctx->Output(0, output_shape);
|
||||
auto const output_data = output_tensor->template MutableData<std::string>();
|
||||
|
||||
size_t output_idx = 0;
|
||||
while (first != end) {
|
||||
auto& s = *first;
|
||||
if (caseaction == StringNormalizer::LOWER || caseaction == StringNormalizer::UPPER) {
|
||||
std::wstring wstr = converter.from_bytes(s);
|
||||
if (wstr == wconv_error) {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Input contains invalid utf8 chars at: " + static_cast<const std::string&>(s));
|
||||
}
|
||||
// In place transform
|
||||
ChangeCase(loc, caseaction, wstr);
|
||||
new (output_data + output_idx) std::string(converter.to_bytes(wstr));
|
||||
} else {
|
||||
assert(caseaction == StringNormalizer::NONE);
|
||||
// Simple copy or move if the iterator points to a non-const string
|
||||
new (output_data + output_idx) std::string(std::move(s));
|
||||
}
|
||||
++output_idx;
|
||||
++first;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline std::locale GetLocale(const std::string& locale_name) {
|
||||
try {
|
||||
std::locale result(locale_name);
|
||||
return result;
|
||||
} catch (const std::runtime_error& e) {
|
||||
ONNXRUNTIME_THROW("Failed to construct locale with name:",
|
||||
locale_name, ":", e.what(), ":Please, install necessary language-pack-XX and configure locales");
|
||||
}
|
||||
}
|
||||
} // namespace string_normalizer
|
||||
|
||||
using namespace string_normalizer;
|
||||
|
||||
// Reads and validates the operator attributes.
//   is_case_sensitive - required; non-zero means stop words are matched exactly
//   casechangeaction  - required; one of "LOWER", "UPPER", "NONE"
//   locale            - optional; defaults to "en_US.UTF-8"
//   stopwords         - optional; entries must be non-empty and unique
// Any violation is reported through ONNXRUNTIME_ENFORCE. The locale is
// constructed here as well, so an unavailable locale is reported at kernel
// creation (via GetLocale).
StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info),
                                                               is_case_sensitive_(true),
                                                               casechangeaction_(NONE),
                                                               compare_caseaction_(NONE) {
  int64_t iscasesensitive = 0;
  Status status = info.GetAttr("is_case_sensitive", &iscasesensitive);
  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute is_case_sensitive is not set");
  is_case_sensitive_ = iscasesensitive != 0;

  std::string casechangeaction;
  status = info.GetAttr("casechangeaction", &casechangeaction);
  // The message names the attribute exactly as it appears in the schema.
  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute casechangeaction is not set");
  if (casechangeaction == "LOWER") {
    casechangeaction_ = LOWER;
  } else if (casechangeaction == "UPPER") {
    casechangeaction_ = UPPER;
  } else if (casechangeaction == "NONE") {
    casechangeaction_ = NONE;
  } else {
    ONNXRUNTIME_ENFORCE(false, "attribute casechangeaction has invalid value");
  }

  if (!is_case_sensitive_) {
    // Case-insensitive matching compares case-folded strings. Folding to the
    // requested output case (lower case when no change is requested) lets
    // Compute reuse the folded string as the output value.
    compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER;
  }

  locale_name_ = info.GetAttrOrDefault("locale", std::string("en_US.UTF-8"));
  std::locale locale = GetLocale(locale_name_);
  std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);

  std::vector<std::string> swords = info.GetAttrsOrDefault<std::string>("stopwords");
  for (const auto& sw : swords) {
    ONNXRUNTIME_ENFORCE(!sw.empty(), "Empty stopwords not allowed");
    if (is_case_sensitive_) {
      // Exact matching: keep the stop word as given.
      auto p = stopwords_.insert(sw);
      ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed");
    } else {
      // Case-insensitive matching: keep the case-folded wide form, so words
      // that differ only by case count as duplicates.
      std::wstring wstr = converter.from_bytes(sw);
      ONNXRUNTIME_ENFORCE(wstr != wconv_error, "Stopword contains invalid utf8 chars");
      ChangeCase(locale, compare_caseaction_, wstr);
      auto p = wstopwords_.insert(wstr);
      ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed");
    }
  }
}
|
||||
|
||||
// Filters stop words out of input 0 and applies the configured case change.
// Accepted input shapes are [C] and [1, C] with C > 0; anything else yields
// INVALID_ARGUMENT. Output allocation and the empty-result case are handled by
// CopyCaseAction.
Status StringNormalizer::Compute(OpKernelContext* ctx) const {
  using namespace string_normalizer;

  auto X = ctx->Input<Tensor>(0);
  const auto& input_dims = X->Shape().GetDims();

  size_t N = 0;
  size_t C = 0;
  switch (input_dims.size()) {
    case 1:
      if (input_dims[0] < 1) {
        return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                      "Single dimension value must be greater than 0");
      }
      C = static_cast<size_t>(input_dims[0]);
      break;
    case 2:
      if (!(input_dims[0] == 1 && input_dims[1] >= 1)) {
        return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                      "Input dimensions are either[C > 0] or [1][C > 0] allowed");
      }
      N = 1;
      C = static_cast<size_t>(input_dims[1]);
      break;
    default:
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                    "Input dimensions are either[C > 0] or [1][C > 0] allowed");
  }

  std::locale locale = GetLocale(locale_name_);
  std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
  const std::string* const input_begin = X->Data<std::string>();
  const std::string* const input_end = input_begin + C;

  // Only the stop word set that belongs to the active matching mode is consulted.
  const bool nothing_to_filter = is_case_sensitive_ ? stopwords_.empty() : wstopwords_.empty();
  if (nothing_to_filter) {
    // Copy input to output and change case if needed.
    return CopyCaseAction(input_begin, input_end, ctx, locale, converter, N, C, casechangeaction_);
  }

  using StrRef = std::reference_wrapper<const std::string>;

  if (is_case_sensitive_) {
    // Exact matching: remember references to the surviving inputs and let
    // CopyCaseAction apply the case change.
    std::vector<StrRef> kept;
    kept.reserve(C);
    for (const std::string* it = input_begin; it != input_end; ++it) {
      if (stopwords_.count(*it) == 0) {
        kept.push_back(std::cref(*it));
      }
    }
    return CopyCaseAction(kept.cbegin(), kept.cend(), ctx, locale, converter,
                          N, kept.size(), casechangeaction_);
  }

  // Case-insensitive matching: every input is case-folded for the comparison.
  // With no case change requested the original strings are emitted; otherwise
  // the folded string already is the output value, so it is stored directly
  // and CopyCaseAction is told not to convert again.
  const bool keep_original_case = (casechangeaction_ == NONE);
  std::vector<StrRef> kept_original;
  std::vector<std::string> kept_cased;
  kept_original.reserve(C);
  kept_cased.reserve(C);
  for (const std::string* it = input_begin; it != input_end; ++it) {
    const std::string& s = *it;
    std::wstring wstr = converter.from_bytes(s);
    if (wstr == wconv_error) {
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                    "Input contains invalid utf8 chars at: " + s);
    }
    ChangeCase(locale, compare_caseaction_, wstr);
    if (wstopwords_.count(wstr) != 0) {
      continue;  // stop word, dropped
    }
    if (keep_original_case) {
      kept_original.push_back(std::cref(s));
    } else {
      kept_cased.push_back(converter.to_bytes(wstr));
    }
  }

  if (keep_original_case) {
    return CopyCaseAction(kept_original.cbegin(), kept_original.cend(), ctx, locale, converter,
                          N, kept_original.size(), NONE);
  }
  return CopyCaseAction(kept_cased.begin(), kept_cased.end(), ctx, locale, converter,
                        N, kept_cased.size(), NONE);
}
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
39
onnxruntime/contrib_ops/cpu/string_normalizer.h
Normal file
39
onnxruntime/contrib_ops/cpu/string_normalizer.h
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
#include <locale>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class StringNormalizer : public OpKernel {
|
||||
public:
|
||||
enum CaseAction {
|
||||
NONE = 0,
|
||||
LOWER = 1,
|
||||
UPPER = 2,
|
||||
};
|
||||
|
||||
explicit StringNormalizer(const OpKernelInfo& info);
|
||||
~StringNormalizer() = default;
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
bool is_case_sensitive_;
|
||||
CaseAction casechangeaction_;
|
||||
CaseAction compare_caseaction_; // used for case-insensitive compare
|
||||
std::string locale_name_;
|
||||
// Either if these are populated but not both
|
||||
std::unordered_set<std::string> stopwords_;
|
||||
std::unordered_set<std::wstring> wstopwords_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
212
onnxruntime/test/contrib_ops/string_normalizer_test.cc
Normal file
212
onnxruntime/test/contrib_ops/string_normalizer_test.cc
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <codecvt>
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
namespace str_normalizer_test {
|
||||
constexpr const char* domain = onnxruntime::kMSDomain;
|
||||
const int opset_ver = 1;
|
||||
|
||||
// Adds the StringNormalizer attributes to `test`.
// "casechangeaction" and "is_case_sensitive" are always added; "stopwords" and
// "locale" are added only when the corresponding argument is non-empty, so an
// empty argument leaves that attribute unset on the node.
void InitTestAttr(OpTester& test, const std::string& casechangeaction,
                  bool iscasesensitive,
                  const std::vector<std::string>& stopwords,
                  const std::string& locale) {
  const int64_t case_sensitive_flag = iscasesensitive ? 1 : 0;
  test.AddAttribute("casechangeaction", casechangeaction);
  test.AddAttribute("is_case_sensitive", case_sensitive_flag);

  const bool has_stopwords = !stopwords.empty();
  if (has_stopwords) {
    test.AddAttribute("stopwords", stopwords);
  }

  const bool has_locale = !locale.empty();
  if (has_locale) {
    test.AddAttribute("locale", locale);
  }
}
|
||||
} // namespace str_normalizer_test
|
||||
|
||||
using namespace str_normalizer_test;
|
||||
|
||||
// Exercises StringNormalizer: shape validation, stop word filtering (exact and
// case-insensitive), case conversion with mixed-language UTF-8 text, and the
// empty-result shapes. The input is registered as "X" to match the schema.
TEST(ContribOpTest, StringNormalizerTest) {
  const std::string en_us_locale("en_US.UTF-8");

  // Shared fixtures.
  const std::vector<std::string> weekdays = {"monday", "tuesday", "wednesday", "thursday"};
  const std::vector<std::string> weekdays_without_monday = {"tuesday", "wednesday", "thursday"};
  const std::vector<std::string> only_mondays = {"monday", "monday"};
  const std::vector<std::string> single_empty_string = {""};

  // English, French, German, Russian and Chinese text.
  const std::vector<std::string> mixed_input = {u8"monday",
                                                u8"tuesday",
                                                u8"Besançon",
                                                u8"École élémentaire",
                                                u8"Понедельник",
                                                u8"mit freundlichen grüßen",
                                                u8"中文"};
  // Expected en_US upper-casing of mixed_input with "monday" removed.
  const std::vector<std::string> mixed_upper = {u8"TUESDAY",
                                                // Cedilla and accented E are upper-cased.
                                                u8"BESANÇON",
                                                u8"ÉCOLE ÉLÉMENTAIRE",
                                                // Cyrillic is upper-cased.
                                                u8"ПОНЕДЕЛЬНИК",
                                                // The umlaut is upper-cased, the eszett is not.
                                                u8"MIT FREUNDLICHEN GRÜßEN",
                                                // Chinese has no letter case.
                                                u8"中文"};

  // Rejected shape: [2, 2]
  // - case-sensitive, no stop words, no case change
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "NONE", true, {}, en_us_locale);
    std::vector<int64_t> dims{2, 2};
    test.AddInput<std::string>("X", dims, weekdays);
    test.AddOutput<std::string>("Y", dims, weekdays);  // placeholder, failure is expected

    test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either[C > 0] or [1][C > 0] allowed");
  }
  // Pass-through
  // - case-sensitive, no stop words, no case change
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "NONE", true, {}, en_us_locale);
    std::vector<int64_t> dims{4};
    test.AddInput<std::string>("X", dims, weekdays);
    test.AddOutput<std::string>("Y", dims, weekdays);
    test.Run(OpTester::ExpectResult::kExpectSuccess);
  }
  // - case-sensitive, filter out monday, no case change
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "NONE", true, {"monday"}, en_us_locale);
    std::vector<int64_t> dims{4};
    test.AddInput<std::string>("X", dims, weekdays);
    test.AddOutput<std::string>("Y", {3}, weekdays_without_monday);
    test.Run(OpTester::ExpectResult::kExpectSuccess);
  }
  // - case-sensitive, filter out monday
  // - LOWER leaves the remaining strings unchanged because they are already lower case
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "LOWER", true, {"monday"}, en_us_locale);
    std::vector<int64_t> dims{4};
    test.AddInput<std::string>("X", dims, weekdays);
    test.AddOutput<std::string>("Y", {3}, weekdays_without_monday);
    test.Run(OpTester::ExpectResult::kExpectSuccess);
  }
  // - case-sensitive, filter out monday
  // - UPPER converts the remaining strings to upper case
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "UPPER", true, {"monday"}, en_us_locale);
    std::vector<int64_t> dims{4};
    test.AddInput<std::string>("X", dims, weekdays);

    std::vector<std::string> output = {"TUESDAY", "WEDNESDAY", "THURSDAY"};
    test.AddOutput<std::string>("Y", {3}, output);
    test.Run(OpTester::ExpectResult::kExpectSuccess);
  }
  // Mixed-language input, case-SENSITIVE matching, en_US locale
  // - filter out monday, UPPER
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "UPPER", true, {u8"monday"}, en_us_locale);
    std::vector<int64_t> dims{7};
    test.AddInput<std::string>("X", dims, mixed_input);
    test.AddOutput<std::string>("Y", {6}, mixed_upper);
    test.Run(OpTester::ExpectResult::kExpectSuccess);
  }
  // Mixed-language input, case-INSENSITIVE matching, en_US locale
  // - filter out monday, UPPER
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "UPPER", false, {u8"monday"}, en_us_locale);
    std::vector<int64_t> dims{7};
    test.AddInput<std::string>("X", dims, mixed_input);
    test.AddOutput<std::string>("Y", {6}, mixed_upper);
    test.Run(OpTester::ExpectResult::kExpectSuccess);
  }

  // Empty result, input shape [C]
  // - case-sensitive, every element is a stop word
  // - output is one empty string with shape [1]
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "UPPER", true, {"monday"}, en_us_locale);
    std::vector<int64_t> dims{2};
    test.AddInput<std::string>("X", dims, only_mondays);
    test.AddOutput<std::string>("Y", {1}, single_empty_string);
    test.Run(OpTester::ExpectResult::kExpectSuccess);
  }
  // Empty result, input shape [1, C], locale attribute left unset
  // - case-sensitive, every element is a stop word
  // - output is one empty string with shape [1, 1]
  {
    OpTester test("StringNormalizer", opset_ver, domain);
    InitTestAttr(test, "UPPER", true, {"monday"}, "");
    std::vector<int64_t> dims{1, 2};
    test.AddInput<std::string>("X", dims, only_mondays);
    test.AddOutput<std::string>("Y", {1, 1}, single_empty_string);
    test.Run(OpTester::ExpectResult::kExpectSuccess);
  }
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -557,10 +557,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
|
|||
// Validate that an unused initializer doesn't break graph loading/resolution
|
||||
// and is removed as expected.
|
||||
TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) {
|
||||
OPERATOR_SCHEMA(Identity_Fake)
|
||||
.SetDoc("Identity.")
|
||||
.Input(0, "input_1", "docstr for input_1.", "tensor(int32)")
|
||||
.Output(0, "output_1", "docstr for output_1.", "tensor(int32)");
|
||||
ASSERT_TRUE(kSchemasRegistered);
|
||||
|
||||
Model model("UnusedInitializerIsIgnored");
|
||||
auto& graph = model.MainGraph();
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ apt-get update && apt-get install -y --no-install-recommends \
|
|||
sudo \
|
||||
gfortran \
|
||||
python3-dev \
|
||||
language-pack-en \
|
||||
libopenblas-dev \
|
||||
liblttng-ust0 \
|
||||
libcurl3 \
|
||||
|
|
@ -38,6 +39,9 @@ apt-get update && apt-get install -y --no-install-recommends \
|
|||
rsync libunwind8 libpng16-dev \
|
||||
python3-setuptools python3-numpy python3-wheel python python3-pip python3-pytest
|
||||
|
||||
locale-gen en_US.UTF-8
|
||||
update-locale LANG=en_US.UTF-8
|
||||
|
||||
if [ $PYTHON_VER != "3.5" ]; then
|
||||
apt-get install -y --no-install-recommends \
|
||||
python${PYTHON_VER} \
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ apt-get update && apt-get install -y --no-install-recommends \
|
|||
sudo \
|
||||
gfortran \
|
||||
python3-dev \
|
||||
language-pack-en \
|
||||
libopenblas-dev \
|
||||
liblttng-ust0 \
|
||||
libcurl3 \
|
||||
|
|
@ -28,6 +29,9 @@ apt-get update && apt-get install -y --no-install-recommends \
|
|||
rsync libunwind8 \
|
||||
python3-setuptools python3-numpy python3-wheel python python3-pip
|
||||
|
||||
locale-gen en_US.UTF-8
|
||||
update-locale LANG=en_US.UTF-8
|
||||
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
aria2c -q -d /tmp https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
|
||||
|
|
|
|||
Loading…
Reference in a new issue