mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Merge remote-tracking branch 'upstream/main' into snnn/vcpkg2
This commit is contained in:
commit
9634ab5f24
9 changed files with 581 additions and 15 deletions
|
|
@ -14,6 +14,13 @@ pipelines:
|
|||
lastModifiedDate: 2024-10-25
|
||||
armory:
|
||||
lastModifiedDate: 2024-10-25
|
||||
binary:
|
||||
credscan:
|
||||
lastModifiedDate: 2025-02-06
|
||||
binskim:
|
||||
lastModifiedDate: 2025-02-06
|
||||
spotbugs:
|
||||
lastModifiedDate: 2025-02-06
|
||||
usedNonDefaultBranch: true
|
||||
1299:
|
||||
retail:
|
||||
|
|
|
|||
|
|
@ -38,6 +38,20 @@
|
|||
"createdDate": "2024-11-13 11:20:17Z",
|
||||
"expirationDate": "2025-05-02 11:55:15Z",
|
||||
"justification": "This error is baselined with an expiration date of 180 days from 2024-11-13 11:55:15Z"
|
||||
},
|
||||
"6f6606e50e82b2d3c823c435151f4b69c1fbde92f274753b793d948856cfc462": {
|
||||
"signature": "6f6606e50e82b2d3c823c435151f4b69c1fbde92f274753b793d948856cfc462",
|
||||
"alternativeSignatures": [],
|
||||
"target": "ScanTelemetry_20250206154816289.json",
|
||||
"line": 1,
|
||||
"memberOf": [
|
||||
"default"
|
||||
],
|
||||
"tool": "credscan",
|
||||
"ruleId": "CSCAN-AZURE0130",
|
||||
"createdDate": "2025-02-06 15:53:46Z",
|
||||
"expirationDate": "2025-07-26 16:26:55Z",
|
||||
"justification": "This error is baselined with an expiration date of 180 days from 2025-02-06 16:26:55Z"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -136,7 +137,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
|
|||
temperature = 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// The following parameter is read from environment variable for testing purpose.
|
||||
use_fast_topk = ParseEnvironmentVariableWithDefault<bool>(kBeamSearchUseFastTopK, true);
|
||||
}
|
||||
|
||||
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
|
||||
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
|
||||
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
|
||||
|
|
|
|||
|
|
@ -199,8 +199,14 @@ struct IGenerationParameters {
|
|||
int extra_decoding_ids_input_id = -1;
|
||||
int cross_qk_output_id = -1;
|
||||
int no_speech_probs_output_id = -1;
|
||||
|
||||
// Parameter for testing slow topk path. It can be updated by the below environment variable.
|
||||
bool use_fast_topk = true;
|
||||
};
|
||||
|
||||
// Environment variable to enable/disable fast topk kernel on GPU. Default is 1 (enabled).
|
||||
constexpr const char* kBeamSearchUseFastTopK = "ORT_BEAM_SEARCH_USE_FAST_TOPK";
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -524,7 +524,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size());
|
||||
}
|
||||
|
||||
if (num_beams <= 32) {
|
||||
gsl::span<float> scores_to_process = beam_state->next_scores;
|
||||
if (parameters->use_fast_topk && num_beams <= 32) {
|
||||
constexpr size_t max_parts_of_vocab = 128;
|
||||
size_t candidate_count = SafeInt<size_t>(batch_beam_size) * 2 * num_beams;
|
||||
float* topk_tmp_buffer = beam_state->topk_buffer.data();
|
||||
|
|
@ -546,13 +547,6 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
beam_state->next_tokens.data(),
|
||||
beam_state->next_indices.data(),
|
||||
cuda_stream);
|
||||
|
||||
// Select [batch_size, 2 * num_beams] from [batch_size * num_beams, 2 * num_beams]
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, 2 * num_beams);
|
||||
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
|
||||
dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
|
||||
#endif
|
||||
} else {
|
||||
// Apply top-k selection like the following:
|
||||
// next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||
|
|
@ -588,18 +582,20 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
|
||||
batch_size, top_k, vocab_size, cuda_stream);
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("next_scores before scorer", topk_scores->Data<float>(), batch_size, top_k);
|
||||
dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k);
|
||||
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
|
||||
#endif
|
||||
scores_to_process = gsl::span<float>(topk_scores->MutableData<float>(), batch_size * top_k);
|
||||
}
|
||||
|
||||
// gsl::span doesn't convert from non const to const, so all we're doing here is making each const.
|
||||
gsl::span<const float> next_scores(beam_state->next_scores.data(), beam_state->next_scores.size());
|
||||
gsl::span<const float> next_scores(scores_to_process.data(), scores_to_process.size());
|
||||
gsl::span<const int32_t> next_tokens(beam_state->next_tokens.data(), beam_state->next_tokens.size());
|
||||
gsl::span<const int32_t> next_indices(beam_state->next_indices.data(), beam_state->next_indices.size());
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("next_scores before scorer", next_scores.data(), batch_size, 2 * num_beams);
|
||||
dumper->Print("next_tokens before scorer", next_tokens.data(), batch_size, 2 * num_beams);
|
||||
dumper->Print("next_indices before scorer", next_indices.data(), batch_size, 2 * num_beams);
|
||||
#endif
|
||||
|
||||
beam_scorer->Process(
|
||||
*sequences,
|
||||
next_scores,
|
||||
|
|
@ -735,6 +731,7 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
|
|||
next_tokens,
|
||||
next_indices,
|
||||
stream_);
|
||||
|
||||
CUDA_CALL_THROW(cudaEventRecord(event_process_complete_.Get(), stream_));
|
||||
|
||||
cuda::LaunchBeamSearchScorer_AppendNextTokenToSequences(*state_cpu_,
|
||||
|
|
|
|||
453
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Normal file
453
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Normal file
|
|
@ -0,0 +1,453 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
|
||||
#include "contrib_ops/webgpu/bert/flash_attention.h"
|
||||
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
|
||||
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
|
||||
using namespace onnxruntime::webgpu;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::contrib::multihead_attention_helper;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace webgpu {
|
||||
|
||||
Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
||||
// Expectations are
|
||||
// qkv have same number of heads and hidden dimension (head size).
|
||||
// qkv are in BSNH format.
|
||||
// B - batch size but shader only supports batch_size 1.
|
||||
// S - current sequence length but shader supports only S = 1.
|
||||
// N - number of heads.
|
||||
// H - head size or hidden dimension for each qkv head.
|
||||
// KV cache is stored as BN(total_sequence_length)H
|
||||
// Attention bias is in BN(total_sequence_length)
|
||||
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
|
||||
shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
|
||||
if (has_past_) {
|
||||
shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
|
||||
shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
|
||||
}
|
||||
shader.AddOutput("present_key", ShaderUsage::UseUniform);
|
||||
shader.AddOutput("present_value", ShaderUsage::UseUniform);
|
||||
|
||||
shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n"
|
||||
<< "let kIdx = workgroup_id.x;\n"
|
||||
<< "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n";
|
||||
if (has_past_) {
|
||||
shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n"
|
||||
<< " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"
|
||||
<< " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n"
|
||||
<< " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n"
|
||||
<< " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n"
|
||||
<< " }\n"
|
||||
<< "}\n"
|
||||
<< "else if (kIdx >= uniforms.past_sequence_length) {\n";
|
||||
} else {
|
||||
shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n";
|
||||
}
|
||||
shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n"
|
||||
<< " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n"
|
||||
<< " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n"
|
||||
<< " // Assumes kv have BNSH layout.\n"
|
||||
<< " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n"
|
||||
<< " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n"
|
||||
<< " present_key[presentKeyOffset+w] = key[nOffset+w];\n"
|
||||
<< " present_value[presentKeyOffset+w] = value[nOffset+w];\n"
|
||||
<< " }\n"
|
||||
<< "}\n";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
|
||||
const Tensor* K, const Tensor* past_key, Tensor* present_key,
|
||||
const Tensor* V, const Tensor* past_value, Tensor* present_value,
|
||||
int past_sequence_length, int total_sequence_length) {
|
||||
// CopyKVCache takes past key/value and current key/value and copies them to present key and value.
|
||||
// This makes it so that FlashAttention only needs to look at present key and value, and saves
|
||||
// number of input buffers in the shader, which we run out of (<=8) without this optimization.
|
||||
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
|
||||
bool has_past = (past_sequence_length != 0);
|
||||
CopyKVCacheProgram program{"CopyKVCache", has_past};
|
||||
if (has_past) {
|
||||
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
|
||||
{V, ProgramTensorMetadataDependency::TypeAndRank, components},
|
||||
{past_key, ProgramTensorMetadataDependency::TypeAndRank, components},
|
||||
{past_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
|
||||
} else {
|
||||
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
|
||||
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
|
||||
}
|
||||
|
||||
program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components},
|
||||
{present_value, ProgramTensorMetadataDependency::Rank, components}});
|
||||
|
||||
program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads_)
|
||||
.SetWorkgroupSize(1)
|
||||
.CacheHint(std::to_string(components) + std::to_string(has_past))
|
||||
.AddUniformVariables({{static_cast<uint32_t>(past_sequence_length)},
|
||||
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
|
||||
{static_cast<uint32_t>(parameters.head_size_ / components)}});
|
||||
|
||||
return context.RunProgram(program);
|
||||
}
|
||||
|
||||
Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
||||
// Expectations are
|
||||
// qkv have same number of heads and hidden dimension (head size).
|
||||
// qkv are in BSNH format.
|
||||
// B - batch size but shader only supports batch_size 1.
|
||||
// S - current sequence length but shader supports only S = 1.
|
||||
// N - number of heads.
|
||||
// H - head size or hidden dimension for each qkv head.
|
||||
// KV cache is stored as BN(total_sequence_length)H
|
||||
// Attention bias is in BN(new_sequence_length)(total_sequence_length)
|
||||
//
|
||||
// Expectation is that present_key, and present_value contain past key and values since
|
||||
// we are out of storage buffers a shader can have and both past/present cant be passed.
|
||||
// The hidden size of each q head should be a multiple of 4 because shader uses vectorized loads.
|
||||
shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
|
||||
shader.AddInput("present_key", ShaderUsage::UseUniform);
|
||||
shader.AddInput("present_value", ShaderUsage::UseUniform);
|
||||
if (has_attention_bias_) {
|
||||
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
|
||||
}
|
||||
shader.AddOutput("output", ShaderUsage::UseUniform);
|
||||
|
||||
shader.AdditionalImplementation() << "const qkv_head_size: u32 = " << qkv_head_size_ << ";\n"
|
||||
<< "const num_heads: u32 =" << qkv_num_heads_ << ";\n";
|
||||
|
||||
shader.AdditionalImplementation() << R"HELPER_FN(
|
||||
// For max performance max_k_step should be the same as sg_size, however we might run out of registers
|
||||
// for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16).
|
||||
const max_k_step: u32 = 16u;
|
||||
const vec_factor: u32 = 4u;
|
||||
const qkv_head_size_vec: u32 = qkv_head_size / vec_factor;
|
||||
const min_value : q_element_t = q_element_t(-65504.0);
|
||||
|
||||
// Default SHM usage limit is 16KB in Dawn.
|
||||
var<workgroup> k_tile : array<array<q_value_t, qkv_head_size_vec>, max_k_step>; // 96 * 2 * 16 = 3KB.
|
||||
var<workgroup> v_tile : array<array<q_value_t, qkv_head_size_vec>, max_k_step>; // 96 * 2 * 16 = 3KB.
|
||||
|
||||
// Private memory per lane.
|
||||
var<private> q_tile : array<q_value_t, qkv_head_size_vec>;
|
||||
var<private> o_tile : array<q_value_t, qkv_head_size_vec>;
|
||||
fn loadq(q_idx_global : u32, head_idx: u32)
|
||||
{
|
||||
// Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA
|
||||
// This is the layout if TransferBSDToBNSH has not been run.
|
||||
let offset = q_idx_global * (qkv_head_size_vec) * num_heads + qkv_head_size_vec * head_idx;
|
||||
// Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run.
|
||||
//let offset = head_idx * uniforms.new_sequence_length * qkv_head_size_vec + q_idx_global * qkv_head_size_vec;
|
||||
for (var idx:u32 = 0; idx < qkv_head_size_vec; idx++)
|
||||
{
|
||||
q_tile[idx] = q[idx+offset];
|
||||
}
|
||||
}
|
||||
fn loadk(k_start : u32, head_idx: u32, local_idx: u32, k_step: u32)
|
||||
{
|
||||
// Stored as float16[batch_size,num_heads,present_sequence_length,96]
|
||||
let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + k_start * qkv_head_size_vec;
|
||||
for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x)
|
||||
{
|
||||
let slot = u32(idx/qkv_head_size_vec);
|
||||
let val = select(q_value_t(0), present_key[offset+idx], k_start + slot < uniforms.present_sequence_length);
|
||||
k_tile[slot][idx%qkv_head_size_vec] = val;
|
||||
}
|
||||
}
|
||||
fn loadv(v_start : u32, head_idx: u32, local_idx: u32, k_step: u32)
|
||||
{
|
||||
// Stored as float16[batch_size,num_heads,present_sequence_length,96]
|
||||
let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + v_start * qkv_head_size_vec;
|
||||
for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x)
|
||||
{
|
||||
let slot = u32(idx/qkv_head_size_vec);
|
||||
let val = select(q_value_t(0), present_value[offset+idx], v_start + slot < uniforms.present_sequence_length);
|
||||
v_tile[slot][idx%qkv_head_size_vec] = val;
|
||||
}
|
||||
}
|
||||
fn writeo(o_idx_global: u32, head_idx: u32)
|
||||
{
|
||||
// Stored as float16[batch_size,sequence_length,3072]
|
||||
let offset = o_idx_global * num_heads * qkv_head_size_vec + head_idx * qkv_head_size_vec;
|
||||
for (var idx:u32 = 0; idx < qkv_head_size_vec; idx ++)
|
||||
{
|
||||
output[offset+idx] = o_tile[idx];
|
||||
}
|
||||
}
|
||||
)HELPER_FN";
|
||||
|
||||
if (has_attention_bias_) {
|
||||
shader.AdditionalImplementation() << R"HELPER_FN(
|
||||
fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4<q_element_t>
|
||||
{
|
||||
// Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
|
||||
if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) {
|
||||
return vec4<q_element_t>(0);
|
||||
}
|
||||
let offset_base = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length;
|
||||
let offset = offset_base + k_idx_global;
|
||||
let offset_max = offset_base + uniforms.present_sequence_length;
|
||||
let c1 = q_element_t(attention_bias[min(offset, offset_max)]);
|
||||
let c2 = q_element_t(attention_bias[min(offset+1, offset_max)]);
|
||||
let c3 = q_element_t(attention_bias[min(offset+2, offset_max)]);
|
||||
let c4 = q_element_t(attention_bias[min(offset+3, offset_max)]);
|
||||
return vec4<q_element_t>(c1,c2,c3,c4);
|
||||
}
|
||||
)HELPER_FN";
|
||||
} else {
|
||||
shader.AdditionalImplementation() << R"HELPER_FN(
|
||||
fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4<q_element_t>
|
||||
{
|
||||
return vec4<q_element_t>(0);
|
||||
}
|
||||
)HELPER_FN";
|
||||
}
|
||||
|
||||
// Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / workgroup_size_x, 1)
|
||||
// Each lane/thread is responsible for a single q.
|
||||
shader.MainFunctionBody() << R"MAIN_FN(
|
||||
let head_idx = workgroup_id.x;
|
||||
let capped_sg_id = min(sg_id, max_k_step);
|
||||
let capped_sg_size = min(sg_size, max_k_step);
|
||||
|
||||
// Load Q
|
||||
let q_idx_global = workgroup_id.y * workgroup_size_x + local_idx;
|
||||
let valid_q = q_idx_global < uniforms.new_sequence_length;
|
||||
if (valid_q)
|
||||
{
|
||||
loadq(q_idx_global, head_idx);
|
||||
}
|
||||
|
||||
var previous_max : q_element_t = min_value;
|
||||
var previous_denom : q_element_t = 0;
|
||||
|
||||
for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=capped_sg_size)
|
||||
{
|
||||
workgroupBarrier();
|
||||
loadk(k_start, head_idx, local_idx, capped_sg_size);
|
||||
loadv(k_start, head_idx, local_idx, capped_sg_size);
|
||||
workgroupBarrier();
|
||||
|
||||
// Compute QKt
|
||||
var qk_1:vec4<q_element_t>;
|
||||
var qk_2:vec4<q_element_t>;
|
||||
var qk_3:vec4<q_element_t>;
|
||||
var qk_4:vec4<q_element_t>;
|
||||
if (sg_size > 8)
|
||||
{
|
||||
for (var i:u32 = 0u; i < qkv_head_size_vec; i++)
|
||||
{
|
||||
var k_local = k_tile[capped_sg_id][i];
|
||||
var q_own = q_tile[i];
|
||||
qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0));
|
||||
qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1));
|
||||
qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2));
|
||||
qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3));
|
||||
qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4));
|
||||
qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5));
|
||||
qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6));
|
||||
qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7));
|
||||
qk_3[0] += dot(q_own, subgroupShuffle(k_local, 8));
|
||||
qk_3[1] += dot(q_own, subgroupShuffle(k_local, 9));
|
||||
qk_3[2] += dot(q_own, subgroupShuffle(k_local, 10));
|
||||
qk_3[3] += dot(q_own, subgroupShuffle(k_local, 11));
|
||||
qk_4[0] += dot(q_own, subgroupShuffle(k_local, 12));
|
||||
qk_4[1] += dot(q_own, subgroupShuffle(k_local, 13));
|
||||
qk_4[2] += dot(q_own, subgroupShuffle(k_local, 14));
|
||||
qk_4[3] += dot(q_own, subgroupShuffle(k_local, 15));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (var i:u32 = 0u; i < qkv_head_size_vec; i++)
|
||||
{
|
||||
var k_local = k_tile[capped_sg_id][i];
|
||||
var q_own = q_tile[i];
|
||||
qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0));
|
||||
qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1));
|
||||
qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2));
|
||||
qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3));
|
||||
qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4));
|
||||
qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5));
|
||||
qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6));
|
||||
qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7));
|
||||
}
|
||||
}
|
||||
|
||||
qk_1 = qk_1 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start, head_idx);
|
||||
qk_2 = qk_2 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+4, head_idx);
|
||||
if (sg_size > 8)
|
||||
{
|
||||
qk_3 = qk_3 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+8, head_idx);
|
||||
qk_4 = qk_4 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+12, head_idx);
|
||||
}
|
||||
|
||||
// Neuter qk values where K is out of bounds.
|
||||
qk_1[1] = select(min_value, qk_1[1], k_start+1 < uniforms.present_sequence_length);
|
||||
qk_1[2] = select(min_value, qk_1[2], k_start+2 < uniforms.present_sequence_length);
|
||||
qk_1[3] = select(min_value, qk_1[3], k_start+3 < uniforms.present_sequence_length);
|
||||
qk_2[0] = select(min_value, qk_2[0], k_start+4 < uniforms.present_sequence_length);
|
||||
qk_2[1] = select(min_value, qk_2[1], k_start+5 < uniforms.present_sequence_length);
|
||||
qk_2[2] = select(min_value, qk_2[2], k_start+6 < uniforms.present_sequence_length);
|
||||
qk_2[3] = select(min_value, qk_2[3], k_start+7 < uniforms.present_sequence_length);
|
||||
if (sg_size > 8)
|
||||
{
|
||||
qk_3[0] = select(min_value, qk_3[0], k_start+8 < uniforms.present_sequence_length);
|
||||
qk_3[1] = select(min_value, qk_3[1], k_start+9 < uniforms.present_sequence_length);
|
||||
qk_3[2] = select(min_value, qk_3[2], k_start+10 < uniforms.present_sequence_length);
|
||||
qk_3[3] = select(min_value, qk_3[3], k_start+11 < uniforms.present_sequence_length);
|
||||
qk_4[0] = select(min_value, qk_4[0], k_start+12 < uniforms.present_sequence_length);
|
||||
qk_4[1] = select(min_value, qk_4[1], k_start+13 < uniforms.present_sequence_length);
|
||||
qk_4[2] = select(min_value, qk_4[2], k_start+14 < uniforms.present_sequence_length);
|
||||
qk_4[3] = select(min_value, qk_4[3], k_start+15 < uniforms.present_sequence_length);
|
||||
}
|
||||
|
||||
//
|
||||
// Compute SoftMax as per Flash Attention technique.
|
||||
//
|
||||
// Crux of Flash Attention is here, that allows for partial softmax computation,
|
||||
// direct update of output and merging with previous results.
|
||||
// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
|
||||
// Where b is the block size of the tile. Xi is storing QKtranspose for the ith tile.
|
||||
// mi_local is the max of Xi. Note: _ in this notation means what follows is a
|
||||
// subscript. max_j=1:b (Xi[j]) is the max of Xi[j] for j=1 to b.
|
||||
//
|
||||
// for i = 1, #tiles do
|
||||
// Xi = Q[k,:] Kt[:, (i-1) b : i b]
|
||||
// mi_local= max_j=1:b (Xi[j])
|
||||
// Mi = max(M_(i-1), mi_local)
|
||||
// d'_i = d'_(i-1) * e^(M_(i-1)-M_i) + Σ_j=1:b e^(Xi[j]-Mi)
|
||||
// o'_i = o'_(i-1) * d'_(i-1) * e^(M_(i-1)-M_i) / d'_i + Σ_j=1:b (e^(Xi[j]-Mi) / d'_i) V[j + (i - 1)b,:]
|
||||
// end
|
||||
//
|
||||
// In the code below:
|
||||
// dleft is the first term of d'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i).
|
||||
// sum is the second term of the same expression : Σ_j=1:b e^(Xi[j]-Mi)
|
||||
// o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i
|
||||
//
|
||||
var local_max_temp = max(qk_1, qk_2);
|
||||
if (sg_size > 8)
|
||||
{
|
||||
local_max_temp = max(local_max_temp, qk_3);
|
||||
local_max_temp = max(local_max_temp, qk_4);
|
||||
}
|
||||
let local_max = max(max(local_max_temp.x, local_max_temp.y),max(local_max_temp.z, local_max_temp.w));
|
||||
let new_max = max(previous_max, local_max);
|
||||
qk_1 = q_value_t(exp(vec4<f32>(qk_1) - f32(new_max)));
|
||||
qk_2 = q_value_t(exp(vec4<f32>(qk_2) - f32(new_max)));
|
||||
if (sg_size > 8) {
|
||||
qk_3 = q_value_t(exp(vec4<f32>(qk_3) - f32(new_max)));
|
||||
qk_4 = q_value_t(exp(vec4<f32>(qk_4) - f32(new_max)));
|
||||
}
|
||||
let sum_vec = qk_1 + qk_2 + qk_3 + qk_4;
|
||||
let sum = sum_vec.x + sum_vec.y + sum_vec.z + sum_vec.w;
|
||||
// Compute lhs term of update di prime and the compute di prime.
|
||||
let dleft = previous_denom * exp(previous_max-new_max);
|
||||
var d = dleft + sum;
|
||||
d = select(d,q_element_t(0.0000001),d==0);
|
||||
qk_1 = qk_1 / d;
|
||||
qk_2 = qk_2 / d;
|
||||
if (sg_size > 8) {
|
||||
qk_3 = qk_3 / d;
|
||||
qk_4 = qk_4 / d;
|
||||
}
|
||||
previous_max = new_max;
|
||||
previous_denom = d;
|
||||
let o_ratio = dleft / d;
|
||||
|
||||
if (sg_size > 8) {
|
||||
for (var i:u32 = 0; i < qkv_head_size_vec; i++)
|
||||
{
|
||||
var val = select(vec4<q_element_t>(0), v_tile[capped_sg_id][i], k_start + capped_sg_id < uniforms.present_sequence_length);
|
||||
var sum = subgroupShuffle(val, 0) * qk_1[0];
|
||||
sum += subgroupShuffle(val, 1) * qk_1[1];
|
||||
sum += subgroupShuffle(val, 2) * qk_1[2];
|
||||
sum += subgroupShuffle(val, 3) * qk_1[3];
|
||||
sum += subgroupShuffle(val, 4) * qk_2[0];
|
||||
sum += subgroupShuffle(val, 5) * qk_2[1];
|
||||
sum += subgroupShuffle(val, 6) * qk_2[2];
|
||||
sum += subgroupShuffle(val, 7) * qk_2[3];
|
||||
sum += subgroupShuffle(val, 8) * qk_3[0];
|
||||
sum += subgroupShuffle(val, 9) * qk_3[1];
|
||||
sum += subgroupShuffle(val, 10) * qk_3[2];
|
||||
sum += subgroupShuffle(val, 11) * qk_3[3];
|
||||
sum += subgroupShuffle(val, 12) * qk_4[0];
|
||||
sum += subgroupShuffle(val, 13) * qk_4[1];
|
||||
sum += subgroupShuffle(val, 14) * qk_4[2];
|
||||
sum += subgroupShuffle(val, 15) * qk_4[3];
|
||||
o_tile[i] = o_tile[i] * o_ratio + sum;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (var i:u32 = 0; i < qkv_head_size_vec; i++)
|
||||
{
|
||||
var val = select(vec4<q_element_t>(0), v_tile[capped_sg_id][i], k_start + capped_sg_id < uniforms.present_sequence_length);
|
||||
var sum = subgroupShuffle(val, 0) * qk_1[0];
|
||||
sum += subgroupShuffle(val, 1) * qk_1[1];
|
||||
sum += subgroupShuffle(val, 2) * qk_1[2];
|
||||
sum += subgroupShuffle(val, 3) * qk_1[3];
|
||||
sum += subgroupShuffle(val, 4) * qk_2[0];
|
||||
sum += subgroupShuffle(val, 5) * qk_2[1];
|
||||
sum += subgroupShuffle(val, 6) * qk_2[2];
|
||||
sum += subgroupShuffle(val, 7) * qk_2[3];
|
||||
o_tile[i] = o_tile[i] * o_ratio + sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (valid_q) {
|
||||
writeo(q_idx_global, head_idx);
|
||||
}
|
||||
)MAIN_FN";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
|
||||
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
|
||||
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
|
||||
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length_, parameters.total_sequence_length_));
|
||||
|
||||
const uint32_t tile_size = 64;
|
||||
bool has_attention_bias = attention_bias != nullptr;
|
||||
FlashAttentionProgram program{"FlashAttention", has_attention_bias, parameters.head_size_, parameters.num_heads_};
|
||||
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
|
||||
{present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
|
||||
{present_value, ProgramTensorMetadataDependency::TypeAndRank, 4},
|
||||
{attention_bias, ProgramTensorMetadataDependency::TypeAndRank}});
|
||||
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}});
|
||||
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
|
||||
: parameters.scale_;
|
||||
std::string cache_hint = std::to_string(has_attention_bias) +
|
||||
std::to_string(parameters.head_size_) +
|
||||
std::to_string(parameters.num_heads_);
|
||||
program.SetDispatchGroupSize(parameters.num_heads_, (parameters.sequence_length_ + tile_size - 1) / tile_size, 1)
|
||||
.SetWorkgroupSize(tile_size)
|
||||
.CacheHint(cache_hint)
|
||||
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
|
||||
{static_cast<uint32_t>(parameters.total_sequence_length_)},
|
||||
{alpha}});
|
||||
|
||||
return context.RunProgram(program);
|
||||
}
|
||||
|
||||
bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
|
||||
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
|
||||
return parameters.batch_size_ == 1 &&
|
||||
bias == nullptr &&
|
||||
parameters.sequence_length_ > 1 &&
|
||||
context.Device().HasFeature(wgpu::FeatureName::Subgroups) &&
|
||||
present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
|
||||
present_value->SizeInBytes() > 0 && parameters.head_size_ % 4 == 0;
|
||||
}
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
66
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Normal file
66
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/webgpu/bert/attention_common.h"
|
||||
#include "core/providers/webgpu/compute_context.h"
|
||||
#include "core/providers/webgpu/program.h"
|
||||
#include "core/providers/webgpu/shader_helper.h"
|
||||
#include "core/providers/webgpu/webgpu_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace webgpu {
|
||||
|
||||
using namespace onnxruntime::webgpu;
|
||||
|
||||
class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
|
||||
public:
|
||||
CopyKVCacheProgram(const std::string& kernel_name, bool has_past)
|
||||
: Program{kernel_name}, has_past_(has_past) {
|
||||
}
|
||||
|
||||
Status GenerateShaderCode(ShaderHelper& sh) const override;
|
||||
|
||||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32},
|
||||
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
|
||||
{"vectorized_head_size", ProgramUniformVariableDataType::Uint32});
|
||||
|
||||
private:
|
||||
bool has_past_;
|
||||
};
|
||||
|
||||
class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
|
||||
public:
|
||||
FlashAttentionProgram(const std::string& kernel_name,
|
||||
bool has_attention_bias,
|
||||
int qkv_head_size,
|
||||
int qkv_num_heads)
|
||||
: Program{kernel_name},
|
||||
has_attention_bias_(has_attention_bias),
|
||||
qkv_head_size_(qkv_head_size),
|
||||
qkv_num_heads_(qkv_num_heads) {
|
||||
}
|
||||
|
||||
Status GenerateShaderCode(ShaderHelper& sh) const override;
|
||||
|
||||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
|
||||
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
|
||||
{"alpha", ProgramUniformVariableDataType::Float32});
|
||||
|
||||
private:
|
||||
bool has_attention_bias_;
|
||||
int qkv_head_size_;
|
||||
int qkv_num_heads_;
|
||||
};
|
||||
|
||||
Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
|
||||
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
|
||||
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
|
||||
|
||||
bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
|
||||
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
|
||||
} // namespace webgpu
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
#include "contrib_ops/webgpu/bert/attention_common.h"
|
||||
#include "contrib_ops/webgpu/bert/multihead_attention.h"
|
||||
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
|
||||
#include "contrib_ops/webgpu/bert/flash_attention.h"
|
||||
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
|
||||
|
|
@ -74,6 +75,11 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
|
|||
Tensor* present_key = context.Output(1, present_shape);
|
||||
Tensor* present_value = context.Output(2, present_shape);
|
||||
|
||||
if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
|
||||
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
|
||||
present_value, parameters, context);
|
||||
}
|
||||
|
||||
TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
|
||||
parameters.sequence_length_, parameters.head_size_});
|
||||
TensorShape q_new_shape(q_new_dims);
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@
|
|||
#include "test/common/cuda_op_test_utils.h"
|
||||
#include "test/providers/model_tester.h"
|
||||
#include "test/util/include/current_test_name.h"
|
||||
#include "test/util/include/scoped_env_vars.h"
|
||||
#include "contrib_ops/cpu/transformers/generation_shared.h"
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_provider_options.h"
|
||||
|
|
@ -19,7 +21,7 @@ extern std::unique_ptr<Ort::Env> ort_env;
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(BeamSearchTest, GptBeamSearchFp32) {
|
||||
void RunGptBeamSearchFp32() {
|
||||
std::vector<int64_t> input_ids_shape{3, 12};
|
||||
std::vector<int32_t> input_ids{
|
||||
0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
|
||||
|
|
@ -107,6 +109,16 @@ TEST(BeamSearchTest, GptBeamSearchFp32) {
|
|||
ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
|
||||
}
|
||||
|
||||
TEST(BeamSearchTest, GptBeamSearchFp32) {
|
||||
RunGptBeamSearchFp32();
|
||||
}
|
||||
|
||||
TEST(BeamSearchTest, GptBeamSearchFp32_DisableFastTopK) {
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{{onnxruntime::contrib::transformers::kBeamSearchUseFastTopK, "0"}}};
|
||||
RunGptBeamSearchFp32();
|
||||
}
|
||||
|
||||
TEST(BeamSearchTest, GptBeamSearchFp16) {
|
||||
std::vector<int64_t> input_ids_shape{3, 12};
|
||||
std::vector<int32_t> input_ids{
|
||||
|
|
|
|||
Loading…
Reference in a new issue