diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc index c0a7797ba7..1210442580 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cc +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc @@ -15,6 +15,10 @@ namespace onnxruntime { namespace contrib { namespace rocm { +constexpr int kPastSequenceLengthInputIndex = 6; +constexpr int kPastInputIndex = 4; +constexpr int kPresentOutputIndex = 1; + #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Attention, \ @@ -22,8 +26,10 @@ namespace rocm { 1, \ T, \ kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ Attention); REGISTER_KERNEL_TYPED(float) @@ -40,47 +46,41 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); const Tensor* relative_position_bias = context->Input(5); + const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); + AttentionParameters parameters; ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past, relative_position_bias, - nullptr, - device_prop.maxThreadsPerBlock)); - - // input shape (batch_size, sequence_length, input_hidden_size) - const auto& shape = input->Shape(); - int batch_size = static_cast(shape[0]); - int sequence_length = static_cast(shape[1]); - int input_hidden_size = static_cast(shape[2]); - - // Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in ROCM EP - // bias shape (3 * hidden_size) - const auto& bias_shape = bias->Shape(); - int hidden_size = static_cast(bias_shape[0]) / 3; - - int 
head_size = hidden_size / num_heads_; + &parameters, + device_prop.maxThreadsPerBlock, + past_seq_len)); + ORT_ENFORCE(parameters.sequence_length == parameters.kv_sequence_length); // self attention TensorShapeVector output_shape(3); - output_shape[0] = shape[0]; - output_shape[1] = shape[1]; - output_shape[2] = static_cast(hidden_size); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(parameters.sequence_length); + output_shape[2] = static_cast(parameters.v_hidden_size); Tensor* output = context->Output(0, output_shape); - int past_sequence_length = 0; - Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); + std::vector present_dims{ + 2, parameters.batch_size, parameters.num_heads, + parameters.past_present_share_buffer ? parameters.max_sequence_length : parameters.total_sequence_length, + parameters.head_size}; + TensorShape present_shape(present_dims); + Tensor* present = context->Output(kPresentOutputIndex, present_shape); rocblas_handle rocblas = GetRocblasHandle(context); constexpr size_t element_size = sizeof(T); - // Use GEMM for fully connection. 
- int m = batch_size * sequence_length; - int n = 3 * hidden_size; - int k = input_hidden_size; - auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size, context->GetComputeStream()); + int m = parameters.batch_size * parameters.sequence_length; + int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); + int k = parameters.input_hidden_size; + auto gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); typedef typename ToHipType::MappedType HipT; namespace blas = rocm::tunable::blas; @@ -108,8 +108,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { /*beta=*/1.0f, reinterpret_cast(gemm_buffer.get()), n)); - size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, - sequence_length, past_sequence_length); + size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, + parameters.batch_size, + parameters.num_heads, + parameters.head_size, + parameters.sequence_length, + parameters.past_sequence_length); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); return LaunchAttentionKernel( @@ -118,16 +122,16 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { Stream(context), rocblas, element_size, - batch_size, - sequence_length, - num_heads_, - head_size, - past_sequence_length, - is_unidirectional_, + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.head_size, + parameters.past_sequence_length, + parameters.is_unidirectional, reinterpret_cast(gemm_buffer.get()), nullptr == mask_index ? nullptr : mask_index->Data(), nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), - mask_filter_value_, + parameters.mask_filter_value, nullptr == past ? nullptr : past->Data(), nullptr == relative_position_bias ? 
nullptr : relative_position_bias->Data(), work_space.get(), diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc index 21eff66004..a254a8c04f 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc @@ -43,6 +43,10 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { Tensor* output = ctx->Output(0, input->Shape()); + // For inferencing, we support one more optional output which is the sum + // of the input and skip tensors + Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); + if (input->Shape() != skip->Shape()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "skip is expected to have same shape as input"); @@ -101,6 +105,7 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { GetTuningContext(), Stream(ctx), reinterpret_cast(output->MutableData()), + skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, reinterpret_cast(input->Data()), reinterpret_cast(skip->Data()), reinterpret_cast(gamma->Data()), diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu index e543945c16..bf33f940b3 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu @@ -41,12 +41,12 @@ namespace rocm { template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, int ld, int element_count) { + RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, + const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, int ld, int element_count) { // this must be true because element_count is the 
total size of the tensor assert(element_count % ld == 0); - SkipLayerNormParams params(tuning_ctx, stream, output, input, skip, gamma, beta, bias, epsilon, ld, element_count); + SkipLayerNormParams params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, gamma, beta, bias, epsilon, ld, element_count); if (tuning_ctx->IsTunableOpEnabled()) { static SkipLayerNormTunableOp op; @@ -57,13 +57,13 @@ Status LaunchSkipLayerNormKernel( } template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, const float* input, + RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, float* skip_input_bias_add_output, const float* input, const float* skip, const float* gamma, const float* beta, const float* bias, float epsilon, int ld, int element_count); template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, const half* input, + RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, half* skip_input_bias_add_output, const half* input, const half* skip, const half* gamma, const half* beta, const half* bias, float epsilon, int ld, int element_count); diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h index 7289347400..911164af92 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h @@ -15,6 +15,7 @@ Status LaunchSkipLayerNormKernel( RocmTuningContext* tuning, hipStream_t stream, T* output, // output tensor + T* skip_input_bias_add_output, // optional output tensor const T* input, // input tensor const T* skip, // skip tensor const T* gamma, // Layer normalization gamma tensor diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h index ee38b1c7e7..bcef54871a 100644 --- 
a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h @@ -26,7 +26,7 @@ half maybe2half(float x) { template __global__ void SkipLayerNormKernel( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output) { + const T epsilon, T* output, T* skip_input_bias_add_output) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; @@ -39,6 +39,11 @@ __global__ void SkipLayerNormKernel( const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i]; const T rldval = reverse_ld * val; thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); + + if (skip_input_bias_add_output != nullptr) { + skip_input_bias_add_output[idx] = val; + } + output[idx] = val; } @@ -49,7 +54,8 @@ __global__ void SkipLayerNormKernel( template __global__ void SkipLayerNormKernelVec( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, bool hasBias) { + const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, + bool hasBias, bool hasSkipInputBiasAdditionOutput) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; @@ -58,7 +64,7 @@ __global__ void SkipLayerNormKernelVec( hipcub::KeyValuePair thread_data(0, 0); using VecT = aligned_vector; - T input_v[ILP], skip_v[ILP], bias_v[ILP]; + T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];; if (threadIdx.x * ILP < ld) { VecT* input_val = reinterpret_cast(&input_v); VecT* skip_val = reinterpret_cast(&skip_v); @@ -76,9 +82,19 @@ __global__ void SkipLayerNormKernelVec( #pragma unroll for (int k = 0; k < ILP; k++) { input_v[k] += hasBias ? 
skip_v[k] + bias_v[k] : skip_v[k]; + + if (hasSkipInputBiasAdditionOutput) { + skip_input_bias_add_output_v[k] = input_v[k]; // use this loop's index k, matching the other accesses in the loop + } + const T rldval = reverse_ld * input_v[k]; thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * input_v[k])); } + + if (hasSkipInputBiasAdditionOutput) { + *(reinterpret_cast(&skip_input_bias_add_output[idx])) = *reinterpret_cast(&skip_input_bias_add_output_v); + } + *(reinterpret_cast(&output[idx])) = *reinterpret_cast(&input_v[0]); } } @@ -90,12 +106,13 @@ __global__ void SkipLayerNormKernelVec( template __global__ void SkipLayerNormKernelSmall( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, bool hasBias) { + const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, + bool hasBias, bool hasSkipInputBiasAdditionOutput) { const T rld = T(1.f / ld); const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld using VecT = aligned_vector; - T input_v[ILP], skip_v[ILP], bias_v[ILP]; + T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP]; hipcub::KeyValuePair thread_data(T(0.f), T(0.f)); @@ -116,10 +133,20 @@ __global__ void SkipLayerNormKernelSmall( #pragma unroll for (int i = 0; i < ILP; i++) { input_v[i] += hasBias ? 
skip_v[i] + bias_v[i] : skip_v[i]; + + if (hasSkipInputBiasAdditionOutput) { + skip_input_bias_add_output_v[i] = input_v[i]; + } + const T rldval = rld * input_v[i]; rldval_sum += rldval; rldvalsq_sum += rldval * input_v[i]; } + + if (hasSkipInputBiasAdditionOutput) { + *(reinterpret_cast(&skip_input_bias_add_output[idx])) = *reinterpret_cast(&skip_input_bias_add_output_v); + } + thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); } LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h index 44a8a65863..b8d0dfee74 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h @@ -20,11 +20,11 @@ namespace rocm { template struct SkipLayerNormParams : OpParams { - SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, const T* input, + SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, int ld, int element_count) - : OpParams(tuning_ctx, stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), - epsilon(epsilon), ld(ld), element_count(element_count) {} + : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), + gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} std::string Signature() const override { std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); @@ -32,6 +32,7 @@ struct SkipLayerNormParams : OpParams { } T* output; + T* skip_input_bias_add_output; const T* input; const T* skip; const T* gamma; @@ -51,8 +52,8 @@ Status SkipLayerNormSmallOp(const 
SkipLayerNormParams* params) { dim3(ThreadsPerBlock), 0, params->stream>>>( params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, maybe2half(params->epsilon), params->output, - (params->bias == nullptr) ? false : true); + params->beta, params->gamma, params->bias, maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, + (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); return HIP_CALL(hipGetLastError()); } @@ -66,51 +67,52 @@ Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { dim3(ThreadsPerBlock), 0, params->stream>>>( params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, maybe2half(params->epsilon), params->output, - (params->bias == nullptr) ? false : true); + params->beta, params->gamma, params->bias, maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, + (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); return HIP_CALL(hipGetLastError()); } template Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { bool hasBias = (params->bias == nullptr) ? false : true; + bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? 
false : true; if (0 == (params->ld % 4)) { const int grid_size = params->element_count / params->ld; if (params->ld <= 32) { constexpr int block_size = 32; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 64) { constexpr int block_size = 64 / 2; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 128) { constexpr int block_size = 128 / 4; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 384) { constexpr int block_size = 384 / 4; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 768) { constexpr int block_size = 768 / 4; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 1024) { constexpr int block_size 
= 1024 / 4; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else { constexpr int block_size = 256; SkipLayerNormKernel<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output); } } else { const int grid_size = params->element_count / params->ld; @@ -118,27 +120,27 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { constexpr int block_size = 32; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 64) { constexpr int block_size = 64; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld <= 128) { constexpr int block_size = 128; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else if (params->ld == 384) { constexpr int block_size = 384; SkipLayerNormKernelSmall<<stream>>>( params->ld, params->input, params->skip, params->beta, 
params->gamma, params->bias, - maybe2half(params->epsilon), params->output, hasBias); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput); } else { constexpr int block_size = 256; SkipLayerNormKernel<<stream>>>( params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - maybe2half(params->epsilon), params->output); + maybe2half(params->epsilon), params->output, params->skip_input_bias_add_output); } } return HIP_CALL(hipPeekAtLastError()); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu index 10d73467d8..21ffe83293 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu @@ -17,12 +17,12 @@ namespace onnxruntime { template class SkipLayerNormSmall : public IKernelExplorer { public: - SkipLayerNormSmall(DeviceArray& output, DeviceArray& input, DeviceArray& skip, + SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), + static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp(&params_))); @@ -41,12 +41,12 @@ class SkipLayerNormSmall : public 
IKernelExplorer { template class SkipLayerNormRegular : public IKernelExplorer { public: - SkipLayerNormRegular(DeviceArray& output, DeviceArray& input, DeviceArray& skip, + SkipLayerNormRegular(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), + static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp(&params_))); @@ -65,12 +65,12 @@ class SkipLayerNormRegular : public IKernelExplorer { template class SkipLayerNormStaticSelection : public IKernelExplorer { public: - SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& input, DeviceArray& skip, - DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, + SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, + DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), + 
static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} void Run() override { ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection(&params_))); @@ -89,12 +89,12 @@ class SkipLayerNormStaticSelection : public IKernelExplorer { template class SkipLayerNormTunable : public IKernelExplorer { public: - SkipLayerNormTunable(DeviceArray& output, DeviceArray& input, DeviceArray& skip, + SkipLayerNormTunable(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, float epsilon, int hidden_size, int element_count) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), - static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - static_cast(bias.ptr()), epsilon, hidden_size, element_count) { + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), + static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) { params_.TuningContext()->EnableTunableOp(); } @@ -113,14 +113,14 @@ class SkipLayerNormTunable : public IKernelExplorer { contrib::rocm::SkipLayerNormTunableOp op_{}; }; -#define REGISTER_OP(name, type, threads_per_block, vec_size) \ - py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ - .def("Run", &name::Run) \ +#define REGISTER_OP(name, type, threads_per_block, vec_size) \ + py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run) \ .def("IsSupported", &name::IsSupported); #define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ @@ -141,7 +141,7 
@@ class SkipLayerNormTunable : public IKernelExplorer { #define REGISTER_OP_TYPED(name, type) \ py::class_>(m, #name "_" #type) \ .def(py::init()) \ .def("SetRepeats", &name::SetRepeats) \ .def("Profile", &name::Profile) \ diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py index d35965513d..c75358a2e5 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py @@ -56,6 +56,9 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: # Becuase of rocm FMAs calculation issue with float16, epsilon should be larger when hidden_size is small epsilon = 0.05 if hidden_size < 8 else 0.0005 output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + # output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + # enforce nullptr in the backend + output_optional = np.empty((0), dtype=dtype) input_d = ke.DeviceArray(input_x) skip_d = ke.DeviceArray(skip) @@ -63,8 +66,24 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) y_d = ke.DeviceArray(output_y) + optional_d = ke.DeviceArray(output_optional) f = getattr(ke, func) - my_op = f(y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size) + + # output_optional is newly added optional output tensor for SkipLayerNorm + # Right now we are not testing the tensor like we do for output_y + # the test for output_optional could be considered in future as needed + my_op = f( + y_d, + optional_d, + input_d, + skip_d, + gamma_d, + beta_d, + bias_d, + epsilon, + hidden_size, + batch_size * seq_len * hidden_size, + ) if my_op.IsSupported(): my_op.Run() @@ -106,6 +125,9 @@ def profile_skip_layer_norm_func(batch_size, seq_len, 
hidden_size, dtype, func): bias = np.random.rand(hidden_size).astype(dtype) epsilon = 0.0005 output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + # output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + # enforce nullptr in the backend - optional when profiling + output_optional = np.empty((0), dtype=dtype) input_d = ke.DeviceArray(input_x) skip_d = ke.DeviceArray(skip) @@ -113,8 +135,24 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func): beta_d = ke.DeviceArray(beta) bias_d = ke.DeviceArray(bias) y_d = ke.DeviceArray(output_y) + optional_d = ke.DeviceArray(output_optional) f = getattr(ke, func) - my_op = f(y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size) + + # output_optional is newly added optional output tensor for SkipLayerNorm + # Right now we are not testing the tensor like we do for output_y + # the profile for output_optional could be considered in future as needed + my_op = f( + y_d, + optional_d, + input_d, + skip_d, + gamma_d, + beta_d, + bias_d, + epsilon, + hidden_size, + batch_size * seq_len * hidden_size, + ) duration_ms = -1 if my_op.IsSupported(): diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index 4ed141c893..c501bd72a9 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -436,8 +436,8 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias) { hidden_size); } -// Don't enable this test for DML/ROCM builds as these EP doesn't produce the new optional output yet -#if !defined(USE_ROCM) && !defined(USE_DML) +// Don't enable this test for DML builds as these EP doesn't produce the new optional output yet +#if !defined(USE_DML) TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias_ProducingOptionalOutput) { int batch_size = 2; int sequence_length = 2; @@ -494,6 +494,8 @@ 
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias_ProducingOptionalOutput) { hidden_size); } +// SkipSimplifiedLayerNorm has not been enabled for ROCm yet +#if !defined(USE_ROCM) TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) { int batch_size = 1; int sequence_length = 2; @@ -530,6 +532,7 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) { true); } #endif +#endif } // namespace test } // namespace onnxruntime