Implement Tokenizer op (#31)

* Implement separator tokenizer with TST.
  TODO: Clarify what to do if the output is empty and no start/end text
  markers required. Also see if the current search algo is acceptable.

* Add utf8 util test

* For empty output produce [C] -> [C][0], [N][C] -> [N][C][0]

* Augment TST search with match conflict resolution in favor of
  earlier specified pattern matches.

* Address macOS build error.

* Adjust error message

* Address review comments.

* Remove nested loops.

* Remove 3rd party utf8 validation code.

* Address review comments part I.

* Move padding outside start/end markers.
  Split unit tests for individual test cases.

* Fix a common prefix bug reported by Xavier.
This commit is contained in:
Dmitri Smirnov 2018-12-05 17:52:04 -08:00 committed by GitHub
parent a68f5ccfd9
commit c52636e187
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 1440 additions and 0 deletions

View file

@ -92,6 +92,36 @@ Sample echo operator.)DOC");
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetDoc(R"DOC(Returns which elements of the input are NaN.)DOC");
// Schema of the Tokenizer contrib operator (kMSDomain, since version 1).
// It takes one string tensor "X" and produces one string tensor "Y"; both are
// bound to the same type constraint "T" which only admits tensor(string).
ONNX_CONTRIB_OPERATOR_SCHEMA(Tokenizer)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Input(0, "X", "Strings to tokenize", "T")
.Output(0, "Y", "Tokenized strings", "T")
.TypeConstraint(
"T",
{"tensor(string)"},
"Input/Output is a string tensor")
// "mark": integer used as a boolean flag for start/end-of-text marker tokens.
.Attr(
"mark",
"Boolean whether to mark the beginning/end character with start of text character (0x02)/end of text character (0x03).",
AttributeProto::INT)
// "pad_value": filler string for rows that yield fewer tokens than the longest row.
.Attr(
"pad_value",
"The string used to pad output tensors when the tokens extracted doesn't match the maximum number of tokens found. If start/end markers are needed, padding will appear outside the markers.",
AttributeProto::STRING)
// "separators": list of separator strings on which the input is split.
.Attr(
"separators",
"The list of separators, two consecutive segments in X connected by a separator would be divided into two tokens.",
AttributeProto::STRINGS)
// "mincharnum": tokens shorter than this many characters are dropped.
.Attr(
"mincharnum",
"Minimum number of characters allowed in the output. For example, if mincharnum is 2, tokens such as \"A\" and \"B\" would be ignored",
AttributeProto::INT)
// Only the element type is propagated here; the output shape is not inferred
// by this function.
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
})
.SetDoc(R"DOC(Tokenizer divides each string in X into a vector of strings along the last axis. All input strings including attributes are UTF-8 encoded.)DOC");
// Operators for linear 8 bit quantization support.
ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear)
.SetDomain(kMSDomain)
@ -491,6 +521,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear);
@ -505,6 +536,7 @@ void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear)>());

View file

@ -0,0 +1,489 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "tokenizer.h"
#include "onnx/defs/schema.h"
#include "core/common/common.h"
#include "core/framework/tensor.h"
#include "core/common/utf8_util.h"
#include <codecvt>
#include <locale>
namespace onnxruntime {
namespace contrib {
using namespace utf8_util;
// Registers contrib::Tokenizer as the version 1 kernel for the Tokenizer op,
// typed on "string"; type constraint "T" is bound to tensors of std::string.
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
Tokenizer,
1,
string,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<std::string>()),
contrib::Tokenizer);
namespace tokenizer_details {
// Single-character marker tokens emitted around each row when the "mark"
// attribute is set: start of text (0x02) and end of text (0x03).
const char start_text{0x2};
const char end_text{0x3};
// Sentinel strings handed to std::wstring_convert; the converter returns them
// instead of throwing when a conversion fails.
const std::string conv_error{"Conversion Error"};
const std::wstring wconv_error{L"Conversion Error"};
// Ternary search tree: a trie-like structure for matching several patterns
// at once, stored as a ternary tree to save space.
//
// CharT is a fixed-width character type (char, wchar_t, ...) that supports
// operator< , operator> and operator==. It must not be a code unit of a
// variable-length encoding, which is why callers convert UTF-8 input to wide
// characters before inserting.
//
// Value is supplementary information attached to the node that terminates a
// complete pattern. It must be default constructible, copy assignable and
// provide operator<, which get() uses to pick the preferred match.
template <class CharT, class Value>
class TernarySearchTree {
 private:
  struct Node {
    std::unique_ptr<Node> left_;   // characters less than c_
    std::unique_ptr<Node> mid_;    // next character of patterns that contain c_
    std::unique_ptr<Node> right_;  // characters greater than c_
    CharT c_;                      // character stored at this node
    Value value_;                  // meaningful only when has_val_ is true
    bool has_val_;                 // true when a pattern ends at this node
    explicit Node(CharT c) : c_(c), value_(), has_val_(false) {
    }
  };

 public:
  TernarySearchTree() = default;
  ~TernarySearchTree() = default;

