diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9b8f1e6c7f..87cd91d902 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -8,6 +8,7 @@ Do not modify directly.*
* com.microsoft.BiasDropout
* com.microsoft.BiasGelu
* com.microsoft.BiasSoftmax
+ * com.microsoft.BifurcationDetector
* com.microsoft.CDist
* com.microsoft.ComplexMul
* com.microsoft.ComplexMulConj
@@ -463,6 +464,62 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.BifurcationDetector**
+
+ Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
+ starting from previous suffix match index, and predicted tokens.
+ Concat predicted tokens, starting from bifurcation index, to the back
+ of current tokens. This forms the output tokens.
+ Detect suffix match index in source tokens, between source tokens and output tokens.
+ Detection is based on finding the appearances of last n-gram in output tokens
+ in source tokens.
+ A match is considered found if source tokens contain a single matching n-gram.
+ Return the index of the start of the n-gram in source tokens.
+ No matching if found if src tokens contain multiple or zero matching n-grams. Return -1.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- max_ngram_size : int
+- The maximum NGram size for suffix matching.
+- min_ngram_size : int
+- The minimum NGram size for suffix matching.
+
+
+#### Inputs (3 - 4)
+
+
+- src_tokens : T
+- Encoder input ids.
+- cur_tokens : T
+- Decoder input ids.
+- prev_suffix_match_idx : T
+- Previous suffix match index
+- pred_tokens (optional) : T
+- Predicted token ids from aggressive decoding
+
+
+#### Outputs
+
+
+- tokens : T
+- Decoder input ids after merging predicted tokens
+- suffix_match_idx : T
+- new suffix match index
+
+
+#### Type Constraints
+
+
+- T : tensor(int64)
+- Constrain to integer types.
+
+
+
### **com.microsoft.CDist**
#### Version
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 5099eca81c..2c69f40c5e 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -372,6 +372,7 @@ Do not modify directly.*
|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)|
+|BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T** = tensor(float)
**T2** = tensor(int32)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc
new file mode 100644
index 0000000000..bce490d772
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "bifurcation_detector.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_OPERATOR_KERNEL_EX(
+ BifurcationDetector,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ BifurcationDetector);
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h
new file mode 100644
index 0000000000..22ea0c0ea9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+class BifurcationDetector : public OpKernel {
+ public:
+ explicit BifurcationDetector(const OpKernelInfo& info) : OpKernel(info) {
+ ORT_ENFORCE(info.GetAttr("min_ngram_size", &min_ngram_size_).IsOK());
+ ORT_ENFORCE(min_ngram_size_ > 0);
+ ORT_ENFORCE(info.GetAttr("max_ngram_size", &max_ngram_size_).IsOK());
+ ORT_ENFORCE(max_ngram_size_ > 0);
+ ORT_ENFORCE(max_ngram_size_ >= min_ngram_size_);
+ }
+
+ Status Compute(OpKernelContext* context) const override {
+ const Tensor* src_tokens = context->Input(0);
+ const Tensor* cur_tokens = context->Input(1);
+ const Tensor* prev_suffix_match_idx = context->Input(2);
+ const Tensor* pred_tokens = context->Input(3);
+ const auto* src_tokens_data = static_cast(src_tokens->DataRaw());
+ const auto* cur_tokens_data = static_cast(cur_tokens->DataRaw());
+ int64_t src_tokens_len = src_tokens->Shape().GetDims().at(0);
+ int64_t cur_tokens_len = cur_tokens->Shape().GetDims().at(0);
+
+ Tensor* out_tokens = nullptr;
+
+ // Find the bifurcation index of predicted tokens,
+ // between source tokens, starting from previous suffix match index,
+ // and predicted tokens.
+ // Concat predicted tokens, starting from bifurcation index, to the back
+ // of current tokens. This forms the output tokens.
+ if (nullptr == pred_tokens) {
+ // No prediction tokens. Output tokens equals to current tokens.
+ out_tokens = context->Output(0, cur_tokens->Shape());
+ auto* out_tokens_data = static_cast(out_tokens->MutableDataRaw());
+ memcpy(out_tokens_data, cur_tokens_data, cur_tokens_len * sizeof(int64_t));
+ } else {
+ const auto* pred_tokens_data = static_cast(pred_tokens->DataRaw());
+ const int64_t prev_suffix_match_idx_data = static_cast(prev_suffix_match_idx->DataRaw())[0];
+ int64_t pred_tokens_len = pred_tokens->Shape().GetDims().at(0);
+ // Find bifurcation index between prediction tokens, and source tokens
+ // starting from previous suffix match index.
+ ORT_ENFORCE(src_tokens_len >= prev_suffix_match_idx_data);
+ ORT_ENFORCE(pred_tokens_len == (src_tokens_len + 1 - prev_suffix_match_idx_data));
+ int64_t pred_bifur_idx = 0;
+ for (; pred_bifur_idx < src_tokens_len - prev_suffix_match_idx_data; ++pred_bifur_idx) {
+ if (pred_tokens_data[pred_bifur_idx] != src_tokens_data[pred_bifur_idx + prev_suffix_match_idx_data]) {
+ break;
+ }
+ }
+ // pred_bifur_idx in [0, pred_tokens_len - 1]
+ out_tokens = context->Output(0, TensorShape({cur_tokens_len + pred_bifur_idx + 1}));
+ auto* out_tokens_data = static_cast(out_tokens->MutableDataRaw());
+ memcpy(out_tokens_data, cur_tokens_data, cur_tokens_len * sizeof(int64_t));
+ memcpy(out_tokens_data + cur_tokens_len, pred_tokens_data, (pred_bifur_idx + 1) * sizeof(int64_t));
+ }
+
+ // Detect suffix match index in source tokens, between source tokens and output tokens.
+ // Detection is based on finding the appearances of last n-gram in output tokens
+ // in source tokens.
+ // A match is considered found if source tokens contain a single matching n-gram.
+ // Return the index of the start of the n-gram in source tokens.
+ // No matching if found if src tokens contain multiple or zero matching n-grams.
+ // Return -1.
+ int64_t tokens_len = out_tokens->Shape().GetDims().at(0);
+ int64_t min_gram = min_ngram_size_;
+ int64_t max_gram = max_ngram_size_;
+ int64_t suffix_idx = -1;
+ const auto* tokens_data = static_cast(out_tokens->DataRaw());
+ for (int64_t i = min_gram; i < max_gram + 1; ++i) {
+ if (i > tokens_len) {
+ break;
+ }
+ auto it = std::search(
+ src_tokens_data,
+ src_tokens_data + src_tokens_len,
+ tokens_data + tokens_len - i,
+ tokens_data + tokens_len);
+ if (it == (src_tokens_data + src_tokens_len)) {
+ break;
+ } else {
+ suffix_idx = it - src_tokens_data + i;
+ if (suffix_idx >= src_tokens_len) {
+ break;
+ }
+ auto it_2 = std::search(
+ src_tokens_data + suffix_idx - i + 1,
+ src_tokens_data + src_tokens_len,
+ tokens_data + tokens_len - i,
+ tokens_data + tokens_len);
+ if (it_2 != (src_tokens_data + src_tokens_len)) {
+ suffix_idx = -1;
+ continue;
+ }
+ }
+ }
+
+ Tensor* out_suffix_match_idx = context->Output(1, prev_suffix_match_idx->Shape());
+ auto* out_suffix_match_idx_data = static_cast(out_suffix_match_idx->MutableDataRaw());
+ out_suffix_match_idx_data[0] = suffix_idx;
+
+ return Status::OK();
+ }
+
+ private:
+ int64_t min_ngram_size_;
+ int64_t max_ngram_size_;
+};
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 6847d24b23..702b5bcd67 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -38,6 +38,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
#ifdef BUILD_MS_EXPERIMENTAL_OPS
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, DFT);
@@ -208,6 +209,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
#ifdef BUILD_MS_EXPERIMENTAL_OPS
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index c20f3bd478..27b854e397 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -236,7 +236,7 @@ void embedLayerNormalizationShapeInference(InferenceContext& ctx) {
"gamma should have 2 dimension, dimension size known, "
"and same hidden size as word_embedding.");
}
-
+
auto& beta_shape = getInputShape(ctx, 6);
auto& beta_dims = gamma_shape.dim();
if (beta_dims.size() != 1 ||
@@ -772,13 +772,13 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03
.Const("one", 1.0, elem_type)
.Add(hasBias ? "X_bias = Add (X, bias)" : "X_bias = Identity (X)")
.Add(R"(
- T1 = Mul (X_bias, X_bias)
- T2 = Mul (c, T1)
- T3 = Add (b, T2)
- T4 = Mul (X_bias, T3)
- T5 = Tanh (T4)
- T6 = Add (one, T5)
- T7 = Mul (X_bias, T6)
+ T1 = Mul (X_bias, X_bias)
+ T2 = Mul (c, T1)
+ T3 = Add (b, T2)
+ T4 = Mul (X_bias, T3)
+ T5 = Tanh (T4)
+ T6 = Add (one, T5)
+ T7 = Mul (X_bias, T6)
Y = Mul (a, T7)
)");
@@ -825,6 +825,43 @@ Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form
}
propagateShapeFromInputToOutput(ctx, 1, 0);
});
+
+ static const char* BifurcationDetector_ver1_doc = R"DOC(
+Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
+starting from previous suffix match index, and predicted tokens.
+Concat predicted tokens, starting from bifurcation index, to the back
+of current tokens. This forms the output tokens.
+Detect suffix match index in source tokens, between source tokens and output tokens.
+Detection is based on finding the appearances of last n-gram in output tokens
+in source tokens.
+A match is considered found if source tokens contain a single matching n-gram.
+Return the index of the start of the n-gram in source tokens.
+No matching if found if src tokens contain multiple or zero matching n-grams. Return -1.
+)DOC";
+
+ ONNX_CONTRIB_OPERATOR_SCHEMA(BifurcationDetector)
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .SetDoc(BifurcationDetector_ver1_doc)
+ .Attr("min_ngram_size", "The minimum NGram size for suffix matching.", AttributeProto::INT, static_cast(1))
+ .Attr("max_ngram_size", "The maximum NGram size for suffix matching.", AttributeProto::INT, static_cast(3))
+ .Input(0, "src_tokens", "Encoder input ids.", "T")
+ .Input(1, "cur_tokens", "Decoder input ids.", "T")
+ .Input(2, "prev_suffix_match_idx", "Previous suffix match index", "T")
+ .Input(3, "pred_tokens", "Predicted token ids from aggressive decoding", "T", OpSchema::Optional)
+ .Output(0, "tokens", "Decoder input ids after merging predicted tokens", "T")
+ .Output(1, "suffix_match_idx", "new suffix match index", "T")
+ .TypeConstraint("T", {"tensor(int64)"}, "Constrain to integer types.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 1, 0);
+ propagateElemTypeFromInputToOutput(ctx, 2, 1);
+ if (hasInputShape(ctx, 2)) {
+ propagateShapeFromInputToOutput(ctx, 2, 1);
+ }
+ // output tokens lengths is dynamic as it depends on the bifurcation index of predicted tokens and source tokens,
+ // and current tokens length.
+ // tokens_length = cur_tokens_length + bifurcation_index + 1.
+ });
}
void RegisterContribSchemas() {
diff --git a/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc b/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc
new file mode 100644
index 0000000000..51829e8cf8
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+TEST(BifurcationDetectorTest, Test1) {
+ OpTester tester("BifurcationDetector", 1, onnxruntime::kMSDomain);
+
+ tester.AddInput("src_tokens", {4}, {1, 5, 3, 4});
+ tester.AddInput("cur_tokens", {1}, {2});
+ tester.AddInput("prev_suffix_match_idx", {}, {0});
+ tester.AddInput("pred_tokens", {5}, {1, 5, 3, 4, 2});
+ tester.AddOutput("tokens", {6}, {2, 1, 5, 3, 4, 2});
+ tester.AddOutput("suffix_match_idx", {}, {-1});
+
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(BifurcationDetectorTest, Test2) {
+ OpTester tester("BifurcationDetector", 1, onnxruntime::kMSDomain);
+
+ tester.AddInput("src_tokens", {26}, {756, 194, 39, 1015, 5529, 1216, 24, 72, 23, 1976, 6174, 1340,
+ 6, 39, 194, 2161, 1480, 4955, 8, 7806, 65, 1091, 8, 560,
+ 4077, 196});
+ tester.AddInput("cur_tokens", {6}, {2, 756, 194, 39, 8155, 23});
+ tester.AddInput("find_end_idx", {}, {0});
+ tester.AddOutput("tokens", {6}, {2, 756, 194, 39, 8155, 23});
+ tester.AddOutput("new_end_idx", {}, {9});
+
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json
index d9c2666284..bb0f31e904 100644
--- a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json
+++ b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json
@@ -67,6 +67,10 @@
"BiasGelu com.microsoft CPUExecutionProvider",
12457646955212583504
],
+ [
+ "BifurcationDetector com.microsoft CPUExecutionProvider",
+ 12148442056374193608
+ ],
[
"CDist com.microsoft CPUExecutionProvider",
889036143745127232