mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Rework tokenexp to match tokens instead of separators. (#617)
Adjust tests.
This commit is contained in:
parent
2ae83c580c
commit
fba98bb4de
2 changed files with 65 additions and 102 deletions
|
|
@ -305,22 +305,23 @@ Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C,
|
|||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Input string contains invalid utf8 chars: " + s);
|
||||
}
|
||||
if (mark_) {
|
||||
tokens += 2; // Start/end markers as separate tokens
|
||||
}
|
||||
max_tokens = std::max(max_tokens, tokens);
|
||||
++curr_input;
|
||||
}
|
||||
|
||||
std::vector<int64_t> output_dims(input_dims);
|
||||
// Check if we have no output due to apparently empty strings input.
|
||||
if ((max_tokens - mark_ * 2) == 0) {
|
||||
if (max_tokens == 0) {
|
||||
output_dims.push_back(0);
|
||||
TensorShape output_shape(output_dims);
|
||||
ctx->Output(0, output_shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (mark_) {
|
||||
max_tokens += 2; // Start/end markers as separate tokens
|
||||
}
|
||||
|
||||
output_dims.push_back(max_tokens);
|
||||
TensorShape output_shape(output_dims);
|
||||
auto output_tensor = ctx->Output(0, output_shape);
|
||||
|
|
@ -435,24 +436,24 @@ Status Tokenizer::SeparatorTokenize(OpKernelContext* ctx,
|
|||
row_tokens.emplace_back(ws, wstr.length() - offset);
|
||||
}
|
||||
|
||||
size_t tokens = row_tokens.size();
|
||||
if (mark_) {
|
||||
tokens += 2; // Start/end markers as separate tokens
|
||||
}
|
||||
max_tokens = std::max(max_tokens, tokens);
|
||||
max_tokens = std::max(max_tokens, row_tokens.size());
|
||||
++curr_input;
|
||||
}
|
||||
|
||||
std::vector<int64_t> output_dims(input_dims);
|
||||
// Check if we have no output due to either empty input
|
||||
// everything is a separator
|
||||
if ((max_tokens - mark_ * 2) == 0) {
|
||||
if (max_tokens == 0) {
|
||||
output_dims.push_back(0);
|
||||
TensorShape output_shape(output_dims);
|
||||
ctx->Output(0, output_shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (mark_) {
|
||||
max_tokens += 2; // Start/end markers as separate tokens
|
||||
}
|
||||
|
||||
output_dims.push_back(max_tokens);
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
|
|
@ -499,8 +500,7 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx,
|
|||
using namespace re2;
|
||||
// Represents a token that will be output after
|
||||
// first is the index, second is the size;
|
||||
using Token = std::pair<size_t, size_t>;
|
||||
std::vector<std::vector<Token>> tokens;
|
||||
std::vector<std::vector<StringPiece>> tokens;
|
||||
tokens.reserve(N * C);
|
||||
|
||||
size_t max_tokens = 0;
|
||||
|
|
@ -524,51 +524,43 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx,
|
|||
StringPiece submatch;
|
||||
|
||||
bool match = true;
|
||||
while (match) {
|
||||
do {
|
||||
match = regex_->Match(text, start_pos, end_pos, anchor, &submatch, 1);
|
||||
if (match) {
|
||||
// Record pos/len
|
||||
assert(submatch.data() != nullptr);
|
||||
size_t match_pos = submatch.data() - s.data();
|
||||
assert(match_pos >= start_pos);
|
||||
auto token_len = match_pos - start_pos;
|
||||
// Guard against empty match and make
|
||||
// sure we make progress either way
|
||||
auto token_len = submatch.length();
|
||||
if (token_len > 0) {
|
||||
row.emplace_back(start_pos, token_len);
|
||||
}
|
||||
// Update starting position
|
||||
// Guard against empty string match
|
||||
auto match_len = submatch.length();
|
||||
if (match_len > 0) {
|
||||
start_pos = match_pos + match_len;
|
||||
row.push_back(submatch);
|
||||
start_pos = match_pos + token_len;
|
||||
} else {
|
||||
start_pos = match_pos + 1;
|
||||
}
|
||||
} else {
|
||||
// record trailing token
|
||||
auto trailing_len = end_pos - start_pos;
|
||||
if (trailing_len > 0) {
|
||||
row.emplace_back(start_pos, trailing_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
size_t tokens_num = row.size();
|
||||
if (mark_) {
|
||||
tokens_num += 2; // Start/end markers as separate tokens
|
||||
}
|
||||
max_tokens = std::max(max_tokens, tokens_num);
|
||||
} while (match);
|
||||
max_tokens = std::max(max_tokens, row.size());
|
||||
++curr_input;
|
||||
}
|
||||
|
||||
// Check for empty output
|
||||
std::vector<int64_t> output_dims(input_dims);
|
||||
// Check if we have no output due to either empty input
|
||||
// everything is a separator
|
||||
if ((max_tokens - mark_ * 2) == 0) {
|
||||
if (max_tokens == 0) {
|
||||
output_dims.push_back(0);
|
||||
TensorShape output_shape(output_dims);
|
||||
ctx->Output(0, output_shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (mark_) {
|
||||
max_tokens += 2; // Start/end markers as separate tokens
|
||||
}
|
||||
|
||||
output_dims.push_back(max_tokens);
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
|
|
@ -590,15 +582,8 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx,
|
|||
++output_index;
|
||||
}
|
||||
// Output tokens for this row
|
||||
const char* data = curr_input->data();
|
||||
for (const auto& token : row) {
|
||||
#ifdef _DEBUG
|
||||
auto s_len = curr_input->length();
|
||||
assert(token.second > 0);
|
||||
assert(token.first < s_len);
|
||||
assert(token.first + token.second <= s_len);
|
||||
#endif
|
||||
(output_data + output_index)->assign(data + token.first, token.second);
|
||||
(output_data + output_index)->assign(token.data(), token.length());
|
||||
++output_index;
|
||||
}
|
||||
if (mark_) {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ namespace test {
|
|||
namespace tokenizer_test {
|
||||
const std::string start_mark{0x2};
|
||||
const std::string end_mark{0x3};
|
||||
const std::string padval("0xdeadbeaf");
|
||||
const std::string padval(u8"0xdeadbeaf");
|
||||
|
||||
constexpr const char* domain = onnxruntime::kMSDomain;
|
||||
const int opset_ver = 1;
|
||||
|
|
@ -693,51 +693,9 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerExpression_SimpleSep) {
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
const std::string tokenexp(";");
|
||||
InitTestAttr(test, true, {}, 1, tokenexp);
|
||||
|
||||
std::vector<int64_t> dims{4};
|
||||
std::vector<std::string> input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(6));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"a",
|
||||
u8"b",
|
||||
end_mark,
|
||||
padval,
|
||||
padval,
|
||||
start_mark,
|
||||
u8"a",
|
||||
u8"b",
|
||||
end_mark,
|
||||
padval,
|
||||
padval,
|
||||
start_mark,
|
||||
u8"b",
|
||||
u8"c",
|
||||
u8"d",
|
||||
u8"e",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"a",
|
||||
u8"b",
|
||||
u8"c",
|
||||
end_mark,
|
||||
padval,
|
||||
};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerExpression_RegEx) {
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
const std::string tokenexp("a.");
|
||||
const std::string tokenexp(u8"a.");
|
||||
InitTestAttr(test, true, {}, 1, tokenexp);
|
||||
|
||||
std::vector<int64_t> dims{4};
|
||||
|
|
@ -748,16 +706,16 @@ TEST(ContribOpTest, TokenizerExpression_RegEx) {
|
|||
output_dims.push_back(int64_t(3));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"b",
|
||||
u8"a;",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8";;b",
|
||||
u8"a;",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"b;c;;;d;e",
|
||||
end_mark,
|
||||
padval,
|
||||
start_mark,
|
||||
u8";b;;;c",
|
||||
u8"a;",
|
||||
end_mark,
|
||||
};
|
||||
|
||||
|
|
@ -767,7 +725,7 @@ TEST(ContribOpTest, TokenizerExpression_RegEx) {
|
|||
|
||||
TEST(ContribOpTest, TokenizerExpression_RegRep) {
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
const std::string tokenexp("c;*");
|
||||
const std::string tokenexp(u8"c;+");
|
||||
InitTestAttr(test, true, {}, 1, tokenexp);
|
||||
|
||||
std::vector<int64_t> dims{4};
|
||||
|
|
@ -775,22 +733,18 @@ TEST(ContribOpTest, TokenizerExpression_RegRep) {
|
|||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(4));
|
||||
output_dims.push_back(int64_t(3));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"a;b",
|
||||
end_mark,
|
||||
padval,
|
||||
start_mark,
|
||||
u8"a;;;b",
|
||||
end_mark,
|
||||
padval,
|
||||
start_mark,
|
||||
u8"b;",
|
||||
u8"d;e",
|
||||
u8"c;;;",
|
||||
end_mark,
|
||||
start_mark,
|
||||
u8"a;;b;;;",
|
||||
end_mark,
|
||||
padval};
|
||||
|
||||
|
|
@ -800,7 +754,7 @@ TEST(ContribOpTest, TokenizerExpression_RegRep) {
|
|||
|
||||
TEST(ContribOpTest, TokenizerExpression_Grouping) {
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
const std::string tokenexp("(a;)|(b;)");
|
||||
const std::string tokenexp(u8"(a;)|(b;)");
|
||||
InitTestAttr(test, true, {}, 1, tokenexp);
|
||||
|
||||
std::vector<int64_t> dims{4};
|
||||
|
|
@ -811,20 +765,44 @@ TEST(ContribOpTest, TokenizerExpression_Grouping) {
|
|||
output_dims.push_back(int64_t(4));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"b",
|
||||
u8"a;",
|
||||
end_mark,
|
||||
padval,
|
||||
start_mark,
|
||||
u8";;b",
|
||||
u8"a;",
|
||||
end_mark,
|
||||
padval,
|
||||
start_mark,
|
||||
u8"c;;;d;e",
|
||||
u8"b;",
|
||||
end_mark,
|
||||
padval,
|
||||
start_mark,
|
||||
u8"a;",
|
||||
u8"b;",
|
||||
end_mark};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
|
||||
TEST(ContribOpTest, TokenizerExpression_RegDot) {
|
||||
OpTester test("Tokenizer", opset_ver, domain);
|
||||
const std::string tokenexp(u8".");
|
||||
InitTestAttr(test, true, {}, 1, tokenexp);
|
||||
|
||||
std::vector<int64_t> dims{1};
|
||||
std::vector<std::string> input{u8"a;;;b"};
|
||||
test.AddInput<std::string>("T", dims, input);
|
||||
|
||||
std::vector<int64_t> output_dims(dims);
|
||||
output_dims.push_back(int64_t(7));
|
||||
std::vector<std::string> output{
|
||||
start_mark,
|
||||
u8"a",
|
||||
u8";",
|
||||
u8";;c",
|
||||
u8";",
|
||||
u8";",
|
||||
u8"b",
|
||||
end_mark};
|
||||
|
||||
test.AddOutput<std::string>("Y", output_dims, output);
|
||||
|
|
|
|||
Loading…
Reference in a new issue