  /**
   * Matches stored patterns against the beginning of str[0, len).
   * Among all stored patterns that are a prefix of the input, returns a
   * pointer to the value that compares smallest with operator<; when values
   * compare equal the shortest pattern is kept.
   * Returns nullptr when nothing matches or len is 0.
   */
  const Value* get(const CharT* str, size_t len) const {
    const Value* best = nullptr;
    const Node* node = root_.get();
    size_t depth = 0;
    while (node != nullptr && depth < len) {
      const CharT c = str[depth];
      if (c < node->c_) {
        node = node->left_.get();
      } else if (c > node->c_) {
        node = node->right_.get();
      } else {
        // A pattern may end here even though longer patterns share the prefix.
        if (node->has_val_ && (best == nullptr || node->value_ < *best)) {
          best = &node->value_;
        }
        ++depth;
        node = node->mid_.get();
      }
    }
    return best;
  }

  /**
   * Inserts pattern str[0, len) with the associated value v.
   * Returns true on success, and false (leaving stored values untouched) for
   * empty patterns and for patterns that are already present.
   */
  bool put(const CharT* str, size_t len, const Value& v) {
    if (str == nullptr || len == 0) {
      return false;
    }
    // link always points at the owning pointer of the node to visit next, so
    // new nodes are attached in place and ownership never leaves unique_ptr.
    std::unique_ptr<Node>* link = &root_;
    size_t depth = 0;
    for (;;) {
      const CharT c = str[depth];
      if (*link == nullptr) {
        *link = std::make_unique<Node>(c);
      }
      Node* const n = link->get();
      if (c < n->c_) {
        link = &n->left_;
      } else if (c > n->c_) {
        link = &n->right_;
      } else if (depth + 1 < len) {
        ++depth;
        link = &n->mid_;
      } else {
        if (n->has_val_) {
          return false;  // duplicate pattern
        }
        n->value_ = v;
        n->has_val_ = true;
        return true;
      }
    }
  }

