diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index d020d4efa0..a08786115b 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1036,7 +1036,7 @@ if (onnxruntime_USE_ROCM) target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template) endif() # During transition to separate hipFFT repo, put hipfft/include early - target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include) + target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include) target_include_directories(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${MPI_INCLUDE_DIRS} ${SAFEINT_INCLUDE_DIR} ${ONNXRUNTIME_ROOT}/../cmake/external/eigen) if (onnxruntime_ENABLE_TRAINING) diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 84f9570250..3a7c0ce625 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -134,7 +134,7 @@ void Check(const OpTester::Data& expected_data, } double threshold = 0.001; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) threshold = 0.005; #endif @@ -186,7 +186,7 @@ void InternalNumericalCheck(const OpTester::Data& expected_data, } float threshold = 0.0001f; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) threshold = 0.005f; #endif @@ -247,7 +247,7 @@ void Check(const OpTester::Data& expected_data, } float threshold = 0.001f; -#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) +#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM) threshold = 0.005f; #endif for (int i = 0; i < size; ++i) { diff --git a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_grad_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_grad_op_test.cc index 8965572ba6..fe093b7bc9 100644 --- a/orttraining/orttraining/test/training_ops/cpu/tensor/gather_grad_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/tensor/gather_grad_op_test.cc @@ -170,7 +170,7 @@ TEST(GatherGradOpTest, GatherFewDistinctIndices) { RunGatherGradTestWithRandomData(0, {2, 32}, {6, 128}, absolute_error); } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) namespace { void RunGatherGradConsistentOutputTest( int64_t axis, diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/gather_grad_impl.cu index 9c0537c81d..76fda57864 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_grad_impl.cu @@ -104,7 +104,7 @@ IAllocatorUniquePtr GetOffsetsFromCounts( // adapted from here: // https://github.com/pytorch/pytorch/blob/b186831c08e0e4e447eedb8a5cfab582995d37f9/aten/src/ATen/native/cuda/Embedding.cu#L121 -template +template __global__ void DirectSumKernel( const TIndex* dX_indices_sorted, const TIndex* dY_indices_sorted, @@ -116,21 +116,20 @@ __global__ void DirectSumKernel( int64_t num_batches) { GatheredIndexIndex_t idx = blockIdx.x * 4 + threadIdx.y; - const int SZ = 4; if (idx < num_gathered_indices && (idx == 0 || dX_indices_sorted[idx] != dX_indices_sorted[idx - 1])) { do { for (int64_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) { - const auto gathered_element_idx_start = threadIdx.x + blockIdx.y * blockDim.x * SZ; + const auto gathered_element_idx_start = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread; const auto dX_row_offset = (batch_idx * gather_dimension_size + dX_indices_sorted[idx]) * num_gathered_per_index; const auto dY_row_offset = (batch_idx * num_gathered_indices + dY_indices_sorted[idx]) * num_gathered_per_index; - AccumulationType_t dY_value[SZ]; - AccumulationType_t dX_value[SZ]; + AccumulationType_t dY_value[NumElementsPerThread]; + AccumulationType_t dX_value[NumElementsPerThread]; #pragma unroll - for (int ii = 0; ii < SZ; ii++) { + for (int ii = 0; ii < NumElementsPerThread; ii++) { const auto gathered_element_idx = gathered_element_idx_start + ii * GPU_WARP_SIZE; if (gathered_element_idx < num_gathered_per_index) { dY_value[ii] = static_cast>(dY_data[dY_row_offset + gathered_element_idx]); @@ -139,12 +138,12 @@ __global__ void DirectSumKernel( } #pragma unroll - for (int ii = 0; ii < SZ; ii++) { + for (int ii = 0; ii < NumElementsPerThread; ii++) { dX_value[ii] += dY_value[ii]; } #pragma unroll - for (int ii = 0; ii < SZ; ii++) { + for (int ii = 0; ii < NumElementsPerThread; ii++) { const auto gathered_element_idx = gathered_element_idx_start + ii * GPU_WARP_SIZE; if (gathered_element_idx < num_gathered_per_index) { dX_data[dX_row_offset + gathered_element_idx] = static_cast(dX_value[ii]); @@ -169,9 +168,9 @@ void DirectSumImpl( int64_t gather_dimension_size, int64_t num_batches) { dim3 block(GPU_WARP_SIZE, 4); - dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, 128)); + dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, GridDim::maxElementsPerThread * GPU_WARP_SIZE)); - DirectSumKernel<<>>( + DirectSumKernel<<>>( dX_indices_sorted, dY_indices_sorted, dY_data, @@ -509,9 +508,9 @@ void Impl_Simplified( dX_indices_sorted, dY_indices_sorted); dim3 block(GPU_WARP_SIZE, 4); - dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, 128)); + dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, GridDim::maxElementsPerThread * GPU_WARP_SIZE)); - DirectSumKernel<<>>( + DirectSumKernel<<>>( dX_indices_sorted.get(), dY_indices_sorted.get(), dY_data, diff --git a/orttraining/orttraining/training_ops/rocm/optimizer/lamb.cc b/orttraining/orttraining/training_ops/rocm/optimizer/lamb.cc deleted file mode 100644 index 0809149ec5..0000000000 --- a/orttraining/orttraining/training_ops/rocm/optimizer/lamb.cc +++ /dev/null @@ -1,705 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include "core/providers/rocm/rocm_allocator.h" -#include "core/providers/rocm/reduction/reduction_functions.h" -#include "core/providers/rocm/math/binary_elementwise_ops.h" -#include "orttraining/training_ops/rocm/optimizer/common.h" -#include "orttraining/training_ops/rocm/optimizer/lamb.h" - -namespace onnxruntime { -namespace rocm { - -std::vector> GenerateLambExtraAliasMapping() { - // Starting index of extra inputs. - constexpr int input_index_bias = 5; - // Starting index of extra outputs. - constexpr int output_index_bias = 1; - // Count of extra I/O groups. One group corresponds to a weight update. - constexpr int group_count = 1024; - // length of [w, g, m1, m2, w_mixed_precision]. - constexpr int input_stride = 5; - // length of [w_new, g_new, m1_new, m2_new, w_mixed_precision_new]. - constexpr int output_stride = 5; - - std::vector> alias_pairs{}; - for (int i = 0; i < group_count; ++i) { - const int input = input_index_bias + i * input_stride; - const int output = output_index_bias + i * output_stride; - // w --> w_new - alias_pairs.emplace_back(std::make_pair(input, output)); - // g --> g_new - alias_pairs.emplace_back(std::make_pair(input + 1, output + 1)); - // m1 --> m1_new - alias_pairs.emplace_back(std::make_pair(input + 2, output + 2)); - // m2 --> m2_new - alias_pairs.emplace_back(std::make_pair(input + 3, output + 3)); - // w_mixed_precision --> w_mixed_precision_new - alias_pairs.emplace_back(std::make_pair(input + 4, output + 4)); - } - - // update_count are updated in place. - alias_pairs.emplace_back(std::make_pair(4, 0)); - - return alias_pairs; -} - -// TODO: Once Schema is checked in to onnx lets fix this to match that -#define REGISTER_LAMB_KERNEL_TYPED(T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - LambOptimizer, \ - kMSDomain, \ - 1, \ - T1##_##T2##_##T3##_##T4##_##T_GRAD_NORM##_##T_MIXED_PRECISION_FP, \ - kRocmExecutionProvider, \ - KernelDefBuilder() \ - .Alias(GenerateLambExtraAliasMapping()) \ - .InputMemoryType(0) /* Keep do_update in CPU */ \ - .InputMemoryType(4) /* Keep iteration_count in CPU */ \ - .OutputMemoryType(0) /* Keep iteration_count in CPU */ \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T3", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T4", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T_MIXED_PRECISION_FP", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T_GRAD_NORM", DataTypeImpl::GetTensorType()), \ - LambOptimizer); - -REGISTER_LAMB_KERNEL_TYPED(float, float, MLFloat16, float, MLFloat16, MLFloat16) -REGISTER_LAMB_KERNEL_TYPED(float, float, MLFloat16, float, float, MLFloat16) -REGISTER_LAMB_KERNEL_TYPED(float, float, float, float, float, MLFloat16) -// REGISTER_LAMB_KERNEL_TYPED(double, double, double, double, double, MLFloat16) -// REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, MLFloat16, MLFloat16, MLFloat16) -// REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, MLFloat16, float, MLFloat16) -REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, float, MLFloat16, MLFloat16) -REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, float, float, MLFloat16) - -void check_inputs_and_outputs( - const Tensor* w, - const Tensor* g, - const Tensor* m1, - const Tensor* m2, - const Tensor* w_mixed_precision, - const Tensor* w_new, - const Tensor* g_new, - const Tensor* m1_new, - const Tensor* m2_new, - const Tensor* w_mixed_precision_new) { - // Throw if we have incomplete input or output lists. - ORT_ENFORCE(w, "Weight tensor should not be null."); - ORT_ENFORCE(g, "gradient tensor should not be null."); - ORT_ENFORCE(m1, "First-order momentum tensor should not be null."); - ORT_ENFORCE(m2, "Second-order momentum tensor should not be null."); - ORT_ENFORCE(m1_new, "New first-order momentum tensor should not be null."); - ORT_ENFORCE(m2_new, "New second-order momentum tensor should not be null."); - // Check if all shapes are good. - ORT_ENFORCE(m1->Shape() == m1_new->Shape()); - ORT_ENFORCE(m2->Shape() == m2_new->Shape()); - if (w_new) - ORT_ENFORCE(w->Shape() == w_new->Shape()); - if (g_new) - ORT_ENFORCE(g->Shape() == g_new->Shape()); - if (w_mixed_precision && w_mixed_precision_new) - ORT_ENFORCE(w_mixed_precision->Shape() == w_mixed_precision_new->Shape()); -} - -template -Status copy_inputs_to_outputs( - hipStream_t stream, - OpKernelContext* ctx, - const int non_grouped_input_count, - const int non_grouped_output_count, - const int group_count, - const int input_group_size, - const int output_group_size) { - const Tensor* step_tensor = ctx->Input(4); - if (step_tensor) { - const int64_t* step_data = step_tensor->template Data(); - Tensor* step_tensor_new = ctx->Output(0, step_tensor->Shape()); - ORT_ENFORCE(step_tensor_new != nullptr, "Step tensor (input) and updated step tensor (output) must be specified together."); - int64_t* step_data_new = step_tensor_new->template MutableData(); - *step_data_new = *step_data; - } - - for (int group_index = 0; group_index < group_count; ++group_index) { - const int input_start_index = non_grouped_input_count + group_index * input_group_size; - const Tensor& w = *ctx->Input(input_start_index); - const Tensor& g = *ctx->Input(input_start_index + 1); - const Tensor& m1 = *ctx->Input(input_start_index + 2); - const Tensor& m2 = *ctx->Input(input_start_index + 3); - const Tensor* w_mixed_precision = ctx->Input(input_start_index + 4); - const int output_start_index = non_grouped_output_count + group_index * output_group_size; - Tensor* w_new = ctx->Output(output_start_index, w.Shape()); - Tensor* g_new = ctx->Output(output_start_index + 1, g.Shape()); - Tensor& m1_new = *ctx->Output(output_start_index + 2, m1.Shape()); - Tensor& m2_new = *ctx->Output(output_start_index + 3, m2.Shape()); - Tensor* w_mixed_precision_new = w_mixed_precision != nullptr ? ctx->Output(output_start_index + 4, w_mixed_precision->Shape()) : nullptr; - - // TODO: temporary hack until View is improved (it doesn't work with Alias) - if (w_new != nullptr) - w_new->SetByteOffset(w.ByteOffset()); - if (g_new != nullptr) - g_new->SetByteOffset(g.ByteOffset()); - if (w_mixed_precision_new != nullptr) - w_mixed_precision_new->SetByteOffset(w_mixed_precision->ByteOffset()); - - if (w_new) { - ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer(stream, w, *w_new)); - } - if (g_new) { - ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer(stream, g, *g_new)); - } - ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer(stream, m1, m1_new)); - ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer(stream, m2, m2_new)); - - if (w_mixed_precision_new) { - ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer(stream, *w_mixed_precision, *w_mixed_precision_new)); - } - } - - return Status::OK(); -} - -template -Status launch_lamb_compute_direction( - hipStream_t stream, - const int64_t update_count, - const int group_count, - const HipT2* p_loss_scale, - const HipT_GRAD_NORM* p_g_norm, - std::vector& tensor_sizes, - std::vector& p_ws, - std::vector& p_gs, - std::vector& p_m1s, - std::vector& p_m2s, - std::vector& p_ds, - std::vector& p_m1_news, - std::vector& p_m2_news, - const std::vector& alphas, - const std::vector& betas, - const std::vector& lambdas, - const std::vector& epsilons, - const std::vector& max_norms, - const int64_t do_bias_correction) { - ORT_ENFORCE(group_count == static_cast(tensor_sizes.size())); - - ORT_ENFORCE(group_count == static_cast(p_ws.size())); - ORT_ENFORCE(group_count == static_cast(p_gs.size())); - ORT_ENFORCE(group_count == static_cast(p_m1s.size())); - ORT_ENFORCE(group_count == static_cast(p_m2s.size())); - ORT_ENFORCE(group_count == static_cast(p_ds.size())); - ORT_ENFORCE(group_count == static_cast(p_m1_news.size())); - ORT_ENFORCE(group_count == static_cast(p_m2_news.size())); - - ORT_ENFORCE(group_count == static_cast(alphas.size())); - ORT_ENFORCE(group_count == static_cast(betas.size())); - ORT_ENFORCE(group_count == static_cast(lambdas.size())); - ORT_ENFORCE(group_count == static_cast(epsilons.size())); - - constexpr int tensor_count_per_group = 6; - const int max_tensor_size = compute_max_tensor_size_per_launch(4); - // Bucketize tensor groups by the associated optimizer configuration. - // If two tensor groups use different "alpha", they should be put into two distinct buckets. - std::map, std::vector>> buckets; - std::map, std::vector> tensor_sizes_in_buckets; - for (int i = 0; i < group_count; ++i) { - if (tensor_sizes[i] > max_tensor_size) { - // For the first iteration (indexed by 0), the update count should be 2. - const float alpha_correction = - do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alphas[i], update_count) : 1.f; - const float beta_correction = - do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(betas[i], update_count) : 1.f; - - LambComputeDirection( - stream, - p_ws[i], - p_gs[i], - p_m1s[i], - p_m2s[i], - p_loss_scale, - p_g_norm, - HipT4(alphas[i]), - HipT4(betas[i]), - HipT2(lambdas[i]), - HipT4(epsilons[i]), - HipT2(max_norms[i]), - HipT4(alpha_correction), - HipT4(beta_correction), - p_ds[i], - p_m1_news[i], - p_m2_news[i], - tensor_sizes[i]); - } else { - std::vector ptrs(tensor_count_per_group); - ptrs[0] = const_cast(p_ws[i]); // weight tensor - ptrs[1] = const_cast(p_gs[i]); // gradient (reused to store update direction) - ptrs[2] = const_cast(p_m1s[i]); // 1st momentum - ptrs[3] = const_cast(p_m2s[i]); // 2nd momentum - ptrs[4] = p_m1_news[i]; // new 1st momentum - ptrs[5] = p_m2_news[i]; // new 2nd momentum - - auto key = std::make_tuple(alphas[i], betas[i], lambdas[i], epsilons[i], max_norms[i]); - buckets[key].push_back(ptrs); - tensor_sizes_in_buckets[key].push_back(tensor_sizes[i]); - } - } - - for (auto& pair : buckets) { - const auto key = pair.first; - float alpha = 0.f, beta = 0.f, lambda = 0.f, epsilon = 0.f, max_norm = 0.f; - std::tie(alpha, beta, lambda, epsilon, max_norm) = key; - - // For the first iteration (indexed by 0), the update count should be 1. - const float alpha_correction = - do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alpha, update_count) : 1.f; - const float beta_correction = - do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(beta, update_count) : 1.f; - - typedef LambMultiTensorComputeDirectionFunctor LambStage1; - LambStage1 lamb_stage1; - - launch_multi_tensor_functor( - stream, - 2048 * 32, - tensor_sizes_in_buckets[key], - buckets[key], - lamb_stage1, - p_loss_scale, p_g_norm, lambda, alpha, beta, epsilon, HipT2(max_norm), alpha_correction, beta_correction); - } - - return Status::OK(); -} - -template -Status launch_lamb_reduction( - const RocmKernel& kernel, - const int group_count, - std::vector& tensor_sizes, - std::vector& p_w_norms, - std::vector& p_d_norms, - std::vector& p_ws, - std::vector& p_ds, - void* reduction_buffer, - size_t reduction_buffer_size) { - ORT_ENFORCE(group_count == static_cast(tensor_sizes.size())); - - ORT_ENFORCE(group_count == static_cast(p_w_norms.size())); - ORT_ENFORCE(group_count == static_cast(p_d_norms.size())); - - ORT_ENFORCE(group_count == static_cast(p_ws.size())); - ORT_ENFORCE(group_count == static_cast(p_ds.size())); - - constexpr int tensor_count_per_group = 4; - hipStream_t stream = kernel.Stream(); - // Bucketize tensor groups by the associated optimizer configuration. - // If two tensor groups use different "alpha", they should be put into two distinct buckets. - std::vector> buckets; - std::vector tensor_sizes_in_buckets; - const int max_tensor_size = compute_max_tensor_size_per_launch(4); - for (int i = 0; i < group_count; ++i) { - if (tensor_sizes[i] > max_tensor_size) { - ORT_RETURN_IF_ERROR(reduce_square_sum( - stream, - p_ws[i], - p_w_norms[i], - tensor_sizes[i], - reduction_buffer, - reduction_buffer_size)); - ORT_RETURN_IF_ERROR(reduce_square_sum( - stream, - p_ds[i], - p_d_norms[i], - tensor_sizes[i], - reduction_buffer, - reduction_buffer_size)); - } else { - std::vector ptrs(tensor_count_per_group); - ptrs[0] = const_cast(p_ws[i]); // weight tensor - ptrs[1] = const_cast(p_ds[i]); // update direction - ptrs[2] = p_w_norms[i]; // weight tensor's norm - ptrs[3] = p_d_norms[i]; // update direction's norm - - buckets.push_back(ptrs); - tensor_sizes_in_buckets.push_back(tensor_sizes[i]); - } - } - - if (buckets.size() > 0) { - ORT_ENFORCE(tensor_sizes_in_buckets.size() > 0); - } - - if (tensor_sizes_in_buckets.size() > 0) { - ORT_ENFORCE(buckets.size() > 0); - } - - // Only launch multi-tensor function if we have at least one tensor in the buckets. - if (tensor_sizes_in_buckets.size() > 0 && buckets.size() > 0) { - typedef LambMultiTensorReductionFunctor TReducer; - TReducer reducer; - launch_multi_tensor_functor( - stream, - 2048 * 32, - tensor_sizes_in_buckets, - buckets, - reducer, - kernel, - reduction_buffer, - reduction_buffer_size); - } - - return Status::OK(); -} - -template -Status launch_lamb_update( - hipStream_t stream, - const int group_count, - const HipT1* eta, - const float ratio_min, - const float ratio_max, - std::vector& tensor_sizes, - std::vector& p_w_norms, - std::vector& p_d_norms, - std::vector& p_ws, - std::vector& p_ds, - /* output */ std::vector& p_w_news, - /* output */ std::vector& p_g_news, - /* output */ std::vector& p_w_mixed_precision_news) { - ORT_ENFORCE(group_count == static_cast(tensor_sizes.size())); - - ORT_ENFORCE(group_count == static_cast(p_w_norms.size())); - ORT_ENFORCE(group_count == static_cast(p_d_norms.size())); - ORT_ENFORCE(group_count == static_cast(p_ws.size())); - ORT_ENFORCE(group_count == static_cast(p_ds.size())); - ORT_ENFORCE(group_count == static_cast(p_w_news.size())); - ORT_ENFORCE(group_count == static_cast(p_g_news.size())); - ORT_ENFORCE(group_count == static_cast(p_w_mixed_precision_news.size())); - - constexpr int tensor_count_per_group = 7; - - // Bucketize tensor groups by the associated optimizer configuration. - // If two tensor groups use different "alpha", they should be put into two distinct buckets. - std::vector> buckets; - std::vector tensor_sizes_in_bucket; - const int max_tensor_size = compute_max_tensor_size_per_launch(4); - for (int i = 0; i < group_count; ++i) { - if (tensor_sizes[i] > max_tensor_size) { - LambUpdate( - stream, - eta, - ratio_min, - ratio_max, - p_d_norms[i], - p_w_norms[i], - p_ws[i], - p_ds[i], - p_w_news[i], - p_g_news[i], - p_w_mixed_precision_news[i], - tensor_sizes[i]); - } else { - std::vector ptrs(tensor_count_per_group); - ptrs[0] = p_w_norms[i]; // weight tensor's norm - ptrs[1] = p_d_norms[i]; // direction's norm - ptrs[2] = const_cast(p_ws[i]); // weight tensor - ptrs[3] = p_ds[i]; // direction - ptrs[4] = p_w_news[i]; // new weight tensor - ptrs[5] = p_g_news[i]; // new gradient tensor - ptrs[6] = p_w_mixed_precision_news[i]; // new half-precision weight tensor - buckets.push_back(ptrs); - tensor_sizes_in_bucket.push_back(tensor_sizes[i]); - } - } - - if (buckets.size() > 0) { - ORT_ENFORCE(tensor_sizes_in_bucket.size() > 0); - } - - if (tensor_sizes_in_bucket.size() > 0) { - ORT_ENFORCE(buckets.size() > 0); - } - - // Only launch multi-tensor function if we have at least one tensor in the buckets. - if (tensor_sizes_in_bucket.size() > 0 && buckets.size() > 0) { - typedef LambMultiTensorUpdateFunctor< - HipT1, HipT2, HipT3, HipT_MIXED_PRECISION_FP> - LambStage2; - LambStage2 lamb_stage2; - - launch_multi_tensor_functor( - stream, - 2048 * 32, - tensor_sizes_in_bucket, - buckets, - lamb_stage2, - eta, - ratio_min, - ratio_max); - } - - return Status::OK(); -} - -template -Status LambOptimizer::ComputeInternal(OpKernelContext* ctx) const { - // HipT* are types used to invoke ROCM-based functions. It, for example, maps - // MLFloat16 in ONNXRuntime to half in ROCM. - typedef typename ToHipType::MappedType HipT1; - typedef typename ToHipType::MappedType HipT2; - typedef typename ToHipType::MappedType HipT3; - typedef typename ToHipType::MappedType HipT4; - typedef typename ToHipType::MappedType HipT_GRAD_NORM; - typedef typename ToHipType::MappedType HipT_MIXED_PRECISION_FP; - - constexpr int non_grouped_input_count = 5; - constexpr int input_group_size = 5; - constexpr int output_group_size = 5; - constexpr int non_grouped_output_count = 1; - constexpr int minimal_input_count = non_grouped_input_count + 1 * input_group_size - 1; - constexpr int minimal_output_count = non_grouped_output_count + 1 * output_group_size - 1; - const int grouped_input_tensor_count = ctx->InputCount() - non_grouped_input_count; - const int grouped_output_tensor_count = ctx->OutputCount() - non_grouped_output_count; - - // At least one variable group for updating one weight tensor. - ORT_ENFORCE( - ctx->InputCount() >= minimal_input_count, - "Expect at least ", minimal_input_count, " inputs but got ", - ctx->InputCount()); - // At least one variable group for updating one weight tensor. - ORT_ENFORCE( - ctx->OutputCount() >= minimal_output_count, - "Expect at least ", minimal_output_count, " outputs but got ", - ctx->OutputCount()); - - // In addition to the first non_grouped_input_count inputs, all inputs are repeated sequence of [w, g, m1, m2, w_mixed_precision]. - ORT_ENFORCE( - grouped_input_tensor_count % input_group_size == 0, - "Input count must be ", non_grouped_input_count, " + ", input_group_size, - " x (number of weights to optimize)."); - // Outputs are repeated sequence of [w_new, g_new, m1_new, m2_new, w_mixed_precision_new]. - ORT_ENFORCE( - grouped_output_tensor_count % output_group_size == 0, - "Output count must be ", non_grouped_output_count, " + ", output_group_size, - " x (number of weights to optimize)."); - // Number of repeated [w, g, m1, m2, w_mixed_precision]'s should match number of repeated [w_new, g_new, m1_new, m2_new, w_mixed_precision_new]. - ORT_ENFORCE( - grouped_input_tensor_count / input_group_size == grouped_output_tensor_count / output_group_size, - "Input and output tensor counts are not aligned. Please check LambOptimizer's input and output lists."); - - // Number of [w, g, m1, m2, (w_mixed_precision)] (or [w_new, m1_new, m2_new, (w_mixed_precision_new)]). - const int group_count = (grouped_input_tensor_count + input_group_size - 1) / input_group_size; - - // At least we need one group of alpha, beta, lambda, ..., for processing one group. - ORT_ENFORCE(alpha_.size() >= static_cast(group_count)); - ORT_ENFORCE(beta_.size() >= static_cast(group_count)); - ORT_ENFORCE(lambda_.size() >= static_cast(group_count)); - ORT_ENFORCE(epsilon_.size() >= static_cast(group_count)); - ORT_ENFORCE(max_norm_clip_.size() >= static_cast(group_count)); - - // If gradient norm is not finite, we copy inputs to outputs directly. - if (ctx->Input(0)) { - auto update_signal_tensor = ctx->Input(0); - auto update_signal = *update_signal_tensor->template Data(); - if (!update_signal) { - return copy_inputs_to_outputs( - Stream(), - ctx, - non_grouped_input_count, - non_grouped_output_count, - group_count, - input_group_size, - output_group_size); - } - } - - const HipT2* loss_scale_data = nullptr; - if (ctx->Input(1)) { - const Tensor& loss_scale_tensor = *ctx->Input(1); - loss_scale_data = reinterpret_cast(loss_scale_tensor.template Data()); - } - - const HipT_GRAD_NORM* g_norm_data = nullptr; - if (ctx->Input(2)) { - const Tensor& g_norm_tensor = *ctx->Input(2); - g_norm_data = reinterpret_cast(g_norm_tensor.template Data()); - } - - const Tensor& eta = *ctx->Input(3); - const HipT1* eta_data = reinterpret_cast(eta.template Data()); - - const Tensor* step_tensor = ctx->Input(4); - const int64_t* step_data = nullptr; - if (step_tensor) { - step_data = step_tensor->template Data(); - } - - // Allocate buffer for reduction computation of update directions. - // The i-th update direction's norm is stored at the i-th element. - // We reduce type T3 tensor to type T2 scalar. An example is that T3=float16 - // and T2=float. - IAllocatorUniquePtr d_norm_buffer = GetScratchBuffer(group_count); - HipT2* d_norm_data = reinterpret_cast(d_norm_buffer.get()); - HIP_RETURN_IF_ERROR(hipMemsetAsync(d_norm_data, 0, group_count * sizeof(T2), Stream())); - - // Allocate buffer for reduction computation of weight tensor. - // The i-th weight's norm is stored at the i-th element. - // We reduce type T2 tensor to type T2 scalar. An example is that T2=float. - IAllocatorUniquePtr w_norm_buffer = GetScratchBuffer(group_count); - HipT2* w_norm_data = reinterpret_cast(w_norm_buffer.get()); - HIP_RETURN_IF_ERROR(hipMemsetAsync(w_norm_data, 0, group_count * sizeof(T2), Stream())); - - // Find the max size of updated weight tensors. - int max_tensor_size = 0; - for (int group_index = 0; group_index < group_count; ++group_index) { - // Prepare used input tensors for this group. - const int input_start_index = non_grouped_input_count + group_index * input_group_size; - const Tensor& w = *ctx->Input(input_start_index); - max_tensor_size = std::max(max_tensor_size, static_cast(w.Shape().Size())); - } - - const size_t reduction_buffer_size = [&]() { - // Allocate a buffer in byte for reduction API calls. - size_t rbs = compute_reduction_buffer_size(max_tensor_size); - - // Enlarge reduction buffer to accomodate multi-tensor reduction kernel as well - const int tensor_group_size = 4; // w, d, w_norm, d_norm - const int max_blocks = ChunkGroup::max_block_count; - const size_t multitensor_block_reduce_buffer_size = 2 * max_blocks * sizeof(HipT2); - rbs = std::max(rbs, multitensor_block_reduce_buffer_size); - - return rbs; - }(); - - // Allocate reduction buffer whose size is reduction_buffer_size bytes. - IAllocatorUniquePtr reduction_buffer = GetScratchBuffer(reduction_buffer_size); - - // Input tensors' pointers. - std::vector p_ws(group_count); - std::vector p_gs(group_count); - std::vector p_m1s(group_count); - std::vector p_m2s(group_count); - std::vector p_w_mixed_precisions(group_count); - // ds' is an mutable version of gs' because we want to reuse - // gs' memory to store the update direction to avoid allocating a model-scale buffer. - std::vector p_ds(group_count); - // Intermediate tensors, weight tensors' and directions' norms. - std::vector p_w_norms(group_count); - std::vector p_d_norms(group_count); - // Output tensors' pointers. - std::vector p_w_news(group_count); - std::vector p_g_news(group_count); - std::vector p_m1_news(group_count); - std::vector p_m2_news(group_count); - std::vector p_w_mixed_precision_news(group_count); - // The i-th element in following array is the size of - // the i-th updated weight tensor and other related tensors. - std::vector tensor_sizes(group_count); - - for (int group_index = 0; group_index < group_count; ++group_index) { - // Prepare used input tensors for this group. - const int input_start_index = non_grouped_input_count + group_index * input_group_size; - const Tensor* w = ctx->Input(input_start_index); - const Tensor* g = ctx->Input(input_start_index + 1); - const Tensor* m1 = ctx->Input(input_start_index + 2); - const Tensor* m2 = ctx->Input(input_start_index + 3); - const Tensor* w_mixed_precision = ctx->Input(input_start_index + 4); - - // Prepare used outputs tensors for this group. - const int output_start_index = non_grouped_output_count + group_index * output_group_size; - Tensor* w_new = ctx->Output(output_start_index, w->Shape()); - Tensor* g_new = ctx->Output(output_start_index + 1, g->Shape()); - Tensor* m1_new = ctx->Output(output_start_index + 2, m1->Shape()); - Tensor* m2_new = ctx->Output(output_start_index + 3, m2->Shape()); - Tensor* w_mixed_precision_new = w_mixed_precision != nullptr ? ctx->Output(output_start_index + 4, w_mixed_precision->Shape()) : nullptr; - - // TODO: temporary hack until View is improved (it doesn't work with Alias) - if (w_new != nullptr) - w_new->SetByteOffset(w->ByteOffset()); - if (g_new != nullptr) - g_new->SetByteOffset(g->ByteOffset()); - if (w_mixed_precision_new != nullptr) - w_mixed_precision_new->SetByteOffset(w_mixed_precision->ByteOffset()); - - check_inputs_and_outputs(w, g, m1, m2, w_mixed_precision, w_new, g_new, m1_new, m2_new, w_mixed_precision_new); - - // We should throw for preventing overflow in reduction APIs. - // The index in ROCM system is 32-bit integer. - ORT_ENFORCE( - w->Shape().Size() < - static_cast(std::numeric_limits::max())); - tensor_sizes[group_index] = static_cast(w->Shape().Size()); - - // Input tensors' pointers. - p_ws[group_index] = reinterpret_cast(w->template Data()); - p_gs[group_index] = reinterpret_cast(g->template Data()); - p_m1s[group_index] = reinterpret_cast(m1->template Data()); - p_m2s[group_index] = reinterpret_cast(m2->template Data()); - p_w_mixed_precisions[group_index] = w_mixed_precision != nullptr ? reinterpret_cast(w_mixed_precision->template Data()) : nullptr; - - // The following cast is for reusing gradient tensor g to store update direction d. - p_ds[group_index] = const_cast(reinterpret_cast(g->template Data())); - - // Set up which pointer to store which tensor's norm. - p_w_norms[group_index] = w_norm_data + group_index; - p_d_norms[group_index] = d_norm_data + group_index; - - // Output tensors' pointers. - p_w_news[group_index] = w_new != nullptr ? reinterpret_cast(w_new->template MutableData()) : nullptr; - p_g_news[group_index] = g_new != nullptr ? reinterpret_cast(g_new->template MutableData()) : nullptr; - p_m1_news[group_index] = reinterpret_cast(m1_new->template MutableData()); - p_m2_news[group_index] = reinterpret_cast(m2_new->template MutableData()); - p_w_mixed_precision_news[group_index] = w_mixed_precision_new != nullptr ? reinterpret_cast(w_mixed_precision_new->template MutableData()) : nullptr; - } - - ORT_RETURN_IF_ERROR(launch_lamb_compute_direction( - Stream(), - step_data ? *step_data : 0, - group_count, - loss_scale_data, - g_norm_data, - tensor_sizes, - p_ws, p_gs, p_m1s, p_m2s, - p_ds, - p_m1_news, p_m2_news, - alpha_, beta_, lambda_, epsilon_, max_norm_clip_, - do_bias_correction_)); - - ORT_RETURN_IF_ERROR(launch_lamb_reduction( - *this, - group_count, - tensor_sizes, - p_w_norms, - p_d_norms, - p_ws, - p_ds, - reduction_buffer.get(), - reduction_buffer_size)); - - ORT_RETURN_IF_ERROR(launch_lamb_update( - Stream(), - group_count, - eta_data, - ratio_min_, - ratio_max_, - tensor_sizes, - p_w_norms, - p_d_norms, - p_ws, - p_ds, - p_w_news, - p_g_news, - p_w_mixed_precision_news)); - - if (step_tensor) { - Tensor* step_tensor_new = ctx->Output(0, step_tensor->Shape()); - ORT_ENFORCE(step_tensor_new != nullptr, "Step tensor (input) and updated step tensor (output) must be specified together."); - int64_t* step_data_new = step_tensor_new->template MutableData(); - *step_data_new = *step_data + 1; - } - - return Status::OK(); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index fece1d6ece..58af838854 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -159,9 +159,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/rocm/tensor/gather_grad.cc b/orttraining/orttraining/training_ops/rocm/tensor/gather_grad.cc deleted file mode 100644 index d4ff6c37d8..0000000000 --- a/orttraining/orttraining/training_ops/rocm/tensor/gather_grad.cc +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "orttraining/training_ops/rocm/tensor/gather_grad.h" -#include "orttraining/training_ops/rocm/tensor/gather_grad_impl.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace rocm { - -ONNX_OPERATOR_KERNEL_EX( - GatherGrad, - kMSDomain, - 1, - kRocmExecutionProvider, - KernelDefBuilder() - .InputMemoryType(0) - .TypeConstraint("I", DataTypeImpl::GetTensorType()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .TypeConstraint("Tind", std::vector{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), - GatherGrad); - -namespace { -template -Status CallGatherGradImpl( - const RocmKernel& rocm_kernel, - int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs, - const Tensor& grad, const Tensor& indices, - Tensor& output) { - using HipT = typename ToHipType::MappedType; - - const T* grad_data = grad.template Data(); - T* output_data = output.template MutableData(); - const Tin* indices_data = indices.template Data(); - - GatherGradImpl( - rocm_kernel, - reinterpret_cast(grad_data), - indices_data, - indices.Shape().Size(), - num_weights, - stride, - reinterpret_cast(output_data), - num_inputs, - param_itrs); - - return Status::OK(); -} - -template -Status DispatchToGatherGradImplByTin( - MLDataType tin_data_type, - const RocmKernel& rocm_kernel, - int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs, - const Tensor& grad, const Tensor& indices, - Tensor& output) { - if (utils::IsPrimitiveDataType(tin_data_type)) { - return CallGatherGradImpl( - rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output); - } else if (utils::IsPrimitiveDataType(tin_data_type)) { - return CallGatherGradImpl( - rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output); - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GatherGrad unsupported Tin type: ", tin_data_type); -} - -Status DispatchToGatherGradImpl( - MLDataType t_data_type, MLDataType tin_data_type, - const RocmKernel& rocm_kernel, - int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs, - const Tensor& grad, const Tensor& indices, - Tensor& output) { - if (utils::IsPrimitiveDataType(t_data_type)) { - return DispatchToGatherGradImplByTin( - tin_data_type, rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output); - } else if (utils::IsPrimitiveDataType(t_data_type)) { - return DispatchToGatherGradImplByTin( - tin_data_type, rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output); - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GatherGrad unsupported T type: ", t_data_type); -} -} // namespace - -Status GatherGrad::ComputeInternal(OpKernelContext* context) const { - const Tensor* shape = context->Input(0); - const TensorShape data_shape(shape->template Data(), shape->Shape().Size()); - const Tensor* indices = context->Input(1); - const Tensor* grad = context->Input(2); - - Tensor* output = context->Output(0, data_shape); - HIP_RETURN_IF_ERROR(hipMemsetAsync(output->MutableDataRaw(), 0, output->SizeInBytes(), Stream())); - MLDataType T_type = grad->DataType(); - MLDataType Tin_type = indices->DataType(); - - const auto axis = HandleNegativeAxis(axis_, data_shape.NumDimensions()); - const int64_t stride = data_shape.SizeFromDimension(axis + 1); - const int64_t num_weights = data_shape.Size() / stride; - const int64_t num_inputs = data_shape.SizeFromDimension(axis); - const int64_t param_itrs = data_shape.SizeFromDimension(0) / num_inputs; - - return DispatchToGatherGradImpl( - T_type, Tin_type, *this, - num_weights, stride, num_inputs, param_itrs, - *grad, *indices, *output); -} - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/tensor/gather_grad_impl.cu b/orttraining/orttraining/training_ops/rocm/tensor/gather_grad_impl.cu deleted file mode 100644 index bd988d0402..0000000000 --- a/orttraining/orttraining/training_ops/rocm/tensor/gather_grad_impl.cu +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "orttraining/training_ops/rocm/tensor/gather_grad_impl.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/shared_inc/rocm_call.h" - -#include - - -namespace onnxruntime { -namespace rocm { - -template -__global__ void _Iota( - hipcub::CountingInputIterator input, - size_t length, - T* output) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, length); - output[idx] = input[idx]; -} - -template -__global__ void _GatherGradImpl( - const Tin* input, - const Tin* indices, - const T* grad_output, - T* grad_weight, - int64_t numel, - int64_t input_numel, - int64_t param_itrs, - int64_t stride) { - int idx = blockIdx.x * 4 + threadIdx.y; - - if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) { - do { - for (int itr = 0; itr < param_itrs; ++itr) { - const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread; - const int weight_row = itr * input_numel + ((int)input[idx]) * stride; //the offset of the input - const int grad_row = (itr * numel + ((int)indices[idx])) * stride; //the offset of the gradient - - float gradient[NumElementsPerThread]; - float weight[NumElementsPerThread]; - -#pragma unroll - for (int ii = 0; ii < NumElementsPerThread; ii++) { - int feature_dim = start_feature + ii * GPU_WARP_SIZE; - if (feature_dim < stride) { - gradient[ii] = static_cast(grad_output[grad_row + feature_dim]); - weight[ii] = static_cast(grad_weight[weight_row + feature_dim]); - } - } - -#pragma unroll - for (int ii = 0; ii < NumElementsPerThread; ii++) { - weight[ii] += gradient[ii]; - } - -#pragma unroll - for (int ii = 0; ii < NumElementsPerThread; ii++) { - int feature_dim = start_feature + ii * GPU_WARP_SIZE; - if (feature_dim < stride) { - grad_weight[weight_row + feature_dim] = static_cast(weight[ii]); - } - } - } - idx++; - } while (idx < numel && input[idx] == input[idx - 1]); - } -} - -// Special optimization for the case which the gather is on axis=0 -template -__global__ void _GatherAxis0GradImpl( - const Tin* input, - const Tin* indices, - const T* grad_output, - T* grad_weight, - int64_t numel, - int64_t input_numel, - int64_t stride) -{ - int idx = blockIdx.x * 4 + threadIdx.y; - - if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) { - const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread; - const int weight_row = ((int)input[idx]) * stride; //the offset of the input - - float weight[NumElementsPerThread]; - for (int ii = 0; ii < NumElementsPerThread; ii++) { - int feature_dim = start_feature + ii * GPU_WARP_SIZE/4; - if (feature_dim < stride) - weight[ii] = static_cast(grad_weight[weight_row + feature_dim]); - } - - do { - const int grad_row = ((int)indices[idx]) * stride; //the offset of the gradient - float gradient[NumElementsPerThread]; - -#pragma unroll - for (int ii = 0; ii < NumElementsPerThread; ii++) { - int feature_dim = start_feature + ii * GPU_WARP_SIZE/4; - if (feature_dim < stride) { - gradient[ii] = static_cast(grad_output[grad_row + feature_dim]); - weight[ii] += gradient[ii]; - } - } - idx++; - } while (idx < numel && input[idx] == input[idx - 1]); - -#pragma unroll - for (int ii = 0; ii < NumElementsPerThread; ii++) { - int feature_dim = start_feature + ii * GPU_WARP_SIZE/4; - if (feature_dim < stride) - grad_weight[weight_row + feature_dim] = static_cast(weight[ii]); - } - } -} - -template -void GatherGradImpl( - const RocmKernel& rocm_kernel, - const T* grad_data, - const Tin* indices_data, - const int64_t num_indices, - const int64_t num_weights, - const int64_t stride, - T* output_data, - const int64_t num_inputs, //The number of input elements starting from the gathering dimension - const int64_t param_itrs //The size of dimensions of the data before gathering dimension - ) { - // allocate intermediate buffers - auto original_indices = rocm_kernel.template GetScratchBuffer(num_indices); - hipStream_t stream = rocm_kernel.Stream(); - - // initialize original_indices with [0, num_indices) - { - const auto blocks_per_grid = CeilDiv(num_indices, GridDim::maxThreadsPerBlock); - hipcub::CountingInputIterator counting_input(Tin{}); - hipLaunchKernelGGL(_Iota, dim3(blocks_per_grid), dim3(GridDim::maxThreadsPerBlock), 0, stream, - counting_input, num_indices, original_indices.get()); - } - - auto indices_data_sorted = rocm_kernel.template GetScratchBuffer(num_indices); - auto original_indices_sorted = rocm_kernel.template GetScratchBuffer(num_indices); - - // sort indices and original indices - size_t sort_temp_storage_size_bytes = 0; - HIP_CALL_THROW(hipcub::DeviceRadixSort::SortPairs( - nullptr, sort_temp_storage_size_bytes, - indices_data, indices_data_sorted.get(), - original_indices.get(), original_indices_sorted.get(), - num_indices, 0, sizeof(Tin)*8, stream)); - - auto sort_temp_storage = rocm_kernel.GetScratchBuffer(sort_temp_storage_size_bytes); - - HIP_CALL_THROW(hipcub::DeviceRadixSort::SortPairs( - sort_temp_storage.get(), sort_temp_storage_size_bytes, - indices_data, indices_data_sorted.get(), - original_indices.get(), original_indices_sorted.get(), - num_indices, 0, sizeof(Tin)*8, stream)); - - dim3 block(GPU_WARP_SIZE, 4); - dim3 grid(CeilDiv(num_indices, 4), CeilDiv(stride, GridDim::maxElementsPerThread * GPU_WARP_SIZE)); - -// commented optimization resulted in increase variance of loss for BERT across multiple reasons -//if (param_itrs == 1) -//{ -// hipLaunchKernelGGL(HIP_KERNEL_NAME(_GatherAxis0GradImpl), dim3(grid), dim3(block), 0, stream, -// indices_data_sorted.get(), -// original_indices_sorted.get(), -// grad_data, -// output_data, -// num_indices, -// num_inputs, -// stride); -//} else { - hipLaunchKernelGGL(HIP_KERNEL_NAME(_GatherGradImpl), dim3(grid), dim3(block), 0, stream, - indices_data_sorted.get(), - original_indices_sorted.get(), - grad_data, - output_data, - num_indices, - num_inputs, - param_itrs, - stride); -//} -} - -#define SPECIALIZED_GRAD_IMPL2(T) \ - template void GatherGradImpl( \ - const RocmKernel& rocm_kernel, \ - const T* grad_data, \ - const int64_t* indices_data, \ - const int64_t num_indices, \ - const int64_t num_weights, \ - const int64_t stride, \ - T* output_data, \ - const int64_t num_inputs, \ - const int64_t param_itrs); \ - template void GatherGradImpl( \ - const RocmKernel& rocm_kernel, \ - const T* grad_data, \ - const int32_t* indices_data, \ - const int64_t num_indices, \ - const int64_t num_weights, \ - const int64_t stride, \ - T* output_data, \ - const int64_t num_inputs, \ - const int64_t param_itrs); - -SPECIALIZED_GRAD_IMPL2(float) -SPECIALIZED_GRAD_IMPL2(half) - -} // namespace rocm -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/tensor/gather_grad_impl.h b/orttraining/orttraining/training_ops/rocm/tensor/gather_grad_impl.h deleted file mode 100644 index ab910a4663..0000000000 --- a/orttraining/orttraining/training_ops/rocm/tensor/gather_grad_impl.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" - -namespace onnxruntime { -namespace rocm { - -template -void GatherGradImpl( - const RocmKernel& rocm_kernel, - const T* grad_data, - const Tin* indices_data, - const int64_t num_indices, - const int64_t num_weights, - const int64_t stride, - T* output_data, - const int64_t num_inputs, - const int64_t param_itrs); - -} // namespace rocm -} // namespace onnxruntime diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 897f4b6425..4b2bdc7329 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -204,12 +204,8 @@ training_ops_excluded_files = [ 'nn/batch_norm_grad.h', 'optimizer/adam.cc', 'optimizer/adam.cu', - 'optimizer/lamb.cc', 'reduction/reduction_all.cc', 'reduction/reduction_ops.cc', - 'tensor/gather_grad.cc', - 'tensor/gather_grad_impl.cu', - 'tensor/gather_grad_impl.h', 'tensor/gather_nd_grad_impl.cu', 'cuda_training_kernels.cc', 'cuda_training_kernels.h', @@ -245,8 +241,18 @@ def hipify(src_file_path, dst_file_path): s = s.replace('GPU_WARP_SIZE = 32', 'GPU_WARP_SIZE = 64') s = s.replace('std::exp', 'expf') s = s.replace('std::log', 'logf') - s = s.replace('#include ', '#include ') - s = s.replace('#include ', '') + s = s.replace('#include ', + '#include \n#include ') + s = s.replace('#include ', + '#include ') + s = s.replace('#include ', + '#include ') + s = s.replace('#include ', + '#include ') + s = s.replace('#include ', + '#include ') + s = s.replace('#include ', + '#include ') s = s.replace('typedef half MappedType', 'typedef __half MappedType') # CUBLAS -> ROCBLAS # s = s.replace('CUBLAS', 'HIPBLAS')