diff --git a/onnxruntime/contrib_ops/cpu/tokenizer.cc b/onnxruntime/contrib_ops/cpu/tokenizer.cc index 418cbc4486..36c99909c5 100644 --- a/onnxruntime/contrib_ops/cpu/tokenizer.cc +++ b/onnxruntime/contrib_ops/cpu/tokenizer.cc @@ -305,22 +305,23 @@ Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C, return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Input string contains invalid utf8 chars: " + s); } - if (mark_) { - tokens += 2; // Start/end markers as separate tokens - } max_tokens = std::max(max_tokens, tokens); ++curr_input; } std::vector output_dims(input_dims); // Check if we have no output due to apparently empty strings input. - if ((max_tokens - mark_ * 2) == 0) { + if (max_tokens == 0) { output_dims.push_back(0); TensorShape output_shape(output_dims); ctx->Output(0, output_shape); return Status::OK(); } + if (mark_) { + max_tokens += 2; // Start/end markers as separate tokens + } + output_dims.push_back(max_tokens); TensorShape output_shape(output_dims); auto output_tensor = ctx->Output(0, output_shape); @@ -435,24 +436,24 @@ Status Tokenizer::SeparatorTokenize(OpKernelContext* ctx, row_tokens.emplace_back(ws, wstr.length() - offset); } - size_t tokens = row_tokens.size(); - if (mark_) { - tokens += 2; // Start/end markers as separate tokens - } - max_tokens = std::max(max_tokens, tokens); + max_tokens = std::max(max_tokens, row_tokens.size()); ++curr_input; } std::vector output_dims(input_dims); // Check if we have no output due to either empty input // everything is a separator - if ((max_tokens - mark_ * 2) == 0) { + if (max_tokens == 0) { output_dims.push_back(0); TensorShape output_shape(output_dims); ctx->Output(0, output_shape); return Status::OK(); } + if (mark_) { + max_tokens += 2; // Start/end markers as separate tokens + } + output_dims.push_back(max_tokens); TensorShape output_shape(output_dims); @@ -499,8 +500,7 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx, using namespace re2; // Represents a token that will be output after // first is the index, second is the size; - using Token = std::pair; - std::vector> tokens; + std::vector> tokens; tokens.reserve(N * C); size_t max_tokens = 0; @@ -524,51 +524,43 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx, StringPiece submatch; bool match = true; - while (match) { + do { match = regex_->Match(text, start_pos, end_pos, anchor, &submatch, 1); if (match) { // Record pos/len assert(submatch.data() != nullptr); size_t match_pos = submatch.data() - s.data(); assert(match_pos >= start_pos); - auto token_len = match_pos - start_pos; + // Guard against empty match and make + // sure we make progress either way + auto token_len = submatch.length(); if (token_len > 0) { - row.emplace_back(start_pos, token_len); - } - // Update starting position - // Guard against empty string match - auto match_len = submatch.length(); - if (match_len > 0) { - start_pos = match_pos + match_len; + row.push_back(submatch); + start_pos = match_pos + token_len; } else { start_pos = match_pos + 1; } - } else { - // record trailing token - auto trailing_len = end_pos - start_pos; - if (trailing_len > 0) { - row.emplace_back(start_pos, trailing_len); - } } - } - size_t tokens_num = row.size(); - if (mark_) { - tokens_num += 2; // Start/end markers as separate tokens - } - max_tokens = std::max(max_tokens, tokens_num); + } while (match); + max_tokens = std::max(max_tokens, row.size()); ++curr_input; } + // Check for empty output std::vector output_dims(input_dims); // Check if we have no output due to either empty input // everything is a separator - if ((max_tokens - mark_ * 2) == 0) { + if (max_tokens == 0) { output_dims.push_back(0); TensorShape output_shape(output_dims); ctx->Output(0, output_shape); return Status::OK(); } + if (mark_) { + max_tokens += 2; // Start/end markers as separate tokens + } + output_dims.push_back(max_tokens); TensorShape output_shape(output_dims); @@ -590,15 +582,8 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx, ++output_index; } // Output tokens for this row - const char* data = curr_input->data(); for (const auto& token : row) { -#ifdef _DEBUG - auto s_len = curr_input->length(); - assert(token.second > 0); - assert(token.first < s_len); - assert(token.first + token.second <= s_len); -#endif - (output_data + output_index)->assign(data + token.first, token.second); + (output_data + output_index)->assign(token.data(), token.length()); ++output_index; } if (mark_) { diff --git a/onnxruntime/test/contrib_ops/tokenizer_test.cc b/onnxruntime/test/contrib_ops/tokenizer_test.cc index 83a09a6d85..8823b53eb2 100644 --- a/onnxruntime/test/contrib_ops/tokenizer_test.cc +++ b/onnxruntime/test/contrib_ops/tokenizer_test.cc @@ -11,7 +11,7 @@ namespace test { namespace tokenizer_test { const std::string start_mark{0x2}; const std::string end_mark{0x3}; -const std::string padval("0xdeadbeaf"); +const std::string padval(u8"0xdeadbeaf"); constexpr const char* domain = onnxruntime::kMSDomain; const int opset_ver = 1; @@ -693,51 +693,9 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) { test.Run(OpTester::ExpectResult::kExpectSuccess); } -TEST(ContribOpTest, TokenizerExpression_SimpleSep) { - OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp(";"); - InitTestAttr(test, true, {}, 1, tokenexp); - - std::vector dims{4}; - std::vector input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"}; - test.AddInput("T", dims, input); - - std::vector output_dims(dims); - output_dims.push_back(int64_t(6)); - std::vector output{ - start_mark, - u8"a", - u8"b", - end_mark, - padval, - padval, - start_mark, - u8"a", - u8"b", - end_mark, - padval, - padval, - start_mark, - u8"b", - u8"c", - u8"d", - u8"e", - end_mark, - start_mark, - u8"a", - u8"b", - u8"c", - end_mark, - padval, - }; - - test.AddOutput("Y", output_dims, output); - test.Run(OpTester::ExpectResult::kExpectSuccess); -} - TEST(ContribOpTest, TokenizerExpression_RegEx) { OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp("a."); + const std::string tokenexp(u8"a."); InitTestAttr(test, true, {}, 1, tokenexp); std::vector dims{4}; @@ -748,16 +706,16 @@ TEST(ContribOpTest, TokenizerExpression_RegEx) { output_dims.push_back(int64_t(3)); std::vector output{ start_mark, - u8"b", + u8"a;", end_mark, start_mark, - u8";;b", + u8"a;", end_mark, start_mark, - u8"b;c;;;d;e", end_mark, + padval, start_mark, - u8";b;;;c", + u8"a;", end_mark, }; @@ -767,7 +725,7 @@ TEST(ContribOpTest, TokenizerExpression_RegEx) { TEST(ContribOpTest, TokenizerExpression_RegRep) { OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp("c;*"); + const std::string tokenexp(u8"c;+"); InitTestAttr(test, true, {}, 1, tokenexp); std::vector dims{4}; @@ -775,22 +733,18 @@ TEST(ContribOpTest, TokenizerExpression_RegRep) { test.AddInput("T", dims, input); std::vector output_dims(dims); - output_dims.push_back(int64_t(4)); + output_dims.push_back(int64_t(3)); std::vector output{ start_mark, - u8"a;b", end_mark, padval, start_mark, - u8"a;;;b", end_mark, padval, start_mark, - u8"b;", - u8"d;e", + u8"c;;;", end_mark, start_mark, - u8"a;;b;;;", end_mark, padval}; @@ -800,7 +754,7 @@ TEST(ContribOpTest, TokenizerExpression_RegRep) { TEST(ContribOpTest, TokenizerExpression_Grouping) { OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp("(a;)|(b;)"); + const std::string tokenexp(u8"(a;)|(b;)"); InitTestAttr(test, true, {}, 1, tokenexp); std::vector dims{4}; @@ -811,20 +765,44 @@ TEST(ContribOpTest, TokenizerExpression_Grouping) { output_dims.push_back(int64_t(4)); std::vector output{ start_mark, - u8"b", + u8"a;", end_mark, padval, start_mark, - u8";;b", + u8"a;", end_mark, padval, start_mark, - u8"c;;;d;e", + u8"b;", end_mark, padval, start_mark, + u8"a;", + u8"b;", + end_mark}; + + test.AddOutput("Y", output_dims, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, TokenizerExpression_RegDot) { + OpTester test("Tokenizer", opset_ver, domain); + const std::string tokenexp(u8"."); + InitTestAttr(test, true, {}, 1, tokenexp); + + std::vector dims{1}; + std::vector input{u8"a;;;b"}; + test.AddInput("T", dims, input); + + std::vector output_dims(dims); + output_dims.push_back(int64_t(7)); + std::vector output{ + start_mark, + u8"a", u8";", - u8";;c", + u8";", + u8";", + u8"b", end_mark}; test.AddOutput("Y", output_dims, output);