 private:
  std::unique_ptr<Node> root_;
};
// Payload kept in the ternary tree for every separator.
// w_len is the length of the separator pattern, which lets the tokenizer cut
// the matched separator out of the original string. priority_ orders the
// values: a smaller number compares as less.
struct SearchValue {
  size_t w_len;
  int priority_;
  bool operator<(const SearchValue& rhs) const {
    return rhs.priority_ > priority_;
  }
};
} // namespace tokenizer_details
using namespace tokenizer_details;
// Definition of the type that is only forward declared inside Tokenizer.
// It wraps the ternary search tree of wide-character separators, with a
// SearchValue stored for each separator.
struct Tokenizer::SearchData {
TernarySearchTree<wchar_t, SearchValue> tst_;
};
// Reads and validates the operator attributes and, unless character level
// tokenization was requested, builds the separator search tree.
//
// Attributes (all required): mark, pad_value, mincharnum (> 0), separators.
// A single empty separator selects character level tokenization, in which
// case mincharnum must be 1. Any violation fails an ONNXRUNTIME_ENFORCE check.
Tokenizer::Tokenizer(const OpKernelInfo& info) : OpKernel(info) {
  int64_t mark = 0;
  auto status = info.GetAttr("mark", &mark);
  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute mark is not set");
  mark_ = mark != 0;

  status = info.GetAttr("pad_value", &pad_value_);
  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute pad_value is not set");

  status = info.GetAttr("mincharnum", &mincharnum_);
  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute mincharnum is not set");
  ONNXRUNTIME_ENFORCE(mincharnum_ > 0, "attribute mincharnum must have a positive value");

  std::vector<std::string> separators;
  status = info.GetAttrs<std::string>("separators", separators);
  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute separators is not set");
  ONNXRUNTIME_ENFORCE(!separators.empty(), "Requires at least one separator");

  // Exactly one separator which is the empty string means "split into characters".
  char_tokenezation_ = (separators.size() == 1 &&
                        separators[0].empty());
  ONNXRUNTIME_ENFORCE(!char_tokenezation_ || mincharnum_ < 2,
                      "mincharnum is too big for char level tokenezation");

  // Create TST and insert separators
  if (!char_tokenezation_) {
    auto sd = std::make_unique<SearchData>();
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
    int priority = 0;  // earlier search patterns get priority
    for (const auto& sep : separators) {
      ONNXRUNTIME_ENFORCE(!sep.empty(), "No empty separators allowed");
      std::wstring wsep = converter.from_bytes(sep);
      // The converter reports failure by returning the sentinel string. A
      // separator whose text equals the sentinel converts to the very same
      // value legitimately, so it must not be treated as an error.
      ONNXRUNTIME_ENFORCE(wsep != wconv_error || sep == conv_error,
                          "Separator strings contains invalid utf8 chars");
      bool result = sd->tst_.put(wsep.c_str(), wsep.length(), {wsep.length(), priority});
      ONNXRUNTIME_ENFORCE(result, "duplicate separator detected");
      ++priority;
    }
    search_data_.swap(sd);
  }
}
// Defined out of line in this translation unit, where SearchData is a
// complete type, so that std::unique_ptr<SearchData> can destroy it.
Tokenizer::~Tokenizer() = default;
// Character level tokenization: every UTF-8 character of an input string
// becomes its own token.
//
// N * C input strings are read from input 0. The output shape is input_dims
// with one extra trailing dimension: the largest per-string token count
// (including the two marker tokens when mark_ is set), or 0 when every input
// string is empty. Shorter rows are padded with pad_value_ after the end marker.
// Returns INVALID_ARGUMENT when an input string is not valid UTF-8.
Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C,
                               const std::vector<int64_t>& input_dims) const {
  // Number of extra tokens contributed by the start/end markers.
  const size_t marker_tokens = mark_ ? 2 : 0;

  auto X = ctx->Input<Tensor>(0);
  auto const first = X->template Data<std::string>();
  auto const last = first + N * C;

  // First pass: validate the input and find the widest row.
  size_t max_tokens = 0;
  for (auto it = first; it != last; ++it) {
    size_t char_count = 0;  // length in utf8 chars
    const bool valid = utf8_validate(reinterpret_cast<const unsigned char*>(it->data()),
                                     it->size(), char_count);
    if (!valid) {
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                    "Input string contains invalid utf8 chars: " + *it);
    }
    max_tokens = std::max(max_tokens, char_count + marker_tokens);
  }

  std::vector<int64_t> output_dims(input_dims);
  // Nothing but (optional) markers means all input strings were empty.
  if (max_tokens == marker_tokens) {
    output_dims.push_back(0);
    TensorShape output_shape(output_dims);
    ctx->Output(0, output_shape);
    return Status::OK();
  }

  output_dims.push_back(max_tokens);
  TensorShape output_shape(output_dims);
  auto output_tensor = ctx->Output(0, output_shape);
  auto out = output_tensor->template MutableData<std::string>();

  // Second pass: emit marker, characters, marker, padding for every row.
  for (auto it = first; it != last; ++it) {
    const std::string& s = *it;
    if (mark_) {
      new (out++) std::string(&start_text, 1);
    }

    size_t tokens = 0;
    size_t pos = 0;
    while (pos < s.size()) {
      size_t char_bytes = 0;
      bool result = utf8_bytes(static_cast<unsigned char>(s[pos]), char_bytes);
      assert(result);
      (void)result;
      assert(pos + char_bytes <= s.size());
      new (out++) std::string(s, pos, char_bytes);
      pos += char_bytes;
      ++tokens;
    }

    if (mark_) {
      new (out++) std::string(&end_text, 1);
    }

    // Padding strings
    assert(tokens + marker_tokens <= max_tokens);
    for (size_t pads = max_tokens - marker_tokens - tokens; pads > 0; --pads) {
      new (out++) std::string(pad_value_);
    }
  }
  return Status::OK();
}
// Separator based tokenization.
//
// Every one of the N * C input strings is converted to wide characters and
// scanned for separators using the search tree built in the constructor.
// The text between accepted separator matches becomes the tokens of that row;
// tokens shorter than mincharnum_ characters are dropped.
//
// The output shape is input_dims with one extra trailing dimension: the
// largest per-row token count (including the two marker tokens when mark_ is
// set), or 0 when no row produced a token. Shorter rows are padded with
// pad_value_ after the end marker.
// Returns INVALID_ARGUMENT when an input string cannot be converted from UTF-8.
Status Tokenizer::SeparatorTokenize(OpKernelContext* ctx,
                                    size_t N, size_t C,
                                    const std::vector<int64_t>& input_dims) const {
  struct Match {
    int priority_;
    size_t offset_;
    size_t size_;
    // create a conflict for overlapping matches
    // thus if they overlap neither is less than the other
    // and they are considered equal
    bool operator<(const Match& o) const {
      return (offset_ + size_) <= o.offset_;
    }
  };

  std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
  const size_t min_chars = static_cast<size_t>(mincharnum_);

  // Scan all strings and attempt to find separators in them
  // collect all the output tokens here
  size_t max_tokens = 0;
  std::vector<std::vector<std::wstring>> tokenized_strings;
  tokenized_strings.reserve(N * C);

  auto X = ctx->Input<Tensor>(0);
  auto const input_data = X->template Data<std::string>();
  auto curr_input = input_data;
  auto const last = input_data + N * C;
  while (curr_input != last) {
    const auto& s = *curr_input;
    std::wstring wstr = converter.from_bytes(s);
    // The converter signals failure by returning the sentinel. An input whose
    // text equals the sentinel converts to the same value legitimately and
    // must not be rejected.
    if (wstr == wconv_error && s != conv_error) {
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                    "Invalid utf8 chars in the input: " + s);
    }

    // Try to match a separator at every character position.
    std::set<Match> matches;
    const wchar_t* ws = wstr.c_str();
    size_t len_remaining = wstr.length();
    size_t offset = 0;
    while (len_remaining > 0) {
      const auto* val = search_data_->tst_.get(ws, len_remaining);
      if (val != nullptr) {
        auto p = matches.insert({val->priority_, offset, val->w_len});
        // An insert fails when the new match overlaps a stored one. The
        // stored match is evicted only if the new one has higher priority
        // (smaller number); on equal priority the earlier match wins.
        while (!p.second && val->priority_ < p.first->priority_) {
          matches.erase(p.first);
          p = matches.insert({val->priority_, offset, val->w_len});
        }
      }
      ++ws;
      ++offset;
      --len_remaining;
    }

    // Tokenize: emit the text between consecutive matches.
    tokenized_strings.emplace_back();
    auto& row_tokens = tokenized_strings.back();
    row_tokens.reserve(matches.size() + 1);
    ws = wstr.c_str();
    offset = 0;
    for (const auto& m : matches) {
      assert(m.offset_ >= offset);
      const size_t sz = (m.offset_ - offset);
      if (sz > 0 && sz >= min_chars) {
        row_tokens.emplace_back(ws, sz);
      }
      offset = m.offset_ + m.size_;
      ws = wstr.c_str() + offset;
    }
    // Text after the last separator is a token too and is subject to the
    // same minimum length as every other token.
    assert(offset <= wstr.length());
    if (offset < wstr.length()) {
      const size_t sz = wstr.length() - offset;
      if (sz >= min_chars) {
        row_tokens.emplace_back(ws, sz);
      }
    }

    size_t tokens = row_tokens.size();
    if (mark_) {
      tokens += 2;  // Start/end markers as separate tokens
    }
    max_tokens = std::max(max_tokens, tokens);
    ++curr_input;
  }

  std::vector<int64_t> output_dims(input_dims);
  // Check if we have no output due to either empty input
  // or everything being a separator
  if ((max_tokens - mark_ * 2) == 0) {
    output_dims.push_back(0);
    TensorShape output_shape(output_dims);
    ctx->Output(0, output_shape);
    return Status::OK();
  }

  output_dims.push_back(max_tokens);
  TensorShape output_shape(output_dims);
  auto output_tensor = ctx->Output(0, output_shape);
  auto const output_data = output_tensor->template MutableData<std::string>();
  // NOTE(review): strings are created with placement new, which presumes the
  // output buffer elements need no prior destruction - confirm against the
  // tensor allocator.

