Rework tokenexp to match tokens instead of separators. (#617)

Adjust tests.
This commit is contained in:
Dmitri Smirnov 2019-03-13 17:43:37 -07:00 committed by GitHub
parent 2ae83c580c
commit fba98bb4de
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 102 deletions

View file

@ -305,22 +305,23 @@ Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C,
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Input string contains invalid utf8 chars: " + s);
}
if (mark_) {
tokens += 2; // Start/end markers as separate tokens
}
max_tokens = std::max(max_tokens, tokens);
++curr_input;
}
std::vector<int64_t> output_dims(input_dims);
// Check if we have no output due to apparently empty strings input.
if ((max_tokens - mark_ * 2) == 0) {
if (max_tokens == 0) {
output_dims.push_back(0);
TensorShape output_shape(output_dims);
ctx->Output(0, output_shape);
return Status::OK();
}
if (mark_) {
max_tokens += 2; // Start/end markers as separate tokens
}
output_dims.push_back(max_tokens);
TensorShape output_shape(output_dims);
auto output_tensor = ctx->Output(0, output_shape);
@ -435,24 +436,24 @@ Status Tokenizer::SeparatorTokenize(OpKernelContext* ctx,
row_tokens.emplace_back(ws, wstr.length() - offset);
}
size_t tokens = row_tokens.size();
if (mark_) {
tokens += 2; // Start/end markers as separate tokens
}
max_tokens = std::max(max_tokens, tokens);
max_tokens = std::max(max_tokens, row_tokens.size());
++curr_input;
}
std::vector<int64_t> output_dims(input_dims);
// Check if we have no output due to either empty input
// everything is a separator
if ((max_tokens - mark_ * 2) == 0) {
if (max_tokens == 0) {
output_dims.push_back(0);
TensorShape output_shape(output_dims);
ctx->Output(0, output_shape);
return Status::OK();
}
if (mark_) {
max_tokens += 2; // Start/end markers as separate tokens
}
output_dims.push_back(max_tokens);
TensorShape output_shape(output_dims);
@ -499,8 +500,7 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx,
using namespace re2;
// Represents a token that will be output after
// first is the index, second is the size;
using Token = std::pair<size_t, size_t>;
std::vector<std::vector<Token>> tokens;
std::vector<std::vector<StringPiece>> tokens;
tokens.reserve(N * C);
size_t max_tokens = 0;
@ -524,51 +524,43 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx,
StringPiece submatch;
bool match = true;
while (match) {
do {
match = regex_->Match(text, start_pos, end_pos, anchor, &submatch, 1);
if (match) {
// Record pos/len
assert(submatch.data() != nullptr);
size_t match_pos = submatch.data() - s.data();
assert(match_pos >= start_pos);
auto token_len = match_pos - start_pos;
// Guard against empty match and make
// sure we make progress either way
auto token_len = submatch.length();
if (token_len > 0) {
row.emplace_back(start_pos, token_len);
}
// Update starting position
// Guard against empty string match
auto match_len = submatch.length();
if (match_len > 0) {
start_pos = match_pos + match_len;
row.push_back(submatch);
start_pos = match_pos + token_len;
} else {
start_pos = match_pos + 1;
}
} else {
// record trailing token
auto trailing_len = end_pos - start_pos;
if (trailing_len > 0) {
row.emplace_back(start_pos, trailing_len);
}
}
}
size_t tokens_num = row.size();
if (mark_) {
tokens_num += 2; // Start/end markers as separate tokens
}
max_tokens = std::max(max_tokens, tokens_num);
} while (match);
max_tokens = std::max(max_tokens, row.size());
++curr_input;
}
// Check for empty output
std::vector<int64_t> output_dims(input_dims);
// Check if we have no output due to either empty input
// everything is a separator
if ((max_tokens - mark_ * 2) == 0) {
if (max_tokens == 0) {
output_dims.push_back(0);
TensorShape output_shape(output_dims);
ctx->Output(0, output_shape);
return Status::OK();
}
if (mark_) {
max_tokens += 2; // Start/end markers as separate tokens
}
output_dims.push_back(max_tokens);
TensorShape output_shape(output_dims);
@ -590,15 +582,8 @@ Status Tokenizer::ExpressionTokenize(OpKernelContext* ctx,
++output_index;
}
// Output tokens for this row
const char* data = curr_input->data();
for (const auto& token : row) {
#ifdef _DEBUG
auto s_len = curr_input->length();
assert(token.second > 0);
assert(token.first < s_len);
assert(token.first + token.second <= s_len);
#endif
(output_data + output_index)->assign(data + token.first, token.second);
(output_data + output_index)->assign(token.data(), token.length());
++output_index;
}
if (mark_) {

View file

@ -11,7 +11,7 @@ namespace test {
namespace tokenizer_test {
const std::string start_mark{0x2};
const std::string end_mark{0x3};
const std::string padval("0xdeadbeaf");
const std::string padval(u8"0xdeadbeaf");
constexpr const char* domain = onnxruntime::kMSDomain;
const int opset_ver = 1;
@ -693,51 +693,9 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) {
test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerExpression_SimpleSep) {
OpTester test("Tokenizer", opset_ver, domain);
const std::string tokenexp(";");
InitTestAttr(test, true, {}, 1, tokenexp);
std::vector<int64_t> dims{4};
std::vector<std::string> input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"};
test.AddInput<std::string>("T", dims, input);
std::vector<int64_t> output_dims(dims);
output_dims.push_back(int64_t(6));
std::vector<std::string> output{
start_mark,
u8"a",
u8"b",
end_mark,
padval,
padval,
start_mark,
u8"a",
u8"b",
end_mark,
padval,
padval,
start_mark,
u8"b",
u8"c",
u8"d",
u8"e",
end_mark,
start_mark,
u8"a",
u8"b",
u8"c",
end_mark,
padval,
};
test.AddOutput<std::string>("Y", output_dims, output);
test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerExpression_RegEx) {
OpTester test("Tokenizer", opset_ver, domain);
const std::string tokenexp("a.");
const std::string tokenexp(u8"a.");
InitTestAttr(test, true, {}, 1, tokenexp);
std::vector<int64_t> dims{4};
@ -748,16 +706,16 @@ TEST(ContribOpTest, TokenizerExpression_RegEx) {
output_dims.push_back(int64_t(3));
std::vector<std::string> output{
start_mark,
u8"b",
u8"a;",
end_mark,
start_mark,
u8";;b",
u8"a;",
end_mark,
start_mark,
u8"b;c;;;d;e",
end_mark,
padval,
start_mark,
u8";b;;;c",
u8"a;",
end_mark,
};
@ -767,7 +725,7 @@ TEST(ContribOpTest, TokenizerExpression_RegEx) {
TEST(ContribOpTest, TokenizerExpression_RegRep) {
OpTester test("Tokenizer", opset_ver, domain);
const std::string tokenexp("c;*");
const std::string tokenexp(u8"c;+");
InitTestAttr(test, true, {}, 1, tokenexp);
std::vector<int64_t> dims{4};
@ -775,22 +733,18 @@ TEST(ContribOpTest, TokenizerExpression_RegRep) {
test.AddInput<std::string>("T", dims, input);
std::vector<int64_t> output_dims(dims);
output_dims.push_back(int64_t(4));
output_dims.push_back(int64_t(3));
std::vector<std::string> output{
start_mark,
u8"a;b",
end_mark,
padval,
start_mark,
u8"a;;;b",
end_mark,
padval,
start_mark,
u8"b;",
u8"d;e",
u8"c;;;",
end_mark,
start_mark,
u8"a;;b;;;",
end_mark,
padval};
@ -800,7 +754,7 @@ TEST(ContribOpTest, TokenizerExpression_RegRep) {
TEST(ContribOpTest, TokenizerExpression_Grouping) {
OpTester test("Tokenizer", opset_ver, domain);
const std::string tokenexp("(a;)|(b;)");
const std::string tokenexp(u8"(a;)|(b;)");
InitTestAttr(test, true, {}, 1, tokenexp);
std::vector<int64_t> dims{4};
@ -811,20 +765,44 @@ TEST(ContribOpTest, TokenizerExpression_Grouping) {
output_dims.push_back(int64_t(4));
std::vector<std::string> output{
start_mark,
u8"b",
u8"a;",
end_mark,
padval,
start_mark,
u8";;b",
u8"a;",
end_mark,
padval,
start_mark,
u8"c;;;d;e",
u8"b;",
end_mark,
padval,
start_mark,
u8"a;",
u8"b;",
end_mark};
test.AddOutput<std::string>("Y", output_dims, output);
test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ContribOpTest, TokenizerExpression_RegDot) {
OpTester test("Tokenizer", opset_ver, domain);
const std::string tokenexp(u8".");
InitTestAttr(test, true, {}, 1, tokenexp);
std::vector<int64_t> dims{1};
std::vector<std::string> input{u8"a;;;b"};
test.AddInput<std::string>("T", dims, input);
std::vector<int64_t> output_dims(dims);
output_dims.push_back(int64_t(7));
std::vector<std::string> output{
start_mark,
u8"a",
u8";",
u8";;c",
u8";",
u8";",
u8"b",
end_mark};
test.AddOutput<std::string>("Y", output_dims, output);