Bifurcation detector for aggressive decoding (#9432)

```
Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
starting from previous suffix match index, and predicted tokens.
Concat predicted tokens, starting from bifurcation index, to the back
of current tokens. This forms the output tokens.
Detect suffix match index in source tokens, between source tokens and output tokens.
Detection is based on finding the appearances of last n-gram in output tokens
in source tokens.
A match is considered found if source tokens contain a single matching n-gram.
Return the index of the start of the n-gram in source tokens.
No matching if found if src tokens contain multiple or zero matching n-grams. Return -1.
```
This commit is contained in:
Bowen Bao 2021-10-19 19:53:56 -07:00 committed by GitHub
parent 20eaed43e5
commit e983f37121
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 287 additions and 8 deletions

View file

@ -8,6 +8,7 @@ Do not modify directly.*
* <a href="#com.microsoft.BiasDropout">com.microsoft.BiasDropout</a>
* <a href="#com.microsoft.BiasGelu">com.microsoft.BiasGelu</a>
* <a href="#com.microsoft.BiasSoftmax">com.microsoft.BiasSoftmax</a>
* <a href="#com.microsoft.BifurcationDetector">com.microsoft.BifurcationDetector</a>
* <a href="#com.microsoft.CDist">com.microsoft.CDist</a>
* <a href="#com.microsoft.ComplexMul">com.microsoft.ComplexMul</a>
* <a href="#com.microsoft.ComplexMulConj">com.microsoft.ComplexMulConj</a>
@ -463,6 +464,62 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.BifurcationDetector"></a><a name="com.microsoft.bifurcationdetector">**com.microsoft.BifurcationDetector**</a>
Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
starting from previous suffix match index, and predicted tokens.
Concat predicted tokens, starting from bifurcation index, to the back
of current tokens. This forms the output tokens.
Detect suffix match index in source tokens, between source tokens and output tokens.
Detection is based on finding the appearances of last n-gram in output tokens
in source tokens.
A match is considered found if source tokens contain a single matching n-gram.
Return the index of the start of the n-gram in source tokens.
No matching if found if src tokens contain multiple or zero matching n-grams. Return -1.
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>max_ngram_size</tt> : int</dt>
<dd>The maximum NGram size for suffix matching.</dd>
<dt><tt>min_ngram_size</tt> : int</dt>
<dd>The minimum NGram size for suffix matching.</dd>
</dl>
#### Inputs (3 - 4)
<dl>
<dt><tt>src_tokens</tt> : T</dt>
<dd>Encoder input ids.</dd>
<dt><tt>cur_tokens</tt> : T</dt>
<dd>Decoder input ids.</dd>
<dt><tt>prev_suffix_match_idx</tt> : T</dt>
<dd>Previous suffix match index</dd>
<dt><tt>pred_tokens</tt> (optional) : T</dt>
<dd>Predicted token ids from aggressive decoding</dd>
</dl>
#### Outputs
<dl>
<dt><tt>tokens</tt> : T</dt>
<dd>Decoder input ids after merging predicted tokens</dd>
<dt><tt>suffix_match_idx</tt> : T</dt>
<dd>new suffix match index</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(int64)</dt>
<dd>Constrain to integer types.</dd>
</dl>
### <a name="com.microsoft.CDist"></a><a name="com.microsoft.cdist">**com.microsoft.CDist**</a>
#### Version

View file

@ -372,6 +372,7 @@ Do not modify directly.*
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T** = tensor(float)<br/> **T2** = tensor(int32)|

View file

@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "bifurcation_detector.h"
namespace onnxruntime {
namespace contrib {
ONNX_OPERATOR_KERNEL_EX(
BifurcationDetector,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
BifurcationDetector);
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
namespace contrib {
class BifurcationDetector : public OpKernel {
public:
explicit BifurcationDetector(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("min_ngram_size", &min_ngram_size_).IsOK());
ORT_ENFORCE(min_ngram_size_ > 0);
ORT_ENFORCE(info.GetAttr<int64_t>("max_ngram_size", &max_ngram_size_).IsOK());
ORT_ENFORCE(max_ngram_size_ > 0);
ORT_ENFORCE(max_ngram_size_ >= min_ngram_size_);
}
Status Compute(OpKernelContext* context) const override {
const Tensor* src_tokens = context->Input<Tensor>(0);
const Tensor* cur_tokens = context->Input<Tensor>(1);
const Tensor* prev_suffix_match_idx = context->Input<Tensor>(2);
const Tensor* pred_tokens = context->Input<Tensor>(3);
const auto* src_tokens_data = static_cast<const int64_t*>(src_tokens->DataRaw());
const auto* cur_tokens_data = static_cast<const int64_t*>(cur_tokens->DataRaw());
int64_t src_tokens_len = src_tokens->Shape().GetDims().at(0);
int64_t cur_tokens_len = cur_tokens->Shape().GetDims().at(0);
Tensor* out_tokens = nullptr;
// Find the bifurcation index of predicted tokens,
// between source tokens, starting from previous suffix match index,
// and predicted tokens.
// Concat predicted tokens, starting from bifurcation index, to the back
// of current tokens. This forms the output tokens.
if (nullptr == pred_tokens) {
// No prediction tokens. Output tokens equals to current tokens.
out_tokens = context->Output(0, cur_tokens->Shape());
auto* out_tokens_data = static_cast<int64_t*>(out_tokens->MutableDataRaw());
memcpy(out_tokens_data, cur_tokens_data, cur_tokens_len * sizeof(int64_t));
} else {
const auto* pred_tokens_data = static_cast<const int64_t*>(pred_tokens->DataRaw());
const int64_t prev_suffix_match_idx_data = static_cast<const int64_t*>(prev_suffix_match_idx->DataRaw())[0];
int64_t pred_tokens_len = pred_tokens->Shape().GetDims().at(0);
// Find bifurcation index between prediction tokens, and source tokens
// starting from previous suffix match index.
ORT_ENFORCE(src_tokens_len >= prev_suffix_match_idx_data);
ORT_ENFORCE(pred_tokens_len == (src_tokens_len + 1 - prev_suffix_match_idx_data));
int64_t pred_bifur_idx = 0;
for (; pred_bifur_idx < src_tokens_len - prev_suffix_match_idx_data; ++pred_bifur_idx) {
if (pred_tokens_data[pred_bifur_idx] != src_tokens_data[pred_bifur_idx + prev_suffix_match_idx_data]) {
break;
}
}
// pred_bifur_idx in [0, pred_tokens_len - 1]
out_tokens = context->Output(0, TensorShape({cur_tokens_len + pred_bifur_idx + 1}));
auto* out_tokens_data = static_cast<int64_t*>(out_tokens->MutableDataRaw());
memcpy(out_tokens_data, cur_tokens_data, cur_tokens_len * sizeof(int64_t));
memcpy(out_tokens_data + cur_tokens_len, pred_tokens_data, (pred_bifur_idx + 1) * sizeof(int64_t));
}
// Detect suffix match index in source tokens, between source tokens and output tokens.
// Detection is based on finding the appearances of last n-gram in output tokens
// in source tokens.
// A match is considered found if source tokens contain a single matching n-gram.
// Return the index of the start of the n-gram in source tokens.
// No matching if found if src tokens contain multiple or zero matching n-grams.
// Return -1.
int64_t tokens_len = out_tokens->Shape().GetDims().at(0);
int64_t min_gram = min_ngram_size_;
int64_t max_gram = max_ngram_size_;
int64_t suffix_idx = -1;
const auto* tokens_data = static_cast<const int64_t*>(out_tokens->DataRaw());
for (int64_t i = min_gram; i < max_gram + 1; ++i) {
if (i > tokens_len) {
break;
}
auto it = std::search(
src_tokens_data,
src_tokens_data + src_tokens_len,
tokens_data + tokens_len - i,
tokens_data + tokens_len);
if (it == (src_tokens_data + src_tokens_len)) {
break;
} else {
suffix_idx = it - src_tokens_data + i;
if (suffix_idx >= src_tokens_len) {
break;
}
auto it_2 = std::search(
src_tokens_data + suffix_idx - i + 1,
src_tokens_data + src_tokens_len,
tokens_data + tokens_len - i,
tokens_data + tokens_len);
if (it_2 != (src_tokens_data + src_tokens_len)) {
suffix_idx = -1;
continue;
}
}
}
Tensor* out_suffix_match_idx = context->Output(1, prev_suffix_match_idx->Shape());
auto* out_suffix_match_idx_data = static_cast<int64_t*>(out_suffix_match_idx->MutableDataRaw());
out_suffix_match_idx_data[0] = suffix_idx;
return Status::OK();
}
private:
int64_t min_ngram_size_;
int64_t max_ngram_size_;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -38,6 +38,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
#ifdef BUILD_MS_EXPERIMENTAL_OPS
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, DFT);
@ -208,6 +209,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector)>,
#ifdef BUILD_MS_EXPERIMENTAL_OPS
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, DFT)>,

View file

@ -236,7 +236,7 @@ void embedLayerNormalizationShapeInference(InferenceContext& ctx) {
"gamma should have 2 dimension, dimension size known, "
"and same hidden size as word_embedding.");
}
auto& beta_shape = getInputShape(ctx, 6);
auto& beta_dims = gamma_shape.dim();
if (beta_dims.size() != 1 ||
@ -772,13 +772,13 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03
.Const("one", 1.0, elem_type)
.Add(hasBias ? "X_bias = Add (X, bias)" : "X_bias = Identity (X)")
.Add(R"(
T1 = Mul (X_bias, X_bias)
T2 = Mul (c, T1)
T3 = Add (b, T2)
T4 = Mul (X_bias, T3)
T5 = Tanh (T4)
T6 = Add (one, T5)
T7 = Mul (X_bias, T6)
T1 = Mul (X_bias, X_bias)
T2 = Mul (c, T1)
T3 = Add (b, T2)
T4 = Mul (X_bias, T3)
T5 = Tanh (T4)
T6 = Add (one, T5)
T7 = Mul (X_bias, T6)
Y = Mul (a, T7)
)");
@ -825,6 +825,43 @@ Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form
}
propagateShapeFromInputToOutput(ctx, 1, 0);
});
static const char* BifurcationDetector_ver1_doc = R"DOC(
Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
starting from previous suffix match index, and predicted tokens.
Concat predicted tokens, starting from bifurcation index, to the back
of current tokens. This forms the output tokens.
Detect suffix match index in source tokens, between source tokens and output tokens.
Detection is based on finding the appearances of last n-gram in output tokens
in source tokens.
A match is considered found if source tokens contain a single matching n-gram.
Return the index of the start of the n-gram in source tokens.
No matching if found if src tokens contain multiple or zero matching n-grams. Return -1.
)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(BifurcationDetector)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(BifurcationDetector_ver1_doc)
.Attr("min_ngram_size", "The minimum NGram size for suffix matching.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("max_ngram_size", "The maximum NGram size for suffix matching.", AttributeProto::INT, static_cast<int64_t>(3))
.Input(0, "src_tokens", "Encoder input ids.", "T")
.Input(1, "cur_tokens", "Decoder input ids.", "T")
.Input(2, "prev_suffix_match_idx", "Previous suffix match index", "T")
.Input(3, "pred_tokens", "Predicted token ids from aggressive decoding", "T", OpSchema::Optional)
.Output(0, "tokens", "Decoder input ids after merging predicted tokens", "T")
.Output(1, "suffix_match_idx", "new suffix match index", "T")
.TypeConstraint("T", {"tensor(int64)"}, "Constrain to integer types.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 1, 0);
propagateElemTypeFromInputToOutput(ctx, 2, 1);
if (hasInputShape(ctx, 2)) {
propagateShapeFromInputToOutput(ctx, 2, 1);
}
// output tokens lengths is dynamic as it depends on the bifurcation index of predicted tokens and source tokens,
// and current tokens length.
// tokens_length = cur_tokens_length + bifurcation_index + 1.
});
}
void RegisterContribSchemas() {

View file

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
TEST(BifurcationDetectorTest, Test1) {
OpTester tester("BifurcationDetector", 1, onnxruntime::kMSDomain);
tester.AddInput<int64_t>("src_tokens", {4}, {1, 5, 3, 4});
tester.AddInput<int64_t>("cur_tokens", {1}, {2});
tester.AddInput<int64_t>("prev_suffix_match_idx", {}, {0});
tester.AddInput<int64_t>("pred_tokens", {5}, {1, 5, 3, 4, 2});
tester.AddOutput<int64_t>("tokens", {6}, {2, 1, 5, 3, 4, 2});
tester.AddOutput<int64_t>("suffix_match_idx", {}, {-1});
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
TEST(BifurcationDetectorTest, Test2) {
OpTester tester("BifurcationDetector", 1, onnxruntime::kMSDomain);
tester.AddInput<int64_t>("src_tokens", {26}, {756, 194, 39, 1015, 5529, 1216, 24, 72, 23, 1976, 6174, 1340,
6, 39, 194, 2161, 1480, 4955, 8, 7806, 65, 1091, 8, 560,
4077, 196});
tester.AddInput<int64_t>("cur_tokens", {6}, {2, 756, 194, 39, 8155, 23});
tester.AddInput<int64_t>("find_end_idx", {}, {0});
tester.AddOutput<int64_t>("tokens", {6}, {2, 756, 194, 39, 8155, 23});
tester.AddOutput<int64_t>("new_end_idx", {}, {9});
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
} // namespace test
} // namespace onnxruntime

View file

@ -67,6 +67,10 @@
"BiasGelu com.microsoft CPUExecutionProvider",
12457646955212583504
],
[
"BifurcationDetector com.microsoft CPUExecutionProvider",
12148442056374193608
],
[
"CDist com.microsoft CPUExecutionProvider",
889036143745127232