mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Implement Tokenizer op (#31)
* Implement separator tokenizer with TST. TODO: Clarify what to do if the output is empty and no start/end text markers required. Also see if the current search algo is acceptable. * Add utf8 util test * For empty output produce [C] -> [C][0], [N][C] -> [N][C][0] * Augument TST search with match conflict resolution in favor of earlier specified pattern matches. * Address MAcOS build error. * Adjust error message * Address review comments. * Remove nested loops. * Remove 3rd party utf8 validation code. * Address review comments part I. * Move padding outside start/end markers. Split unit tests for invidividual test cases. * Fix a common prefix bug reported by Xavier.
This commit is contained in:
parent
a68f5ccfd9
commit
c52636e187
6 changed files with 1440 additions and 0 deletions
|
|
@ -92,6 +92,36 @@ Sample echo operator.)DOC");
|
|||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
|
||||
.SetDoc(R"DOC(Returns which elements of the input are NaN.)DOC");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Tokenizer)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.Input(0, "X", "Strings to tokenize", "T")
|
||||
.Output(0, "Y", "Tokenized strings", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(string)"},
|
||||
"Input/Output is a string tensor")
|
||||
.Attr(
|
||||
"mark",
|
||||
"Boolean whether to mark the beginning/end character with start of text character (0x02)/end of text character (0x03).",
|
||||
AttributeProto::INT)
|
||||
.Attr(
|
||||
"pad_value",
|
||||
"The string used to pad output tensors when the tokens extracted doesn't match the maximum number of tokens found. If start/end markers are needed, padding will appear outside the markers.",
|
||||
AttributeProto::STRING)
|
||||
.Attr(
|
||||
"separators",
|
||||
"The list of separators, two consecutive segments in X connected by a separator would be divided into two tokens.",
|
||||
AttributeProto::STRINGS)
|
||||
.Attr(
|
||||
"mincharnum",
|
||||
"Minimum number of characters allowed in the output. For example, if mincharnum is 2, tokens such as \"A\" and \"B\" would be ignored",
|
||||
AttributeProto::INT)
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
})
|
||||
.SetDoc(R"DOC(Tokenizer divides each string in X into a vector of strings along the last axis. All input strings including attributes are UTF-8 encoded.)DOC");
|
||||
|
||||
// Operators for linear 8 bit quanitzation support.
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
@ -491,6 +521,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear);
|
||||
|
|
@ -505,6 +536,7 @@ void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear)>());
|
||||
|
|
|
|||
489
onnxruntime/contrib_ops/cpu/tokenizer.cc
Normal file
489
onnxruntime/contrib_ops/cpu/tokenizer.cc
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "tokenizer.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/tensor.h"
|
||||
|
||||
#include "core/common/utf8_util.h"
|
||||
|
||||
#include <codecvt>
|
||||
#include <locale>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
using namespace utf8_util;
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
|
||||
Tokenizer,
|
||||
1,
|
||||
string,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<std::string>()),
|
||||
contrib::Tokenizer);
|
||||
|
||||
namespace tokenizer_details {
|
||||
|
||||
const char start_text = 0x2;
|
||||
const char end_text = 0x3;
|
||||
|
||||
const std::string conv_error("Conversion Error");
|
||||
const std::wstring wconv_error(L"Conversion Error");
|
||||
|
||||
// Use a Trie like structure for searching multiple strings
|
||||
// at once but convert it to a ternary tree for saving space.
|
||||
// We insert separators in the same order they are specified.
|
||||
// Template parameter is a CharT which can be a char/wchar_t
|
||||
// or anything else that supports operator ><,== as long as
|
||||
// this is not a variable length sequence. We convert utf8 to utf16
|
||||
// before inserting.
|
||||
// Value is a supplementary information useful for search hit
|
||||
// and is present in the nodes that terminate the whole search pattern
|
||||
template <class CharT, class Value>
|
||||
class TernarySearchTree {
|
||||
private:
|
||||
struct Node {
|
||||
std::unique_ptr<Node> left_;
|
||||
std::unique_ptr<Node> mid_;
|
||||
std::unique_ptr<Node> right_;
|
||||
CharT c_; // character
|
||||
Value value_;
|
||||
bool has_val_;
|
||||
|
||||
explicit Node(CharT c) : c_(c), value_(), has_val_(false) {
|
||||
}
|
||||
~Node() = default;
|
||||
};
|
||||
|
||||
struct GetState {
|
||||
const CharT* const str_;
|
||||
const size_t len_;
|
||||
size_t depth_;
|
||||
const Value* result_;
|
||||
};
|
||||
|
||||
public:
|
||||
TernarySearchTree() = default;
|
||||
~TernarySearchTree() = default;
|
||||
|
||||
/**
|
||||
* Returns a ptr to an associated value and nullptr on search miss.
|
||||
* Must use default constructed state.
|
||||
*/
|
||||
const Value* get(const CharT* str, size_t len) const {
|
||||
if (len == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
GetState get_state{str, len, 0, nullptr};
|
||||
get(root_.get(), get_state);
|
||||
return get_state.result_;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if successful and false on empty strings
|
||||
* and duplicates.
|
||||
*/
|
||||
bool put(const CharT* str, size_t len, const Value& v) {
|
||||
if (len < 1) {
|
||||
assert(false);
|
||||
return false;
|
||||
}
|
||||
Node* new_root = put(root_.get(), str, len, v, 0);
|
||||
if (new_root != nullptr) {
|
||||
root_.release();
|
||||
root_.reset(new_root);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
void update_state(const Node* node, GetState& state) const {
|
||||
if (node->has_val_) {
|
||||
if (state.result_ == nullptr) {
|
||||
state.result_ = &node->value_;
|
||||
} else if (node->value_ < *state.result_) {
|
||||
state.result_ = &node->value_;
|
||||
}
|
||||
}
|
||||
}
|
||||
void get(const Node* node, GetState& state) const {
|
||||
if (node == nullptr) {
|
||||
return;
|
||||
}
|
||||
assert(state.depth_ < state.len_);
|
||||
CharT c = state.str_[state.depth_];
|
||||
if (c < node->c_) {
|
||||
get(node->left_.get(), state);
|
||||
return;
|
||||
} else if (c > node->c_) {
|
||||
get(node->right_.get(), state);
|
||||
return;
|
||||
} else if (state.depth_ < (state.len_ - 1)) {
|
||||
// Check if we have a match at this node
|
||||
update_state(node, state);
|
||||
if (node->mid_ != nullptr) {
|
||||
++state.depth_;
|
||||
get(node->mid_.get(), state);
|
||||
}
|
||||
return;
|
||||
}
|
||||
update_state(node, state);
|
||||
}
|
||||
|
||||
Node* put(Node* node, const CharT* str, size_t len, const Value& v, size_t depth) {
|
||||
CharT c = str[depth];
|
||||
|
||||
std::unique_ptr<Node> new_node;
|
||||
if (node == nullptr) {
|
||||
new_node.reset(new Node(c));
|
||||
}
|
||||
|
||||
Node* new_link = nullptr;
|
||||
Node* n = (node != nullptr) ? node : new_node.get();
|
||||
if (c < n->c_) {
|
||||
new_link = put(n->left_.get(), str, len, v, depth);
|
||||
if (new_link != nullptr) {
|
||||
n->left_.release();
|
||||
n->left_.reset(new_link);
|
||||
}
|
||||
} else if (c > n->c_) {
|
||||
new_link = put(n->right_.get(), str, len, v, depth);
|
||||
if (new_link != nullptr) {
|
||||
n->right_.release();
|
||||
n->right_.reset(new_link);
|
||||
}
|
||||
} else if (depth < (len - 1)) {
|
||||
new_link = put(n->mid_.get(), str, len, v, depth + 1);
|
||||
if (new_link != nullptr) {
|
||||
n->mid_.release();
|
||||
n->mid_.reset(new_link);
|
||||
}
|
||||
} else {
|
||||
if (!n->has_val_) {
|
||||
n->value_ = v;
|
||||
n->has_val_ = true;
|
||||
new_link = n;
|
||||
}
|
||||
}
|
||||
if (new_link != nullptr) {
|
||||
new_node.release();
|
||||
return n;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<Node> root_;
|
||||
};
|
||||
|
||||
// We store the length of the original pattern within the
|
||||
// Ternary Tree. This allows us to cut out the length of the matching
|
||||
// separator from the original string.
|
||||
struct SearchValue {
|
||||
size_t w_len;
|
||||
int priority_;
|
||||
bool operator<(const SearchValue& o) const {
|
||||
return priority_ < o.priority_;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tokenizer_details
|
||||
|
||||
using namespace tokenizer_details;
|
||||
|
||||
struct Tokenizer::SearchData {
|
||||
TernarySearchTree<wchar_t, SearchValue> tst_;
|
||||
};
|
||||
|
||||
Tokenizer::Tokenizer(const OpKernelInfo& info) : OpKernel(info) {
|
||||
int64_t mark = 0;
|
||||
auto status = info.GetAttr("mark", &mark);
|
||||
ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute mark is not set");
|
||||
mark_ = mark != 0;
|
||||
|
||||
status = info.GetAttr("pad_value", &pad_value_);
|
||||
ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute pad_value is not set");
|
||||
|
||||
status = info.GetAttr("mincharnum", &mincharnum_);
|
||||
ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute mincharnum is not set");
|
||||
ONNXRUNTIME_ENFORCE(mincharnum_ > 0, "attribute mincharnum must have a positive value");
|
||||
|
||||
std::vector<std::string> separators;
|
||||
status = info.GetAttrs<std::string>("separators", separators);
|
||||
ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute separators is not set");
|
||||
ONNXRUNTIME_ENFORCE(!separators.empty(), "Requires at least one separator");
|
||||
|
||||
char_tokenezation_ = (separators.size() == 1 &&
|
||||
separators[0].empty());
|
||||
|
||||
ONNXRUNTIME_ENFORCE(!char_tokenezation_ || mincharnum_ < 2,
|
||||
"mincharnum is too big for char level tokenezation");
|
||||
|
||||
// Create TST and insert separators
|
||||
if (!char_tokenezation_) {
|
||||
std::unique_ptr<SearchData> sd(std::make_unique<SearchData>());
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
|
||||
int priority = 0; // earlier search patterns get priority
|
||||
for (const auto& sep : separators) {
|
||||
ONNXRUNTIME_ENFORCE(!sep.empty(), "No empty separators allowed");
|
||||
std::wstring wsep = converter.from_bytes(sep);
|
||||
ONNXRUNTIME_ENFORCE(wsep != wconv_error, "Separator strings contains invalid utf8 chars");
|
||||
bool result = sd->tst_.put(wsep.c_str(), wsep.length(), {wsep.length(), priority});
|
||||
ONNXRUNTIME_ENFORCE(result, "duplicate separator detected");
|
||||
++priority;
|
||||
}
|
||||
search_data_.swap(sd);
|
||||
}
|
||||
}
|
||||
|
||||
// Make SearchData definition available for destruction
|
||||
Tokenizer ::~Tokenizer() {
|
||||
}
|
||||
|
||||
Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C,
|
||||
const std::vector<int64_t>& input_dims) const {
|
||||
// With char tokenzation we get as many tokens as the number of
|
||||
// utf8 characters in the string. So for every string we calculate its character(utf8) length
|
||||
// add padding and add start/end test separators if necessary
|
||||
size_t max_tokens = 0;
|
||||
auto X = ctx->Input<Tensor>(0);
|
||||
auto const input_data = X->template Data<std::string>();
|
||||
auto curr_input = input_data;
|
||||
auto const last = input_data + N * C;
|
||||
while (curr_input != last) {
|
||||
const auto& s = *curr_input;
|
||||
size_t tokens = 0; // length in utf8 chars
|
||||
if (!utf8_validate(reinterpret_cast<const unsigned char*>(s.data()), s.size(),
|
||||
tokens)) {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Input string contains invalid utf8 chars: " + s);
|
||||
}
|
||||
if (mark_) {
|
||||
tokens += 2; // Start/end markers as separate tokens
|
||||
}
|
||||
max_tokens = std::max(max_tokens, tokens);
|
||||
++curr_input;
|
||||
}
|
||||
|
||||
std::vector<int64_t> output_dims(input_dims);
|
||||
// Check if we have no output due to apparently empty strings input.
|
||||
if ((max_tokens - mark_ * 2) == 0) {
|
||||
output_dims.push_back(0);
|
||||
TensorShape output_shape(output_dims);
|
||||
ctx->Output(0, output_shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
output_dims.push_back(max_tokens);
|
||||
TensorShape output_shape(output_dims);
|
||||
auto output_tensor = ctx->Output(0, output_shape);
|
||||
auto const output_data = output_tensor->template MutableData<std::string>();
|
||||
size_t output_index = 0;
|
||||
curr_input = input_data;
|
||||
while (curr_input != last) {
|
||||
const auto& s = *curr_input;
|
||||
if (mark_) {
|
||||
new (output_data + output_index) std::string(&start_text, 1);
|
||||
++output_index;
|
||||
}
|
||||
size_t tokens = 0;
|
||||
const size_t str_len = s.size();
|
||||
for (size_t token_idx = 0; token_idx < str_len;) {
|
||||
size_t tlen = 0;
|
||||
bool result = utf8_bytes(static_cast<unsigned char>(s[token_idx]), tlen);
|
||||
assert(result);
|
||||
(void)result;
|
||||
assert(token_idx + tlen <= str_len);
|
||||
new (output_data + output_index) std::string(s.substr(token_idx, tlen));
|
||||
++output_index;
|
||||
token_idx += tlen;
|
||||
++tokens;
|
||||
}
|
||||
if (mark_) {
|
||||
new (output_data + output_index) std::string(&end_text, 1);
|
||||
++output_index;
|
||||
}
|
||||
// Padding strings
|
||||
assert(tokens + (mark_ * 2) <= max_tokens);
|
||||
const size_t pads = max_tokens - (mark_ * 2) - tokens;
|
||||
for (size_t p = 0; p < pads; ++p) {
|
||||
new (output_data + output_index) std::string(pad_value_);
|
||||
++output_index;
|
||||
}
|
||||
++curr_input;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Tokenizer::SeparatorTokenize(OpKernelContext* ctx,
|
||||
size_t N, size_t C,
|
||||
const std::vector<int64_t>& input_dims) const {
|
||||
struct Match {
|
||||
int priority_;
|
||||
size_t offset_;
|
||||
size_t size_;
|
||||
// create a conflict for overlapping matches
|
||||
// thus if they overlap neither is less than the other
|
||||
// and they are considered equal
|
||||
bool operator<(const Match& o) const {
|
||||
return (offset_ + size_) <= o.offset_;
|
||||
}
|
||||
};
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
|
||||
// Scan all strings and attempt to find separators in them
|
||||
// collect all the output tokens here
|
||||
size_t max_tokens = 0;
|
||||
std::vector<std::vector<std::wstring>> tokenized_strings;
|
||||
tokenized_strings.reserve(N * C);
|
||||
auto X = ctx->Input<Tensor>(0);
|
||||
auto const input_data = X->template Data<std::string>();
|
||||
auto curr_input = input_data;
|
||||
auto const last = input_data + N * C;
|
||||
while (curr_input != last) {
|
||||
const auto& s = *curr_input;
|
||||
std::wstring wstr = converter.from_bytes(s);
|
||||
if (wstr == wconv_error) {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Invalid utf8 chars in the input: " + s);
|
||||
}
|
||||
|
||||
std::set<Match> matches;
|
||||
const wchar_t* ws = wstr.c_str();
|
||||
size_t len_remaining = wstr.length();
|
||||
size_t offset = 0;
|
||||
while (len_remaining > 0) {
|
||||
const auto* val = search_data_->tst_.get(ws, len_remaining);
|
||||
if (val != nullptr) {
|
||||
auto p = matches.insert({val->priority_, offset, val->w_len});
|
||||
while (!p.second && val->priority_ < p.first->priority_) {
|
||||
// if overlapping matches of the same pattern(priority), then
|
||||
// the earlier match naturally wins
|
||||
matches.erase(p.first);
|
||||
p = matches.insert({val->priority_, offset, val->w_len});
|
||||
}
|
||||
}
|
||||
++ws;
|
||||
++offset;
|
||||
--len_remaining;
|
||||
}
|
||||
|
||||
// Tokenize
|
||||
tokenized_strings.emplace_back();
|
||||
auto& row_tokens = tokenized_strings.back();
|
||||
row_tokens.reserve(matches.size() + 1);
|
||||
ws = wstr.c_str();
|
||||
offset = 0;
|
||||
for (const auto& m : matches) {
|
||||
assert(m.offset_ >= offset);
|
||||
size_t sz = (m.offset_ - offset);
|
||||
if (sz > 0 && sz >= size_t(mincharnum_)) {
|
||||
row_tokens.emplace_back(ws, sz);
|
||||
}
|
||||
offset = m.offset_ + m.size_;
|
||||
ws = wstr.c_str() + offset;
|
||||
}
|
||||
assert(offset <= wstr.length());
|
||||
if (offset < wstr.length()) {
|
||||
row_tokens.emplace_back(ws, wstr.length() - offset);
|
||||
}
|
||||
|
||||
size_t tokens = row_tokens.size();
|
||||
if (mark_) {
|
||||
tokens += 2; // Start/end markers as separate tokens
|
||||
}
|
||||
max_tokens = std::max(max_tokens, tokens);
|
||||
++curr_input;
|
||||
}
|
||||
|
||||
std::vector<int64_t> output_dims(input_dims);
|
||||
// Check if we have no output due to either empty input
|
||||
// everything is a separator
|
||||
if ((max_tokens - mark_ * 2) == 0) {
|
||||
output_dims.push_back(0);
|
||||
TensorShape output_shape(output_dims);
|
||||
ctx->Output(0, output_shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
output_dims.push_back(max_tokens);
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
auto output_tensor = ctx->Output(0, output_shape);
|
||||
auto const output_data = output_tensor->template MutableData<std::string>();
|
||||
|
||||
#ifdef _DEBUG
|
||||
const size_t max_output_index = N * C * max_tokens;
|
||||
#endif
|
||||
size_t output_index = 0;
|
||||
for (auto& row : tokenized_strings) {
|
||||
#ifdef _DEBUG
|
||||
size_t c_idx = output_index;
|
||||
#endif
|
||||
if (mark_) {
|
||||
new (output_data + output_index) std::string(&start_text, 1);
|
||||
++output_index;
|
||||
}
|
||||
// Output tokens for this row
|
||||
for (auto& token : row) {
|
||||
new (output_data + output_index) std::string(converter.to_bytes(token));
|
||||
++output_index;
|
||||
}
|
||||
if (mark_) {
|
||||
new (output_data + output_index) std::string(&end_text, 1);
|
||||
++output_index;
|
||||
}
|
||||
const size_t pads = max_tokens - (mark_ * 2) - row.size();
|
||||
for (size_t p = 0; p < pads; ++p) {
|
||||
new (output_data + output_index) std::string(pad_value_);
|
||||
++output_index;
|
||||
}
|
||||
#ifdef _DEBUG
|
||||
assert(output_index <= max_output_index);
|
||||
assert((output_index - c_idx) <= max_tokens);
|
||||
#endif
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Tokenizer::Compute(OpKernelContext* ctx) const {
|
||||
// Get input buffer ptr
|
||||
auto X = ctx->Input<Tensor>(0);
|
||||
if (X->DataType() != DataTypeImpl::GetType<std::string>()) {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"tensor(string) expected as input");
|
||||
}
|
||||
|
||||
auto& input_dims = X->Shape().GetDims();
|
||||
size_t N = 0;
|
||||
size_t C = 0;
|
||||
if (input_dims.size() == 1) {
|
||||
N = 1;
|
||||
if (input_dims[0] < 1) {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Invalid C dimension value");
|
||||
}
|
||||
C = input_dims[0];
|
||||
} else if (input_dims.size() == 2) {
|
||||
if (input_dims[0] < 1 || input_dims[1] < 1) {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Invalid N and/or C dimension values");
|
||||
}
|
||||
N = input_dims[0];
|
||||
C = input_dims[1];
|
||||
} else {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Input dimensions are either [C] or [N][C] allowed");
|
||||
}
|
||||
|
||||
Status s;
|
||||
if (char_tokenezation_) {
|
||||
s = CharTokenize(ctx, N, C, input_dims);
|
||||
} else {
|
||||
s = SeparatorTokenize(ctx, N, C, input_dims);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
36
onnxruntime/contrib_ops/cpu/tokenizer.h
Normal file
36
onnxruntime/contrib_ops/cpu/tokenizer.h
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
class Tokenizer final : public OpKernel {
|
||||
public:
|
||||
explicit Tokenizer(const OpKernelInfo& info);
|
||||
Tokenizer(const Tokenizer&) = delete;
|
||||
Tokenizer& operator=(const Tokenizer&) = delete;
|
||||
~Tokenizer();
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
Status CharTokenize(OpKernelContext* context, size_t N, size_t C,
|
||||
const std::vector<int64_t>& input_dims) const;
|
||||
|
||||
Status SeparatorTokenize(OpKernelContext* context, size_t N, size_t C,
|
||||
const std::vector<int64_t>& input_dims) const;
|
||||
|
||||
bool mark_;
|
||||
std::string pad_value_;
|
||||
int64_t mincharnum_;
|
||||
bool char_tokenezation_;
|
||||
struct SearchData;
|
||||
std::unique_ptr<SearchData> search_data_;
|
||||
};
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
129
onnxruntime/core/common/utf8_util.h
Normal file
129
onnxruntime/core/common/utf8_util.h
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace utf8_util {
|
||||
|
||||
// Returns the number of bytes in the utf8 character
|
||||
// by analyzing its leading byte
|
||||
inline bool utf8_bytes(unsigned char ch, size_t& len) {
|
||||
if ((ch & 0x80) == 0) {
|
||||
len = 1;
|
||||
return true;
|
||||
}
|
||||
if ((ch & 0xE0) == 0xC0) {
|
||||
len = 2;
|
||||
return true;
|
||||
}
|
||||
unsigned int result = (ch & 0xF0);
|
||||
if (result == 0xE0) {
|
||||
len = 3;
|
||||
return true;
|
||||
}
|
||||
if (result == 0xF0) {
|
||||
len = 4;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool utf8_validate(const unsigned char* s, size_t len, size_t& utf8_chars) {
|
||||
size_t utf8_len = 0;
|
||||
size_t idx = 0;
|
||||
while (idx < len) {
|
||||
size_t bytes = 0;
|
||||
auto ch = s[idx];
|
||||
if (utf8_bytes(ch, bytes)) {
|
||||
switch (bytes) {
|
||||
case 1:
|
||||
break;
|
||||
case 2: {
|
||||
if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
|
||||
return false;
|
||||
}
|
||||
} break; // 2
|
||||
case 3: {
|
||||
auto ch1 = s[idx];
|
||||
switch (ch1) {
|
||||
case 0xE0u:
|
||||
if (++idx >= len || s[idx] < 0xA0u || s[idx] > 0xBFu) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case 0xEDu:
|
||||
if (++idx >= len || s[idx] < 0x80u || s[idx] > 0x9Fu) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
default: {
|
||||
if ((ch1 >= 0xE1u && ch1 <= 0xECu) ||
|
||||
(ch1 >= 0xEEu && ch1 <= 0xEFu)) {
|
||||
if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
// validate byte 3
|
||||
if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
|
||||
return false;
|
||||
}
|
||||
} break; // 3
|
||||
case 4: {
|
||||
auto ch1 = s[idx];
|
||||
switch (ch1) {
|
||||
case 0xF0u: {
|
||||
if (++idx >= len || s[idx] < 0x90u || s[idx] > 0xBFu) {
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
case 0xF4u: {
|
||||
if (++idx >= len || s[idx] < 0x80u || s[idx] > 0x8Fu) {
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
if (ch1 >= 0xF1u && ch1 <= 0xF3u) {
|
||||
if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
// validate bytes 3 and 4
|
||||
size_t stop = idx + 2;
|
||||
while (idx < stop) {
|
||||
if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} break; // 4
|
||||
default:
|
||||
// no chars longer than 4
|
||||
return false;
|
||||
} // switch bytes
|
||||
++idx;
|
||||
++utf8_len;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// End index must match
|
||||
// the end of the last byte sequence.
|
||||
if (idx != len) {
|
||||
return false;
|
||||
}
|
||||
utf8_chars = utf8_len;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace utf8_util
|
||||
} // namespace onnxruntime
|
||||
46
onnxruntime/test/common/utf8_util_test.cc
Normal file
46
onnxruntime/test/common/utf8_util_test.cc
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/utf8_util.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
struct Sample {
|
||||
const char* sequence;
|
||||
bool valid;
|
||||
};
|
||||
|
||||
const std::vector<Sample> samples = {
|
||||
{"a", true},
|
||||
{"\xc3\xb1", true},
|
||||
{"\xc3\x28", false},
|
||||
{"\xa0\xa1", false},
|
||||
{"\xe2\x82\xa1", true},
|
||||
{"\xe2\x28\xa1", false},
|
||||
{"\xe2\x82\x28", false},
|
||||
{"\xf0\x90\x8c\xbc", true},
|
||||
{"\xf0\x28\x8c\xbc", false},
|
||||
{"\xf0\x90\x28\xbc", false},
|
||||
{"\xf0\x28\x8c\x28", false},
|
||||
{"\xf8\xa1\xa1\xa1\xa1", false}, // valid but not Unicode
|
||||
{"\xfc\xa1\xa1\xa1\xa1\xa1", false}}; // valid but not Unicode
|
||||
|
||||
TEST(Utf8UtilTest, Validate) {
|
||||
using namespace utf8_util;
|
||||
for (auto& s : samples) {
|
||||
size_t utf8_len = 0;
|
||||
if (s.valid != utf8_validate(reinterpret_cast<const unsigned char*>(s.sequence), strlen(s.sequence), utf8_len)) {
|
||||
ASSERT_TRUE(false);
|
||||
} else {
|
||||
if (s.valid) {
|
||||
ASSERT_EQ(1U, utf8_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
708
onnxruntime/test/contrib_ops/tokenizer_test.cc
Normal file
708
onnxruntime/test/contrib_ops/tokenizer_test.cc
Normal file
|
|
@ -0,0 +1,708 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <codecvt>
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
namespace tokenizer_test {
|
||||
const std::string start_mark{0x2};
|
||||
const std::string end_mark{0x3};
|
||||
const std::string padval("0xdeadbeaf");
|
||||
|
||||
constexpr const char* domain = onnxruntime::kMSDomain;
|
||||
const int opset_ver = 1;
|
||||
|
||||
} // namespace tokenizer_test
|
||||
|
||||
using namespace tokenizer_test;
|
||||
|
||||
void InitTestAttr(OpTester& test, bool mark, const std::vector<std::string>& seps,
|
||||
int64_t mincharnum) {
|
||||
test.AddAttribute("mark", int64_t{mark});
|
||||
test.AddAttribute("separators", seps);
|
||||
// Padding for alignment
|
||||
test.AddAttribute("pad_value", padval);
|
||||
test.AddAttribute("mincharnum", mincharnum);
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_InvalidDim) {
|
||||
// Invalid input dimensions
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, false, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{1, 1, 2};
|
||||
std::vector<std::string> input = {std::string("s1"), std::string("s2")};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
std::vector<std::string> output(input); // do the same for now
|
||||
test.AddOutput<std::string>("Y", dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either [C] or [N][C] allowed");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_LatinCharsNoMarkersC) {
|
||||
// Char level tokenezation with latin characters and no
|
||||
// start/end text markers
|
||||
// [C] dimensions
|
||||
// Output [C][D]
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, false, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{"abcdef", "abcd"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(input[0].length()));
|
||||
std::vector<std::string> output{
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f",
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
padval,
|
||||
padval};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_LatinCharsWithMarkersC) {
|
||||
// Char level tokenezation with latin characters and
|
||||
// with start/end text markers
|
||||
// [C] dimensions
|
||||
// Output [C][D]
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{"abcdef", "abcd"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(input[0].length() + 2));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f",
|
||||
end_mark,
|
||||
start_mark,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
end_mark,
|
||||
padval,
|
||||
padval};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_LatinCharsNoMarkersNC) {
|
||||
// Char level tokenezation with latin characters and no
|
||||
// start/end text markers
|
||||
// [N][C] dimensions
|
||||
// Output [N][C][D]
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, false, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::vector<std::string> input{"abcd", "abcd", "abcd", "abcdef"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(input[3].length()));
|
||||
std::vector<std::string> output{
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
padval,
|
||||
padval,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
padval,
|
||||
padval,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
padval,
|
||||
padval,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f"};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_LatinCharsWithMarkersNC) {
|
||||
// Char level tokenezation with latin characters and
|
||||
// with start/end text markers
|
||||
// [N][C] dimensions
|
||||
// Output [N][C][D]
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::vector<std::string> input{"abcd", "abcd", "abcd", "abcdef"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(input[3].length() + 2));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
end_mark,
|
||||
padval,
|
||||
padval,
|
||||
start_mark,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
end_mark,
|
||||
padval,
|
||||
padval,
|
||||
start_mark,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
end_mark,
|
||||
padval,
|
||||
padval,
|
||||
start_mark,
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f",
|
||||
end_mark};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_CyrillicCharsWithMarkersC) {
|
||||
// Char level tokenezation with Cyrillic characters and
|
||||
// with start/end text markers
|
||||
// [C] dimensions
|
||||
// Output [C][D]
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"Абсурд", u8"Кома"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Word Absurd is 6 chars long so we must get 6 individual strings out of it
|
||||
// which is the max plus start/end text markers
|
||||
output_dims.push_back(int64_t(6 + 2));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"А", u8"б", u8"с", u8"у", u8"р", u8"д",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"К", u8"о", u8"м", u8"а",
|
||||
end_mark,
|
||||
padval,
|
||||
padval};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_MixedCharsWithMarkersC) {
|
||||
// Char level tokenezation with a mix of latin, Spanish, Cyrillic and Chinese
|
||||
// characters and
|
||||
// with start/end text markers
|
||||
// [C] dimensions
|
||||
// Output [C][D]
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Word Absu?? is 6 chars long so we must get 6 individual strings out of it
|
||||
// which is the max plus start/end text markers
|
||||
output_dims.push_back(int64_t(6 + 2));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"А", u8"б", u8"с", u8"у", u8"中", u8"文",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"К", u8"о", u8"ñ", u8"ó",
|
||||
end_mark,
|
||||
padval,
|
||||
padval};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_EmptyOutputC) {
|
||||
// Special case where empty output is produced
|
||||
// For [C] we expect [C][0] output
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"", u8""};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(0));
|
||||
std::vector<std::string> output{};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerCharLevel_EmptyOutputNC) {
|
||||
// Special case where empty output is produced
|
||||
// For [N][C] we expect [N][C][0] output
|
||||
{
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, {""}, 1);
|
||||
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::vector<std::string> input{u8"", u8"", u8"", u8""};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(0));
|
||||
std::vector<std::string> output{};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersC) {
|
||||
// Separators and strings with a mix of latin, Spanish, Cyrillic and Chinese
|
||||
// characters and with start/end text markers
|
||||
// [C] dimensions
|
||||
// Output [C][D]
|
||||
{
|
||||
std::vector<std::string> separators = {
|
||||
u8"у",
|
||||
u8"ñ"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must split both in 2
|
||||
output_dims.push_back(int64_t(2 + 2));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"Абс", u8"中文",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"Ко", u8"ó",
|
||||
end_mark};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
} // namespace test
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersCompleteMatchEmptyOutputC) {
|
||||
// Test entire separators match so we get nothing
|
||||
// in the output
|
||||
{
|
||||
std::vector<std::string> separators = {
|
||||
u8"Абсу中文",
|
||||
u8"Коñó"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must have no output
|
||||
output_dims.push_back(int64_t(0));
|
||||
std::vector<std::string> output;
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersStartMatchC) {
|
||||
// Match the start
|
||||
{
|
||||
std::vector<std::string> separators = {
|
||||
u8"А",
|
||||
u8"К"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must drop first characters from both strings
|
||||
output_dims.push_back(int64_t(3));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"бсу中文",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"оñó",
|
||||
end_mark};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchC) {
|
||||
// Match the end
|
||||
{
|
||||
std::vector<std::string> separators = {
|
||||
u8"文",
|
||||
u8"ó"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must drop last characters from both strings
|
||||
output_dims.push_back(int64_t(3));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"Абсу中",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"Коñ",
|
||||
end_mark};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchAtLeast4CharsC) {
|
||||
// Match the end, require at least 4 chars
|
||||
{
|
||||
std::vector<std::string> separators = {
|
||||
u8"文",
|
||||
u8"ó"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, separators, 4);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must drop the last character from the first
|
||||
// and the second 3 character token does not pass mincharnum
|
||||
output_dims.push_back(int64_t(3));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"Абсу中",
|
||||
end_mark,
|
||||
start_mark,
|
||||
end_mark,
|
||||
padval};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOutputC) {
|
||||
// Empty input for [C] should produce [C][0]
|
||||
{
|
||||
std::vector<std::string> separators = {
|
||||
u8"文",
|
||||
u8"ó"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, separators, 4);
|
||||
|
||||
std::vector<int64_t> dims{2};
|
||||
std::vector<std::string> input{u8"", u8""};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(0));
|
||||
std::vector<std::string> output;
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOutputNC) {
|
||||
// Empty input for [N][C] should produce [N][C][0]
|
||||
{
|
||||
std::vector<std::string> separators = {
|
||||
u8"文",
|
||||
u8"ó"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, separators, 4);
|
||||
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
std::vector<std::string> input{u8"", u8"文", u8"ó", u8""};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(0));
|
||||
std::vector<std::string> output;
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapShortFirstC) {
|
||||
// Test of the overlapping search patterns
|
||||
// The spec mandates that the patterns that appear
|
||||
// in the separators earlier must be matched first.
|
||||
{
|
||||
// In this case the first pattern must match first
|
||||
// and there would be no match for the second
|
||||
std::vector<std::string> separators = {
|
||||
u8"су",
|
||||
u8"Абсу"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, false, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{1};
|
||||
std::vector<std::string> input{u8"Абсу中文"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// must split in 2 with no two middle characters
|
||||
output_dims.push_back(int64_t(2));
|
||||
std::vector<std::string> output{
|
||||
u8"Аб", u8"中文"};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapLongFirstC) {
|
||||
// Test of the overlapping search patterns
|
||||
// The spec mandates that the patterns that appear
|
||||
// in the separators earlier must be matched first.
|
||||
{
|
||||
// In this case the first pattern must match first
|
||||
// and there would be no match for the second
|
||||
std::vector<std::string> separators = {
|
||||
u8"Абсу",
|
||||
u8"су"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, false, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{1};
|
||||
std::vector<std::string> input{u8"Абсу中文"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must drop the beginning of the word that
|
||||
// also contains the second separator
|
||||
output_dims.push_back(int64_t(1));
|
||||
std::vector<std::string> output{u8"中文"};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapLongFirstRepeatedShortC) {
|
||||
// Test of the overlapping search patterns
|
||||
// The spec mandates that the patterns that appear
|
||||
// in the separators earlier must be matched first.
|
||||
{
|
||||
// In this case the first pattern must match first
|
||||
// and there would be no match for the second
|
||||
std::vector<std::string> separators = {
|
||||
u8"Абсу",
|
||||
u8"су"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, false, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{1};
|
||||
std::vector<std::string> input{u8"Абсусусу中文"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must drop the beginning of the word that
|
||||
// also contains the second separator
|
||||
output_dims.push_back(int64_t(1));
|
||||
std::vector<std::string> output{u8"中文"};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapingMatchC) {
|
||||
// Test of the overlapping search patterns
|
||||
// The spec mandates that the patterns that appear
|
||||
// in the separators earlier must be matched first.
|
||||
{
|
||||
// In this case the first pattern must match first
|
||||
// and there are more than one overlapping matches for the first
|
||||
// so the earlier match for the first wins.
|
||||
// and there would be no match for the second
|
||||
std::vector<std::string> separators = {
|
||||
u8"усу",
|
||||
u8"Абсу"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, false, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{1};
|
||||
std::vector<std::string> input{u8"Абсусусу中文"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must drop the beginning of the word that
|
||||
// also contains the second separator
|
||||
output_dims.push_back(int64_t(2));
|
||||
std::vector<std::string> output{u8"Абс", u8"су中文"};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) {
|
||||
// Separators and strings with a mix of latin, Spanish, Cyrillic and Chinese
|
||||
// characters and with start/end text markers
|
||||
// [C] dimensions
|
||||
// Output [C][D]
|
||||
std::vector<std::string> separators = {
|
||||
u8";",
|
||||
u8";;;"};
|
||||
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
InitTestAttr(test, true, separators, 1);
|
||||
|
||||
std::vector<int64_t> dims{4};
|
||||
std::vector<std::string> input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
// Must split both in 2
|
||||
output_dims.push_back(int64_t(6));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"a",
|
||||
u8"b",
|
||||
end_mark,
|
||||
padval,
|
||||
padval,
|
||||
start_mark,
|
||||
u8"a",
|
||||
u8"b",
|
||||
end_mark,
|
||||
padval,
|
||||
padval,
|
||||
start_mark,
|
||||
u8"b",
|
||||
u8"c",
|
||||
u8"d",
|
||||
u8"e",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"a",
|
||||
u8"b",
|
||||
u8"c",
|
||||
end_mark,
|
||||
padval,
|
||||
};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue