From e983f371211fd21bcfa9d25ec1a42f4fd92b648d Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Tue, 19 Oct 2021 19:53:56 -0700 Subject: [PATCH] Bifurcation detector for aggressive decoding (#9432) ``` Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens, starting from previous suffix match index, and predicted tokens. Concat predicted tokens, starting from bifurcation index, to the back of current tokens. This forms the output tokens. Detect suffix match index in source tokens, between source tokens and output tokens. Detection is based on finding the appearances of last n-gram in output tokens in source tokens. A match is considered found if source tokens contain a single matching n-gram. Return the index of the start of the n-gram in source tokens. No matching if found if src tokens contain multiple or zero matching n-grams. Return -1. ``` --- docs/ContribOperators.md | 57 +++++++++ docs/OperatorKernels.md | 1 + .../cpu/bert/bifurcation_detector.cc | 18 +++ .../cpu/bert/bifurcation_detector.h | 116 ++++++++++++++++++ .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../core/graph/contrib_ops/contrib_defs.cc | 53 ++++++-- .../bifurcation_detector_op_test.cc | 44 +++++++ .../kernel_def_hashes/contrib.cpu.json | 4 + 8 files changed, 287 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc create mode 100644 onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h create mode 100644 onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9b8f1e6c7f..87cd91d902 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -8,6 +8,7 @@ Do not modify directly.* * com.microsoft.BiasDropout * com.microsoft.BiasGelu * com.microsoft.BiasSoftmax + * com.microsoft.BifurcationDetector * com.microsoft.CDist * com.microsoft.ComplexMul * com.microsoft.ComplexMulConj @@ -463,6 +464,62 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.BifurcationDetector** + + Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens, + starting from previous suffix match index, and predicted tokens. + Concat predicted tokens, starting from bifurcation index, to the back + of current tokens. This forms the output tokens. + Detect suffix match index in source tokens, between source tokens and output tokens. + Detection is based on finding the appearances of last n-gram in output tokens + in source tokens. + A match is considered found if source tokens contain a single matching n-gram. + Return the index of the start of the n-gram in source tokens. + No matching if found if src tokens contain multiple or zero matching n-grams. Return -1. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
max_ngram_size : int
+
The maximum NGram size for suffix matching.
+
min_ngram_size : int
+
The minimum NGram size for suffix matching.
+
+ +#### Inputs (3 - 4) + +
+
src_tokens : T
+
Encoder input ids.
+
cur_tokens : T
+
Decoder input ids.
+
prev_suffix_match_idx : T
+
Previous suffix match index
+
pred_tokens (optional) : T
+
Predicted token ids from aggressive decoding
+
+ +#### Outputs + +
+
tokens : T
+
Decoder input ids after merging predicted tokens
+
suffix_match_idx : T
+
new suffix match index
+
+ +#### Type Constraints + +
+
T : tensor(int64)
+
Constrain to integer types.
+
+ + ### **com.microsoft.CDist** #### Version diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 5099eca81c..2c69f40c5e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -372,6 +372,7 @@ Do not modify directly.* |Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| +|BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)| |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T** = tensor(float)
**T2** = tensor(int32)| diff --git a/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc new file mode 100644 index 0000000000..bce490d772 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "bifurcation_detector.h" + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + BifurcationDetector, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + BifurcationDetector); +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h new file mode 100644 index 0000000000..22ea0c0ea9 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { +class BifurcationDetector : public OpKernel { + public: + explicit BifurcationDetector(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(info.GetAttr("min_ngram_size", &min_ngram_size_).IsOK()); + ORT_ENFORCE(min_ngram_size_ > 0); + ORT_ENFORCE(info.GetAttr("max_ngram_size", &max_ngram_size_).IsOK()); + ORT_ENFORCE(max_ngram_size_ > 0); + ORT_ENFORCE(max_ngram_size_ >= min_ngram_size_); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* src_tokens = context->Input(0); + const Tensor* cur_tokens = context->Input(1); + const Tensor* prev_suffix_match_idx = context->Input(2); + const Tensor* pred_tokens = context->Input(3); + const auto* src_tokens_data = static_cast(src_tokens->DataRaw()); + const auto* cur_tokens_data = static_cast(cur_tokens->DataRaw()); + int64_t src_tokens_len = src_tokens->Shape().GetDims().at(0); + int64_t cur_tokens_len = cur_tokens->Shape().GetDims().at(0); + + Tensor* out_tokens = nullptr; + + // Find the bifurcation index of predicted tokens, + // between source tokens, starting from previous suffix match index, + // and predicted tokens. + // Concat predicted tokens, starting from bifurcation index, to the back + // of current tokens. This forms the output tokens. + if (nullptr == pred_tokens) { + // No prediction tokens. Output tokens equals to current tokens. + out_tokens = context->Output(0, cur_tokens->Shape()); + auto* out_tokens_data = static_cast(out_tokens->MutableDataRaw()); + memcpy(out_tokens_data, cur_tokens_data, cur_tokens_len * sizeof(int64_t)); + } else { + const auto* pred_tokens_data = static_cast(pred_tokens->DataRaw()); + const int64_t prev_suffix_match_idx_data = static_cast(prev_suffix_match_idx->DataRaw())[0]; + int64_t pred_tokens_len = pred_tokens->Shape().GetDims().at(0); + // Find bifurcation index between prediction tokens, and source tokens + // starting from previous suffix match index. + ORT_ENFORCE(src_tokens_len >= prev_suffix_match_idx_data); + ORT_ENFORCE(pred_tokens_len == (src_tokens_len + 1 - prev_suffix_match_idx_data)); + int64_t pred_bifur_idx = 0; + for (; pred_bifur_idx < src_tokens_len - prev_suffix_match_idx_data; ++pred_bifur_idx) { + if (pred_tokens_data[pred_bifur_idx] != src_tokens_data[pred_bifur_idx + prev_suffix_match_idx_data]) { + break; + } + } + // pred_bifur_idx in [0, pred_tokens_len - 1] + out_tokens = context->Output(0, TensorShape({cur_tokens_len + pred_bifur_idx + 1})); + auto* out_tokens_data = static_cast(out_tokens->MutableDataRaw()); + memcpy(out_tokens_data, cur_tokens_data, cur_tokens_len * sizeof(int64_t)); + memcpy(out_tokens_data + cur_tokens_len, pred_tokens_data, (pred_bifur_idx + 1) * sizeof(int64_t)); + } + + // Detect suffix match index in source tokens, between source tokens and output tokens. + // Detection is based on finding the appearances of last n-gram in output tokens + // in source tokens. + // A match is considered found if source tokens contain a single matching n-gram. + // Return the index of the start of the n-gram in source tokens. + // No matching if found if src tokens contain multiple or zero matching n-grams. + // Return -1. + int64_t tokens_len = out_tokens->Shape().GetDims().at(0); + int64_t min_gram = min_ngram_size_; + int64_t max_gram = max_ngram_size_; + int64_t suffix_idx = -1; + const auto* tokens_data = static_cast(out_tokens->DataRaw()); + for (int64_t i = min_gram; i < max_gram + 1; ++i) { + if (i > tokens_len) { + break; + } + auto it = std::search( + src_tokens_data, + src_tokens_data + src_tokens_len, + tokens_data + tokens_len - i, + tokens_data + tokens_len); + if (it == (src_tokens_data + src_tokens_len)) { + break; + } else { + suffix_idx = it - src_tokens_data + i; + if (suffix_idx >= src_tokens_len) { + break; + } + auto it_2 = std::search( + src_tokens_data + suffix_idx - i + 1, + src_tokens_data + src_tokens_len, + tokens_data + tokens_len - i, + tokens_data + tokens_len); + if (it_2 != (src_tokens_data + src_tokens_len)) { + suffix_idx = -1; + continue; + } + } + } + + Tensor* out_suffix_match_idx = context->Output(1, prev_suffix_match_idx->Shape()); + auto* out_suffix_match_idx_data = static_cast(out_suffix_match_idx->MutableDataRaw()); + out_suffix_match_idx_data[0] = suffix_idx; + + return Status::OK(); + } + + private: + int64_t min_ngram_size_; + int64_t max_ngram_size_; +}; +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 6847d24b23..702b5bcd67 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -38,6 +38,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu) class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector); #ifdef BUILD_MS_EXPERIMENTAL_OPS class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, DFT); @@ -208,6 +209,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef BUILD_MS_EXPERIMENTAL_OPS BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index c20f3bd478..27b854e397 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -236,7 +236,7 @@ void embedLayerNormalizationShapeInference(InferenceContext& ctx) { "gamma should have 2 dimension, dimension size known, " "and same hidden size as word_embedding."); } - + auto& beta_shape = getInputShape(ctx, 6); auto& beta_dims = gamma_shape.dim(); if (beta_dims.size() != 1 || @@ -772,13 +772,13 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03 .Const("one", 1.0, elem_type) .Add(hasBias ? "X_bias = Add (X, bias)" : "X_bias = Identity (X)") .Add(R"( - T1 = Mul (X_bias, X_bias) - T2 = Mul (c, T1) - T3 = Add (b, T2) - T4 = Mul (X_bias, T3) - T5 = Tanh (T4) - T6 = Add (one, T5) - T7 = Mul (X_bias, T6) + T1 = Mul (X_bias, X_bias) + T2 = Mul (c, T1) + T3 = Add (b, T2) + T4 = Mul (X_bias, T3) + T5 = Tanh (T4) + T6 = Add (one, T5) + T7 = Mul (X_bias, T6) Y = Mul (a, T7) )"); @@ -825,6 +825,43 @@ Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form } propagateShapeFromInputToOutput(ctx, 1, 0); }); + + static const char* BifurcationDetector_ver1_doc = R"DOC( +Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens, +starting from previous suffix match index, and predicted tokens. +Concat predicted tokens, starting from bifurcation index, to the back +of current tokens. This forms the output tokens. +Detect suffix match index in source tokens, between source tokens and output tokens. +Detection is based on finding the appearances of last n-gram in output tokens +in source tokens. +A match is considered found if source tokens contain a single matching n-gram. +Return the index of the start of the n-gram in source tokens. +No matching if found if src tokens contain multiple or zero matching n-grams. Return -1. +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(BifurcationDetector) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(BifurcationDetector_ver1_doc) + .Attr("min_ngram_size", "The minimum NGram size for suffix matching.", AttributeProto::INT, static_cast(1)) + .Attr("max_ngram_size", "The maximum NGram size for suffix matching.", AttributeProto::INT, static_cast(3)) + .Input(0, "src_tokens", "Encoder input ids.", "T") + .Input(1, "cur_tokens", "Decoder input ids.", "T") + .Input(2, "prev_suffix_match_idx", "Previous suffix match index", "T") + .Input(3, "pred_tokens", "Predicted token ids from aggressive decoding", "T", OpSchema::Optional) + .Output(0, "tokens", "Decoder input ids after merging predicted tokens", "T") + .Output(1, "suffix_match_idx", "new suffix match index", "T") + .TypeConstraint("T", {"tensor(int64)"}, "Constrain to integer types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); + propagateElemTypeFromInputToOutput(ctx, 2, 1); + if (hasInputShape(ctx, 2)) { + propagateShapeFromInputToOutput(ctx, 2, 1); + } + // output tokens lengths is dynamic as it depends on the bifurcation index of predicted tokens and source tokens, + // and current tokens length. + // tokens_length = cur_tokens_length + bifurcation_index + 1. + }); } void RegisterContribSchemas() { diff --git a/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc b/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc new file mode 100644 index 0000000000..51829e8cf8 --- /dev/null +++ b/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(BifurcationDetectorTest, Test1) { + OpTester tester("BifurcationDetector", 1, onnxruntime::kMSDomain); + + tester.AddInput("src_tokens", {4}, {1, 5, 3, 4}); + tester.AddInput("cur_tokens", {1}, {2}); + tester.AddInput("prev_suffix_match_idx", {}, {0}); + tester.AddInput("pred_tokens", {5}, {1, 5, 3, 4, 2}); + tester.AddOutput("tokens", {6}, {2, 1, 5, 3, 4, 2}); + tester.AddOutput("suffix_match_idx", {}, {-1}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(BifurcationDetectorTest, Test2) { + OpTester tester("BifurcationDetector", 1, onnxruntime::kMSDomain); + + tester.AddInput("src_tokens", {26}, {756, 194, 39, 1015, 5529, 1216, 24, 72, 23, 1976, 6174, 1340, + 6, 39, 194, 2161, 1480, 4955, 8, 7806, 65, 1091, 8, 560, + 4077, 196}); + tester.AddInput("cur_tokens", {6}, {2, 756, 194, 39, 8155, 23}); + tester.AddInput("find_end_idx", {}, {0}); + tester.AddOutput("tokens", {6}, {2, 756, 194, 39, 8155, 23}); + tester.AddOutput("new_end_idx", {}, {9}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json index d9c2666284..bb0f31e904 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json @@ -67,6 +67,10 @@ "BiasGelu com.microsoft CPUExecutionProvider", 12457646955212583504 ], + [ + "BifurcationDetector com.microsoft CPUExecutionProvider", + 12148442056374193608 + ], [ "CDist com.microsoft CPUExecutionProvider", 889036143745127232