Phi3 MoE CUDA kernel (#21819)

### Description
Add CUDA kernel support for Phi-3 MoE. This change:

- adds a sparse-mixer top-2 routing kernel and a new `use_sparse_mixer` attribute on the `com.microsoft.MoE`, `ShardedMoE`, and `QMoE` contrib ops;
- extends `QMoE` from int4-only to int4 or int8 expert weights via a new `expert_weight_bits` attribute (default 4), including a new fp16/uint8 MoE GEMM instantiation;
- updates the operator docs and schemas accordingly.

### Motivation and Context
Phi-3 MoE models route tokens with a sparse mixer rather than plain top-k softmax, and their expert weights are typically served quantized. The existing MoE/QMoE CUDA kernels supported neither, so this change is required to run Phi-3 MoE through the CUDA execution provider.

---------

Co-authored-by: Your Name <you@example.com>
Ye Wang authored on 2024-08-27 09:21:30 -07:00; committed by GitHub
parent 252222034f
commit 1d059b8702
13 changed files with 1080 additions and 595 deletions


@@ -3083,6 +3083,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
<dt><tt>use_sparse_mixer</tt> : int</dt>
<dd>Whether to use sparse mixer</dd>
</dl>
#### Inputs (5 - 8)
@@ -4398,7 +4400,7 @@ This version of the operator has been available since version 1 of the 'com.micr
### <a name="com.microsoft.QMoE"></a><a name="com.microsoft.qmoe">**com.microsoft.QMoE**</a>
Int4 MoE
Quantized MoE
#### Version
@@ -4409,10 +4411,14 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
<dt><tt>expert_weight_bits</tt> : int</dt>
<dd>Number of bits used in quantized weights. Default is 4 bits</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
<dt><tt>use_sparse_mixer</tt> : int</dt>
<dd>Whether to use sparse mixer</dd>
</dl>
#### Inputs (7 - 11)
@@ -4423,19 +4429,19 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc1_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size / 2)</dd>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
<dt><tt>fc2_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc3_scales</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
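For illustration, here is a minimal sketch of driving the updated `com.microsoft.QMoE` schema from Python. The sizes are hypothetical and initializer creation is omitted; the point is the two new attributes and the weight-shape convention above (with `expert_weight_bits=4`, two 4-bit values are packed per byte, halving the last weight dimension):

```python
from onnx import helper

# Hypothetical sizes; any mutually consistent values work.
num_experts, hidden_size, inter_size, top_k = 8, 4096, 6400, 2

qmoe_node = helper.make_node(
    "QMoE",
    inputs=[
        "input", "router_probs",
        "fc1_experts_weights", "fc1_scales", "",   # fc1 bias omitted (optional)
        "fc2_experts_weights", "fc2_scales", "",   # fc2 bias omitted (optional)
        "fc3_experts_weights", "fc3_scales",
    ],
    outputs=["output"],
    domain="com.microsoft",
    activation_type="silu",
    k=top_k,
    normalize_routing_weights=0,
    use_sparse_mixer=1,     # new attribute; requires k == 2
    expert_weight_bits=8,   # new attribute; 8-bit weights keep the full
                            # (num_experts, hidden_size, inter_size) shape,
                            # 4-bit weights use (..., inter_size / 2)
)
```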


@@ -79,7 +79,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");
ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
normalize_routing_weights_, use_sparse_mixer_);
size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),


@@ -0,0 +1,31 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4200)
#endif
#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
namespace ort_fastertransformer {
template class MoeGemmRunner<half, uint8_t>;
} // namespace ort_fastertransformer


@@ -127,7 +127,7 @@ __launch_bounds__(TPB) __global__
const int block_row = blockIdx.x;
const bool should_process_row = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
const int thread_row_offset = blockIdx.x * num_experts;
float output_row_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
@@ -135,7 +135,7 @@ __launch_bounds__(TPB) __global__
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
const int idx = thread_row_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];
@@ -169,6 +169,107 @@ __launch_bounds__(TPB) __global__
}
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
template <typename T, int TPB, int NUM_EXPERTS>
__launch_bounds__(TPB) __global__ void sparse_mixer_top2(const T *, T *, int *, int *, const float) {
// Does not support pre-Kepler architectures
;
}
#else
template <typename T, int TPB, int NUM_EXPERTS>
__launch_bounds__(TPB) __global__
void sparse_mixer_top2(const T *inputs, T *output, int *indices, int *source_rows, const float jitter_eps) {
static constexpr int K = 2;
using cub_kvp = cub::KeyValuePair<int, T>;
using KVBlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ float result_kvp_value[K];
__shared__ typename KVBlockReduce::TempStorage kvTmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
int num_rows = gridDim.x;
const int block_row = blockIdx.x;
const int thread_row_offset = blockIdx.x * NUM_EXPERTS;
float factor[K];
bool logits_mask[K];
#pragma unroll
for (int k_idx = 0; k_idx < K; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f);
cub_kvp inp_kvp;
#pragma unroll
for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) {
const int idx = thread_row_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const int prior_winning_expert = indices[K * block_row + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = K * block_row + k_idx;
result_kvp_value[k_idx] = (float)result_kvp.value;
indices[idx] = result_kvp.key;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
#pragma unroll
for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) {
const int idx = thread_row_offset + expert;
factor[k_idx] = max(abs((float)inputs[idx]), result_kvp_value[k_idx]);
logits_mask[k_idx] = (result_kvp_value[k_idx] - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]);
if (k_idx == 1 && expert == indices[K * block_row]) {
logits_mask[1] = true;
}
}
}
#pragma unroll
for (int k_idx = 0; k_idx < K; ++k_idx) {
float row_sum(0);
#pragma unroll
for (int ii = threadIdx.x; ii < NUM_EXPERTS; ii += TPB) {
const int idx = thread_row_offset + ii;
row_sum += logits_mask[k_idx] ? 0 : exp((static_cast<float>(inputs[idx]) - result_kvp_value[k_idx]));
}
#pragma unroll
for (int mask = NUM_EXPERTS / 2; mask > 0; mask /= 2) {
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, NUM_EXPERTS);
}
const float normalizing_factor = 1.f / row_sum;
const int idx = K * block_row + k_idx;
if (threadIdx.x == indices[idx]) {
const int input_idx = thread_row_offset + threadIdx.x;
output[idx] = logits_mask[k_idx] ? 0
: exp((static_cast<float>(inputs[input_idx]) - result_kvp_value[k_idx])) *
normalizing_factor;
}
}
}
#endif
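To make the routing math easier to follow, a NumPy transcription of `sparse_mixer_top2` for a single row is sketched below (reading aid only, not an ORT API; `jitter_eps` and the masking rule are taken directly from the kernel above):

```python
import numpy as np

def sparse_mixer_top2_ref(logits, jitter_eps=0.01):
    """Reference for one row: returns (expert indices, routing weights), K = 2."""
    logits = logits.astype(np.float32)
    masked = logits.copy()
    indices, weights = [], []
    for _ in range(2):  # K = 2
        winner = int(np.argmax(masked))
        max_val = masked[winner]
        # Keep only experts whose logit lies within the jitter band of the max;
        # earlier winners stay excluded (they were set to -inf below).
        factor = np.maximum(np.abs(logits), max_val)
        keep = ((max_val - logits) <= 2.0 * jitter_eps * factor) & np.isfinite(masked)
        denom = np.exp(logits[keep] - max_val).sum()
        indices.append(winner)
        weights.append(1.0 / denom)  # the winner's own term is exp(0) = 1
        masked[winner] = -np.inf     # exclude the winner from the next round
    return indices, weights
```

The kernel computes the same quantities cooperatively: a block-wide ArgMax per round, then a warp-shuffle reduction for the masked-softmax denominator, with each winner's weight written by the thread that owns that expert.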
// ====================== TopK softmax things ===============================
/*
@@ -406,9 +507,30 @@ void topk_gating_softmax_launcher_helper(const T *input, const bool *finished, T
template <typename T>
void topk_gating_softmax_kernelLauncher(const T *input, const bool *finished, T *output, T *softmax_temp_output,
int *indices, int *source_row, int num_rows, int num_experts, int k,
bool normalize_routing_weights, cudaStream_t stream) {
bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
if (use_sparse_mixer) {
static constexpr int TPB = WARP_SIZE * WARPS_PER_TB;
static constexpr float jitter_eps = 0.01f;
switch (num_experts) {
case 8: {
sparse_mixer_top2<T, TPB, 8><<<num_rows, TPB, 0, stream>>>(input, output, indices, source_row, jitter_eps);
break;
}
case 16: {
sparse_mixer_top2<T, TPB, 16><<<num_rows, TPB, 0, stream>>>(input, output, indices, source_row, jitter_eps);
break;
}
default: {
ORT_THROW("Sparse mixer only supports 8 and 16 experts");
}
}
return;
}
switch (num_experts) {
case 2: {
topk_gating_softmax_launcher_helper<T, 2, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
@@ -542,9 +664,9 @@ __global__ void dispatch_activations_kernel(int64_t *total_rows_before_expert, i
template <typename T, typename WeightType, typename Enable>
CutlassMoeFCRunner<T, WeightType, Enable>::CutlassMoeFCRunner(int sm_version, bool has_fc3,
bool normalize_routing_weights)
bool normalize_routing_weights, bool use_sparse_mixer)
: has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0),
normalize_routing_weights_(normalize_routing_weights) {
normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) {
moe_gemm_runner_.initialize(sm_version);
}
@@ -729,7 +851,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
configure_ws_ptrs(workspace_ptr, static_cast<size_t>(num_rows), static_cast<size_t>(hidden_size),
static_cast<size_t>(inter_size), static_cast<size_t>(num_experts), static_cast<size_t>(k));
topk_gating_softmax_kernelLauncher<T>(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream);
source_rows_, num_rows, num_experts, k, normalize_routing_weights_,
use_sparse_mixer_, stream);
const int sorter_ws_size_bytes = static_cast<int>(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)));
sorter_.run(reinterpret_cast<void *>(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_,
@@ -748,7 +871,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
stream);
}
// moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, expanded_active_expert_rows);
// moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size,
// expanded_active_expert_rows);
moe_gemm_runner_.moe_gemm_bias_act(
permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases,
fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index,
@@ -868,9 +992,9 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::get_total_rows_info(int64_t expe
// experts in the end.
// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0,
// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input
// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus
// of the expanded index.
// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ...
// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we
// simply take the modulus of the expanded index.
template <typename T>
__global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *permuted_output,
@@ -878,9 +1002,9 @@ __global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *perm
int *expanded_source_row_to_expanded_dest_row, int num_rows,
int active_rows, int cols) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the
// reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
// I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need
// the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in
// MoE. 1 thread block will be responsible for all k summations.
const int expanded_dest_row = blockIdx.x;
const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0) {
@@ -1014,14 +1138,15 @@ void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *red
// ========================= TopK Softmax specializations ===========================
template void topk_gating_softmax_kernelLauncher(const float *, const bool *, float *, float *, int *, int *, int, int,
int, bool, cudaStream_t);
int, bool, bool, cudaStream_t);
template void topk_gating_softmax_kernelLauncher(const half *, const bool *, half *, half *, int *, int *, int, int,
int, bool, cudaStream_t);
int, bool, bool, cudaStream_t);
// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner<float, float>;
template class CutlassMoeFCRunner<half, half>;
template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
template class CutlassMoeFCRunner<half, uint8_t>;
// ===================== Specializations for init routing =========================
template void initialize_moe_routing_kernelLauncher(const float *, float *, const int *, int *, int, int, int, int,
@@ -1043,4 +1168,4 @@ template void finalize_moe_routing_kernelLauncher(const float *, float *, const
template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *,
const half *, const int *, const int *, int, int, int, cudaStream_t);
} // namespace ort_fastertransformer
} // namespace ort_fastertransformer


@@ -109,7 +109,7 @@ template <typename T, /*The type used for activations/scales/compute*/
typename Enable = void>
class CutlassMoeFCRunner {
public:
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);
size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k);
@@ -161,6 +161,7 @@ class CutlassMoeFCRunner {
bool has_fc3_;
bool normalize_routing_weights_;
bool use_sparse_mixer_;
// Cuda events
contrib::cuda::AutoDestoryCudaEvent cuda_event_;
@@ -175,7 +176,7 @@ class CutlassMoeFCRunner {
template <typename WeightType>
class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_same<float, WeightType>::value>> {
public:
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);
size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
return 0;


@@ -49,7 +49,7 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
const int sm = device_prop.major * 10 + device_prop.minor;
ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
normalize_routing_weights_, use_sparse_mixer_);
size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),


@@ -22,6 +22,7 @@ enum class MoEParallelType {
enum class MoEQuantType {
None = 0,
UINT4 = 1,
UINT8 = 2,
};
struct MoEParameters {
@@ -225,9 +226,15 @@ class MoEBase {
}
normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault<int64_t>("normalize_routing_weights", 0) == 1;
use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault<int64_t>("use_sparse_mixer", 0) == 1;
if (use_sparse_mixer_) {
ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2");
}
}
bool normalize_routing_weights_;
bool use_sparse_mixer_;
int64_t k_;
ort_fastertransformer::ActivationType activation_type_;
};
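As a usage note for the check above: a `MoE` node that opts into sparse-mixer routing must set `k=2`, or session creation fails. A minimal sketch mirroring the existing test helpers (input names hypothetical):

```python
from onnx import helper

# use_sparse_mixer=1 is only valid with k=2 (the ORT_ENFORCE above);
# any other k is rejected when the kernel is constructed.
moe_node = helper.make_node(
    "MoE",
    ["input", "router_probs",
     "fc1_experts_weights", "",        # fc1 bias omitted (optional)
     "fc2_experts_weights", "",        # fc2 bias omitted (optional)
     "fc3_experts_weights"],
    ["output"],
    domain="com.microsoft",
    k=2,
    use_sparse_mixer=1,
    normalize_routing_weights=0,
    activation_type="silu",
)
```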


@@ -37,61 +37,54 @@ template <>
struct ToCudaTypeWrapper<uint8_t, true> {
using MappedType = cutlass::uint4b_t;
};
} // anonymous namespace
QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {}
Status QMoE::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* router_probs = context->Input<Tensor>(1);
const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
const Tensor* fc1_scales = context->Input<Tensor>(3);
const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
const Tensor* fc2_experts_weights = context->Input<Tensor>(5);
const Tensor* fc2_scales = context->Input<Tensor>(6);
const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(7);
const Tensor* fc3_experts_weights_optional = context->Input<Tensor>(8);
const Tensor* fc3_scales_optional = context->Input<Tensor>(9);
const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(10);
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // Mute "maybe used uninitialized" warning for MoEParameters.
#endif
MoEParameters moe_params;
MoEQuantType quant_type = MoEQuantType::UINT4;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
fc3_experts_weights_optional, fc3_experts_bias_optional));
ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts,
moe_params.hidden_size, moe_params.inter_size));
// Support int4 only at the moment. We can add uint8 if needed.
static constexpr bool use_quint4x2 = true;
using T = MLFloat16;
using CudaT = typename ToCudaType<T>::MappedType;
using CudaWeightT = typename ToCudaTypeWrapper<uint8_t, use_quint4x2>::MappedType;
QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("expert_weight_bits", &expert_weight_bits_).IsOK());
ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4,
"expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_);
}
template <typename CudaWeightT>
Status QMoE::QuantizedMoEImpl(OpKernelContext* context,
MoEParameters& moe_params,
const Tensor* input,
const Tensor* router_probs,
const Tensor* fc1_experts_weights,
const Tensor* fc1_experts_bias_optional,
const Tensor* fc2_experts_weights,
const Tensor* fc2_experts_bias_optional,
const Tensor* fc3_experts_weights_optional,
const Tensor* fc3_experts_bias_optional,
const Tensor* fc1_scales,
const Tensor* fc2_scales,
const Tensor* fc3_scales_optional,
const cudaDeviceProp& device_prop) const {
auto stream = context->GetComputeStream();
auto& device_prop = GetDeviceProp();
const int sm = device_prop.major * 10 + device_prop.minor;
ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaWeightT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
using T = MLFloat16;
using CudaT = typename ToCudaType<T>::MappedType;
ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaWeightT> moe_runner(sm,
fc3_experts_weights_optional != nullptr,
normalize_routing_weights_,
use_sparse_mixer_);
size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
static_cast<size_t>(moe_params.inter_size), static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));
static_cast<size_t>(moe_params.inter_size), static_cast<size_t>(moe_params.num_experts),
static_cast<size_t>(k_));
size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int);
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, ws_size, false, stream);
IAllocatorUniquePtr<void> fc2_output = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
IAllocatorUniquePtr<void> expert_scales =
@@ -140,13 +133,56 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
reinterpret_cast<int*>(expert_for_source_row.get()), static_cast<int>(moe_params.num_rows),
static_cast<int>(moe_params.hidden_size), static_cast<int>(k_), Stream(context));
return Status::OK();
}
Status QMoE::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* router_probs = context->Input<Tensor>(1);
const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
const Tensor* fc1_scales = context->Input<Tensor>(3);
const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
const Tensor* fc2_experts_weights = context->Input<Tensor>(5);
const Tensor* fc2_scales = context->Input<Tensor>(6);
const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(7);
const Tensor* fc3_experts_weights_optional = context->Input<Tensor>(8);
const Tensor* fc3_scales_optional = context->Input<Tensor>(9);
const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(10);
MoEQuantType quant_type = expert_weight_bits_ == 4 ? MoEQuantType::UINT4 : MoEQuantType::UINT8;
MoEParameters moe_params;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
fc3_experts_weights_optional, fc3_experts_bias_optional));
ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts,
moe_params.hidden_size, moe_params.inter_size));
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // Mute "maybe used uninitialized" warning for MoEParameters.
#endif
if (quant_type == MoEQuantType::UINT4) {
using CudaWeightT = typename ToCudaTypeWrapper<uint8_t, true>::MappedType;
return QuantizedMoEImpl<CudaWeightT>(context, moe_params, input, router_probs,
fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights,
fc2_experts_bias_optional, fc3_experts_weights_optional,
fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional,
GetDeviceProp());
} else {
using CudaWeightT = typename ToCudaTypeWrapper<uint8_t, false>::MappedType;
return QuantizedMoEImpl<CudaWeightT>(context, moe_params, input, router_probs,
fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights,
fc2_experts_bias_optional, fc3_experts_weights_optional,
fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional,
GetDeviceProp());
}
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
return Status::OK();
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime
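For context on the two `CudaWeightT` paths above, here is a sketch of how expert weights might be prepared on the Python side. The symmetric per-output-channel scheme is an assumption inferred from the scale shapes (`fc1_scales` is `(num_experts, inter_size)`); the exact zero-point and int4 packing conventions of the CUTLASS kernels are not visible in this diff:

```python
import numpy as np

def quantize_expert_weights(w, bits):
    """w: (num_experts, rows, cols) float weights -> (uint8 payload, float scales).

    Assumed scheme: symmetric per-(expert, col) scales; for bits=4 two values
    are packed per byte along the last axis, halving it (matching the schema).
    """
    assert bits in (4, 8)
    qmax = 7 if bits == 4 else 127
    scales = np.maximum(np.abs(w).max(axis=1), 1e-8) / qmax    # (num_experts, cols)
    q = np.clip(np.rint(w / scales[:, None, :]), -qmax, qmax).astype(np.int16)
    q = (q + (8 if bits == 4 else 128)).astype(np.uint8)       # shift to unsigned
    if bits == 4:
        q = (q[..., 0::2] & 0x0F) | (q[..., 1::2] << 4)        # two nibbles per byte
    return q, scales.astype(np.float32)
```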


@@ -18,6 +18,25 @@ class QMoE final : public CudaKernel, public MoEBase {
public:
explicit QMoE(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
template <typename CudaWeightT>
Status QuantizedMoEImpl(OpKernelContext* context,
MoEParameters& moe_params,
const Tensor* input,
const Tensor* router_probs,
const Tensor* fc1_experts_weights,
const Tensor* fc1_experts_bias_optional,
const Tensor* fc2_experts_weights,
const Tensor* fc2_experts_bias_optional,
const Tensor* fc3_experts_weights_optional,
const Tensor* fc3_experts_bias_optional,
const Tensor* fc1_scales,
const Tensor* fc2_scales,
const Tensor* fc3_scales_optional,
const cudaDeviceProp& device_prop) const;
int64_t expert_weight_bits_;
};
} // namespace cuda


@@ -95,6 +95,10 @@ void RegisterCollectiveOps() {
"Whether to normalize routing weights",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr("use_sparse_mixer",
"Whether to use sparse mixer",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr("local_experts_start_index",
"The start index of local experts",
AttributeProto::INT,


@@ -1395,6 +1395,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
.Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu"))
.Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast<int64_t>(0))
.Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
.Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
.Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T")
@@ -1410,7 +1411,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
ONNX_MS_OPERATOR_SET_SCHEMA(
QMoE, 1,
OpSchema()
.SetDoc("Int4 MoE")
.SetDoc("Quantized MoE")
.Attr("activation_type",
"Activation function to use. Choose from relu, gelu, silu and identity. Default is relu",
AttributeProto::STRING,
@@ -1423,18 +1424,31 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Whether to normalize routing weights",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("expert_weight_bits",
"Number of bits used in quantized weights. Default is 4 bits",
AttributeProto::INT,
static_cast<int64_t>(4))
.Input(0,
"input",
"2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape "
"(batch_size, sequence_length, hidden_size)",
"T")
.Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
.Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size / 2)", "T1")
.Input(2,
"fc1_experts_weights",
"3D input tensor with shape (num_experts, hidden_size, inter_size) "
"or (num_experts, hidden_size, inter_size / 2)",
"T1")
.Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T")
.Input(4,
"fc1_experts_bias",
"2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
.Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size / 2)", "T1")
.Input(5,
"fc2_experts_weights",
"3D input tensor with shape (num_experts, inter_size, hidden_size) "
"or (num_experts, inter_size, hidden_size / 2)",
"T1")
.Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T")
.Input(7,
"fc2_experts_bias",
@@ -1443,7 +1457,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.Input(8,
"fc3_experts_weights",
"3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)",
"3D optional input tensor with shape (num_experts, hidden_size, inter_size) "
"or (num_experts, hidden_size, inter_size / 2)",
"T1",
OpSchema::Optional)
.Input(9,


@@ -1,361 +0,0 @@
# --------------------------------------------------------------------------
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import unittest
from collections import OrderedDict
import numpy
import torch
import torch.nn.functional as F
from onnx import TensorProto, helper
from torch import nn
import onnxruntime
torch.manual_seed(42)
numpy.random.seed(42)
ORT_DTYPE = TensorProto.FLOAT
NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
THRESHOLD = 3e-2
def value_string_of(numpy_array):
arr = numpy_array.flatten()
lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)]
return "{\n " + "f,\n ".join(lines) + "f}"
def print_tensor(name, numpy_array):
print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")
def create_moe_onnx_graph(
num_rows,
num_experts,
hidden_size,
inter_size,
fc1_experts_weights,
fc2_experts_weights,
fc3_experts_weights,
topk,
):
nodes = [
helper.make_node(
"MoE",
[
"input",
"router_probs",
"fc1_experts_weights",
"",
"fc2_experts_weights",
"",
"fc3_experts_weights",
],
["output"],
"MoE_0",
k=topk,
normalize_routing_weights=1,
activation_type="silu",
domain="com.microsoft",
),
]
fc1_shape = [num_experts, hidden_size, inter_size]
fc2_shape = [num_experts, inter_size, hidden_size]
fc3_shape = [num_experts, hidden_size, inter_size]
torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32
initializers = [
helper.make_tensor(
"fc1_experts_weights",
ORT_DTYPE,
fc1_shape,
fc1_experts_weights.to(torch_type).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_weights",
ORT_DTYPE,
fc2_shape,
fc2_experts_weights.to(torch_type).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc3_experts_weights",
ORT_DTYPE,
fc3_shape,
fc3_experts_weights.to(torch_type).flatten().tolist(),
raw=False,
),
]
graph_inputs = [
helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]),
]
graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
ORT_DTYPE,
[num_rows, num_experts],
)
)
graph_outputs = [
helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]),
]
graph = helper.make_graph(
nodes,
"MoE_Graph",
graph_inputs,
graph_outputs,
initializers,
)
model = helper.make_model(graph)
return model.SerializeToString()
class ClassInstantier(OrderedDict):
def __getitem__(self, key):
content = super().__getitem__(key)
cls, kwargs = content if isinstance(content, tuple) else (content, {})
return cls(**kwargs)
ACT2CLS = {
"silu": nn.SiLU,
}
ACT2FN = ClassInstantier(ACT2CLS)
class MixtralConfig:
def __init__(
self,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
rope_theta=1e6,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
current_hidden_states_1 = self.act_fn(self.w1(hidden_states))
current_hidden_states_3 = self.w3(hidden_states)
current_hidden_states = current_hidden_states_1 * current_hidden_states_3
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class MixtralSparseMoeBlock(nn.Module):
"""
This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accommodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drops tokens at the cost of reduced performance or (2) sets the
capacity factor to the number of experts and thus wastes computation
and memory on padding.
"""
def __init__(self, config, batch_size, sequence_length):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
w1_list = []
w2_list = []
w3_list = []
for i in range(self.num_experts):
w1_list.append(self.experts[i].w1.weight)
w2_list.append(self.experts[i].w2.weight)
w3_list.append(self.experts[i].w3.weight)
self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
self.moe_experts_weight2 = torch.stack(w2_list, dim=0)
self.moe_experts_weight3 = torch.stack(w3_list, dim=0)
self.batch_size = batch_size
self.sequence_length = sequence_length
self.moe_onnx_graph = create_moe_onnx_graph(
self.batch_size * self.sequence_length,
self.num_experts,
self.hidden_dim,
self.ffn_dim,
self.moe_experts_weight1,
self.moe_experts_weight2,
self.moe_experts_weight3,
self.top_k,
)
self.ort_sess = self.create_ort_session()
def create_ort_session(self):
from onnxruntime import InferenceSession, SessionOptions
sess_options = SessionOptions()
cuda_providers = ["CUDAExecutionProvider"]
if cuda_providers[0] not in onnxruntime.get_available_providers():
return None
sess_options.log_severity_level = 2
ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"])
return ort_session
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
if top_x.shape[0] == 0:
continue
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states # , router_logits
def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
ort_inputs = {
"input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
"router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
}
ort_output = None
if self.ort_sess is not None:
ort_output = self.ort_sess.run(None, ort_inputs)
return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits
# print_tensor("input", ort_inputs["input"])
# print_tensor("router_probs", ort_inputs["router_probs"])
# print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy())
# print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy())
# print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy())
# print_tensor("output", ort_output[0])
return None
def parity_check(self):
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
torch_output = self.forward(hidden_state)
ort_output = self.ort_forward(hidden_state)
if ort_output is not None:
assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04)
print(
"batch_size:",
self.batch_size,
" sequence_length:",
self.sequence_length,
" max_diff:",
(torch_output - ort_output).abs().max(),
" parity: OK",
)
class TestMixtralMoE(unittest.TestCase):
def test_mixtral_moe_parity(self):
for batch_size in [1, 16]:
for sequence_length in [128, 1024]:
# use small sizes to speed up the test
config = MixtralConfig(hidden_size=256, intermediate_size=1024)
mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
mixtral_moe.parity_check()
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large.