update lamb and GatherGrad kernel for ROCm EP (#7184)

With ROCm4.1, the CUDA implementation of Lamb and GatherGrad can be
utilized for ROCm EP.
This commit is contained in:
Weixing Zhang 2021-04-02 09:02:49 -07:00 committed by GitHub
parent 17f91ff410
commit a3f17c8b0d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 31 additions and 1085 deletions

View file

@ -1036,7 +1036,7 @@ if (onnxruntime_USE_ROCM)
target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template)
endif()
# During transition to separate hipFFT repo, put hipfft/include early
target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
target_include_directories(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${MPI_INCLUDE_DIRS} ${SAFEINT_INCLUDE_DIR} ${ONNXRUNTIME_ROOT}/../cmake/external/eigen)
if (onnxruntime_ENABLE_TRAINING)

View file

@ -134,7 +134,7 @@ void Check<double>(const OpTester::Data& expected_data,
}
double threshold = 0.001;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
threshold = 0.005;
#endif
@ -186,7 +186,7 @@ void InternalNumericalCheck(const OpTester::Data& expected_data,
}
float threshold = 0.0001f;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
threshold = 0.005f;
#endif
@ -247,7 +247,7 @@ void Check<MLFloat16>(const OpTester::Data& expected_data,
}
float threshold = 0.001f;
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA)
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
threshold = 0.005f;
#endif
for (int i = 0; i < size; ++i) {

View file

@ -170,7 +170,7 @@ TEST(GatherGradOpTest, GatherFewDistinctIndices) {
RunGatherGradTestWithRandomData<float>(0, {2, 32}, {6, 128}, absolute_error);
}
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
namespace {
void RunGatherGradConsistentOutputTest(
int64_t axis,

View file

@ -104,7 +104,7 @@ IAllocatorUniquePtr<T> GetOffsetsFromCounts(
// adapted from here:
// https://github.com/pytorch/pytorch/blob/b186831c08e0e4e447eedb8a5cfab582995d37f9/aten/src/ATen/native/cuda/Embedding.cu#L121
template <typename T, typename TIndex>
template <typename T, typename TIndex, int NumElementsPerThread>
__global__ void DirectSumKernel(
const TIndex* dX_indices_sorted,
const TIndex* dY_indices_sorted,
@ -116,21 +116,20 @@ __global__ void DirectSumKernel(
int64_t num_batches) {
GatheredIndexIndex_t idx = blockIdx.x * 4 + threadIdx.y;
const int SZ = 4;
if (idx < num_gathered_indices && (idx == 0 || dX_indices_sorted[idx] != dX_indices_sorted[idx - 1])) {
do {
for (int64_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
const auto gathered_element_idx_start = threadIdx.x + blockIdx.y * blockDim.x * SZ;
const auto gathered_element_idx_start = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread;
const auto dX_row_offset =
(batch_idx * gather_dimension_size + dX_indices_sorted[idx]) * num_gathered_per_index;
const auto dY_row_offset =
(batch_idx * num_gathered_indices + dY_indices_sorted[idx]) * num_gathered_per_index;
AccumulationType_t<T> dY_value[SZ];
AccumulationType_t<T> dX_value[SZ];
AccumulationType_t<T> dY_value[NumElementsPerThread];
AccumulationType_t<T> dX_value[NumElementsPerThread];
#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
for (int ii = 0; ii < NumElementsPerThread; ii++) {
const auto gathered_element_idx = gathered_element_idx_start + ii * GPU_WARP_SIZE;
if (gathered_element_idx < num_gathered_per_index) {
dY_value[ii] = static_cast<AccumulationType_t<T>>(dY_data[dY_row_offset + gathered_element_idx]);
@ -139,12 +138,12 @@ __global__ void DirectSumKernel(
}
#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
for (int ii = 0; ii < NumElementsPerThread; ii++) {
dX_value[ii] += dY_value[ii];
}
#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
for (int ii = 0; ii < NumElementsPerThread; ii++) {
const auto gathered_element_idx = gathered_element_idx_start + ii * GPU_WARP_SIZE;
if (gathered_element_idx < num_gathered_per_index) {
dX_data[dX_row_offset + gathered_element_idx] = static_cast<T>(dX_value[ii]);
@ -169,9 +168,9 @@ void DirectSumImpl(
int64_t gather_dimension_size,
int64_t num_batches) {
dim3 block(GPU_WARP_SIZE, 4);
dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, 128));
dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, GridDim::maxElementsPerThread * GPU_WARP_SIZE));
DirectSumKernel<<<grid, block, 0, stream>>>(
DirectSumKernel<T, TIndex, GridDim::maxElementsPerThread><<<grid, block, 0, stream>>>(
dX_indices_sorted,
dY_indices_sorted,
dY_data,
@ -509,9 +508,9 @@ void Impl_Simplified(
dX_indices_sorted, dY_indices_sorted);
dim3 block(GPU_WARP_SIZE, 4);
dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, 128));
dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, GridDim::maxElementsPerThread * GPU_WARP_SIZE));
DirectSumKernel<<<grid, block, 0, stream>>>(
DirectSumKernel<T, TIndex, GridDim::maxElementsPerThread><<<grid, block, 0, stream>>>(
dX_indices_sorted.get(),
dY_indices_sorted.get(),
dY_data,

View file

@ -1,705 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cmath>
#include "core/providers/rocm/rocm_allocator.h"
#include "core/providers/rocm/reduction/reduction_functions.h"
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "orttraining/training_ops/rocm/optimizer/common.h"
#include "orttraining/training_ops/rocm/optimizer/lamb.h"
namespace onnxruntime {
namespace rocm {
std::vector<std::pair<int, int>> GenerateLambExtraAliasMapping() {
// Starting index of extra inputs.
constexpr int input_index_bias = 5;
// Starting index of extra outputs.
constexpr int output_index_bias = 1;
// Count of extra I/O groups. One group corresponds to a weight update.
constexpr int group_count = 1024;
// length of [w, g, m1, m2, w_mixed_precision].
constexpr int input_stride = 5;
// length of [w_new, g_new, m1_new, m2_new, w_mixed_precision_new].
constexpr int output_stride = 5;
std::vector<std::pair<int, int>> alias_pairs{};
for (int i = 0; i < group_count; ++i) {
const int input = input_index_bias + i * input_stride;
const int output = output_index_bias + i * output_stride;
// w --> w_new
alias_pairs.emplace_back(std::make_pair(input, output));
// g --> g_new
alias_pairs.emplace_back(std::make_pair(input + 1, output + 1));
// m1 --> m1_new
alias_pairs.emplace_back(std::make_pair(input + 2, output + 2));
// m2 --> m2_new
alias_pairs.emplace_back(std::make_pair(input + 3, output + 3));
// w_mixed_precision --> w_mixed_precision_new
alias_pairs.emplace_back(std::make_pair(input + 4, output + 4));
}
// update_count are updated in place.
alias_pairs.emplace_back(std::make_pair(4, 0));
return alias_pairs;
}
// TODO: Once Schema is checked in to onnx lets fix this to match that
#define REGISTER_LAMB_KERNEL_TYPED(T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
LambOptimizer, \
kMSDomain, \
1, \
T1##_##T2##_##T3##_##T4##_##T_GRAD_NORM##_##T_MIXED_PRECISION_FP, \
kRocmExecutionProvider, \
KernelDefBuilder() \
.Alias(GenerateLambExtraAliasMapping()) \
.InputMemoryType<OrtMemTypeCPUInput>(0) /* Keep do_update in CPU */ \
.InputMemoryType<OrtMemTypeCPUInput>(4) /* Keep iteration_count in CPU */ \
.OutputMemoryType<OrtMemTypeCPUOutput>(0) /* Keep iteration_count in CPU */ \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()) \
.TypeConstraint("T3", DataTypeImpl::GetTensorType<T3>()) \
.TypeConstraint("T4", DataTypeImpl::GetTensorType<T4>()) \
.TypeConstraint("T_MIXED_PRECISION_FP", DataTypeImpl::GetTensorType<T_MIXED_PRECISION_FP>()) \
.TypeConstraint("T_GRAD_NORM", DataTypeImpl::GetTensorType<T_GRAD_NORM>()), \
LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>);
REGISTER_LAMB_KERNEL_TYPED(float, float, MLFloat16, float, MLFloat16, MLFloat16)
REGISTER_LAMB_KERNEL_TYPED(float, float, MLFloat16, float, float, MLFloat16)
REGISTER_LAMB_KERNEL_TYPED(float, float, float, float, float, MLFloat16)
// REGISTER_LAMB_KERNEL_TYPED(double, double, double, double, double, MLFloat16)
// REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, MLFloat16, MLFloat16, MLFloat16)
// REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, MLFloat16, float, MLFloat16)
REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, float, MLFloat16, MLFloat16)
REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, float, float, MLFloat16)
void check_inputs_and_outputs(
const Tensor* w,
const Tensor* g,
const Tensor* m1,
const Tensor* m2,
const Tensor* w_mixed_precision,
const Tensor* w_new,
const Tensor* g_new,
const Tensor* m1_new,
const Tensor* m2_new,
const Tensor* w_mixed_precision_new) {
// Throw if we have incomplete input or output lists.
ORT_ENFORCE(w, "Weight tensor should not be null.");
ORT_ENFORCE(g, "gradient tensor should not be null.");
ORT_ENFORCE(m1, "First-order momentum tensor should not be null.");
ORT_ENFORCE(m2, "Second-order momentum tensor should not be null.");
ORT_ENFORCE(m1_new, "New first-order momentum tensor should not be null.");
ORT_ENFORCE(m2_new, "New second-order momentum tensor should not be null.");
// Check if all shapes are good.
ORT_ENFORCE(m1->Shape() == m1_new->Shape());
ORT_ENFORCE(m2->Shape() == m2_new->Shape());
if (w_new)
ORT_ENFORCE(w->Shape() == w_new->Shape());
if (g_new)
ORT_ENFORCE(g->Shape() == g_new->Shape());
if (w_mixed_precision && w_mixed_precision_new)
ORT_ENFORCE(w_mixed_precision->Shape() == w_mixed_precision_new->Shape());
}
template <typename TWeight, typename TGradient, typename TMomentum, typename TMixedPrecision>
Status copy_inputs_to_outputs(
hipStream_t stream,
OpKernelContext* ctx,
const int non_grouped_input_count,
const int non_grouped_output_count,
const int group_count,
const int input_group_size,
const int output_group_size) {
const Tensor* step_tensor = ctx->Input<Tensor>(4);
if (step_tensor) {
const int64_t* step_data = step_tensor->template Data<int64_t>();
Tensor* step_tensor_new = ctx->Output(0, step_tensor->Shape());
ORT_ENFORCE(step_tensor_new != nullptr, "Step tensor (input) and updated step tensor (output) must be specified together.");
int64_t* step_data_new = step_tensor_new->template MutableData<int64_t>();
*step_data_new = *step_data;
}
for (int group_index = 0; group_index < group_count; ++group_index) {
const int input_start_index = non_grouped_input_count + group_index * input_group_size;
const Tensor& w = *ctx->Input<Tensor>(input_start_index);
const Tensor& g = *ctx->Input<Tensor>(input_start_index + 1);
const Tensor& m1 = *ctx->Input<Tensor>(input_start_index + 2);
const Tensor& m2 = *ctx->Input<Tensor>(input_start_index + 3);
const Tensor* w_mixed_precision = ctx->Input<Tensor>(input_start_index + 4);
const int output_start_index = non_grouped_output_count + group_index * output_group_size;
Tensor* w_new = ctx->Output(output_start_index, w.Shape());
Tensor* g_new = ctx->Output(output_start_index + 1, g.Shape());
Tensor& m1_new = *ctx->Output(output_start_index + 2, m1.Shape());
Tensor& m2_new = *ctx->Output(output_start_index + 3, m2.Shape());
Tensor* w_mixed_precision_new = w_mixed_precision != nullptr ? ctx->Output(output_start_index + 4, w_mixed_precision->Shape()) : nullptr;
// TODO: temporary hack until View is improved (it doesn't work with Alias)
if (w_new != nullptr)
w_new->SetByteOffset(w.ByteOffset());
if (g_new != nullptr)
g_new->SetByteOffset(g.ByteOffset());
if (w_mixed_precision_new != nullptr)
w_mixed_precision_new->SetByteOffset(w_mixed_precision->ByteOffset());
if (w_new) {
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TWeight>(stream, w, *w_new));
}
if (g_new) {
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TGradient>(stream, g, *g_new));
}
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TMomentum>(stream, m1, m1_new));
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TMomentum>(stream, m2, m2_new));
if (w_mixed_precision_new) {
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TMixedPrecision>(stream, *w_mixed_precision, *w_mixed_precision_new));
}
}
return Status::OK();
}
template <typename HipT2, typename HipT3, typename HipT4, typename HipT_GRAD_NORM>
Status launch_lamb_compute_direction(
hipStream_t stream,
const int64_t update_count,
const int group_count,
const HipT2* p_loss_scale,
const HipT_GRAD_NORM* p_g_norm,
std::vector<int>& tensor_sizes,
std::vector<const HipT2*>& p_ws,
std::vector<const HipT3*>& p_gs,
std::vector<const HipT4*>& p_m1s,
std::vector<const HipT4*>& p_m2s,
std::vector<HipT3*>& p_ds,
std::vector<HipT4*>& p_m1_news,
std::vector<HipT4*>& p_m2_news,
const std::vector<float>& alphas,
const std::vector<float>& betas,
const std::vector<float>& lambdas,
const std::vector<float>& epsilons,
const std::vector<float>& max_norms,
const int64_t do_bias_correction) {
ORT_ENFORCE(group_count == static_cast<int>(tensor_sizes.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_ws.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_gs.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_m1s.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_m2s.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_ds.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_m1_news.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_m2_news.size()));
ORT_ENFORCE(group_count == static_cast<int>(alphas.size()));
ORT_ENFORCE(group_count == static_cast<int>(betas.size()));
ORT_ENFORCE(group_count == static_cast<int>(lambdas.size()));
ORT_ENFORCE(group_count == static_cast<int>(epsilons.size()));
constexpr int tensor_count_per_group = 6;
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
// Bucketize tensor groups by the associated optimizer configuration.
// If two tensor groups use different "alpha", they should be put into two distinct buckets.
std::map<std::tuple<float, float, float, float, float>, std::vector<std::vector<void*>>> buckets;
std::map<std::tuple<float, float, float, float, float>, std::vector<int>> tensor_sizes_in_buckets;
for (int i = 0; i < group_count; ++i) {
if (tensor_sizes[i] > max_tensor_size) {
// For the first iteration (indexed by 0), the update count should be 2.
const float alpha_correction =
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alphas[i], update_count) : 1.f;
const float beta_correction =
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(betas[i], update_count) : 1.f;
LambComputeDirection(
stream,
p_ws[i],
p_gs[i],
p_m1s[i],
p_m2s[i],
p_loss_scale,
p_g_norm,
HipT4(alphas[i]),
HipT4(betas[i]),
HipT2(lambdas[i]),
HipT4(epsilons[i]),
HipT2(max_norms[i]),
HipT4(alpha_correction),
HipT4(beta_correction),
p_ds[i],
p_m1_news[i],
p_m2_news[i],
tensor_sizes[i]);
} else {
std::vector<void*> ptrs(tensor_count_per_group);
ptrs[0] = const_cast<HipT2*>(p_ws[i]); // weight tensor
ptrs[1] = const_cast<HipT3*>(p_gs[i]); // gradient (reused to store update direction)
ptrs[2] = const_cast<HipT4*>(p_m1s[i]); // 1st momentum
ptrs[3] = const_cast<HipT4*>(p_m2s[i]); // 2nd momentum
ptrs[4] = p_m1_news[i]; // new 1st momentum
ptrs[5] = p_m2_news[i]; // new 2nd momentum
auto key = std::make_tuple(alphas[i], betas[i], lambdas[i], epsilons[i], max_norms[i]);
buckets[key].push_back(ptrs);
tensor_sizes_in_buckets[key].push_back(tensor_sizes[i]);
}
}
for (auto& pair : buckets) {
const auto key = pair.first;
float alpha = 0.f, beta = 0.f, lambda = 0.f, epsilon = 0.f, max_norm = 0.f;
std::tie(alpha, beta, lambda, epsilon, max_norm) = key;
// For the first iteration (indexed by 0), the update count should be 1.
const float alpha_correction =
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alpha, update_count) : 1.f;
const float beta_correction =
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(beta, update_count) : 1.f;
typedef LambMultiTensorComputeDirectionFunctor<HipT2, HipT3, HipT4, HipT_GRAD_NORM> LambStage1;
LambStage1 lamb_stage1;
launch_multi_tensor_functor<tensor_count_per_group, LambStage1>(
stream,
2048 * 32,
tensor_sizes_in_buckets[key],
buckets[key],
lamb_stage1,
p_loss_scale, p_g_norm, lambda, alpha, beta, epsilon, HipT2(max_norm), alpha_correction, beta_correction);
}
return Status::OK();
}
template <typename HipTNorm, typename HipTIn1, typename HipTIn2>
Status launch_lamb_reduction(
const RocmKernel& kernel,
const int group_count,
std::vector<int>& tensor_sizes,
std::vector<HipTNorm*>& p_w_norms,
std::vector<HipTNorm*>& p_d_norms,
std::vector<const HipTIn1*>& p_ws,
std::vector<HipTIn2*>& p_ds,
void* reduction_buffer,
size_t reduction_buffer_size) {
ORT_ENFORCE(group_count == static_cast<int>(tensor_sizes.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_w_norms.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_d_norms.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_ws.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_ds.size()));
constexpr int tensor_count_per_group = 4;
hipStream_t stream = kernel.Stream();
// Bucketize tensor groups by the associated optimizer configuration.
// If two tensor groups use different "alpha", they should be put into two distinct buckets.
std::vector<std::vector<void*>> buckets;
std::vector<int> tensor_sizes_in_buckets;
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
for (int i = 0; i < group_count; ++i) {
if (tensor_sizes[i] > max_tensor_size) {
ORT_RETURN_IF_ERROR(reduce_square_sum(
stream,
p_ws[i],
p_w_norms[i],
tensor_sizes[i],
reduction_buffer,
reduction_buffer_size));
ORT_RETURN_IF_ERROR(reduce_square_sum(
stream,
p_ds[i],
p_d_norms[i],
tensor_sizes[i],
reduction_buffer,
reduction_buffer_size));
} else {
std::vector<void*> ptrs(tensor_count_per_group);
ptrs[0] = const_cast<HipTIn1*>(p_ws[i]); // weight tensor
ptrs[1] = const_cast<HipTIn2*>(p_ds[i]); // update direction
ptrs[2] = p_w_norms[i]; // weight tensor's norm
ptrs[3] = p_d_norms[i]; // update direction's norm
buckets.push_back(ptrs);
tensor_sizes_in_buckets.push_back(tensor_sizes[i]);
}
}
if (buckets.size() > 0) {
ORT_ENFORCE(tensor_sizes_in_buckets.size() > 0);
}
if (tensor_sizes_in_buckets.size() > 0) {
ORT_ENFORCE(buckets.size() > 0);
}
// Only launch multi-tensor function if we have at least one tensor in the buckets.
if (tensor_sizes_in_buckets.size() > 0 && buckets.size() > 0) {
typedef LambMultiTensorReductionFunctor<HipTIn1, HipTIn2, HipTNorm, HipTNorm, HipTNorm> TReducer;
TReducer reducer;
launch_multi_tensor_functor<tensor_count_per_group, TReducer>(
stream,
2048 * 32,
tensor_sizes_in_buckets,
buckets,
reducer,
kernel,
reduction_buffer,
reduction_buffer_size);
}
return Status::OK();
}
template <typename HipT1, typename HipT2, typename HipT3, typename HipT_MIXED_PRECISION_FP>
Status launch_lamb_update(
hipStream_t stream,
const int group_count,
const HipT1* eta,
const float ratio_min,
const float ratio_max,
std::vector<int>& tensor_sizes,
std::vector<HipT2*>& p_w_norms,
std::vector<HipT2*>& p_d_norms,
std::vector<const HipT2*>& p_ws,
std::vector<HipT3*>& p_ds,
/* output */ std::vector<HipT2*>& p_w_news,
/* output */ std::vector<HipT3*>& p_g_news,
/* output */ std::vector<HipT_MIXED_PRECISION_FP*>& p_w_mixed_precision_news) {
ORT_ENFORCE(group_count == static_cast<int>(tensor_sizes.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_w_norms.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_d_norms.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_ws.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_ds.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_w_news.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_g_news.size()));
ORT_ENFORCE(group_count == static_cast<int>(p_w_mixed_precision_news.size()));
constexpr int tensor_count_per_group = 7;
// Bucketize tensor groups by the associated optimizer configuration.
// If two tensor groups use different "alpha", they should be put into two distinct buckets.
std::vector<std::vector<void*>> buckets;
std::vector<int> tensor_sizes_in_bucket;
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
for (int i = 0; i < group_count; ++i) {
if (tensor_sizes[i] > max_tensor_size) {
LambUpdate(
stream,
eta,
ratio_min,
ratio_max,
p_d_norms[i],
p_w_norms[i],
p_ws[i],
p_ds[i],
p_w_news[i],
p_g_news[i],
p_w_mixed_precision_news[i],
tensor_sizes[i]);
} else {
std::vector<void*> ptrs(tensor_count_per_group);
ptrs[0] = p_w_norms[i]; // weight tensor's norm
ptrs[1] = p_d_norms[i]; // direction's norm
ptrs[2] = const_cast<HipT2*>(p_ws[i]); // weight tensor
ptrs[3] = p_ds[i]; // direction
ptrs[4] = p_w_news[i]; // new weight tensor
ptrs[5] = p_g_news[i]; // new gradient tensor
ptrs[6] = p_w_mixed_precision_news[i]; // new half-precision weight tensor
buckets.push_back(ptrs);
tensor_sizes_in_bucket.push_back(tensor_sizes[i]);
}
}
if (buckets.size() > 0) {
ORT_ENFORCE(tensor_sizes_in_bucket.size() > 0);
}
if (tensor_sizes_in_bucket.size() > 0) {
ORT_ENFORCE(buckets.size() > 0);
}
// Only launch multi-tensor function if we have at least one tensor in the buckets.
if (tensor_sizes_in_bucket.size() > 0 && buckets.size() > 0) {
typedef LambMultiTensorUpdateFunctor<
HipT1, HipT2, HipT3, HipT_MIXED_PRECISION_FP>
LambStage2;
LambStage2 lamb_stage2;
launch_multi_tensor_functor<tensor_count_per_group, LambStage2>(
stream,
2048 * 32,
tensor_sizes_in_bucket,
buckets,
lamb_stage2,
eta,
ratio_min,
ratio_max);
}
return Status::OK();
}
template <typename T1, typename T2, typename T3, typename T4, typename T_GRAD_NORM, typename T_MIXED_PRECISION_FP>
Status LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>::ComputeInternal(OpKernelContext* ctx) const {
// HipT* are types used to invoke ROCM-based functions. It, for example, maps
// MLFloat16 in ONNXRuntime to half in ROCM.
typedef typename ToHipType<T1>::MappedType HipT1;
typedef typename ToHipType<T2>::MappedType HipT2;
typedef typename ToHipType<T3>::MappedType HipT3;
typedef typename ToHipType<T4>::MappedType HipT4;
typedef typename ToHipType<T_GRAD_NORM>::MappedType HipT_GRAD_NORM;
typedef typename ToHipType<T_MIXED_PRECISION_FP>::MappedType HipT_MIXED_PRECISION_FP;
constexpr int non_grouped_input_count = 5;
constexpr int input_group_size = 5;
constexpr int output_group_size = 5;
constexpr int non_grouped_output_count = 1;
constexpr int minimal_input_count = non_grouped_input_count + 1 * input_group_size - 1;
constexpr int minimal_output_count = non_grouped_output_count + 1 * output_group_size - 1;
const int grouped_input_tensor_count = ctx->InputCount() - non_grouped_input_count;
const int grouped_output_tensor_count = ctx->OutputCount() - non_grouped_output_count;
// At least one variable group for updating one weight tensor.
ORT_ENFORCE(
ctx->InputCount() >= minimal_input_count,
"Expect at least ", minimal_input_count, " inputs but got ",
ctx->InputCount());
// At least one variable group for updating one weight tensor.
ORT_ENFORCE(
ctx->OutputCount() >= minimal_output_count,
"Expect at least ", minimal_output_count, " outputs but got ",
ctx->OutputCount());
// In addition to the first non_grouped_input_count inputs, all inputs are repeated sequence of [w, g, m1, m2, w_mixed_precision].
ORT_ENFORCE(
grouped_input_tensor_count % input_group_size == 0,
"Input count must be ", non_grouped_input_count, " + ", input_group_size,
" x (number of weights to optimize).");
// Outputs are repeated sequence of [w_new, g_new, m1_new, m2_new, w_mixed_precision_new].
ORT_ENFORCE(
grouped_output_tensor_count % output_group_size == 0,
"Output count must be ", non_grouped_output_count, " + ", output_group_size,
" x (number of weights to optimize).");
// Number of repeated [w, g, m1, m2, w_mixed_precision]'s should match number of repeated [w_new, g_new, m1_new, m2_new, w_mixed_precision_new].
ORT_ENFORCE(
grouped_input_tensor_count / input_group_size == grouped_output_tensor_count / output_group_size,
"Input and output tensor counts are not aligned. Please check LambOptimizer's input and output lists.");
// Number of [w, g, m1, m2, (w_mixed_precision)] (or [w_new, m1_new, m2_new, (w_mixed_precision_new)]).
const int group_count = (grouped_input_tensor_count + input_group_size - 1) / input_group_size;
// At least we need one group of alpha, beta, lambda, ..., for processing one group.
ORT_ENFORCE(alpha_.size() >= static_cast<size_t>(group_count));
ORT_ENFORCE(beta_.size() >= static_cast<size_t>(group_count));
ORT_ENFORCE(lambda_.size() >= static_cast<size_t>(group_count));
ORT_ENFORCE(epsilon_.size() >= static_cast<size_t>(group_count));
ORT_ENFORCE(max_norm_clip_.size() >= static_cast<size_t>(group_count));
// If gradient norm is not finite, we copy inputs to outputs directly.
if (ctx->Input<Tensor>(0)) {
auto update_signal_tensor = ctx->Input<Tensor>(0);
auto update_signal = *update_signal_tensor->template Data<bool>();
if (!update_signal) {
return copy_inputs_to_outputs<T2, T3, T4, T_MIXED_PRECISION_FP>(
Stream(),
ctx,
non_grouped_input_count,
non_grouped_output_count,
group_count,
input_group_size,
output_group_size);
}
}
const HipT2* loss_scale_data = nullptr;
if (ctx->Input<Tensor>(1)) {
const Tensor& loss_scale_tensor = *ctx->Input<Tensor>(1);
loss_scale_data = reinterpret_cast<const HipT2*>(loss_scale_tensor.template Data<T2>());
}
const HipT_GRAD_NORM* g_norm_data = nullptr;
if (ctx->Input<Tensor>(2)) {
const Tensor& g_norm_tensor = *ctx->Input<Tensor>(2);
g_norm_data = reinterpret_cast<const HipT_GRAD_NORM*>(g_norm_tensor.template Data<T_GRAD_NORM>());
}
const Tensor& eta = *ctx->Input<Tensor>(3);
const HipT1* eta_data = reinterpret_cast<const HipT1*>(eta.template Data<T1>());
const Tensor* step_tensor = ctx->Input<Tensor>(4);
const int64_t* step_data = nullptr;
if (step_tensor) {
step_data = step_tensor->template Data<int64_t>();
}
// Allocate buffer for reduction computation of update directions.
// The i-th update direction's norm is stored at the i-th element.
// We reduce type T3 tensor to type T2 scalar. An example is that T3=float16
// and T2=float.
IAllocatorUniquePtr<T2> d_norm_buffer = GetScratchBuffer<T2>(group_count);
HipT2* d_norm_data = reinterpret_cast<HipT2*>(d_norm_buffer.get());
HIP_RETURN_IF_ERROR(hipMemsetAsync(d_norm_data, 0, group_count * sizeof(T2), Stream()));
// Allocate buffer for reduction computation of weight tensor.
// The i-th weight's norm is stored at the i-th element.
// We reduce type T2 tensor to type T2 scalar. An example is that T2=float.
IAllocatorUniquePtr<T2> w_norm_buffer = GetScratchBuffer<T2>(group_count);
HipT2* w_norm_data = reinterpret_cast<HipT2*>(w_norm_buffer.get());
HIP_RETURN_IF_ERROR(hipMemsetAsync(w_norm_data, 0, group_count * sizeof(T2), Stream()));
// Find the max size of updated weight tensors.
int max_tensor_size = 0;
for (int group_index = 0; group_index < group_count; ++group_index) {
// Prepare used input tensors for this group.
const int input_start_index = non_grouped_input_count + group_index * input_group_size;
const Tensor& w = *ctx->Input<Tensor>(input_start_index);
max_tensor_size = std::max(max_tensor_size, static_cast<int>(w.Shape().Size()));
}
const size_t reduction_buffer_size = [&]() {
// Allocate a buffer in byte for reduction API calls.
size_t rbs = compute_reduction_buffer_size<HipT2>(max_tensor_size);
// Enlarge reduction buffer to accomodate multi-tensor reduction kernel as well
const int tensor_group_size = 4; // w, d, w_norm, d_norm
const int max_blocks = ChunkGroup<tensor_group_size>::max_block_count;
const size_t multitensor_block_reduce_buffer_size = 2 * max_blocks * sizeof(HipT2);
rbs = std::max(rbs, multitensor_block_reduce_buffer_size);
return rbs;
}();
// Allocate reduction buffer whose size is reduction_buffer_size bytes.
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(reduction_buffer_size);
// Input tensors' pointers.
std::vector<const HipT2*> p_ws(group_count);
std::vector<const HipT3*> p_gs(group_count);
std::vector<const HipT4*> p_m1s(group_count);
std::vector<const HipT4*> p_m2s(group_count);
std::vector<const HipT_MIXED_PRECISION_FP*> p_w_mixed_precisions(group_count);
// ds' is an mutable version of gs' because we want to reuse
// gs' memory to store the update direction to avoid allocating a model-scale buffer.
std::vector<HipT3*> p_ds(group_count);
// Intermediate tensors, weight tensors' and directions' norms.
std::vector<HipT2*> p_w_norms(group_count);
std::vector<HipT2*> p_d_norms(group_count);
// Output tensors' pointers.
std::vector<HipT2*> p_w_news(group_count);
std::vector<HipT3*> p_g_news(group_count);
std::vector<HipT4*> p_m1_news(group_count);
std::vector<HipT4*> p_m2_news(group_count);
std::vector<HipT_MIXED_PRECISION_FP*> p_w_mixed_precision_news(group_count);
// The i-th element in following array is the size of
// the i-th updated weight tensor and other related tensors.
std::vector<int> tensor_sizes(group_count);
for (int group_index = 0; group_index < group_count; ++group_index) {
// Prepare used input tensors for this group.
const int input_start_index = non_grouped_input_count + group_index * input_group_size;
const Tensor* w = ctx->Input<Tensor>(input_start_index);
const Tensor* g = ctx->Input<Tensor>(input_start_index + 1);
const Tensor* m1 = ctx->Input<Tensor>(input_start_index + 2);
const Tensor* m2 = ctx->Input<Tensor>(input_start_index + 3);
const Tensor* w_mixed_precision = ctx->Input<Tensor>(input_start_index + 4);
// Prepare used outputs tensors for this group.
const int output_start_index = non_grouped_output_count + group_index * output_group_size;
Tensor* w_new = ctx->Output(output_start_index, w->Shape());
Tensor* g_new = ctx->Output(output_start_index + 1, g->Shape());
Tensor* m1_new = ctx->Output(output_start_index + 2, m1->Shape());
Tensor* m2_new = ctx->Output(output_start_index + 3, m2->Shape());
Tensor* w_mixed_precision_new = w_mixed_precision != nullptr ? ctx->Output(output_start_index + 4, w_mixed_precision->Shape()) : nullptr;
// TODO: temporary hack until View is improved (it doesn't work with Alias)
if (w_new != nullptr)
w_new->SetByteOffset(w->ByteOffset());
if (g_new != nullptr)
g_new->SetByteOffset(g->ByteOffset());
if (w_mixed_precision_new != nullptr)
w_mixed_precision_new->SetByteOffset(w_mixed_precision->ByteOffset());
check_inputs_and_outputs(w, g, m1, m2, w_mixed_precision, w_new, g_new, m1_new, m2_new, w_mixed_precision_new);
// We should throw for preventing overflow in reduction APIs.
// The index in ROCM system is 32-bit integer.
ORT_ENFORCE(
w->Shape().Size() <
static_cast<int64_t>(std::numeric_limits<int>::max()));
tensor_sizes[group_index] = static_cast<int>(w->Shape().Size());
// Input tensors' pointers.
p_ws[group_index] = reinterpret_cast<const HipT2*>(w->template Data<T2>());
p_gs[group_index] = reinterpret_cast<const HipT3*>(g->template Data<T3>());
p_m1s[group_index] = reinterpret_cast<const HipT4*>(m1->template Data<T4>());
p_m2s[group_index] = reinterpret_cast<const HipT4*>(m2->template Data<T4>());
p_w_mixed_precisions[group_index] = w_mixed_precision != nullptr ? reinterpret_cast<const HipT_MIXED_PRECISION_FP*>(w_mixed_precision->template Data<T_MIXED_PRECISION_FP>()) : nullptr;
// The following cast is for reusing gradient tensor g to store update direction d.
p_ds[group_index] = const_cast<HipT3*>(reinterpret_cast<const HipT3*>(g->template Data<T3>()));
// Set up which pointer to store which tensor's norm.
p_w_norms[group_index] = w_norm_data + group_index;
p_d_norms[group_index] = d_norm_data + group_index;
// Output tensors' pointers.
p_w_news[group_index] = w_new != nullptr ? reinterpret_cast<HipT2*>(w_new->template MutableData<T2>()) : nullptr;
p_g_news[group_index] = g_new != nullptr ? reinterpret_cast<HipT3*>(g_new->template MutableData<T3>()) : nullptr;
p_m1_news[group_index] = reinterpret_cast<HipT4*>(m1_new->template MutableData<T4>());
p_m2_news[group_index] = reinterpret_cast<HipT4*>(m2_new->template MutableData<T4>());
p_w_mixed_precision_news[group_index] = w_mixed_precision_new != nullptr ? reinterpret_cast<HipT_MIXED_PRECISION_FP*>(w_mixed_precision_new->template MutableData<T_MIXED_PRECISION_FP>()) : nullptr;
}
ORT_RETURN_IF_ERROR(launch_lamb_compute_direction(
Stream(),
step_data ? *step_data : 0,
group_count,
loss_scale_data,
g_norm_data,
tensor_sizes,
p_ws, p_gs, p_m1s, p_m2s,
p_ds,
p_m1_news, p_m2_news,
alpha_, beta_, lambda_, epsilon_, max_norm_clip_,
do_bias_correction_));
ORT_RETURN_IF_ERROR(launch_lamb_reduction(
*this,
group_count,
tensor_sizes,
p_w_norms,
p_d_norms,
p_ws,
p_ds,
reduction_buffer.get(),
reduction_buffer_size));
ORT_RETURN_IF_ERROR(launch_lamb_update(
Stream(),
group_count,
eta_data,
ratio_min_,
ratio_max_,
tensor_sizes,
p_w_norms,
p_d_norms,
p_ws,
p_ds,
p_w_news,
p_g_news,
p_w_mixed_precision_news));
if (step_tensor) {
Tensor* step_tensor_new = ctx->Output(0, step_tensor->Shape());
ORT_ENFORCE(step_tensor_new != nullptr, "Step tensor (input) and updated step tensor (output) must be specified together.");
int64_t* step_data_new = step_tensor_new->template MutableData<int64_t>();
*step_data_new = *step_data + 1;
}
return Status::OK();
}
} // namespace rocm
} // namespace onnxruntime

