Mirror of https://github.com/saymrwulf/onnxruntime.git (synced 2026-05-14 20:48:00 +00:00)
Phi3 MoE cuda kernel (#21819)
Co-authored-by: Your Name <you@example.com>
parent 252222034f
commit 1d059b8702

13 changed files with 1080 additions and 595 deletions
@@ -3083,6 +3083,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
+<dt><tt>use_sparse_mixer</tt> : int</dt>
+<dd>Whether to use sparse mixer</dd>
</dl>

#### Inputs (5 - 8)
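The new `use_sparse_mixer` attribute selects an alternative top-2 routing scheme; with it off, routing stays the plain top-k softmax, optionally renormalized when `normalize_routing_weights=1`. A small PyTorch sketch of that baseline normalization step (it mirrors the reference block in the deleted Mixtral test at the end of this diff; illustrative, not ORT code):

```python
import torch
import torch.nn.functional as F

router_logits = torch.randn(4, 8)  # (num_rows, num_experts)
probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
topk_w, topk_idx = torch.topk(probs, k=2, dim=-1)
# normalize_routing_weights=1: renormalize the selected weights to sum to 1
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
```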
@@ -4398,7 +4400,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

### <a name="com.microsoft.QMoE"></a><a name="com.microsoft.qmoe">**com.microsoft.QMoE**</a>

-Int4 MoE
+Quantized MoE

#### Version
@@ -4409,10 +4411,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
+<dt><tt>expert_weight_bits</tt> : int</dt>
+<dd>Number of bits used in quantized weights. Default is 4 bits</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
+<dt><tt>use_sparse_mixer</tt> : int</dt>
+<dd>Whether to use sparse mixer</dd>
</dl>

#### Inputs (7 - 11)
@@ -4423,19 +4429,19 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
+<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc1_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, inter_size, hidden_size / 2)</dd>
+<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
<dt><tt>fc2_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
-<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
+<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc3_scales</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
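The `/ 2` in the 4-bit shapes exists because two int4 values are packed into each byte of the weight tensor; with `expert_weight_bits=8` the full `inter_size` (or `hidden_size`) is kept, which is why the docs now list both layouts. A minimal numpy sketch of the packing (an illustrative helper under the assumption that the low nibble holds the even-indexed element; not ORT's actual loader):

```python
import numpy as np

def pack_int4(q):
    """Pack unsigned 4-bit values (0..15) pairwise into uint8, halving the
    last axis: (num_experts, hidden_size, inter_size) ->
    (num_experts, hidden_size, inter_size // 2)."""
    assert q.shape[-1] % 2 == 0 and q.max() < 16
    lo, hi = q[..., 0::2], q[..., 1::2]   # nibble order is an assumption here
    return (lo | (hi << 4)).astype(np.uint8)

q = np.random.randint(0, 16, size=(8, 256, 1024), dtype=np.uint8)
packed = pack_int4(q)  # shape (8, 256, 512) -- the "inter_size / 2" layout
```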
@@ -79,7 +79,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
  ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
-                                                                    normalize_routing_weights_);
+                                                                    normalize_routing_weights_, use_sparse_mixer_);

  size_t ws_size = moe_runner.getWorkspaceSize(
      static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
@@ -0,0 +1,31 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)
#pragma warning(disable : 4244)
#pragma warning(disable : 4200)
#endif

#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"

#if defined(_MSC_VER)
#pragma warning(pop)
#endif

namespace ort_fastertransformer {
template class MoeGemmRunner<half, uint8_t>;
}  // namespace ort_fastertransformer
@@ -127,7 +127,7 @@ __launch_bounds__(TPB) __global__
  const int block_row = blockIdx.x;

  const bool should_process_row = finished ? !finished[block_row] : true;
-  const int thread_read_offset = blockIdx.x * num_experts;
+  const int thread_row_offset = blockIdx.x * num_experts;
  float output_row_sum = 0.f;
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
@@ -135,7 +135,7 @@ __launch_bounds__(TPB) __global__

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
-      const int idx = thread_read_offset + expert;
+      const int idx = thread_row_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = inputs_after_softmax[idx];
@@ -169,6 +169,107 @@ __launch_bounds__(TPB) __global__
}
#endif

+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
+template <typename T, int TPB, int NUM_EXPERTS>
+__launch_bounds__(TPB) __global__ void sparse_mixer_top2(const T *, T *, int *, int *, const float) {
+  // Not supported on architectures older than sm_53.
+}
+#else
+
+template <typename T, int TPB, int NUM_EXPERTS>
+__launch_bounds__(TPB) __global__
+    void sparse_mixer_top2(const T *inputs, T *output, int *indices, int *source_rows, const float jitter_eps) {
+  static constexpr int K = 2;
+
+  using cub_kvp = cub::KeyValuePair<int, T>;
+  using KVBlockReduce = cub::BlockReduce<cub_kvp, TPB>;
+
+  __shared__ float result_kvp_value[K];
+  __shared__ typename KVBlockReduce::TempStorage kvTmpStorage;
+
+  cub_kvp thread_kvp;
+  cub::ArgMax arg_max;
+
+  int num_rows = gridDim.x;
+  const int block_row = blockIdx.x;
+
+  const int thread_row_offset = blockIdx.x * NUM_EXPERTS;
+
+  float factor[K];
+  bool logits_mask[K];
+
+#pragma unroll
+  for (int k_idx = 0; k_idx < K; ++k_idx) {
+    thread_kvp.key = 0;
+    thread_kvp.value = T(-1.f);
+
+    cub_kvp inp_kvp;
+#pragma unroll
+    for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) {
+      const int idx = thread_row_offset + expert;
+      inp_kvp.key = expert;
+      inp_kvp.value = inputs[idx];
+
+      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
+        const int prior_winning_expert = indices[K * block_row + prior_k];
+
+        if (prior_winning_expert == expert) {
+          inp_kvp = thread_kvp;
+        }
+      }
+
+      thread_kvp = arg_max(inp_kvp, thread_kvp);
+    }
+
+    const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max);
+    if (threadIdx.x == 0) {
+      const int idx = K * block_row + k_idx;
+      result_kvp_value[k_idx] = (float)result_kvp.value;
+      indices[idx] = result_kvp.key;
+      source_rows[idx] = k_idx * num_rows + block_row;
+    }
+    __syncthreads();
+
+#pragma unroll
+    for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) {
+      const int idx = thread_row_offset + expert;
+      factor[k_idx] = max(abs((float)inputs[idx]), result_kvp_value[k_idx]);
+      logits_mask[k_idx] = (result_kvp_value[k_idx] - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]);
+      if (k_idx == 1 && expert == indices[K * block_row]) {
+        logits_mask[1] = true;
+      }
+    }
+  }
+
+#pragma unroll
+  for (int k_idx = 0; k_idx < K; ++k_idx) {
+    float row_sum(0);
+
+#pragma unroll
+    for (int ii = threadIdx.x; ii < NUM_EXPERTS; ii += TPB) {
+      const int idx = thread_row_offset + ii;
+      row_sum += logits_mask[k_idx] ? 0 : exp((static_cast<float>(inputs[idx]) - result_kvp_value[k_idx]));
+    }
+
+#pragma unroll
+    for (int mask = NUM_EXPERTS / 2; mask > 0; mask /= 2) {
+      row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, NUM_EXPERTS);
+    }
+
+    const float normalizing_factor = 1.f / row_sum;
+
+    const int idx = K * block_row + k_idx;
+    if (threadIdx.x == indices[idx]) {
+      const int input_idx = thread_row_offset + threadIdx.x;
+      output[idx] = logits_mask[k_idx] ? 0
+                                       : exp((static_cast<float>(inputs[input_idx]) - result_kvp_value[k_idx])) *
+                                             normalizing_factor;
+    }
+  }
+}
+#endif
+
// ====================== TopK softmax things ===============================

/*
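For readers following the new kernel, here is a hedged PyTorch reference of the same top-2 "sparse mixer" routing (a CPU sketch derived from reading the kernel above; names are illustrative, not ORT code). Each round takes the row argmax, masks experts whose logit trails the winner by more than `2 * jitter_eps * max(|logit|, winner)` (plus any previous winner), and normalizes `exp(logit - winner)` over the survivors:

```python
import torch

def sparse_mixer_top2(logits: torch.Tensor, jitter_eps: float = 0.01):
    """logits: (num_rows, num_experts) router output. Returns top-2 expert
    indices and their routing weights, each normalized over the experts that
    survive the jitter mask."""
    indices, weights = [], []
    banned = torch.zeros_like(logits, dtype=torch.bool)
    for _ in range(2):
        top_val, top_idx = logits.masked_fill(banned, float("-inf")).max(dim=-1, keepdim=True)
        # Drop experts trailing the winner by more than 2*eps*max(|logit|, winner),
        # plus any previous winner (mirrors the logits_mask test in the kernel).
        factor = torch.maximum(logits.abs(), top_val)
        drop = ((top_val - logits) > 2 * jitter_eps * factor) | banned
        probs = torch.where(drop, torch.zeros_like(logits), (logits - top_val).exp())
        probs = probs / probs.sum(dim=-1, keepdim=True)
        indices.append(top_idx)
        weights.append(probs.gather(-1, top_idx))
        banned = banned.scatter(-1, top_idx, True)  # exclude this winner next round
    return torch.cat(indices, dim=-1), torch.cat(weights, dim=-1)
```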
@@ -406,9 +507,30 @@ void topk_gating_softmax_launcher_helper(const T *input, const bool *finished, T
template <typename T>
void topk_gating_softmax_kernelLauncher(const T *input, const bool *finished, T *output, T *softmax_temp_output,
                                        int *indices, int *source_row, int num_rows, int num_experts, int k,
-                                        bool normalize_routing_weights, cudaStream_t stream) {
+                                        bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream) {
  static constexpr int WARPS_PER_TB = 4;

+  if (use_sparse_mixer) {
+    static constexpr int TPB = WARP_SIZE * WARPS_PER_TB;
+    static constexpr float jitter_eps = 0.01f;
+
+    switch (num_experts) {
+      case 8: {
+        sparse_mixer_top2<T, TPB, 8><<<num_rows, TPB, 0, stream>>>(input, output, indices, source_row, jitter_eps);
+        break;
+      }
+      case 16: {
+        sparse_mixer_top2<T, TPB, 16><<<num_rows, TPB, 0, stream>>>(input, output, indices, source_row, jitter_eps);
+        break;
+      }
+
+      default: {
+        ORT_THROW("Sparse mixer only supports 8 and 16 experts");
+      }
+    }
+    return;
+  }
+
  switch (num_experts) {
    case 2: {
      topk_gating_softmax_launcher_helper<T, 2, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
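One detail worth noting about the kernel dispatched above: the per-row softmax denominator is combined with a `__shfl_xor_sync` butterfly across `NUM_EXPERTS` lanes (8 or 16 here, both powers of two), so after log2(n) rounds every lane holds the full sum. A tiny Python simulation of that reduction pattern (illustrative sketch, assuming a power-of-two lane count):

```python
# Simulate the xor-shuffle butterfly: each round, lane i adds the value held
# by lane i ^ mask; after all rounds every lane has the sum of all n lanes.
def butterfly_sum(vals):
    n = len(vals)
    vals = list(vals)
    mask = n // 2
    while mask > 0:
        vals = [vals[i] + vals[i ^ mask] for i in range(n)]
        mask //= 2
    return vals

print(butterfly_sum([1, 2, 3, 4]))  # [10, 10, 10, 10]
```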
@@ -542,9 +664,9 @@ __global__ void dispatch_activations_kernel(int64_t *total_rows_before_expert, i

template <typename T, typename WeightType, typename Enable>
CutlassMoeFCRunner<T, WeightType, Enable>::CutlassMoeFCRunner(int sm_version, bool has_fc3,
-                                                              bool normalize_routing_weights)
+                                                              bool normalize_routing_weights, bool use_sparse_mixer)
    : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0),
-      normalize_routing_weights_(normalize_routing_weights) {
+      normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) {
  moe_gemm_runner_.initialize(sm_version);
}
@@ -729,7 +851,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
  configure_ws_ptrs(workspace_ptr, static_cast<size_t>(num_rows), static_cast<size_t>(hidden_size),
                    static_cast<size_t>(inter_size), static_cast<size_t>(num_experts), static_cast<size_t>(k));
  topk_gating_softmax_kernelLauncher<T>(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
-                                        source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream);
+                                        source_rows_, num_rows, num_experts, k, normalize_routing_weights_,
+                                        use_sparse_mixer_, stream);

  const int sorter_ws_size_bytes = static_cast<int>(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)));
  sorter_.run(reinterpret_cast<void *>(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_,
@@ -748,7 +871,8 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
                stream);
  }

-  // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, expanded_active_expert_rows);
+  // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size,
+  //                                       expanded_active_expert_rows);
  moe_gemm_runner_.moe_gemm_bias_act(
      permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases,
      fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index,
@@ -868,9 +992,9 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::get_total_rows_info(int64_t expe
// experts in the end.

// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0,
-// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input
-// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus
-// of the expanded index.
+// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ...
+// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we
+// simply take the modulus of the expanded index.

template <typename T>
__global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *permuted_output,
@@ -878,9 +1002,9 @@ __global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *perm
                                              int *expanded_source_row_to_expanded_dest_row, int num_rows,
                                              int active_rows, int cols) {
  // Reverse permutation map.
-  // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the
-  // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
-  // thread block will be responsible for all k summations.
+  // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need
+  // the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in
+  // MoE. 1 thread block will be responsible for all k summations.
  const int expanded_dest_row = blockIdx.x;
  const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row];
  if (threadIdx.x == 0) {
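The comments reflowed in the two hunks above describe the routing bookkeeping: with top-k routing, each input row appears k times in the expanded matrix at indices row, num_rows + row, ..., (k-1)*num_rows + row, so the source row is recovered by taking the expanded index modulo num_rows, and a reverse (source to dest) map lets one thread block do the k-way reduction without atomics. A hedged numpy sketch of that mapping (illustrative names, not the CUDA kernels):

```python
import numpy as np

def expand_and_permute(x, expert_idx, k):
    """x: (num_rows, hidden); expert_idx: (num_rows, k) chosen experts.
    Groups expanded rows by expert and builds the reverse map used for the
    atomics-free k-way reduction."""
    num_rows = x.shape[0]
    rows, k_ids = np.meshgrid(np.arange(num_rows), np.arange(k), indexing="ij")
    source = (k_ids * num_rows + rows).ravel()   # expanded source index per slot
    experts = expert_idx.ravel()                 # expert id per expanded slot
    order = np.argsort(experts, kind="stable")   # dest -> expanded source
    permuted = x[source[order] % num_rows]       # modulus recovers the input row
    src_to_dest = np.empty(k * num_rows, dtype=np.int64)
    src_to_dest[source[order]] = np.arange(k * num_rows)
    return permuted, src_to_dest
```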
@@ -1014,14 +1138,15 @@ void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *red

// ========================= TopK Softmax specializations ===========================
template void topk_gating_softmax_kernelLauncher(const float *, const bool *, float *, float *, int *, int *, int, int,
-                                                 int, bool, cudaStream_t);
+                                                 int, bool, bool, cudaStream_t);
template void topk_gating_softmax_kernelLauncher(const half *, const bool *, half *, half *, int *, int *, int, int,
-                                                 int, bool, cudaStream_t);
+                                                 int, bool, bool, cudaStream_t);

// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner<float, float>;
template class CutlassMoeFCRunner<half, half>;
template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
+template class CutlassMoeFCRunner<half, uint8_t>;

// ===================== Specializations for init routing =========================
template void initialize_moe_routing_kernelLauncher(const float *, float *, const int *, int *, int, int, int, int,
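After this change the runner is explicitly instantiated for four (activation, weight) type pairs; the new `<half, uint8_t>` instantiation is what backs the 8-bit QMoE path. A small Python summary of the supported combinations (a reading of the instantiation list above, not an ORT API):

```python
# (activation dtype, weight dtype) pairs with a CutlassMoeFCRunner instantiation.
SUPPORTED_MOE_RUNNERS = {
    ("float32", "float32"): "MoE, full precision",
    ("float16", "float16"): "MoE, half precision",
    ("float16", "uint4"): "QMoE, expert_weight_bits=4",
    ("float16", "uint8"): "QMoE, expert_weight_bits=8 (new in this commit)",
}
```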
@@ -1043,4 +1168,4 @@ template void finalize_moe_routing_kernelLauncher(const float *, float *, const
template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *,
                                                  const half *, const int *, const int *, int, int, int, cudaStream_t);

-}  // namespace ort_fastertransformer
+}  // namespace ort_fastertransformer
@@ -109,7 +109,7 @@ template <typename T, /*The type used for activations/scales/compute*/
          typename Enable = void>
class CutlassMoeFCRunner {
 public:
-  CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
+  CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);

  size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k);
@@ -161,6 +161,7 @@ class CutlassMoeFCRunner {

  bool has_fc3_;
  bool normalize_routing_weights_;
+  bool use_sparse_mixer_;

  // Cuda events
  contrib::cuda::AutoDestoryCudaEvent cuda_event_;
@@ -175,7 +176,7 @@ class CutlassMoeFCRunner {
template <typename WeightType>
class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_same<float, WeightType>::value>> {
 public:
-  CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
+  CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);

  size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
    return 0;
@@ -49,7 +49,7 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
  const int sm = device_prop.major * 10 + device_prop.minor;

  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
-                                                                    normalize_routing_weights_);
+                                                                    normalize_routing_weights_, use_sparse_mixer_);

  size_t ws_size = moe_runner.getWorkspaceSize(
      static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
@@ -22,6 +22,7 @@ enum class MoEParallelType {
enum class MoEQuantType {
  None = 0,
  UINT4 = 1,
+  UINT8 = 2,
};

struct MoEParameters {
@@ -225,9 +226,15 @@ class MoEBase {
    }

    normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault<int64_t>("normalize_routing_weights", 0) == 1;
+
+    use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault<int64_t>("use_sparse_mixer", 0) == 1;
+    if (use_sparse_mixer_) {
+      ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2");
+    }
  }

  bool normalize_routing_weights_;
+  bool use_sparse_mixer_;
  int64_t k_;
  ort_fastertransformer::ActivationType activation_type_;
};
@@ -37,61 +37,54 @@ template <>
struct ToCudaTypeWrapper<uint8_t, true> {
  using MappedType = cutlass::uint4b_t;
};

}  // anonymous namespace

-QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {}
-
-Status QMoE::ComputeInternal(OpKernelContext* context) const {
-  const Tensor* input = context->Input<Tensor>(0);
-  const Tensor* router_probs = context->Input<Tensor>(1);
-  const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
-  const Tensor* fc1_scales = context->Input<Tensor>(3);
-  const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
-  const Tensor* fc2_experts_weights = context->Input<Tensor>(5);
-  const Tensor* fc2_scales = context->Input<Tensor>(6);
-  const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(7);
-  const Tensor* fc3_experts_weights_optional = context->Input<Tensor>(8);
-  const Tensor* fc3_scales_optional = context->Input<Tensor>(9);
-  const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(10);
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"  // Mute "maybe used uninitialized" warning for MoEParameters.
-#endif
-
-  MoEParameters moe_params;
-  MoEQuantType quant_type = MoEQuantType::UINT4;
-  ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
-                                  fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
-                                  fc3_experts_weights_optional, fc3_experts_bias_optional));
-  ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts,
-                                       moe_params.hidden_size, moe_params.inter_size));
-
-  // Support int4 only at the moment. We can add uint8 if needed.
-  static constexpr bool use_quint4x2 = true;
-  using T = MLFloat16;
-  using CudaT = typename ToCudaType<T>::MappedType;
-  using CudaWeightT = typename ToCudaTypeWrapper<uint8_t, use_quint4x2>::MappedType;
+QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {
+  ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("expert_weight_bits", &expert_weight_bits_).IsOK());
+  ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4,
+              "expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_);
+}
+
+template <typename CudaWeightT>
+Status QMoE::QuantizedMoEImpl(OpKernelContext* context,
+                              MoEParameters& moe_params,
+                              const Tensor* input,
+                              const Tensor* router_probs,
+                              const Tensor* fc1_experts_weights,
+                              const Tensor* fc1_experts_bias_optional,
+                              const Tensor* fc2_experts_weights,
+                              const Tensor* fc2_experts_bias_optional,
+                              const Tensor* fc3_experts_weights_optional,
+                              const Tensor* fc3_experts_bias_optional,
+                              const Tensor* fc1_scales,
+                              const Tensor* fc2_scales,
+                              const Tensor* fc3_scales_optional,
+                              const cudaDeviceProp& device_prop) const {
+  auto stream = context->GetComputeStream();

-  auto& device_prop = GetDeviceProp();
  const int sm = device_prop.major * 10 + device_prop.minor;

-  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaWeightT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
-                                                                           normalize_routing_weights_);
+  AllocatorPtr allocator;
+  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+
+  using T = MLFloat16;
+  using CudaT = typename ToCudaType<T>::MappedType;
+
+  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaWeightT> moe_runner(sm,
+                                                                           fc3_experts_weights_optional != nullptr,
+                                                                           normalize_routing_weights_,
+                                                                           use_sparse_mixer_);

  size_t ws_size = moe_runner.getWorkspaceSize(
      static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
-      static_cast<size_t>(moe_params.inter_size), static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));
+      static_cast<size_t>(moe_params.inter_size), static_cast<size_t>(moe_params.num_experts),
+      static_cast<size_t>(k_));
  size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
  size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
  size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
  size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int);

-  AllocatorPtr allocator;
-  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
-
  IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, ws_size, false, stream);
  IAllocatorUniquePtr<void> fc2_output = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
  IAllocatorUniquePtr<void> expert_scales =
@@ -140,13 +133,56 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
                                          reinterpret_cast<int*>(expert_for_source_row.get()), static_cast<int>(moe_params.num_rows),
                                          static_cast<int>(moe_params.hidden_size), static_cast<int>(k_), Stream(context));

  return Status::OK();
}

+Status QMoE::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* router_probs = context->Input<Tensor>(1);
+  const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
+  const Tensor* fc1_scales = context->Input<Tensor>(3);
+  const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
+  const Tensor* fc2_experts_weights = context->Input<Tensor>(5);
+  const Tensor* fc2_scales = context->Input<Tensor>(6);
+  const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(7);
+  const Tensor* fc3_experts_weights_optional = context->Input<Tensor>(8);
+  const Tensor* fc3_scales_optional = context->Input<Tensor>(9);
+  const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(10);
+
+  MoEQuantType quant_type = expert_weight_bits_ == 4 ? MoEQuantType::UINT4 : MoEQuantType::UINT8;
+  MoEParameters moe_params;
+  ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
+                                  fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
+                                  fc3_experts_weights_optional, fc3_experts_bias_optional));
+  ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts,
+                                       moe_params.hidden_size, moe_params.inter_size));
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"  // Mute "maybe used uninitialized" warning for MoEParameters.
+#endif
+
+  if (quant_type == MoEQuantType::UINT4) {
+    using CudaWeightT = typename ToCudaTypeWrapper<uint8_t, true>::MappedType;
+    return QuantizedMoEImpl<CudaWeightT>(context, moe_params, input, router_probs,
+                                         fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights,
+                                         fc2_experts_bias_optional, fc3_experts_weights_optional,
+                                         fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional,
+                                         GetDeviceProp());
+  } else {
+    using CudaWeightT = typename ToCudaTypeWrapper<uint8_t, false>::MappedType;
+    return QuantizedMoEImpl<CudaWeightT>(context, moe_params, input, router_probs,
+                                         fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights,
+                                         fc2_experts_bias_optional, fc3_experts_weights_optional,
+                                         fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional,
+                                         GetDeviceProp());
+  }
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
+
+  return Status::OK();
+}
+
}  // namespace cuda
}  // namespace contrib
-}  // namespace onnxruntime
+}  // namespace onnxruntime
@@ -18,6 +18,25 @@ class QMoE final : public CudaKernel, public MoEBase {
 public:
  explicit QMoE(const OpKernelInfo& op_kernel_info);
  Status ComputeInternal(OpKernelContext* ctx) const override;

+ private:
+  template <typename CudaWeightT>
+  Status QuantizedMoEImpl(OpKernelContext* context,
+                          MoEParameters& moe_params,
+                          const Tensor* input,
+                          const Tensor* router_probs,
+                          const Tensor* fc1_experts_weights,
+                          const Tensor* fc1_experts_bias_optional,
+                          const Tensor* fc2_experts_weights,
+                          const Tensor* fc2_experts_bias_optional,
+                          const Tensor* fc3_experts_weights_optional,
+                          const Tensor* fc3_experts_bias_optional,
+                          const Tensor* fc1_scales,
+                          const Tensor* fc2_scales,
+                          const Tensor* fc3_scales_optional,
+                          const cudaDeviceProp& device_prop) const;
+
+  int64_t expert_weight_bits_;
};

}  // namespace cuda
@@ -95,6 +95,10 @@ void RegisterCollectiveOps() {
          "Whether to normalize routing weights",
          AttributeProto::INT,
          static_cast<int64_t>(0))
+    .Attr("use_sparse_mixer",
+          "Whether to use sparse mixer",
+          AttributeProto::INT,
+          static_cast<int64_t>(0))
    .Attr("local_experts_start_index",
          "The start index of local experts",
          AttributeProto::INT,
@@ -1395,6 +1395,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
    .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu"))
    .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast<int64_t>(1))
    .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast<int64_t>(0))
+    .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast<int64_t>(0))
    .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
    .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
    .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T")
@@ -1410,7 +1411,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
ONNX_MS_OPERATOR_SET_SCHEMA(
    QMoE, 1,
    OpSchema()
-        .SetDoc("Int4 MoE")
+        .SetDoc("Quantized MoE")
        .Attr("activation_type",
              "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu",
              AttributeProto::STRING,
@@ -1423,18 +1424,31 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
              "Whether to normalize routing weights",
              AttributeProto::INT,
              static_cast<int64_t>(0))
+        .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast<int64_t>(0))
+        .Attr("expert_weight_bits",
+              "Number of bits used in quantized weights. Default is 4 bits",
+              AttributeProto::INT,
+              static_cast<int64_t>(4))
        .Input(0,
               "input",
               "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape "
               "(batch_size, sequence_length, hidden_size)",
               "T")
        .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
-        .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size / 2)", "T1")
+        .Input(2,
+               "fc1_experts_weights",
+               "3D input tensor with shape (num_experts, hidden_size, inter_size) "
+               "or (num_experts, hidden_size, inter_size / 2)",
+               "T1")
        .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T")
        .Input(4,
               "fc1_experts_bias",
               "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
-        .Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size / 2)", "T1")
+        .Input(5,
+               "fc2_experts_weights",
+               "3D input tensor with shape (num_experts, inter_size, hidden_size) "
+               "or (num_experts, inter_size, hidden_size / 2)",
+               "T1")
        .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T")
        .Input(7,
               "fc2_experts_bias",
@@ -1443,7 +1457,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
               OpSchema::Optional)
        .Input(8,
               "fc3_experts_weights",
-               "3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)",
+               "3D optional input tensor with shape (num_experts, hidden_size, inter_size) "
+               "or (num_experts, hidden_size, inter_size / 2)",
               "T1",
               OpSchema::Optional)
        .Input(9,
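Putting the schema together, here is a hedged Python sketch of building a QMoE node with the new attributes (onnx helper API; an illustrative node, not taken from the ORT test suite). With `expert_weight_bits=8` the weight inputs keep the full last dimension, so no `/ 2` on their shapes:

```python
from onnx import helper

# Optional inputs are passed as "" placeholders, following the input order
# in the schema above.
qmoe_node = helper.make_node(
    "QMoE",
    inputs=[
        "input", "router_probs",
        "fc1_experts_weights", "fc1_scales", "",  # fc1 bias omitted
        "fc2_experts_weights", "fc2_scales", "",  # fc2 bias omitted
        "fc3_experts_weights", "fc3_scales", "",  # fc3 bias omitted
    ],
    outputs=["output"],
    name="QMoE_0",
    domain="com.microsoft",
    k=2,
    activation_type="silu",
    normalize_routing_weights=0,
    use_sparse_mixer=1,    # requires k == 2 (enforced in MoEBase)
    expert_weight_bits=8,  # uint8 weights: shapes keep the full inter_size
)
```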
@@ -1,361 +0,0 @@
# --------------------------------------------------------------------------
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import unittest
from collections import OrderedDict

import numpy
import torch
import torch.nn.functional as F
from onnx import TensorProto, helper
from torch import nn

import onnxruntime

torch.manual_seed(42)
numpy.random.seed(42)

ORT_DTYPE = TensorProto.FLOAT
NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
THRESHOLD = 3e-2


def value_string_of(numpy_array):
    arr = numpy_array.flatten()
    lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)]
    return "{\n    " + "f,\n    ".join(lines) + "f}"


def print_tensor(name, numpy_array):
    print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")


def create_moe_onnx_graph(
    num_rows,
    num_experts,
    hidden_size,
    inter_size,
    fc1_experts_weights,
    fc2_experts_weights,
    fc3_experts_weights,
    topk,
):
    nodes = [
        helper.make_node(
            "MoE",
            [
                "input",
                "router_probs",
                "fc1_experts_weights",
                "",
                "fc2_experts_weights",
                "",
                "fc3_experts_weights",
            ],
            ["output"],
            "MoE_0",
            k=topk,
            normalize_routing_weights=1,
            activation_type="silu",
            domain="com.microsoft",
        ),
    ]

    fc1_shape = [num_experts, hidden_size, inter_size]
    fc2_shape = [num_experts, inter_size, hidden_size]
    fc3_shape = [num_experts, hidden_size, inter_size]

    torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32

    initializers = [
        helper.make_tensor(
            "fc1_experts_weights",
            ORT_DTYPE,
            fc1_shape,
            fc1_experts_weights.to(torch_type).flatten().tolist(),
            raw=False,
        ),
        helper.make_tensor(
            "fc2_experts_weights",
            ORT_DTYPE,
            fc2_shape,
            fc2_experts_weights.to(torch_type).flatten().tolist(),
            raw=False,
        ),
        helper.make_tensor(
            "fc3_experts_weights",
            ORT_DTYPE,
            fc3_shape,
            fc3_experts_weights.to(torch_type).flatten().tolist(),
            raw=False,
        ),
    ]

    graph_inputs = [
        helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]),
    ]

    graph_inputs.append(
        helper.make_tensor_value_info(
            "router_probs",
            ORT_DTYPE,
            [num_rows, num_experts],
        )
    )

    graph_outputs = [
        helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]),
    ]

    graph = helper.make_graph(
        nodes,
        "MoE_Graph",
        graph_inputs,
        graph_outputs,
        initializers,
    )

    model = helper.make_model(graph)
    return model.SerializeToString()


class ClassInstantier(OrderedDict):
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


ACT2CLS = {
    "silu": nn.SiLU,
}
ACT2FN = ClassInstantier(ACT2CLS)


class MixtralConfig:
    def __init__(
        self,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        rope_theta=1e6,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=8,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
    ):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef


class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        current_hidden_states_1 = self.act_fn(self.w1(hidden_states))
        current_hidden_states_3 = self.w3(hidden_states)
        current_hidden_states = current_hidden_states_1 * current_hidden_states_3
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is strictly equivalent to standard MoE with full
    capacity (no dropped tokens). It's faster since it formulates MoE
    operations in terms of block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either (1) drops
    tokens at the cost of reduced performance or (2) sets the capacity factor
    to the number of experts and thus wastes computation and memory on padding.
    """

    def __init__(self, config, batch_size, sequence_length):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

        w1_list = []
        w2_list = []
        w3_list = []
        for i in range(self.num_experts):
            w1_list.append(self.experts[i].w1.weight)
            w2_list.append(self.experts[i].w2.weight)
            w3_list.append(self.experts[i].w3.weight)

        self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
        self.moe_experts_weight2 = torch.stack(w2_list, dim=0)
        self.moe_experts_weight3 = torch.stack(w3_list, dim=0)

        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.moe_onnx_graph = create_moe_onnx_graph(
            self.batch_size * self.sequence_length,
            self.num_experts,
            self.hidden_dim,
            self.ffn_dim,
            self.moe_experts_weight1,
            self.moe_experts_weight2,
            self.moe_experts_weight3,
            self.top_k,
        )

        self.ort_sess = self.create_ort_session()

    def create_ort_session(self):
        from onnxruntime import InferenceSession, SessionOptions

        sess_options = SessionOptions()

        cuda_providers = ["CUDAExecutionProvider"]
        if cuda_providers[0] not in onnxruntime.get_available_providers():
            return None

        sess_options.log_severity_level = 2
        ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"])

        return ort_session

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask;
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only supports torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states  # , router_logits

    def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        ort_inputs = {
            "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
            "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
        }

        ort_output = None
        if self.ort_sess is not None:
            ort_output = self.ort_sess.run(None, ort_inputs)
            return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1)  # , router_logits

        # print_tensor("input", ort_inputs["input"])
        # print_tensor("router_probs", ort_inputs["router_probs"])
        # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy())
        # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy())
        # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy())
        # print_tensor("output", ort_output[0])

        return None

    def parity_check(self):
        hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
        torch_output = self.forward(hidden_state)
        ort_output = self.ort_forward(hidden_state)
        if ort_output is not None:
            assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04)
            print(
                "batch_size:",
                self.batch_size,
                " sequence_length:",
                self.sequence_length,
                " max_diff:",
                (torch_output - ort_output).abs().max(),
                " parity: OK",
            )


class TestMixtralMoE(unittest.TestCase):
    def test_mixtral_moe_parity(self):
        for batch_size in [1, 16]:
            for sequence_length in [128, 1024]:
                # use small sizes to speed up the test
                config = MixtralConfig(hidden_size=256, intermediate_size=1024)
                mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
                mixtral_moe.parity_check()


if __name__ == "__main__":
    unittest.main()
File diff suppressed because it is too large