diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/contrib_ops/contrib_ops.cc index ad52918ab7..c2c8d3318d 100644 --- a/onnxruntime/contrib_ops/contrib_ops.cc +++ b/onnxruntime/contrib_ops/contrib_ops.cc @@ -92,6 +92,36 @@ Sample echo operator.)DOC"); .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput) .SetDoc(R"DOC(Returns which elements of the input are NaN.)DOC"); + ONNX_CONTRIB_OPERATOR_SCHEMA(Tokenizer) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "X", "Strings to tokenize", "T") + .Output(0, "Y", "Tokenized strings", "T") + .TypeConstraint( + "T", + {"tensor(string)"}, + "Input/Output is a string tensor") + .Attr( + "mark", + "Boolean whether to mark the beginning/end character with start of text character (0x02)/end of text character (0x03).", + AttributeProto::INT) + .Attr( + "pad_value", + "The string used to pad output tensors when the tokens extracted doesn't match the maximum number of tokens found. If start/end markers are needed, padding will appear outside the markers.", + AttributeProto::STRING) + .Attr( + "separators", + "The list of separators, two consecutive segments in X connected by a separator would be divided into two tokens.", + AttributeProto::STRINGS) + .Attr( + "mincharnum", + "Minimum number of characters allowed in the output. For example, if mincharnum is 2, tokens such as \"A\" and \"B\" would be ignored", + AttributeProto::INT) + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + }) + .SetDoc(R"DOC(Tokenizer divides each string in X into a vector of strings along the last axis. All input strings including attributes are UTF-8 encoded.)DOC"); + // Operators for linear 8 bit quanitzation support. 
ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear) .SetDomain(kMSDomain) @@ -491,6 +521,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear); @@ -505,6 +536,7 @@ void RegisterContribKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); diff --git a/onnxruntime/contrib_ops/cpu/tokenizer.cc b/onnxruntime/contrib_ops/cpu/tokenizer.cc new file mode 100644 index 0000000000..7e83e1ca71 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/tokenizer.cc @@ -0,0 +1,489 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "tokenizer.h"
+#include "onnx/defs/schema.h"
+#include "core/common/common.h"
+#include "core/framework/tensor.h"
+
+#include "core/common/utf8_util.h"
+
+#include <codecvt>
+#include <locale>
+
+namespace onnxruntime {
+namespace contrib {
+
+using namespace utf8_util;
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    Tokenizer,
+    1,
+    string,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<std::string>()),
+    contrib::Tokenizer);
+
+namespace tokenizer_details {
+
+const char start_text = 0x2;  // start-of-text marker emitted when mark_ is set
+const char end_text = 0x3;    // end-of-text marker emitted when mark_ is set
+
+const std::string conv_error("Conversion Error");     // sentinel returned by the converter on a failed to_bytes()
+const std::wstring wconv_error(L"Conversion Error");  // sentinel returned by the converter on a failed from_bytes()
+
+// Use a Trie like structure for searching multiple strings
+// at once but convert it to a ternary tree for saving space.
+// We insert separators in the same order they are specified.
+// Template parameter is a CharT which can be a char/wchar_t
+// or anything else that supports operator ><,== as long as
+// this is not a variable length sequence. We convert utf8 to wide chars
+// before inserting.
+// Value is a supplementary information useful for search hit
+// and is present in the nodes that terminate the whole search pattern
+template <class CharT, class Value>
+class TernarySearchTree {
+ private:
+  struct Node {
+    std::unique_ptr<Node> left_;   // subtree for characters less than c_
+    std::unique_ptr<Node> mid_;    // subtree for the next character of the pattern
+    std::unique_ptr<Node> right_;  // subtree for characters greater than c_
+    CharT c_;  // character
+    Value value_;   // meaningful only when has_val_ is true
+    bool has_val_;  // true if a stored pattern ends at this node
+
+    explicit Node(CharT c) : c_(c), value_(), has_val_(false) {
+    }
+    ~Node() = default;
+  };
+
+  struct GetState {
+    const CharT* const str_;  // text being searched
+    const size_t len_;        // number of characters available in str_
+    size_t depth_;            // index of the character currently compared
+    const Value* result_;     // best (smallest by operator<) value found so far
+  };
+
+ public:
+  TernarySearchTree() = default;
+  ~TernarySearchTree() = default;
+
+  /**
+   * Returns a ptr to an associated value and nullptr on search miss.
+   * Among all stored patterns that are a prefix of str the smallest Value wins.
+   */
+  const Value* get(const CharT* str, size_t len) const {
+    if (len == 0) {
+      return nullptr;
+    }
+    GetState get_state{str, len, 0, nullptr};
+    get(root_.get(), get_state);
+    return get_state.result_;
+  }
+
+  /**
+   * Returns true if successful and false on empty strings
+   * and duplicates.
+   */
+  bool put(const CharT* str, size_t len, const Value& v) {
+    if (len < 1) {
+      assert(false);
+      return false;
+    }
+    Node* new_root = put(root_.get(), str, len, v, 0);
+    if (new_root != nullptr) {
+      root_.release();  // new_root may be the very node root_ already owns
+      root_.reset(new_root);
+      return true;
+    }
+    return false;
+  }
+
+ private:
+  void update_state(const Node* node, GetState& state) const {
+    if (node->has_val_) {
+      if (state.result_ == nullptr) {
+        state.result_ = &node->value_;
+      } else if (node->value_ < *state.result_) {
+        state.result_ = &node->value_;
+      }
+    }
+  }
+  void get(const Node* node, GetState& state) const {
+    if (node == nullptr) {
+      return;
+    }
+    assert(state.depth_ < state.len_);
+    CharT c = state.str_[state.depth_];
+    if (c < node->c_) {
+      get(node->left_.get(), state);
+      return;
+    } else if (c > node->c_) {
+      get(node->right_.get(), state);
+      return;
+    } else if (state.depth_ < (state.len_ - 1)) {
+      // Check if we have a match at this node
+      update_state(node, state);
+      if (node->mid_ != nullptr) {
+        ++state.depth_;
+        get(node->mid_.get(), state);
+      }
+      return;
+    }
+    update_state(node, state);  // last available character: record a match if one ends here
+  }
+
+  Node* put(Node* node, const CharT* str, size_t len, const Value& v, size_t depth) {
+    CharT c = str[depth];
+
+    std::unique_ptr<Node> new_node;  // owns a freshly created node until the insert is known to succeed
+    if (node == nullptr) {
+      new_node.reset(new Node(c));
+    }
+
+    Node* new_link = nullptr;
+    Node* n = (node != nullptr) ? node : new_node.get();
+    if (c < n->c_) {
+      new_link = put(n->left_.get(), str, len, v, depth);
+      if (new_link != nullptr) {
+        n->left_.release();
+        n->left_.reset(new_link);
+      }
+    } else if (c > n->c_) {
+      new_link = put(n->right_.get(), str, len, v, depth);
+      if (new_link != nullptr) {
+        n->right_.release();
+        n->right_.reset(new_link);
+      }
+    } else if (depth < (len - 1)) {
+      new_link = put(n->mid_.get(), str, len, v, depth + 1);
+      if (new_link != nullptr) {
+        n->mid_.release();
+        n->mid_.reset(new_link);
+      }
+    } else {
+      if (!n->has_val_) {  // a value already present means a duplicate pattern
+        n->value_ = v;
+        n->has_val_ = true;
+        new_link = n;
+      }
+    }
+    if (new_link != nullptr) {
+      new_node.release();  // ownership passes to the caller, which links the node into the tree
+      return n;
+    }
+    return nullptr;
+  }
+  std::unique_ptr<Node> root_;
+};
+
+// We store the length of the original pattern within the
+// Ternary Tree. This allows us to cut out the length of the matching
+// separator from the original string.
+struct SearchValue {
+  size_t w_len;   // separator length in wide characters
+  int priority_;  // position in the separators attribute; smaller wins
+  bool operator<(const SearchValue& o) const {
+    return priority_ < o.priority_;
+  }
+};
+
+}  // namespace tokenizer_details
+
+using namespace tokenizer_details;
+
+struct Tokenizer::SearchData {
+  TernarySearchTree<wchar_t, SearchValue> tst_;
+};
+
+Tokenizer::Tokenizer(const OpKernelInfo& info) : OpKernel(info) {
+  int64_t mark = 0;
+  auto status = info.GetAttr("mark", &mark);
+  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute mark is not set");
+  mark_ = mark != 0;
+
+  status = info.GetAttr("pad_value", &pad_value_);
+  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute pad_value is not set");
+
+  status = info.GetAttr("mincharnum", &mincharnum_);
+  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute mincharnum is not set");
+  ONNXRUNTIME_ENFORCE(mincharnum_ > 0, "attribute mincharnum must have a positive value");
+
+  std::vector<std::string> separators;
+  status = info.GetAttrs("separators", separators);
+  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute separators is not set");
+  ONNXRUNTIME_ENFORCE(!separators.empty(), "Requires at least one separator");
+
+  char_tokenezation_ = (separators.size() == 1 &&
+                        separators[0].empty());  // a single empty separator selects char level mode
+
+  ONNXRUNTIME_ENFORCE(!char_tokenezation_ || mincharnum_ < 2,
+                      "mincharnum is too big for char level tokenezation");
+
+  // Create TST and insert separators
+  if (!char_tokenezation_) {
+    std::unique_ptr<SearchData> sd(std::make_unique<SearchData>());
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
+    int priority = 0;  // earlier search patterns get priority
+    for (const auto& sep : separators) {
+      ONNXRUNTIME_ENFORCE(!sep.empty(), "No empty separators allowed");
+      std::wstring wsep = converter.from_bytes(sep);
+      ONNXRUNTIME_ENFORCE(wsep != wconv_error, "Separator strings contains invalid utf8 chars");
+      bool result = sd->tst_.put(wsep.c_str(), wsep.length(), {wsep.length(), priority});
+      ONNXRUNTIME_ENFORCE(result, "duplicate separator detected");
+      ++priority;
+    }
+    search_data_.swap(sd);
+  }
+}
+
+// Make SearchData definition available for destruction
+Tokenizer::~Tokenizer() {
+}
+
+Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C,
+                               const std::vector<int64_t>& input_dims) const {
+  // With char tokenization we get as many tokens as the number of
+  // utf8 characters in the string. So for every string we calculate its character(utf8) length
+  // add padding and add start/end text separators if necessary
+  size_t max_tokens = 0;
+  auto X = ctx->Input<Tensor>(0);
+  auto const input_data = X->template Data<std::string>();
+  auto curr_input = input_data;
+  auto const last = input_data + N * C;
+  while (curr_input != last) {  // pass 1: validate input and find the widest row
+    const auto& s = *curr_input;
+    size_t tokens = 0;  // length in utf8 chars
+    if (!utf8_validate(reinterpret_cast<const unsigned char*>(s.data()), s.size(),
+                       tokens)) {
+      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                    "Input string contains invalid utf8 chars: " + s);
+    }
+    if (mark_) {
+      tokens += 2;  // Start/end markers as separate tokens
+    }
+    max_tokens = std::max(max_tokens, tokens);
+    ++curr_input;
+  }
+
+  std::vector<int64_t> output_dims(input_dims);
+  // Check if we have no output due to apparently empty strings input.
+  if ((max_tokens - mark_ * 2) == 0) {
+    output_dims.push_back(0);
+    TensorShape output_shape(output_dims);
+    ctx->Output(0, output_shape);
+    return Status::OK();
+  }
+
+  output_dims.push_back(max_tokens);
+  TensorShape output_shape(output_dims);
+  auto output_tensor = ctx->Output(0, output_shape);
+  auto const output_data = output_tensor->template MutableData<std::string>();
+  size_t output_index = 0;
+  curr_input = input_data;
+  while (curr_input != last) {  // pass 2: emit one row of max_tokens strings per input string
+    const auto& s = *curr_input;
+    if (mark_) {
+      new (output_data + output_index) std::string(&start_text, 1);
+      ++output_index;
+    }
+    size_t tokens = 0;
+    const size_t str_len = s.size();
+    for (size_t token_idx = 0; token_idx < str_len;) {
+      size_t tlen = 0;
+      bool result = utf8_bytes(static_cast<unsigned char>(s[token_idx]), tlen);
+      assert(result);  // the string already passed utf8_validate in pass 1
+      (void)result;
+      assert(token_idx + tlen <= str_len);
+      new (output_data + output_index) std::string(s.substr(token_idx, tlen));
+      ++output_index;
+      token_idx += tlen;
+      ++tokens;
+    }
+    if (mark_) {
+      new (output_data + output_index) std::string(&end_text, 1);
+      ++output_index;
+    }
+    // Padding strings
+    assert(tokens + (mark_ * 2) <= max_tokens);
+    const size_t pads = max_tokens - (mark_ * 2) - tokens;
+    for (size_t p = 0; p < pads; ++p) {
+      new (output_data + output_index) std::string(pad_value_);
+      ++output_index;
+    }
+    ++curr_input;
+  }
+  return Status::OK();
+}
+
+Status Tokenizer::SeparatorTokenize(OpKernelContext* ctx,
+                                    size_t N, size_t C,
+                                    const std::vector<int64_t>& input_dims) const {
+  struct Match {
+    int priority_;
+    size_t offset_;
+    size_t size_;
+    // create a conflict for overlapping matches
+    // thus if they overlap neither is less than the other
+    // and they are considered equal
+    bool operator<(const Match& o) const {
+      return (offset_ + size_) <= o.offset_;
+    }
+  };
+
+  std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
+  // Scan all strings and attempt to find separators in them
+  // collect all the output tokens here
+  size_t max_tokens = 0;
+  std::vector<std::vector<std::wstring>> tokenized_strings;
+  tokenized_strings.reserve(N * C);
+  auto X = ctx->Input<Tensor>(0);
+  auto const input_data = X->template Data<std::string>();
+  auto curr_input = input_data;
+  auto const last = input_data + N * C;
+  while (curr_input != last) {
+    const auto& s = *curr_input;
+    std::wstring wstr = converter.from_bytes(s);
+    if (wstr == wconv_error) {
+      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                    "Invalid utf8 chars in the input: " + s);
+    }
+
+    std::set<Match> matches;
+    const wchar_t* ws = wstr.c_str();
+    size_t len_remaining = wstr.length();
+    size_t offset = 0;
+    while (len_remaining > 0) {  // try to match a separator starting at every offset
+      const auto* val = search_data_->tst_.get(ws, len_remaining);
+      if (val != nullptr) {
+        auto p = matches.insert({val->priority_, offset, val->w_len});
+        while (!p.second && val->priority_ < p.first->priority_) {
+          // if overlapping matches of the same pattern(priority), then
+          // the earlier match naturally wins
+          matches.erase(p.first);
+          p = matches.insert({val->priority_, offset, val->w_len});
+        }
+      }
+      ++ws;
+      ++offset;
+      --len_remaining;
+    }
+
+    // Tokenize
+    tokenized_strings.emplace_back();
+    auto& row_tokens = tokenized_strings.back();
+    row_tokens.reserve(matches.size() + 1);
+    ws = wstr.c_str();
+    offset = 0;
+    for (const auto& m : matches) {
+      assert(m.offset_ >= offset);
+      size_t sz = (m.offset_ - offset);
+      if (sz > 0 && sz >= size_t(mincharnum_)) {
+        row_tokens.emplace_back(ws, sz);
+      }
+      offset = m.offset_ + m.size_;
+      ws = wstr.c_str() + offset;
+    }
+    assert(offset <= wstr.length());
+    if (offset < wstr.length() && (wstr.length() - offset) >= size_t(mincharnum_)) {  // the tail token obeys mincharnum too
+      row_tokens.emplace_back(ws, wstr.length() - offset);
+    }
+
+    size_t tokens = row_tokens.size();
+    if (mark_) {
+      tokens += 2;  // Start/end markers as separate tokens
+    }
+    max_tokens = std::max(max_tokens, tokens);
+    ++curr_input;
+  }
+
+  std::vector<int64_t> output_dims(input_dims);
+  // Check if we have no output due to either empty input
+  // or everything being a separator (or too short to be a token)
+  if ((max_tokens - mark_ * 2) == 0) {
+    output_dims.push_back(0);
+    TensorShape output_shape(output_dims);
+    ctx->Output(0, output_shape);
+    return Status::OK();
+  }
+
+  output_dims.push_back(max_tokens);
+  TensorShape output_shape(output_dims);
+
+  auto output_tensor = ctx->Output(0, output_shape);
+  auto const output_data = output_tensor->template MutableData<std::string>();
+
+#ifdef _DEBUG
+  const size_t max_output_index = N * C * max_tokens;
+#endif
+  size_t output_index = 0;
+  for (auto& row : tokenized_strings) {
+#ifdef _DEBUG
+    size_t c_idx = output_index;
+#endif
+    if (mark_) {
+      new (output_data + output_index) std::string(&start_text, 1);
+      ++output_index;
+    }
+    // Output tokens for this row
+    for (auto& token : row) {
+      new (output_data + output_index) std::string(converter.to_bytes(token));
+      ++output_index;
+    }
+    if (mark_) {
+      new (output_data + output_index) std::string(&end_text, 1);
+      ++output_index;
+    }
+    const size_t pads = max_tokens - (mark_ * 2) - row.size();
+    for (size_t p = 0; p < pads; ++p) {
+      new (output_data + output_index) std::string(pad_value_);
+      ++output_index;
+    }
+#ifdef _DEBUG
+    assert(output_index <= max_output_index);
+    assert((output_index - c_idx) <= max_tokens);
+#endif
+  }
+  return Status::OK();
+}
+
+Status Tokenizer::Compute(OpKernelContext* ctx) const {
+  // Get input buffer ptr
+  auto X = ctx->Input<Tensor>(0);
+  if (X->DataType() != DataTypeImpl::GetType<std::string>()) {
+    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                  "tensor(string) expected as input");
+  }
+
+  auto& input_dims = X->Shape().GetDims();
+  size_t N = 0;
+  size_t C = 0;
+  if (input_dims.size() == 1) {  // [C] input is handled as a single row
+    N = 1;
+    if (input_dims[0] < 1) {
+      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                    "Invalid C dimension value");
+    }
+    C = input_dims[0];
+  } else if (input_dims.size() == 2) {
+    if (input_dims[0] < 1 || input_dims[1] < 1) {
+      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                    "Invalid N and/or C dimension values");
+    }
+    N = input_dims[0];
+    C = input_dims[1];
+  } else {
+    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+                  "Input dimensions are either [C] or [N][C] allowed");
+  }
+
+  Status s;
+  if (char_tokenezation_) {
+    s = CharTokenize(ctx, N, C, input_dims);
+  } else {
+    s = SeparatorTokenize(ctx, N, C, input_dims);
+  }
+  return s;
+}
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/tokenizer.h b/onnxruntime/contrib_ops/cpu/tokenizer.h
new file mode 100644
index 0000000000..0de73ca3fb
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/tokenizer.h
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/op_kernel.h"
+
+#include <memory>
+
+namespace onnxruntime {
+namespace contrib {
+class Tokenizer final : public OpKernel {  // CPU kernel that splits each input string into a row of tokens
+ public:
+  explicit Tokenizer(const OpKernelInfo& info);
+  Tokenizer(const Tokenizer&) = delete;
+  Tokenizer& operator=(const Tokenizer&) = delete;
+  ~Tokenizer();  // defined out of line, where SearchData is a complete type
+
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  Status CharTokenize(OpKernelContext* context, size_t N, size_t C,
+                      const std::vector<int64_t>& input_dims) const;
+
+  Status SeparatorTokenize(OpKernelContext* context, size_t N, size_t C,
+                           const std::vector<int64_t>& input_dims) const;
+
+  bool mark_ = false;               // "mark" attribute: wrap every row in start/end of text markers
+  std::string pad_value_;           // "pad_value" attribute: filler for rows shorter than the longest one
+  int64_t mincharnum_ = 0;          // "mincharnum" attribute: minimum token length in characters
+  bool char_tokenezation_ = false;  // true when the only separator is the empty string
+  struct SearchData;                // opaque separator search structure
+  std::unique_ptr<SearchData> search_data_;  // stays null in char level mode
+};
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/common/utf8_util.h b/onnxruntime/core/common/utf8_util.h
new file mode 100644
index 0000000000..bbc10bb094
--- /dev/null
+++ b/onnxruntime/core/common/utf8_util.h
@@ -0,0 +1,129 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+
+namespace onnxruntime {
+namespace utf8_util {
+
+// Returns the number of bytes in the utf8 character by analyzing its leading byte.
+// Returns false (len untouched) for a continuation byte; the sequence itself is not validated here.
+inline bool utf8_bytes(unsigned char ch, size_t& len) {
+  if ((ch & 0x80) == 0) {
+    len = 1;
+    return true;
+  }
+  if ((ch & 0xE0) == 0xC0) {
+    len = 2;
+    return true;
+  }
+  unsigned int result = (ch & 0xF0);
+  if (result == 0xE0) {
+    len = 3;
+    return true;
+  }
+  if (result == 0xF0) {  // also matches 0xF5..0xFF; utf8_validate rejects those lead bytes
+    len = 4;
+    return true;
+  }
+  return false;
+}
+
+inline bool utf8_validate(const unsigned char* s, size_t len, size_t& utf8_chars) {  // utf8_chars is written only on success
+  size_t utf8_len = 0;
+  size_t idx = 0;
+  while (idx < len) {
+    size_t bytes = 0;
+    auto ch = s[idx];
+    if (utf8_bytes(ch, bytes)) {
+      switch (bytes) {
+        case 1:
+          break;
+        case 2: {
+          if (ch < 0xC2u || ++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {  // 0xC0/0xC1 start overlong forms
+            return false;
+          }
+        } break;  // 2
+        case 3: {
+          auto ch1 = s[idx];
+          switch (ch1) {
+            case 0xE0u:
+              if (++idx >= len || s[idx] < 0xA0u || s[idx] > 0xBFu) {
+                return false;
+              }
+              break;
+            case 0xEDu:
+              if (++idx >= len || s[idx] < 0x80u || s[idx] > 0x9Fu) {
+                return false;
+              }
+              break;
+            default: {
+              if ((ch1 >= 0xE1u && ch1 <= 0xECu) ||
+                  (ch1 >= 0xEEu && ch1 <= 0xEFu)) {
+                if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
+                  return false;
+                }
+              } else {
+                return false;
+              }
+            } break;
+          }
+          // validate byte 3
+          if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
+            return false;
+          }
+        } break;  // 3
+        case 4: {
+          auto ch1 = s[idx];
+          switch (ch1) {
+            case 0xF0u: {
+              if (++idx >= len || s[idx] < 0x90u || s[idx] > 0xBFu) {
+                return false;
+              }
+            } break;
+            case 0xF4u: {
+              if (++idx >= len || s[idx] < 0x80u || s[idx] > 0x8Fu) {
+                return false;
+              }
+            } break;
+            default: {
+              if (ch1 >= 0xF1u && ch1 <= 0xF3u) {
+                if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
+                  return false;
+                }
+              } else {
+                return false;
+              }
+            } break;
+          }
+          // validate bytes 3 and 4
+          size_t stop = idx + 2;
+          while (idx < stop) {
+            if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) {
+              return false;
+            }
+          }
+        } break;  // 4
+        default:
+          // no chars longer than 4
+          return false;
+      }  // switch bytes
+      ++idx;
+      ++utf8_len;
+    } else {
+      return false;
+    }
+  }
+  // End index must match
+  // the end of the last byte sequence.
+  if (idx != len) {
+    return false;
+  }
+  utf8_chars = utf8_len;
+  return true;
+}
+
+}  // namespace utf8_util
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/common/utf8_util_test.cc b/onnxruntime/test/common/utf8_util_test.cc
new file mode 100644
index 0000000000..47d631f977
--- /dev/null
+++ b/onnxruntime/test/common/utf8_util_test.cc
@@ -0,0 +1,46 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/utf8_util.h"
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+struct Sample {
+  const char* sequence;
+  bool valid;
+};
+
+const std::vector<Sample> samples = {
+    {"a", true},
+    {"\xc3\xb1", true},
+    {"\xc3\x28", false}, {"\xc0\x80", false},  // second entry: overlong 2-byte form
+    {"\xa0\xa1", false},
+    {"\xe2\x82\xa1", true},
+    {"\xe2\x28\xa1", false},
+    {"\xe2\x82\x28", false},
+    {"\xf0\x90\x8c\xbc", true},
+    {"\xf0\x28\x8c\xbc", false},
+    {"\xf0\x90\x28\xbc", false},
+    {"\xf0\x28\x8c\x28", false},
+    {"\xf8\xa1\xa1\xa1\xa1", false},       // valid but not Unicode
+    {"\xfc\xa1\xa1\xa1\xa1\xa1", false}};  // valid but not Unicode
+
+TEST(Utf8UtilTest, Validate) {
+  using namespace utf8_util;
+  for (auto& s : samples) {
+    size_t utf8_len = 0;
+    if (s.valid != utf8_validate(reinterpret_cast<const unsigned char*>(s.sequence), strlen(s.sequence), utf8_len)) {
+      ASSERT_TRUE(false);
+    } else {
+      if (s.valid) {
+        ASSERT_EQ(1U, utf8_len);  // every valid sample encodes exactly one character
+      }
+    }
+  }
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/tokenizer_test.cc b/onnxruntime/test/contrib_ops/tokenizer_test.cc
new file mode 100644
index 0000000000..2a22f90141
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/tokenizer_test.cc
@@ -0,0 
+1,708 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace tokenizer_test { +const std::string start_mark{0x2}; +const std::string end_mark{0x3}; +const std::string padval("0xdeadbeaf"); + +constexpr const char* domain = onnxruntime::kMSDomain; +const int opset_ver = 1; + +} // namespace tokenizer_test + +using namespace tokenizer_test; + +void InitTestAttr(OpTester& test, bool mark, const std::vector& seps, + int64_t mincharnum) { + test.AddAttribute("mark", int64_t{mark}); + test.AddAttribute("separators", seps); + // Padding for alignment + test.AddAttribute("pad_value", padval); + test.AddAttribute("mincharnum", mincharnum); +} + +TEST(ContribOpTest, TokenizerCharLevel_InvalidDim) { + // Invalid input dimensions + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, false, {""}, 1); + + std::vector dims{1, 1, 2}; + std::vector input = {std::string("s1"), std::string("s2")}; + test.AddInput("T", dims, input); + std::vector output(input); // do the same for now + test.AddOutput("Y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either [C] or [N][C] allowed"); + } +} + +TEST(ContribOpTest, TokenizerCharLevel_LatinCharsNoMarkersC) { + // Char level tokenezation with latin characters and no + // start/end text markers + // [C] dimensions + // Output [C][D] + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, false, {""}, 1); + + std::vector dims{2}; + std::vector input{"abcdef", "abcd"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(input[0].length())); + std::vector output{ + "a", + "b", + "c", + "d", + "e", + "f", + "a", + "b", + "c", + "d", + padval, + padval}; + + test.AddOutput("Y", output_dims, output); + + 
test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerCharLevel_LatinCharsWithMarkersC) { + // Char level tokenezation with latin characters and + // with start/end text markers + // [C] dimensions + // Output [C][D] + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, {""}, 1); + + std::vector dims{2}; + std::vector input{"abcdef", "abcd"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(input[0].length() + 2)); + std::vector output{ + start_mark, + "a", + "b", + "c", + "d", + "e", + "f", + end_mark, + start_mark, + "a", + "b", + "c", + "d", + end_mark, + padval, + padval}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerCharLevel_LatinCharsNoMarkersNC) { + // Char level tokenezation with latin characters and no + // start/end text markers + // [N][C] dimensions + // Output [N][C][D] + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, false, {""}, 1); + + std::vector dims{2, 2}; + std::vector input{"abcd", "abcd", "abcd", "abcdef"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(input[3].length())); + std::vector output{ + "a", + "b", + "c", + "d", + padval, + padval, + "a", + "b", + "c", + "d", + padval, + padval, + "a", + "b", + "c", + "d", + padval, + padval, + "a", + "b", + "c", + "d", + "e", + "f"}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerCharLevel_LatinCharsWithMarkersNC) { + // Char level tokenezation with latin characters and + // with start/end text markers + // [N][C] dimensions + // Output [N][C][D] + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, {""}, 1); + + std::vector dims{2, 2}; + std::vector input{"abcd", "abcd", "abcd", "abcdef"}; + 
test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(input[3].length() + 2)); + std::vector output{ + start_mark, + "a", + "b", + "c", + "d", + end_mark, + padval, + padval, + start_mark, + "a", + "b", + "c", + "d", + end_mark, + padval, + padval, + start_mark, + "a", + "b", + "c", + "d", + end_mark, + padval, + padval, + start_mark, + "a", + "b", + "c", + "d", + "e", + "f", + end_mark}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerCharLevel_CyrillicCharsWithMarkersC) { + // Char level tokenezation with Cyrillic characters and + // with start/end text markers + // [C] dimensions + // Output [C][D] + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, {""}, 1); + + std::vector dims{2}; + std::vector input{u8"Абсурд", u8"Кома"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Word Absurd is 6 chars long so we must get 6 individual strings out of it + // which is the max plus start/end text markers + output_dims.push_back(int64_t(6 + 2)); + std::vector output{ + start_mark, + u8"А", u8"б", u8"с", u8"у", u8"р", u8"д", + end_mark, + start_mark, + u8"К", u8"о", u8"м", u8"а", + end_mark, + padval, + padval}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerCharLevel_MixedCharsWithMarkersC) { + // Char level tokenezation with a mix of latin, Spanish, Cyrillic and Chinese + // characters and + // with start/end text markers + // [C] dimensions + // Output [C][D] + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, {""}, 1); + + std::vector dims{2}; + std::vector input{u8"Абсу中文", u8"Коñó"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Word Absu?? 
is 6 chars long so we must get 6 individual strings out of it + // which is the max plus start/end text markers + output_dims.push_back(int64_t(6 + 2)); + std::vector output{ + start_mark, + u8"А", u8"б", u8"с", u8"у", u8"中", u8"文", + end_mark, + start_mark, + u8"К", u8"о", u8"ñ", u8"ó", + end_mark, + padval, + padval}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerCharLevel_EmptyOutputC) { + // Special case where every input string is empty and empty output is produced + // For [C] we expect [C][0] output + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, {""}, 1); + + std::vector dims{2}; + std::vector input{u8"", u8""}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(0)); + std::vector output{}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerCharLevel_EmptyOutputNC) { + // Special case where every input string is empty and empty output is produced + // For [N][C] we expect [N][C][0] output + { + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, {""}, 1); + + std::vector dims{2, 2}; + std::vector input{u8"", u8"", u8"", u8""}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(0)); + std::vector output{}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersC) { + // Separators and strings with a mix of latin, Spanish, Cyrillic and Chinese + // characters and with start/end text markers + // [C] dimensions + // Output [C][D] + { + std::vector separators = { + u8"у", + u8"ñ"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, separators, 1); + + std::vector dims{2}; + std::vector input{u8"Абсу中文", u8"Коñó"}; + test.AddInput("T", dims, input); + 
+ std::vector output_dims(dims); + // Must split both in 2, plus start/end markers + output_dims.push_back(int64_t(2 + 2)); + std::vector output{ + start_mark, + u8"Абс", u8"中文", + end_mark, + start_mark, + u8"Ко", u8"ó", + end_mark}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } // end of test scope +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersCompleteMatchEmptyOutputC) { + // Test entire separators match so we get nothing + // in the output + { + std::vector separators = { + u8"Абсу中文", + u8"Коñó"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, separators, 1); + + std::vector dims{2}; + std::vector input{u8"Абсу中文", u8"Коñó"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Must have no output + output_dims.push_back(int64_t(0)); + std::vector output; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersStartMatchC) { + // Match the start + { + std::vector separators = { + u8"А", + u8"К"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, separators, 1); + + std::vector dims{2}; + std::vector input{u8"Абсу中文", u8"Коñó"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Must drop first characters from both strings + output_dims.push_back(int64_t(3)); + std::vector output{ + start_mark, + u8"бсу中文", + end_mark, + start_mark, + u8"оñó", + end_mark}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchC) { + // Match the end + { + std::vector separators = { + u8"文", + u8"ó"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, separators, 1); + + std::vector dims{2}; + std::vector input{u8"Абсу中文", u8"Коñó"}; + test.AddInput("T", 
dims, input); + + std::vector output_dims(dims); + // Must drop last characters from both strings + output_dims.push_back(int64_t(3)); + std::vector output{ + start_mark, + u8"Абсу中", + end_mark, + start_mark, + u8"Коñ", + end_mark}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchAtLeast4CharsC) { + // Match the end, require at least 4 chars + { + std::vector separators = { + u8"文", + u8"ó"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, separators, 4); + + std::vector dims{2}; + std::vector input{u8"Абсу中文", u8"Коñó"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Must drop the last character from the first + // and the second 3 character token does not pass mincharnum, so its row is markers and padding + output_dims.push_back(int64_t(3)); + std::vector output{ + start_mark, + u8"Абсу中", + end_mark, + start_mark, + end_mark, + padval}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOutputC) { + // Empty input for [C] should produce [C][0] + { + std::vector separators = { + u8"文", + u8"ó"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, separators, 4); + + std::vector dims{2}; + std::vector input{u8"", u8""}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(0)); + std::vector output; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOutputNC) { + // Empty or separator-only input for [N][C] should produce [N][C][0] + { + std::vector separators = { + u8"文", + u8"ó"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, separators, 4); + + 
std::vector dims{2, 2}; + std::vector input{u8"", u8"文", u8"ó", u8""}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(0)); + std::vector output; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapShortFirstC) { + // Test of the overlapping search patterns + // The spec mandates that the patterns that appear + // in the separators earlier must be matched first. + { + // In this case the first pattern must match first + // and there would be no match for the second + std::vector separators = { + u8"су", + u8"Абсу"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, false, separators, 1); + + std::vector dims{1}; + std::vector input{u8"Абсу中文"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Must split in 2, dropping the two middle characters + output_dims.push_back(int64_t(2)); + std::vector output{ + u8"Аб", u8"中文"}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapLongFirstC) { + // Test of the overlapping search patterns + // The spec mandates that the patterns that appear + // in the separators earlier must be matched first. 
+ { + // In this case the first pattern must match first + // and there would be no match for the second + std::vector separators = { + u8"Абсу", + u8"су"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, false, separators, 1); + + std::vector dims{1}; + std::vector input{u8"Абсу中文"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Must drop the beginning of the word that + // also contains the second separator + output_dims.push_back(int64_t(1)); + std::vector output{u8"中文"}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapLongFirstRepeatedShortC) { + // Test of the overlapping search patterns + // The spec mandates that the patterns that appear + // in the separators earlier must be matched first. + { + // In this case the first pattern must match first + // and the second is expected to match the two repeats that follow it + std::vector separators = { + u8"Абсу", + u8"су"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, false, separators, 1); + + std::vector dims{1}; + std::vector input{u8"Абсусусу中文"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Must drop the first separator at the beginning of the word + // and both repeats of the second separator after it + output_dims.push_back(int64_t(1)); + std::vector output{u8"中文"}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapingMatchC) { + // Test of the overlapping search patterns + // The spec mandates that the patterns that appear + // in the separators earlier must be matched first. + { + // In this case the first pattern must match first + // and there is more than one overlapping match for the first + // so the earlier match for the first wins. 
+ // and there would be no match for the second + std::vector separators = { + u8"усу", + u8"Абсу"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, false, separators, 1); + + std::vector dims{1}; + std::vector input{u8"Абсусусу中文"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Must split in 2 around the earliest match of the first separator; + // the text after that match is kept as the second token + output_dims.push_back(int64_t(2)); + std::vector output{u8"Абс", u8"су中文"}; + + test.AddOutput("Y", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) { + // Separators that share a common prefix (";" and ";;;") applied to + // ASCII strings, with start/end text markers + // [C] dimensions + // Output [C][D] + std::vector separators = { + u8";", + u8";;;"}; + + OpTester test("Tokenizer", opset_ver, domain); + InitTestAttr(test, true, separators, 1); + + std::vector dims{4}; + std::vector input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + // Longest row yields 4 tokens, and start/end markers add 2 more + output_dims.push_back(int64_t(6)); + std::vector output{ + start_mark, + u8"a", + u8"b", + end_mark, + padval, + padval, + start_mark, + u8"a", + u8"b", + end_mark, + padval, + padval, + start_mark, + u8"b", + u8"c", + u8"d", + u8"e", + end_mark, + start_mark, + u8"a", + u8"b", + u8"c", + end_mark, + padval, + }; + + test.AddOutput("Y", output_dims, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +} // namespace test +} // namespace onnxruntime