#ifdef _DEBUG
  const size_t max_output_index = N * C * max_tokens;
#endif
  size_t output_index = 0;
  for (auto& row : tokenized_strings) {
#ifdef _DEBUG
    size_t c_idx = output_index;
#endif
    if (mark_) {
      new (output_data + output_index) std::string(&start_text, 1);
      ++output_index;
    }
    // Output tokens for this row
    for (auto& token : row) {
      new (output_data + output_index) std::string(converter.to_bytes(token));
      ++output_index;
    }
    if (mark_) {
      new (output_data + output_index) std::string(&end_text, 1);
      ++output_index;
    }
    // Padding goes after the end marker.
    const size_t pads = max_tokens - (mark_ * 2) - row.size();
    for (size_t p = 0; p < pads; ++p) {
      new (output_data + output_index) std::string(pad_value_);
      ++output_index;
    }
#ifdef _DEBUG
    assert(output_index <= max_output_index);
    assert((output_index - c_idx) <= max_tokens);
#endif
  }
  return Status::OK();
}
// Kernel entry point. Checks that input 0 is a string tensor of shape [C] or
// [N][C] with strictly positive dimensions, then dispatches to character
// level or separator based tokenization.
Status Tokenizer::Compute(OpKernelContext* ctx) const {
  auto X = ctx->Input<Tensor>(0);
  if (X->DataType() != DataTypeImpl::GetType<std::string>()) {
    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                  "tensor(string) expected as input");
  }

  auto& input_dims = X->Shape().GetDims();
  size_t N = 0;
  size_t C = 0;
  switch (input_dims.size()) {
    case 1:
      // A [C] input is treated as a single batch.
      if (input_dims[0] < 1) {
        return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                      "Invalid C dimension value");
      }
      N = 1;
      C = static_cast<size_t>(input_dims[0]);
      break;
    case 2:
      if (input_dims[0] < 1 || input_dims[1] < 1) {
        return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                      "Invalid N and/or C dimension values");
      }
      N = static_cast<size_t>(input_dims[0]);
      C = static_cast<size_t>(input_dims[1]);
      break;
    default:
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                    "Input dimensions are either [C] or [N][C] allowed");
  }

  if (char_tokenezation_) {
    return CharTokenize(ctx, N, C, input_dims);
  }
  return SeparatorTokenize(ctx, N, C, input_dims);
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/op_kernel.h"
#include <memory>
namespace onnxruntime {
namespace contrib {
// CPU kernel of the Tokenizer contrib operator. Non-copyable.
class Tokenizer final : public OpKernel {
public:
// Reads the operator attributes from info.
explicit Tokenizer(const OpKernelInfo& info);
Tokenizer(const Tokenizer&) = delete;
Tokenizer& operator=(const Tokenizer&) = delete;
// Declared here and defined in the source file, where SearchData is complete.
~Tokenizer();
Status Compute(OpKernelContext* context) const override;
private:
// Tokenization strategies; N and C are the input dimensions and input_dims
// is the original input shape.
Status CharTokenize(OpKernelContext* context, size_t N, size_t C,
const std::vector<int64_t>& input_dims) const;
Status SeparatorTokenize(OpKernelContext* context, size_t N, size_t C,
const std::vector<int64_t>& input_dims) const;
bool mark_;               // value of the "mark" attribute
std::string pad_value_;   // value of the "pad_value" attribute
int64_t mincharnum_;      // value of the "mincharnum" attribute
bool char_tokenezation_;  // selects CharTokenize over SeparatorTokenize
// Opaque separator search state, defined in the source file.
struct SearchData;
std::unique_ptr<SearchData> search_data_;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
namespace onnxruntime {
namespace utf8_util {

// Classifies the leading byte of a UTF-8 encoded character.
// On success stores the total length of the character in bytes (1 to 4) in
// len and returns true. Returns false, leaving len untouched, when ch cannot
// start a character: continuation bytes (10xxxxxx) and bytes 0xF8..0xFF.
// Only the bit pattern is inspected; whether the value is a permitted lead
// byte (e.g. 0xC0, 0xF5) is decided by utf8_validate().
inline bool utf8_bytes(unsigned char ch, size_t& len) {
  if ((ch & 0x80) == 0) {
    len = 1;
    return true;
  }
  if ((ch & 0xE0) == 0xC0) {
    len = 2;
    return true;
  }
  if ((ch & 0xF0) == 0xE0) {
    len = 3;
    return true;
  }
  // 11110xxx - anything with more leading ones is not a 4 byte lead.
  if ((ch & 0xF8) == 0xF0) {
    len = 4;
    return true;
  }
  return false;
}

// Checks that s[0, len) is a well-formed UTF-8 byte sequence.
// On success stores the number of encoded characters in utf8_chars and
// returns true. Returns false, leaving utf8_chars untouched, on truncated
// sequences, bad continuation bytes, overlong encodings, UTF-16 surrogates
// and code points above U+10FFFF.
inline bool utf8_validate(const unsigned char* s, size_t len, size_t& utf8_chars) {
  size_t char_count = 0;
  size_t idx = 0;
  while (idx < len) {
    const unsigned char lead = s[idx];
    size_t bytes = 0;
    if (!utf8_bytes(lead, bytes)) {
      return false;
    }
    if (bytes > len - idx) {
      return false;  // sequence is cut off by the end of input
    }

    // Allowed range of the second byte. It is narrower than the generic
    // continuation range for some lead bytes, which is what excludes
    // overlong forms, surrogates and values above U+10FFFF.
    unsigned char lo = 0x80u;
    unsigned char hi = 0xBFu;
    switch (bytes) {
      case 1:
        break;
      case 2:
        if (lead < 0xC2u) {
          return false;  // 0xC0/0xC1 can only encode overlong ASCII
        }
        break;
      case 3:
        if (lead == 0xE0u) {
          lo = 0xA0u;  // below is overlong
        } else if (lead == 0xEDu) {
          hi = 0x9Fu;  // above are UTF-16 surrogates
        }
        break;
      case 4:
        if (lead > 0xF4u) {
          return false;  // beyond U+10FFFF
        }
        if (lead == 0xF0u) {
          lo = 0x90u;  // below is overlong
        } else if (lead == 0xF4u) {
          hi = 0x8Fu;  // above is beyond U+10FFFF
        }
        break;
      default:
        // no chars longer than 4
        return false;
    }

    for (size_t k = 1; k < bytes; ++k) {
      const unsigned char b = s[idx + k];
      if (b < lo || b > hi) {
        return false;
      }
      // Bytes after the second one use the generic continuation range.
      lo = 0x80u;
      hi = 0xBFu;
    }

    idx += bytes;
    ++char_count;
  }
  utf8_chars = char_count;
  return true;
}
}  // namespace utf8_util
}  // namespace onnxruntime

View file

@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/utf8_util.h"
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
// One validation test case: a NUL terminated byte sequence and whether
// utf8_validate is expected to accept it.
struct Sample {
const char* sequence;
bool valid;
};
// Test vectors for utf8_validate. Every sequence marked valid encodes exactly
// one character, which the test relies on when checking the reported length.
const std::vector<Sample> samples = {
{"a", true},
{"\xc3\xb1", true},
{"\xc3\x28", false},
{"\xa0\xa1", false},
{"\xe2\x82\xa1", true},
{"\xe2\x28\xa1", false},
{"\xe2\x82\x28", false},
{"\xf0\x90\x8c\xbc", true},
{"\xf0\x28\x8c\xbc", false},
{"\xf0\x90\x28\xbc", false},
{"\xf0\x28\x8c\x28", false},
{"\xf8\xa1\xa1\xa1\xa1", false}, // 5 byte form: expected to be rejected as not Unicode
{"\xfc\xa1\xa1\xa1\xa1\xa1", false}}; // 6 byte form: expected to be rejected as not Unicode
// Runs utf8_validate over every sample and checks the verdict; accepted
// samples must additionally report a length of one character.
TEST(Utf8UtilTest, Validate) {
  using namespace utf8_util;
  for (const auto& sample : samples) {
    size_t char_count = 0;
    const auto* bytes = reinterpret_cast<const unsigned char*>(sample.sequence);
    const bool accepted = utf8_validate(bytes, strlen(sample.sequence), char_count);
    ASSERT_EQ(sample.valid, accepted);
    if (sample.valid) {
      ASSERT_EQ(1U, char_count);
    }
  }
}
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,708 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <codecvt>
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
// Constants shared by all Tokenizer tests.
namespace tokenizer_test {
// Single character tokens (0x02 / 0x03) expected around each row when the
// "mark" attribute is set.
const std::string start_mark(1, '\x02');
const std::string end_mark(1, '\x03');
// Value passed as the "pad_value" attribute.
const std::string padval{"0xdeadbeaf"};
constexpr const char* domain = onnxruntime::kMSDomain;
const int opset_ver{1};
}  // namespace tokenizer_test
using namespace tokenizer_test;
// Sets the four Tokenizer attributes on the tester. The padding value is
// always the shared padval constant.
void InitTestAttr(OpTester& test, bool mark, const std::vector<std::string>& seps,
                  int64_t mincharnum) {
  // "mark" is an integer attribute, so the flag is widened explicitly.
  test.AddAttribute("mark", static_cast<int64_t>(mark));
  test.AddAttribute("separators", seps);
  // Padding for alignment
  test.AddAttribute("pad_value", padval);
  test.AddAttribute("mincharnum", mincharnum);
}
// A rank 3 input must be rejected: only [C] and [N][C] are accepted.
TEST(ContribOpTest, TokenizerCharLevel_InvalidDim) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, false, {""}, 1);