View file

@ -159,9 +159,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float_float_float_MLFloat16, LambOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16_float_MLFloat16_MLFloat16, LambOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16_float_float_MLFloat16, LambOptimizer)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double_double_double_MLFloat16, LambOptimizer)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_MLFloat16_MLFloat16, LambOptimizer)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_float_MLFloat16, LambOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double_double_double_MLFloat16, LambOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_MLFloat16_MLFloat16, LambOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_float_MLFloat16, LambOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_float_MLFloat16_MLFloat16, LambOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_float_float_MLFloat16, LambOptimizer)>,

View file

@ -1,113 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/rocm/tensor/gather_grad.h"
#include "orttraining/training_ops/rocm/tensor/gather_grad_impl.h"
#include "core/providers/common.h"
namespace onnxruntime {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(
GatherGrad,
kMSDomain,
1,
kRocmExecutionProvider,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0)
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<MLFloat16>()})
.TypeConstraint("Tind", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
GatherGrad);
namespace {
template <typename T, typename Tin>
Status CallGatherGradImpl(
const RocmKernel& rocm_kernel,
int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs,
const Tensor& grad, const Tensor& indices,
Tensor& output) {
using HipT = typename ToHipType<T>::MappedType;
const T* grad_data = grad.template Data<T>();
T* output_data = output.template MutableData<T>();
const Tin* indices_data = indices.template Data<Tin>();
GatherGradImpl(
rocm_kernel,
reinterpret_cast<const HipT*>(grad_data),
indices_data,
indices.Shape().Size(),
num_weights,
stride,
reinterpret_cast<HipT*>(output_data),
num_inputs,
param_itrs);
return Status::OK();
}
template <typename T>
Status DispatchToGatherGradImplByTin(
MLDataType tin_data_type,
const RocmKernel& rocm_kernel,
int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs,
const Tensor& grad, const Tensor& indices,
Tensor& output) {
if (utils::IsPrimitiveDataType<int32_t>(tin_data_type)) {
return CallGatherGradImpl<T, int32_t>(
rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output);
} else if (utils::IsPrimitiveDataType<int64_t>(tin_data_type)) {
return CallGatherGradImpl<T, int64_t>(
rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output);
}
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GatherGrad unsupported Tin type: ", tin_data_type);
}
Status DispatchToGatherGradImpl(
MLDataType t_data_type, MLDataType tin_data_type,
const RocmKernel& rocm_kernel,
int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs,
const Tensor& grad, const Tensor& indices,
Tensor& output) {
if (utils::IsPrimitiveDataType<float>(t_data_type)) {
return DispatchToGatherGradImplByTin<float>(
tin_data_type, rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output);
} else if (utils::IsPrimitiveDataType<MLFloat16>(t_data_type)) {
return DispatchToGatherGradImplByTin<MLFloat16>(
tin_data_type, rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output);
}
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GatherGrad unsupported T type: ", t_data_type);
}
} // namespace
Status GatherGrad::ComputeInternal(OpKernelContext* context) const {
const Tensor* shape = context->Input<Tensor>(0);
const TensorShape data_shape(shape->template Data<int64_t>(), shape->Shape().Size());
const Tensor* indices = context->Input<Tensor>(1);
const Tensor* grad = context->Input<Tensor>(2);
Tensor* output = context->Output(0, data_shape);
HIP_RETURN_IF_ERROR(hipMemsetAsync(output->MutableDataRaw(), 0, output->SizeInBytes(), Stream()));
MLDataType T_type = grad->DataType();
MLDataType Tin_type = indices->DataType();
const auto axis = HandleNegativeAxis(axis_, data_shape.NumDimensions());
const int64_t stride = data_shape.SizeFromDimension(axis + 1);
const int64_t num_weights = data_shape.Size() / stride;
const int64_t num_inputs = data_shape.SizeFromDimension(axis);
const int64_t param_itrs = data_shape.SizeFromDimension(0) / num_inputs;
return DispatchToGatherGradImpl(
T_type, Tin_type, *this,
num_weights, stride, num_inputs, param_itrs,
*grad, *indices, *output);
}
} // namespace rocm
} // namespace onnxruntime

