diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9b8f1e6c7f..87cd91d902 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -8,6 +8,7 @@ Do not modify directly.* * com.microsoft.BiasDropout * com.microsoft.BiasGelu * com.microsoft.BiasSoftmax + * com.microsoft.BifurcationDetector * com.microsoft.CDist * com.microsoft.ComplexMul * com.microsoft.ComplexMulConj @@ -463,6 +464,62 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.BifurcationDetector** + + Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens, + starting from previous suffix match index, and predicted tokens. + Concat predicted tokens, starting from bifurcation index, to the back + of current tokens. This forms the output tokens. + Detect suffix match index in source tokens, between source tokens and output tokens. + Detection is based on finding the appearances of last n-gram in output tokens + in source tokens. + A match is considered found if source tokens contain a single matching n-gram. + Return the index immediately after the matched n-gram in source tokens. + No match is found if src tokens contain multiple or zero matching n-grams. Return -1. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
max_ngram_size : int
+
The maximum NGram size for suffix matching.
+
min_ngram_size : int
+
The minimum NGram size for suffix matching.
+
+ +#### Inputs (3 - 4) + +
+
src_tokens : T
+
Encoder input ids.
+
cur_tokens : T
+
Decoder input ids.
+
prev_suffix_match_idx : T
+
Previous suffix match index
+
pred_tokens (optional) : T
+
Predicted token ids from aggressive decoding
+
+ +#### Outputs + +
+
tokens : T
+
Decoder input ids after merging predicted tokens
+
suffix_match_idx : T
+
new suffix match index
+
+ +#### Type Constraints + +
+
T : tensor(int64)
+
Constrain to integer types.
+
+ + ### **com.microsoft.CDist** #### Version diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 5099eca81c..2c69f40c5e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -372,6 +372,7 @@ Do not modify directly.* |Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| +|BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)| |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T** = tensor(float)
**T2** = tensor(int32)| diff --git a/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc new file mode 100644 index 0000000000..bce490d772 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.cc @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "bifurcation_detector.h" + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + BifurcationDetector, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()), + BifurcationDetector); +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h new file mode 100644 index 0000000000..22ea0c0ea9 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/bifurcation_detector.h @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#pragma once + +#include <algorithm> +#include <cstring> + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { +// CPU kernel of the BifurcationDetector contrib op (a component of aggressive decoding). +// Output 0 ("tokens"): cur_tokens followed by the leading pred_tokens up to and including the first +// one that differs from src_tokens read from prev_suffix_match_idx (all of pred_tokens if none differs). +// Output 1 ("suffix_match_idx"): index in src_tokens right after a uniquely matched trailing n-gram +// of output 0, or -1 when no unique match is found. +class BifurcationDetector : public OpKernel { + public: + // Reads and validates the n-gram size attributes: both must be positive and max >= min. + explicit BifurcationDetector(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(info.GetAttr<int64_t>("min_ngram_size", &min_ngram_size_).IsOK()); + ORT_ENFORCE(min_ngram_size_ > 0); + ORT_ENFORCE(info.GetAttr<int64_t>("max_ngram_size", &max_ngram_size_).IsOK()); + ORT_ENFORCE(max_ngram_size_ > 0); + ORT_ENFORCE(max_ngram_size_ >= min_ngram_size_); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* src_tokens = context->Input<Tensor>(0); + const Tensor* cur_tokens = context->Input<Tensor>(1); + const Tensor* prev_suffix_match_idx = context->Input<Tensor>(2); + const Tensor* pred_tokens = context->Input<Tensor>(3); + const auto* src_tokens_data = static_cast<const int64_t*>(src_tokens->DataRaw()); + const auto* cur_tokens_data = static_cast<const int64_t*>(cur_tokens->DataRaw()); + int64_t src_tokens_len = src_tokens->Shape().GetDims().at(0); + int64_t cur_tokens_len = cur_tokens->Shape().GetDims().at(0); + + Tensor* out_tokens = nullptr; + + // Find the bifurcation index of predicted tokens, + // between source tokens, starting from previous suffix match index, + // and predicted tokens. + // Concat predicted tokens, starting from bifurcation index, to the back + // of current tokens. This forms the output tokens. + if (nullptr == pred_tokens) { + // No prediction tokens. Output tokens equal the current tokens.
+ out_tokens = context->Output(0, cur_tokens->Shape()); + auto* out_tokens_data = static_cast<int64_t*>(out_tokens->MutableDataRaw()); + memcpy(out_tokens_data, cur_tokens_data, cur_tokens_len * sizeof(int64_t)); + } else { + const auto* pred_tokens_data = static_cast<const int64_t*>(pred_tokens->DataRaw()); + const int64_t prev_suffix_match_idx_data = static_cast<const int64_t*>(prev_suffix_match_idx->DataRaw())[0]; + int64_t pred_tokens_len = pred_tokens->Shape().GetDims().at(0); + // Find bifurcation index between prediction tokens, and source tokens + // starting from previous suffix match index. + // The index must be non-negative: a negative value (such as the -1 emitted for "no match") + // would make the comparison loop below read before the start of src_tokens. + ORT_ENFORCE(prev_suffix_match_idx_data >= 0); + ORT_ENFORCE(src_tokens_len >= prev_suffix_match_idx_data); + ORT_ENFORCE(pred_tokens_len == (src_tokens_len + 1 - prev_suffix_match_idx_data)); + int64_t pred_bifur_idx = 0; + for (; pred_bifur_idx < src_tokens_len - prev_suffix_match_idx_data; ++pred_bifur_idx) { + if (pred_tokens_data[pred_bifur_idx] != src_tokens_data[pred_bifur_idx + prev_suffix_match_idx_data]) { + break; + } + } + // pred_bifur_idx in [0, pred_tokens_len - 1] + out_tokens = context->Output(0, TensorShape({cur_tokens_len + pred_bifur_idx + 1})); + auto* out_tokens_data = static_cast<int64_t*>(out_tokens->MutableDataRaw()); + memcpy(out_tokens_data, cur_tokens_data, cur_tokens_len * sizeof(int64_t)); + memcpy(out_tokens_data + cur_tokens_len, pred_tokens_data, (pred_bifur_idx + 1) * sizeof(int64_t)); + } + + // Detect suffix match index in source tokens, between source tokens and output tokens. + // Detection is based on finding the appearances of last n-gram in output tokens + // in source tokens. + // A match is considered found if source tokens contain a single matching n-gram. + // Return the index immediately after the matched n-gram in source tokens. + // No match is found if src tokens contain multiple or zero matching n-grams. + // Return -1.
+ int64_t tokens_len = out_tokens->Shape().GetDims().at(0); + int64_t min_gram = min_ngram_size_; + int64_t max_gram = max_ngram_size_; + int64_t suffix_idx = -1; + const auto* tokens_data = static_cast<const int64_t*>(out_tokens->DataRaw()); + for (int64_t i = min_gram; i < max_gram + 1; ++i) { + if (i > tokens_len) { + break; + } + // First occurrence of the last i output tokens within src_tokens. + auto it = std::search( + src_tokens_data, + src_tokens_data + src_tokens_len, + tokens_data + tokens_len - i, + tokens_data + tokens_len); + if (it == (src_tokens_data + src_tokens_len)) { + // This n-gram does not occur; keep the result obtained from the shorter n-grams and stop. + break; + } else { + suffix_idx = it - src_tokens_data + i; + // The match ends at the end of src_tokens; keep that index and stop. + if (suffix_idx >= src_tokens_len) { + break; + } + // Search again starting one past the first occurrence; a second hit makes the match ambiguous. + auto it_2 = std::search( + src_tokens_data + suffix_idx - i + 1, + src_tokens_data + src_tokens_len, + tokens_data + tokens_len - i, + tokens_data + tokens_len); + if (it_2 != (src_tokens_data + src_tokens_len)) { + suffix_idx = -1; + continue; + } + } + } + + Tensor* out_suffix_match_idx = context->Output(1, prev_suffix_match_idx->Shape()); + auto* out_suffix_match_idx_data = static_cast<int64_t*>(out_suffix_match_idx->MutableDataRaw()); + out_suffix_match_idx_data[0] = suffix_idx; + + return Status::OK(); + } + + private: + int64_t min_ngram_size_; + int64_t max_ngram_size_; +}; +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 6847d24b23..702b5bcd67 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -38,6 +38,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu) class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector); #ifdef
BUILD_MS_EXPERIMENTAL_OPS class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, DFT); @@ -208,6 +209,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef BUILD_MS_EXPERIMENTAL_OPS BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index c20f3bd478..27b854e397 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -236,7 +236,7 @@ void embedLayerNormalizationShapeInference(InferenceContext& ctx) { "gamma should have 2 dimension, dimension size known, " "and same hidden size as word_embedding."); } - + auto& beta_shape = getInputShape(ctx, 6); auto& beta_dims = gamma_shape.dim(); if (beta_dims.size() != 1 || @@ -772,13 +772,13 @@ GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.03 .Const("one", 1.0, elem_type) .Add(hasBias ? "X_bias = Add (X, bias)" : "X_bias = Identity (X)") .Add(R"( - T1 = Mul (X_bias, X_bias) - T2 = Mul (c, T1) - T3 = Add (b, T2) - T4 = Mul (X_bias, T3) - T5 = Tanh (T4) - T6 = Add (one, T5) - T7 = Mul (X_bias, T6) + T1 = Mul (X_bias, X_bias) + T2 = Mul (c, T1) + T3 = Add (b, T2) + T4 = Mul (X_bias, T3) + T5 = Tanh (T4) + T6 = Add (one, T5) + T7 = Mul (X_bias, T6) Y = Mul (a, T7) )"); @@ -825,6 +825,43 @@ Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form } propagateShapeFromInputToOutput(ctx, 1, 0); }); + + static const char* BifurcationDetector_ver1_doc = R"DOC( +Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens, +starting from previous suffix match index, and predicted tokens. +Concat predicted tokens, starting from bifurcation index, to the back +of current tokens. This forms the output tokens. 
+Detect suffix match index in source tokens, between source tokens and output tokens. +Detection is based on finding the appearances of last n-gram in output tokens +in source tokens. +A match is considered found if source tokens contain a single matching n-gram. +Return the index immediately after the matched n-gram in source tokens. +No match is found if src tokens contain multiple or zero matching n-grams. Return -1. +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(BifurcationDetector) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(BifurcationDetector_ver1_doc) + .Attr("min_ngram_size", "The minimum NGram size for suffix matching.", AttributeProto::INT, static_cast(1)) + .Attr("max_ngram_size", "The maximum NGram size for suffix matching.", AttributeProto::INT, static_cast(3)) + .Input(0, "src_tokens", "Encoder input ids.", "T") + .Input(1, "cur_tokens", "Decoder input ids.", "T") + .Input(2, "prev_suffix_match_idx", "Previous suffix match index", "T") + .Input(3, "pred_tokens", "Predicted token ids from aggressive decoding", "T", OpSchema::Optional) + .Output(0, "tokens", "Decoder input ids after merging predicted tokens", "T") + .Output(1, "suffix_match_idx", "new suffix match index", "T") + .TypeConstraint("T", {"tensor(int64)"}, "Constrain to integer types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); + propagateElemTypeFromInputToOutput(ctx, 2, 1); + if (hasInputShape(ctx, 2)) { + propagateShapeFromInputToOutput(ctx, 2, 1); + } + // output tokens length is dynamic as it depends on the bifurcation index of predicted tokens and source tokens, + // and current tokens length. + // tokens_length = cur_tokens_length + bifurcation_index + 1.
+ }); } void RegisterContribSchemas() { diff --git a/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc b/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc new file mode 100644 index 0000000000..51829e8cf8 --- /dev/null +++ b/onnxruntime/test/contrib_ops/bifurcation_detector_op_test.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +// All four source tokens are predicted correctly, so every predicted token is appended to cur_tokens. +// The last output token (2) does not occur in src_tokens, hence no suffix match (-1). +TEST(BifurcationDetectorTest, Test1) { + OpTester tester("BifurcationDetector", 1, onnxruntime::kMSDomain); + + tester.AddInput<int64_t>("src_tokens", {4}, {1, 5, 3, 4}); + tester.AddInput<int64_t>("cur_tokens", {1}, {2}); + tester.AddInput<int64_t>("prev_suffix_match_idx", {}, {0}); + tester.AddInput<int64_t>("pred_tokens", {5}, {1, 5, 3, 4, 2}); + tester.AddOutput<int64_t>("tokens", {6}, {2, 1, 5, 3, 4, 2}); + tester.AddOutput<int64_t>("suffix_match_idx", {}, {-1}); + + std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// pred_tokens is omitted, so the output equals cur_tokens. The trailing 1-gram {23} occurs exactly +// once in src_tokens (at index 8), so the suffix match index is 9. +TEST(BifurcationDetectorTest, Test2) { + OpTester tester("BifurcationDetector", 1, onnxruntime::kMSDomain); + + tester.AddInput<int64_t>("src_tokens", {26}, {756, 194, 39, 1015, 5529, 1216, 24, 72, 23, 1976, 6174, 1340, + 6, 39, 194, 2161, 1480, 4955, 8, 7806, 65, 1091, 8, 560, + 4077, 196}); + tester.AddInput<int64_t>("cur_tokens", {6}, {2, 756, 194, 39, 8155, 23}); + tester.AddInput<int64_t>("prev_suffix_match_idx", {}, {0}); + tester.AddOutput<int64_t>("tokens", {6}, {2, 756, 194, 39, 8155, 23}); + tester.AddOutput<int64_t>("suffix_match_idx", {}, {9}); + + std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +} // namespace test +} //
namespace onnxruntime diff --git a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json index d9c2666284..bb0f31e904 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json @@ -67,6 +67,10 @@ "BiasGelu com.microsoft CPUExecutionProvider", 12457646955212583504 ], + [ + "BifurcationDetector com.microsoft CPUExecutionProvider", + 12148442056374193608 + ], [ "CDist com.microsoft CPUExecutionProvider", 889036143745127232