  std::vector<int64_t> shape{1, 1, 2};
  std::vector<std::string> strings{"s1", "s2"};
  test.AddInput<std::string>("T", shape, strings);

  // The run is expected to fail, so the output is just a copy of the input.
  std::vector<std::string> expected(strings);
  test.AddOutput<std::string>("Y", shape, expected);

  test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either [C] or [N][C] allowed");
}
// Char level tokenization of latin strings without start/end text markers.
// Input [C], output [C][D] where D is the length of the longest string.
TEST(ContribOpTest, TokenizerCharLevel_LatinCharsNoMarkersC) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, false, {""}, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{"abcdef", "abcd"};
  test.AddInput<std::string>("T", dims, input);

  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(static_cast<int64_t>(input[0].length()));
  // One row per input string.
  std::vector<std::string> output{
      "a", "b", "c", "d", "e", "f",
      "a", "b", "c", "d", padval, padval};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Char level tokenization of latin strings with start/end text markers.
// Input [C], output [C][D] where D is the longest string plus two markers.
TEST(ContribOpTest, TokenizerCharLevel_LatinCharsWithMarkersC) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, {""}, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{"abcdef", "abcd"};
  test.AddInput<std::string>("T", dims, input);

  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(static_cast<int64_t>(input[0].length() + 2));
  // One row per input string; padding follows the end marker.
  std::vector<std::string> output{
      start_mark, "a", "b", "c", "d", "e", "f", end_mark,
      start_mark, "a", "b", "c", "d", end_mark, padval, padval};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Char level tokenization of latin strings without start/end text markers.
// Input [N][C], output [N][C][D] where D is the length of the longest string.
TEST(ContribOpTest, TokenizerCharLevel_LatinCharsNoMarkersNC) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, false, {""}, 1);

  std::vector<int64_t> dims{2, 2};
  std::vector<std::string> input{"abcd", "abcd", "abcd", "abcdef"};
  test.AddInput<std::string>("T", dims, input);

  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(static_cast<int64_t>(input[3].length()));
  // One row per input string.
  std::vector<std::string> output{
      "a", "b", "c", "d", padval, padval,
      "a", "b", "c", "d", padval, padval,
      "a", "b", "c", "d", padval, padval,
      "a", "b", "c", "d", "e", "f"};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Char level tokenization of latin strings with start/end text markers.