View file

@ -1,216 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/rocm/tensor/gather_grad_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_call.h"
#include <hipcub/hipcub.hpp>
namespace onnxruntime {
namespace rocm {
template <typename T>
__global__ void _Iota(
hipcub::CountingInputIterator<T> input,
size_t length,
T* output) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, length);
output[idx] = input[idx];
}
template <typename T, typename Tin, int NumElementsPerThread>
__global__ void _GatherGradImpl(
const Tin* input,
const Tin* indices,
const T* grad_output,
T* grad_weight,
int64_t numel,
int64_t input_numel,
int64_t param_itrs,
int64_t stride) {
int idx = blockIdx.x * 4 + threadIdx.y;
if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) {
do {
for (int itr = 0; itr < param_itrs; ++itr) {
const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread;
const int weight_row = itr * input_numel + ((int)input[idx]) * stride; //the offset of the input
const int grad_row = (itr * numel + ((int)indices[idx])) * stride; //the offset of the gradient
float gradient[NumElementsPerThread];
float weight[NumElementsPerThread];
#pragma unroll
for (int ii = 0; ii < NumElementsPerThread; ii++) {
int feature_dim = start_feature + ii * GPU_WARP_SIZE;
if (feature_dim < stride) {
gradient[ii] = static_cast<float>(grad_output[grad_row + feature_dim]);
weight[ii] = static_cast<float>(grad_weight[weight_row + feature_dim]);
}
}
#pragma unroll
for (int ii = 0; ii < NumElementsPerThread; ii++) {
weight[ii] += gradient[ii];
}
#pragma unroll
for (int ii = 0; ii < NumElementsPerThread; ii++) {
int feature_dim = start_feature + ii * GPU_WARP_SIZE;
if (feature_dim < stride) {
grad_weight[weight_row + feature_dim] = static_cast<T>(weight[ii]);
}
}
}
idx++;
} while (idx < numel && input[idx] == input[idx - 1]);
}
}
// Special optimization for the case which the gather is on axis=0
template <typename T, typename Tin, int NumElementsPerThread>
__global__ void _GatherAxis0GradImpl(
const Tin* input,
const Tin* indices,
const T* grad_output,
T* grad_weight,
int64_t numel,
int64_t input_numel,
int64_t stride)
{
int idx = blockIdx.x * 4 + threadIdx.y;
if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) {
const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread;
const int weight_row = ((int)input[idx]) * stride; //the offset of the input
float weight[NumElementsPerThread];
for (int ii = 0; ii < NumElementsPerThread; ii++) {
int feature_dim = start_feature + ii * GPU_WARP_SIZE/4;
if (feature_dim < stride)
weight[ii] = static_cast<float>(grad_weight[weight_row + feature_dim]);
}
do {
const int grad_row = ((int)indices[idx]) * stride; //the offset of the gradient
float gradient[NumElementsPerThread];
#pragma unroll
for (int ii = 0; ii < NumElementsPerThread; ii++) {
int feature_dim = start_feature + ii * GPU_WARP_SIZE/4;
if (feature_dim < stride) {
gradient[ii] = static_cast<float>(grad_output[grad_row + feature_dim]);
weight[ii] += gradient[ii];
}
}
idx++;
} while (idx < numel && input[idx] == input[idx - 1]);
#pragma unroll
for (int ii = 0; ii < NumElementsPerThread; ii++) {
int feature_dim = start_feature + ii * GPU_WARP_SIZE/4;
if (feature_dim < stride)
grad_weight[weight_row + feature_dim] = static_cast<T>(weight[ii]);
}
}
}
template <typename T, typename Tin>
void GatherGradImpl(
const RocmKernel& rocm_kernel,
const T* grad_data,
const Tin* indices_data,
const int64_t num_indices,
const int64_t num_weights,
const int64_t stride,
T* output_data,
const int64_t num_inputs, //The number of input elements starting from the gathering dimension
const int64_t param_itrs //The size of dimensions of the data before gathering dimension
) {
// allocate intermediate buffers
auto original_indices = rocm_kernel.template GetScratchBuffer<Tin>(num_indices);
hipStream_t stream = rocm_kernel.Stream();
// initialize original_indices with [0, num_indices)
{
const auto blocks_per_grid = CeilDiv(num_indices, GridDim::maxThreadsPerBlock);
hipcub::CountingInputIterator<Tin> counting_input(Tin{});
hipLaunchKernelGGL(_Iota, dim3(blocks_per_grid), dim3(GridDim::maxThreadsPerBlock), 0, stream,
counting_input, num_indices, original_indices.get());
}
auto indices_data_sorted = rocm_kernel.template GetScratchBuffer<Tin>(num_indices);
auto original_indices_sorted = rocm_kernel.template GetScratchBuffer<Tin>(num_indices);
// sort indices and original indices
size_t sort_temp_storage_size_bytes = 0;
HIP_CALL_THROW(hipcub::DeviceRadixSort::SortPairs(
nullptr, sort_temp_storage_size_bytes,
indices_data, indices_data_sorted.get(),
original_indices.get(), original_indices_sorted.get(),
num_indices, 0, sizeof(Tin)*8, stream));
auto sort_temp_storage = rocm_kernel.GetScratchBuffer<void>(sort_temp_storage_size_bytes);
HIP_CALL_THROW(hipcub::DeviceRadixSort::SortPairs(
sort_temp_storage.get(), sort_temp_storage_size_bytes,
indices_data, indices_data_sorted.get(),
original_indices.get(), original_indices_sorted.get(),
num_indices, 0, sizeof(Tin)*8, stream));
dim3 block(GPU_WARP_SIZE, 4);
dim3 grid(CeilDiv(num_indices, 4), CeilDiv(stride, GridDim::maxElementsPerThread * GPU_WARP_SIZE));
// commented optimization resulted in increase variance of loss for BERT across multiple reasons
//if (param_itrs == 1)
//{
// hipLaunchKernelGGL(HIP_KERNEL_NAME(_GatherAxis0GradImpl<T, Tin, GridDim::maxElementsPerThread>), dim3(grid), dim3(block), 0, stream,
// indices_data_sorted.get(),
// original_indices_sorted.get(),
// grad_data,
// output_data,
// num_indices,
// num_inputs,
// stride);
//} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(_GatherGradImpl<T, Tin, GridDim::maxElementsPerThread>), dim3(grid), dim3(block), 0, stream,
indices_data_sorted.get(),
original_indices_sorted.get(),
grad_data,
output_data,
num_indices,
num_inputs,
param_itrs,
stride);
//}
}
#define SPECIALIZED_GRAD_IMPL2(T) \
template void GatherGradImpl<T, int64_t>( \
const RocmKernel& rocm_kernel, \
const T* grad_data, \
const int64_t* indices_data, \
const int64_t num_indices, \
const int64_t num_weights, \
const int64_t stride, \
T* output_data, \
const int64_t num_inputs, \
const int64_t param_itrs); \
template void GatherGradImpl<T, int32_t>( \
const RocmKernel& rocm_kernel, \
const T* grad_data, \
const int32_t* indices_data, \
const int64_t num_indices, \
const int64_t num_weights, \
const int64_t stride, \
T* output_data, \
const int64_t num_inputs, \
const int64_t param_itrs);
SPECIALIZED_GRAD_IMPL2(float)
SPECIALIZED_GRAD_IMPL2(half)
} // namespace rocm
} // namespace onnxruntime

