LLVM16 compat changes (#20294)

The change is similar to #15672 and #11667, for making the code compatible with CUDA 12 and LLVM16 on Mariner2.
This commit is contained in:
Changming Sun 2024-04-12 10:16:12 -07:00 committed by GitHub
parent cd7112f800
commit 794d39a977
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 7 additions and 6 deletions

View file

@ -304,7 +304,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
int m = parameters.token_count;
int n = parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size;
int k = parameters.input_hidden_size;
gemm_buffer = this->GetScratchBuffer<T>(static_cast<size_t>(m) * n, context->GetComputeStream());
gemm_buffer = this->template GetScratchBuffer<T>(static_cast<size_t>(m) * n, context->GetComputeStream());
cublasHandle_t cublas = this->GetCublasHandle(context);
@ -328,7 +328,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
false,
use_memory_efficient_attention,
no_qkv_workspace);
auto work_space = this->GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
auto work_space = this->template GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
typedef typename ToCudaType<T>::MappedType CudaT;
PackedAttentionData<CudaT> data;

View file

@ -298,7 +298,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
use_flash_attention,
use_memory_efficient_attention,
no_qkv_workspace);
auto work_space = this->GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
auto work_space = this->template GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
typedef typename ToCudaType<T>::MappedType CudaT;
PackedMultiHeadAttentionData<CudaT> data;

View file

@ -4,6 +4,7 @@
#pragma once
#include "core/common/common.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h"
@ -130,14 +131,14 @@ class MoEBase {
fc3_experts_weights_optional->Shape().GetDims() != fc1_experts_weights_dims) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc3_experts_weights_dims must be equal to fc1_experts_weights_dims, got ",
fc3_experts_weights_optional->Shape().GetDims(), " and ", fc1_experts_weights_dims);
fc3_experts_weights_optional->Shape(), " and ", TensorShape(fc1_experts_weights_dims));
}
if (fc3_experts_bias_optional != nullptr && fc1_experts_bias_optional != nullptr &&
fc3_experts_bias_optional->Shape().GetDims() != fc1_experts_bias_optional->Shape().GetDims()) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT, "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ",
fc3_experts_bias_optional->Shape().GetDims(), " and ", fc1_experts_bias_optional->Shape().GetDims());
fc3_experts_bias_optional->Shape(), " and ", fc1_experts_bias_optional->Shape());
}
parameters.num_rows = num_rows;
@ -199,7 +200,7 @@ class MoEBase {
if (fc3_experts_scales != nullptr && fc1_experts_scales_dims != fc3_experts_scales->Shape().GetDims()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc3_experts_scales must be equal to fc1_experts_scales, got ",
fc3_experts_scales->Shape().GetDims(), " and ", fc1_experts_scales_dims);
fc3_experts_scales->Shape(), " and ", TensorShape(fc1_experts_scales_dims));
}
return Status::OK();