// Input [N][C], output [N][C][D] where D is the longest string plus two markers.
TEST(ContribOpTest, TokenizerCharLevel_LatinCharsWithMarkersNC) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, {""}, 1);

  std::vector<int64_t> dims{2, 2};
  std::vector<std::string> input{"abcd", "abcd", "abcd", "abcdef"};
  test.AddInput<std::string>("T", dims, input);

  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(static_cast<int64_t>(input[3].length() + 2));
  // One row per input string; padding follows the end marker.
  std::vector<std::string> output{
      start_mark, "a", "b", "c", "d", end_mark, padval, padval,
      start_mark, "a", "b", "c", "d", end_mark, padval, padval,
      start_mark, "a", "b", "c", "d", end_mark, padval, padval,
      start_mark, "a", "b", "c", "d", "e", "f", end_mark};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Char level tokenization of Cyrillic strings with start/end text markers.
// Input [C], output [C][D].
TEST(ContribOpTest, TokenizerCharLevel_CyrillicCharsWithMarkersC) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, {""}, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"Абсурд", u8"Кома"};
  test.AddInput<std::string>("T", dims, input);

  // The longest word has 6 characters, so each row holds 6 single character
  // strings plus the start/end text markers.
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t{6 + 2});
  std::vector<std::string> output{
      start_mark, u8"А", u8"б", u8"с", u8"у", u8"р", u8"д", end_mark,
      start_mark, u8"К", u8"о", u8"м", u8"а", end_mark, padval, padval};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Char level tokenization of strings mixing latin, Spanish, Cyrillic and
// Chinese characters, with start/end text markers.
// Input [C], output [C][D].
TEST(ContribOpTest, TokenizerCharLevel_MixedCharsWithMarkersC) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, {""}, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
  test.AddInput<std::string>("T", dims, input);

  // The first word is 6 characters long, so each row holds 6 single character
  // strings plus the start/end text markers.
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t(6 + 2));
  // The two Chinese characters are tokens of their own; they must not be
  // left as empty strings.
  std::vector<std::string> output{
      start_mark, u8"А", u8"б", u8"с", u8"у", u8"中", u8"文", end_mark,
      start_mark, u8"К", u8"о", u8"ñ", u8"ó", end_mark, padval, padval};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Special case: only empty input strings. A [C] input yields a [C][0] output.
TEST(ContribOpTest, TokenizerCharLevel_EmptyOutputC) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, {""}, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"", u8""};
  test.AddInput<std::string>("T", dims, input);

  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t{0});
  std::vector<std::string> no_tokens;
  test.AddOutput<std::string>("Y", output_dims, no_tokens);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Special case: only empty input strings. An [N][C] input yields an