View file

@ -1,25 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {
template <typename T, typename Tin>
void GatherGradImpl(
const RocmKernel& rocm_kernel,
const T* grad_data,
const Tin* indices_data,
const int64_t num_indices,
const int64_t num_weights,
const int64_t stride,
T* output_data,
const int64_t num_inputs,
const int64_t param_itrs);
} // namespace rocm
} // namespace onnxruntime

View file

@ -204,12 +204,8 @@ training_ops_excluded_files = [
'nn/batch_norm_grad.h',
'optimizer/adam.cc',
'optimizer/adam.cu',
'optimizer/lamb.cc',
'reduction/reduction_all.cc',
'reduction/reduction_ops.cc',
'tensor/gather_grad.cc',
'tensor/gather_grad_impl.cu',
'tensor/gather_grad_impl.h',
'tensor/gather_nd_grad_impl.cu',
'cuda_training_kernels.cc',
'cuda_training_kernels.h',
@ -245,8 +241,18 @@ def hipify(src_file_path, dst_file_path):
s = s.replace('GPU_WARP_SIZE = 32', 'GPU_WARP_SIZE = 64')
s = s.replace('std::exp', 'expf')
s = s.replace('std::log', 'logf')
s = s.replace('#include <cub/device/device_radix_sort.cuh>', '#include <hipcub/hipcub.hpp>')
s = s.replace('#include <cub/iterator/counting_input_iterator.cuh>', '')
s = s.replace('#include <cub/device/device_radix_sort.cuh>',
'#include <hipcub/hipcub.hpp>\n#include <hipcub/backend/rocprim/device/device_radix_sort.hpp>')
s = s.replace('#include <cub/device/device_reduce.cuh>',
'#include <hipcub/backend/rocprim/device/device_reduce.hpp>')
s = s.replace('#include <cub/device/device_run_length_encode.cuh>',
'#include <hipcub/backend/rocprim/device/device_run_length_encode.hpp>')
s = s.replace('#include <cub/device/device_scan.cuh>',
'#include <hipcub/backend/rocprim/device/device_scan.hpp>')
s = s.replace('#include <cub/iterator/counting_input_iterator.cuh>',
'#include <hipcub/backend/rocprim/iterator/counting_input_iterator.hpp>')
s = s.replace('#include <cub/iterator/discard_output_iterator.cuh>',
'#include <hipcub/backend/rocprim/iterator/discard_output_iterator.hpp>')
s = s.replace('typedef half MappedType', 'typedef __half MappedType')
# CUBLAS -> ROCBLAS
# s = s.replace('CUBLAS', 'HIPBLAS')