// [N][C][0] output.
TEST(ContribOpTest, TokenizerCharLevel_EmptyOutputNC) {
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, {""}, 1);

  std::vector<int64_t> dims{2, 2};
  std::vector<std::string> input{u8"", u8"", u8"", u8""};
  test.AddInput<std::string>("T", dims, input);

  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t{0});
  std::vector<std::string> no_tokens;
  test.AddOutput<std::string>("Y", output_dims, no_tokens);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Separator tokenization of strings mixing latin, Spanish, Cyrillic and
// Chinese characters, with start/end text markers.
// Input [C], output [C][D].
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersC) {
  std::vector<std::string> separators = {u8"у", u8"ñ"};
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, separators, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
  test.AddInput<std::string>("T", dims, input);

  // Each string is split in two, plus the two markers.
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t{2 + 2});
  std::vector<std::string> output{
      start_mark, u8"Абс", u8"中文", end_mark,
      start_mark, u8"Ко", u8"ó", end_mark};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Every input string is matched entirely by a separator, so no tokens are
// produced and the output is [C][0].
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersCompleteMatchEmptyOutputC) {
  std::vector<std::string> separators = {u8"Абсу中文", u8"Коñó"};
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, separators, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
  test.AddInput<std::string>("T", dims, input);

  // Must have no output
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t{0});
  std::vector<std::string> no_tokens;
  test.AddOutput<std::string>("Y", output_dims, no_tokens);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Separators match the first character of each input string.
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersStartMatchC) {
  std::vector<std::string> separators = {u8"А", u8"К"};
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, separators, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
  test.AddInput<std::string>("T", dims, input);

  // The first character is dropped from both strings: one token per row
  // plus the two markers.
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t{3});
  std::vector<std::string> output{
      start_mark, u8"бсу中文", end_mark,
      start_mark, u8"оñó", end_mark};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Separators match the last character of each input string.
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchC) {
  // The first separator is the last character of the first input string; an
  // empty separator is not allowed alongside other separators.
  std::vector<std::string> separators = {u8"文", u8"ó"};
  OpTester test("Tokenizer", opset_ver, domain);
  InitTestAttr(test, true, separators, 1);

  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};
  test.AddInput<std::string>("T", dims, input);

  // The last character is dropped from both strings: one token per row
  // plus the two markers.
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t(3));
  std::vector<std::string> output{
      start_mark, u8"Абсу中", end_mark,
      start_mark, u8"Коñ", end_mark};
  test.AddOutput<std::string>("Y", output_dims, output);

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchAtLeast4CharsC) {
  // Separators match the last character of each input, and tokens must be
  // at least 4 characters long to be kept.
  //
  // The first separator must be u8"文" (the final character of the first
  // input). An empty separator cannot remove that character, and the
  // expected first-row token below requires it to be removed.
  std::vector<std::string> separators = {
      u8"文",
      u8"ó"};
  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"Абсу中文", u8"Коñó"};

  // Row 0: the 5-character remainder passes the length filter.
  // Row 1: the 3-character remainder is dropped, leaving only the markers;
  //        the row is then padded (after the end marker) to the widest row.
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t(3));
  std::vector<std::string> output{
      start_mark, u8"Абсу中", end_mark,
      start_mark, end_mark, padval};

  OpTester test("Tokenizer", opset_ver, domain);
  // Markers requested (true), minimum token length 4.
  InitTestAttr(test, true, separators, 4);
  test.AddInput<std::string>("T", dims, input);
  test.AddOutput<std::string>("Y", output_dims, output);
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOutputC) {
  // Empty input strings of shape [C] should produce an output of shape [C][0].
  //
  // The separators cannot match anything in empty inputs, so they do not
  // influence the result. They are still given as non-empty strings because
  // an empty separator is a degenerate pattern.
  // NOTE(review): the op may reject empty separators outright - confirm.
  std::vector<std::string> separators = {
      u8"文",
      u8"ó"};
  std::vector<int64_t> dims{2};
  std::vector<std::string> input{u8"", u8""};

  // Input shape plus a trailing zero-length axis, and no values.
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t(0));
  std::vector<std::string> output;

  OpTester test("Tokenizer", opset_ver, domain);
  // Markers requested (true), minimum token length 4.
  InitTestAttr(test, true, separators, 4);
  test.AddInput<std::string>("T", dims, input);
  test.AddOutput<std::string>("Y", output_dims, output);
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOutputNC) {
  // Input of shape [N][C] that yields no tokens should produce [N][C][0].
  //
  // Three of the four inputs are empty; the remaining one (u8"ó") is exactly
  // the second separator, so it leaves no token text either.
  //
  // The first separator matches nothing in these inputs. It is still given
  // as a non-empty string because an empty separator is a degenerate pattern.
  // NOTE(review): the op may reject empty separators outright - confirm.
  std::vector<std::string> separators = {
      u8"文",
      u8"ó"};
  std::vector<int64_t> dims{2, 2};
  std::vector<std::string> input{u8"", u8"", u8"ó", u8""};

  // Input shape plus a trailing zero-length axis, and no values.
  std::vector<int64_t> output_dims(dims);
  output_dims.push_back(int64_t(0));
  std::vector<std::string> output;

  OpTester test("Tokenizer", opset_ver, domain);
  // Markers requested (true), minimum token length 4.
  InitTestAttr(test, true, separators, 4);
  test.AddInput<std::string>("T", dims, input);
  test.AddOutput<std::string>("Y", output_dims, output);
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapShortFirstC) {
  // Overlapping separators: patterns listed earlier take precedence.
  // The short pattern is listed first and lies inside the region the longer,
  // later pattern would cover, so the short one wins and the long one is
  // never matched.
  std::vector<std::string> sep_list = {
      u8"су",
      u8"Абсу"};
  std::vector<int64_t> x_shape{1};
  std::vector<std::string> x_values{u8"Абсу中文"};

  // The two middle characters are removed, splitting the string in two.
  std::vector<int64_t> y_shape(x_shape);
  y_shape.push_back(int64_t(2));
  std::vector<std::string> y_values{u8"Аб", u8"中文"};

  OpTester test("Tokenizer", opset_ver, domain);
  // No markers (false), minimum token length 1.
  InitTestAttr(test, false, sep_list, 1);
  test.AddInput<std::string>("T", x_shape, x_values);
  test.AddOutput<std::string>("Y", y_shape, y_values);
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapLongFirstC) {
  // Overlapping separators: patterns listed earlier take precedence.
  // The long pattern is listed first and covers the text the short pattern
  // would otherwise match, so the short one finds nothing.
  std::vector<std::string> sep_list = {
      u8"Абсу",
      u8"су"};
  std::vector<int64_t> x_shape{1};
  std::vector<std::string> x_values{u8"Абсу中文"};

  // The whole leading separator is removed; only the tail survives.
  std::vector<int64_t> y_shape(x_shape);
  y_shape.push_back(int64_t(1));
  std::vector<std::string> y_values{u8"中文"};

  OpTester test("Tokenizer", opset_ver, domain);
  // No markers (false), minimum token length 1.
  InitTestAttr(test, false, sep_list, 1);
  test.AddInput<std::string>("T", x_shape, x_values);
  test.AddOutput<std::string>("Y", y_shape, y_values);
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapLongFirstRepeatedShortC) {
  // Overlapping separators: patterns listed earlier take precedence.
  // The input is the long separator followed by two repetitions of the short
  // one and then the remaining text.
  std::vector<std::string> sep_list = {
      u8"Абсу",
      u8"су"};
  std::vector<int64_t> x_shape{1};
  std::vector<std::string> x_values{u8"Абсусусу中文"};

  // All leading separator text is removed and no empty tokens are produced
  // between adjacent separators; only the tail survives.
  std::vector<int64_t> y_shape(x_shape);
  y_shape.push_back(int64_t(1));
  std::vector<std::string> y_values{u8"中文"};

  OpTester test("Tokenizer", opset_ver, domain);
  // No markers (false), minimum token length 1.
  InitTestAttr(test, false, sep_list, 1);
  test.AddInput<std::string>("T", x_shape, x_values);
  test.AddOutput<std::string>("Y", y_shape, y_values);
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapingMatchC) {
  // Overlapping separators: patterns listed earlier take precedence.
  // The first pattern occurs twice in the input, and those two occurrences
  // overlap each other; the earlier occurrence wins. The second pattern is
  // never matched because the first pattern's match cuts into it.
  std::vector<std::string> sep_list = {
      u8"усу",
      u8"Абсу"};
  std::vector<int64_t> x_shape{1};
  std::vector<std::string> x_values{u8"Абсусусу中文"};

  // Removing the earliest occurrence of the first pattern splits the string
  // in two; the leftover of the overlapping occurrence stays in the tail.
  std::vector<int64_t> y_shape(x_shape);
  y_shape.push_back(int64_t(2));
  std::vector<std::string> y_values{u8"Абс", u8"су中文"};

  OpTester test("Tokenizer", opset_ver, domain);
  // No markers (false), minimum token length 1.
  InitTestAttr(test, false, sep_list, 1);
  test.AddInput<std::string>("T", x_shape, x_values);
  test.AddOutput<std::string>("Y", y_shape, y_values);
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) {
  // Two separators where the first (";") is a prefix of the second (";;;").
  // Input shape is [C]; output shape is [C][D], with D being the widest row.
  std::vector<std::string> sep_list = {
      u8";",
      u8";;;"};
  std::vector<int64_t> x_shape{4};
  std::vector<std::string> x_values{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"};

  // The widest row has 4 tokens plus the two markers, hence D == 6.
  // Runs of separators produce no empty tokens, and shorter rows are padded
  // after the end marker.
  std::vector<int64_t> y_shape(x_shape);
  y_shape.push_back(int64_t(6));
  std::vector<std::string> y_values{
      // "a;b"
      start_mark, u8"a", u8"b", end_mark, padval, padval,
      // "a;;;b"
      start_mark, u8"a", u8"b", end_mark, padval, padval,
      // "b;c;;;d;e"
      start_mark, u8"b", u8"c", u8"d", u8"e", end_mark,
      // "a;;b;;;c"
      start_mark, u8"a", u8"b", u8"c", end_mark, padval};

  OpTester test("Tokenizer", opset_ver, domain);
  // Markers requested (true), minimum token length 1.
  InitTestAttr(test, true, sep_list, 1);
  test.AddInput<std::string>("T", x_shape, x_values);
  test.AddOutput<std::string>("Y", y_shape, y_values);
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
} // namespace test
} // namespace onnxruntime