mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
update lamb and GatherGrad kernel for ROCm EP (#7184)
With ROCm4.1, the CUDA implementation of Lamb and GatherGrad can be utilized for ROCm EP.
This commit is contained in:
parent
17f91ff410
commit
a3f17c8b0d
10 changed files with 31 additions and 1085 deletions
|
|
@ -1036,7 +1036,7 @@ if (onnxruntime_USE_ROCM)
|
|||
target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template)
|
||||
endif()
|
||||
# During transition to separate hipFFT repo, put hipfft/include early
|
||||
target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
|
||||
target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
|
||||
target_include_directories(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${MPI_INCLUDE_DIRS} ${SAFEINT_INCLUDE_DIR} ${ONNXRUNTIME_ROOT}/../cmake/external/eigen)
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ void Check<double>(const OpTester::Data& expected_data,
|
|||
}
|
||||
|
||||
double threshold = 0.001;
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
threshold = 0.005;
|
||||
#endif
|
||||
|
||||
|
|
@ -186,7 +186,7 @@ void InternalNumericalCheck(const OpTester::Data& expected_data,
|
|||
}
|
||||
|
||||
float threshold = 0.0001f;
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
threshold = 0.005f;
|
||||
#endif
|
||||
|
||||
|
|
@ -247,7 +247,7 @@ void Check<MLFloat16>(const OpTester::Data& expected_data,
|
|||
}
|
||||
|
||||
float threshold = 0.001f;
|
||||
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA)
|
||||
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
|
||||
threshold = 0.005f;
|
||||
#endif
|
||||
for (int i = 0; i < size; ++i) {
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ TEST(GatherGradOpTest, GatherFewDistinctIndices) {
|
|||
RunGatherGradTestWithRandomData<float>(0, {2, 32}, {6, 128}, absolute_error);
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
namespace {
|
||||
void RunGatherGradConsistentOutputTest(
|
||||
int64_t axis,
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ IAllocatorUniquePtr<T> GetOffsetsFromCounts(
|
|||
|
||||
// adapted from here:
|
||||
// https://github.com/pytorch/pytorch/blob/b186831c08e0e4e447eedb8a5cfab582995d37f9/aten/src/ATen/native/cuda/Embedding.cu#L121
|
||||
template <typename T, typename TIndex>
|
||||
template <typename T, typename TIndex, int NumElementsPerThread>
|
||||
__global__ void DirectSumKernel(
|
||||
const TIndex* dX_indices_sorted,
|
||||
const TIndex* dY_indices_sorted,
|
||||
|
|
@ -116,21 +116,20 @@ __global__ void DirectSumKernel(
|
|||
int64_t num_batches) {
|
||||
GatheredIndexIndex_t idx = blockIdx.x * 4 + threadIdx.y;
|
||||
|
||||
const int SZ = 4;
|
||||
if (idx < num_gathered_indices && (idx == 0 || dX_indices_sorted[idx] != dX_indices_sorted[idx - 1])) {
|
||||
do {
|
||||
for (int64_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
|
||||
const auto gathered_element_idx_start = threadIdx.x + blockIdx.y * blockDim.x * SZ;
|
||||
const auto gathered_element_idx_start = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread;
|
||||
const auto dX_row_offset =
|
||||
(batch_idx * gather_dimension_size + dX_indices_sorted[idx]) * num_gathered_per_index;
|
||||
const auto dY_row_offset =
|
||||
(batch_idx * num_gathered_indices + dY_indices_sorted[idx]) * num_gathered_per_index;
|
||||
|
||||
AccumulationType_t<T> dY_value[SZ];
|
||||
AccumulationType_t<T> dX_value[SZ];
|
||||
AccumulationType_t<T> dY_value[NumElementsPerThread];
|
||||
AccumulationType_t<T> dX_value[NumElementsPerThread];
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < SZ; ii++) {
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
const auto gathered_element_idx = gathered_element_idx_start + ii * GPU_WARP_SIZE;
|
||||
if (gathered_element_idx < num_gathered_per_index) {
|
||||
dY_value[ii] = static_cast<AccumulationType_t<T>>(dY_data[dY_row_offset + gathered_element_idx]);
|
||||
|
|
@ -139,12 +138,12 @@ __global__ void DirectSumKernel(
|
|||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < SZ; ii++) {
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
dX_value[ii] += dY_value[ii];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < SZ; ii++) {
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
const auto gathered_element_idx = gathered_element_idx_start + ii * GPU_WARP_SIZE;
|
||||
if (gathered_element_idx < num_gathered_per_index) {
|
||||
dX_data[dX_row_offset + gathered_element_idx] = static_cast<T>(dX_value[ii]);
|
||||
|
|
@ -169,9 +168,9 @@ void DirectSumImpl(
|
|||
int64_t gather_dimension_size,
|
||||
int64_t num_batches) {
|
||||
dim3 block(GPU_WARP_SIZE, 4);
|
||||
dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, 128));
|
||||
dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, GridDim::maxElementsPerThread * GPU_WARP_SIZE));
|
||||
|
||||
DirectSumKernel<<<grid, block, 0, stream>>>(
|
||||
DirectSumKernel<T, TIndex, GridDim::maxElementsPerThread><<<grid, block, 0, stream>>>(
|
||||
dX_indices_sorted,
|
||||
dY_indices_sorted,
|
||||
dY_data,
|
||||
|
|
@ -509,9 +508,9 @@ void Impl_Simplified(
|
|||
dX_indices_sorted, dY_indices_sorted);
|
||||
|
||||
dim3 block(GPU_WARP_SIZE, 4);
|
||||
dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, 128));
|
||||
dim3 grid(CeilDiv(num_gathered_indices, 4), CeilDiv(num_gathered_per_index, GridDim::maxElementsPerThread * GPU_WARP_SIZE));
|
||||
|
||||
DirectSumKernel<<<grid, block, 0, stream>>>(
|
||||
DirectSumKernel<T, TIndex, GridDim::maxElementsPerThread><<<grid, block, 0, stream>>>(
|
||||
dX_indices_sorted.get(),
|
||||
dY_indices_sorted.get(),
|
||||
dY_data,
|
||||
|
|
|
|||
|
|
@ -1,705 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cmath>
|
||||
#include "core/providers/rocm/rocm_allocator.h"
|
||||
#include "core/providers/rocm/reduction/reduction_functions.h"
|
||||
#include "core/providers/rocm/math/binary_elementwise_ops.h"
|
||||
#include "orttraining/training_ops/rocm/optimizer/common.h"
|
||||
#include "orttraining/training_ops/rocm/optimizer/lamb.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
std::vector<std::pair<int, int>> GenerateLambExtraAliasMapping() {
|
||||
// Starting index of extra inputs.
|
||||
constexpr int input_index_bias = 5;
|
||||
// Starting index of extra outputs.
|
||||
constexpr int output_index_bias = 1;
|
||||
// Count of extra I/O groups. One group corresponds to a weight update.
|
||||
constexpr int group_count = 1024;
|
||||
// length of [w, g, m1, m2, w_mixed_precision].
|
||||
constexpr int input_stride = 5;
|
||||
// length of [w_new, g_new, m1_new, m2_new, w_mixed_precision_new].
|
||||
constexpr int output_stride = 5;
|
||||
|
||||
std::vector<std::pair<int, int>> alias_pairs{};
|
||||
for (int i = 0; i < group_count; ++i) {
|
||||
const int input = input_index_bias + i * input_stride;
|
||||
const int output = output_index_bias + i * output_stride;
|
||||
// w --> w_new
|
||||
alias_pairs.emplace_back(std::make_pair(input, output));
|
||||
// g --> g_new
|
||||
alias_pairs.emplace_back(std::make_pair(input + 1, output + 1));
|
||||
// m1 --> m1_new
|
||||
alias_pairs.emplace_back(std::make_pair(input + 2, output + 2));
|
||||
// m2 --> m2_new
|
||||
alias_pairs.emplace_back(std::make_pair(input + 3, output + 3));
|
||||
// w_mixed_precision --> w_mixed_precision_new
|
||||
alias_pairs.emplace_back(std::make_pair(input + 4, output + 4));
|
||||
}
|
||||
|
||||
// update_count is updated in place.
|
||||
alias_pairs.emplace_back(std::make_pair(4, 0));
|
||||
|
||||
return alias_pairs;
|
||||
}
|
||||
|
||||
// TODO: Once the schema is checked in to ONNX, let's fix this to match it.
|
||||
#define REGISTER_LAMB_KERNEL_TYPED(T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
LambOptimizer, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T1##_##T2##_##T3##_##T4##_##T_GRAD_NORM##_##T_MIXED_PRECISION_FP, \
|
||||
kRocmExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.Alias(GenerateLambExtraAliasMapping()) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) /* Keep do_update in CPU */ \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(4) /* Keep iteration_count in CPU */ \
|
||||
.OutputMemoryType<OrtMemTypeCPUOutput>(0) /* Keep iteration_count in CPU */ \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()) \
|
||||
.TypeConstraint("T3", DataTypeImpl::GetTensorType<T3>()) \
|
||||
.TypeConstraint("T4", DataTypeImpl::GetTensorType<T4>()) \
|
||||
.TypeConstraint("T_MIXED_PRECISION_FP", DataTypeImpl::GetTensorType<T_MIXED_PRECISION_FP>()) \
|
||||
.TypeConstraint("T_GRAD_NORM", DataTypeImpl::GetTensorType<T_GRAD_NORM>()), \
|
||||
LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>);
|
||||
|
||||
REGISTER_LAMB_KERNEL_TYPED(float, float, MLFloat16, float, MLFloat16, MLFloat16)
|
||||
REGISTER_LAMB_KERNEL_TYPED(float, float, MLFloat16, float, float, MLFloat16)
|
||||
REGISTER_LAMB_KERNEL_TYPED(float, float, float, float, float, MLFloat16)
|
||||
// REGISTER_LAMB_KERNEL_TYPED(double, double, double, double, double, MLFloat16)
|
||||
// REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, MLFloat16, MLFloat16, MLFloat16)
|
||||
// REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, MLFloat16, float, MLFloat16)
|
||||
REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, float, MLFloat16, MLFloat16)
|
||||
REGISTER_LAMB_KERNEL_TYPED(MLFloat16, float, MLFloat16, float, float, MLFloat16)
|
||||
|
||||
void check_inputs_and_outputs(
|
||||
const Tensor* w,
|
||||
const Tensor* g,
|
||||
const Tensor* m1,
|
||||
const Tensor* m2,
|
||||
const Tensor* w_mixed_precision,
|
||||
const Tensor* w_new,
|
||||
const Tensor* g_new,
|
||||
const Tensor* m1_new,
|
||||
const Tensor* m2_new,
|
||||
const Tensor* w_mixed_precision_new) {
|
||||
// Throw if we have incomplete input or output lists.
|
||||
ORT_ENFORCE(w, "Weight tensor should not be null.");
|
||||
ORT_ENFORCE(g, "gradient tensor should not be null.");
|
||||
ORT_ENFORCE(m1, "First-order momentum tensor should not be null.");
|
||||
ORT_ENFORCE(m2, "Second-order momentum tensor should not be null.");
|
||||
ORT_ENFORCE(m1_new, "New first-order momentum tensor should not be null.");
|
||||
ORT_ENFORCE(m2_new, "New second-order momentum tensor should not be null.");
|
||||
// Check if all shapes are good.
|
||||
ORT_ENFORCE(m1->Shape() == m1_new->Shape());
|
||||
ORT_ENFORCE(m2->Shape() == m2_new->Shape());
|
||||
if (w_new)
|
||||
ORT_ENFORCE(w->Shape() == w_new->Shape());
|
||||
if (g_new)
|
||||
ORT_ENFORCE(g->Shape() == g_new->Shape());
|
||||
if (w_mixed_precision && w_mixed_precision_new)
|
||||
ORT_ENFORCE(w_mixed_precision->Shape() == w_mixed_precision_new->Shape());
|
||||
}
|
||||
|
||||
template <typename TWeight, typename TGradient, typename TMomentum, typename TMixedPrecision>
|
||||
Status copy_inputs_to_outputs(
|
||||
hipStream_t stream,
|
||||
OpKernelContext* ctx,
|
||||
const int non_grouped_input_count,
|
||||
const int non_grouped_output_count,
|
||||
const int group_count,
|
||||
const int input_group_size,
|
||||
const int output_group_size) {
|
||||
const Tensor* step_tensor = ctx->Input<Tensor>(4);
|
||||
if (step_tensor) {
|
||||
const int64_t* step_data = step_tensor->template Data<int64_t>();
|
||||
Tensor* step_tensor_new = ctx->Output(0, step_tensor->Shape());
|
||||
ORT_ENFORCE(step_tensor_new != nullptr, "Step tensor (input) and updated step tensor (output) must be specified together.");
|
||||
int64_t* step_data_new = step_tensor_new->template MutableData<int64_t>();
|
||||
*step_data_new = *step_data;
|
||||
}
|
||||
|
||||
for (int group_index = 0; group_index < group_count; ++group_index) {
|
||||
const int input_start_index = non_grouped_input_count + group_index * input_group_size;
|
||||
const Tensor& w = *ctx->Input<Tensor>(input_start_index);
|
||||
const Tensor& g = *ctx->Input<Tensor>(input_start_index + 1);
|
||||
const Tensor& m1 = *ctx->Input<Tensor>(input_start_index + 2);
|
||||
const Tensor& m2 = *ctx->Input<Tensor>(input_start_index + 3);
|
||||
const Tensor* w_mixed_precision = ctx->Input<Tensor>(input_start_index + 4);
|
||||
const int output_start_index = non_grouped_output_count + group_index * output_group_size;
|
||||
Tensor* w_new = ctx->Output(output_start_index, w.Shape());
|
||||
Tensor* g_new = ctx->Output(output_start_index + 1, g.Shape());
|
||||
Tensor& m1_new = *ctx->Output(output_start_index + 2, m1.Shape());
|
||||
Tensor& m2_new = *ctx->Output(output_start_index + 3, m2.Shape());
|
||||
Tensor* w_mixed_precision_new = w_mixed_precision != nullptr ? ctx->Output(output_start_index + 4, w_mixed_precision->Shape()) : nullptr;
|
||||
|
||||
// TODO: temporary hack until View is improved (it doesn't work with Alias)
|
||||
if (w_new != nullptr)
|
||||
w_new->SetByteOffset(w.ByteOffset());
|
||||
if (g_new != nullptr)
|
||||
g_new->SetByteOffset(g.ByteOffset());
|
||||
if (w_mixed_precision_new != nullptr)
|
||||
w_mixed_precision_new->SetByteOffset(w_mixed_precision->ByteOffset());
|
||||
|
||||
if (w_new) {
|
||||
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TWeight>(stream, w, *w_new));
|
||||
}
|
||||
if (g_new) {
|
||||
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TGradient>(stream, g, *g_new));
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TMomentum>(stream, m1, m1_new));
|
||||
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TMomentum>(stream, m2, m2_new));
|
||||
|
||||
if (w_mixed_precision_new) {
|
||||
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<TMixedPrecision>(stream, *w_mixed_precision, *w_mixed_precision_new));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename HipT2, typename HipT3, typename HipT4, typename HipT_GRAD_NORM>
|
||||
Status launch_lamb_compute_direction(
|
||||
hipStream_t stream,
|
||||
const int64_t update_count,
|
||||
const int group_count,
|
||||
const HipT2* p_loss_scale,
|
||||
const HipT_GRAD_NORM* p_g_norm,
|
||||
std::vector<int>& tensor_sizes,
|
||||
std::vector<const HipT2*>& p_ws,
|
||||
std::vector<const HipT3*>& p_gs,
|
||||
std::vector<const HipT4*>& p_m1s,
|
||||
std::vector<const HipT4*>& p_m2s,
|
||||
std::vector<HipT3*>& p_ds,
|
||||
std::vector<HipT4*>& p_m1_news,
|
||||
std::vector<HipT4*>& p_m2_news,
|
||||
const std::vector<float>& alphas,
|
||||
const std::vector<float>& betas,
|
||||
const std::vector<float>& lambdas,
|
||||
const std::vector<float>& epsilons,
|
||||
const std::vector<float>& max_norms,
|
||||
const int64_t do_bias_correction) {
|
||||
ORT_ENFORCE(group_count == static_cast<int>(tensor_sizes.size()));
|
||||
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_ws.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_gs.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_m1s.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_m2s.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_ds.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_m1_news.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_m2_news.size()));
|
||||
|
||||
ORT_ENFORCE(group_count == static_cast<int>(alphas.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(betas.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(lambdas.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(epsilons.size()));
|
||||
|
||||
constexpr int tensor_count_per_group = 6;
|
||||
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
// Bucketize tensor groups by the associated optimizer configuration.
|
||||
// If two tensor groups use different "alpha", they should be put into two distinct buckets.
|
||||
std::map<std::tuple<float, float, float, float, float>, std::vector<std::vector<void*>>> buckets;
|
||||
std::map<std::tuple<float, float, float, float, float>, std::vector<int>> tensor_sizes_in_buckets;
|
||||
for (int i = 0; i < group_count; ++i) {
|
||||
if (tensor_sizes[i] > max_tensor_size) {
|
||||
// For the first iteration (indexed by 0), the update count should be 2.
|
||||
const float alpha_correction =
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alphas[i], update_count) : 1.f;
|
||||
const float beta_correction =
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(betas[i], update_count) : 1.f;
|
||||
|
||||
LambComputeDirection(
|
||||
stream,
|
||||
p_ws[i],
|
||||
p_gs[i],
|
||||
p_m1s[i],
|
||||
p_m2s[i],
|
||||
p_loss_scale,
|
||||
p_g_norm,
|
||||
HipT4(alphas[i]),
|
||||
HipT4(betas[i]),
|
||||
HipT2(lambdas[i]),
|
||||
HipT4(epsilons[i]),
|
||||
HipT2(max_norms[i]),
|
||||
HipT4(alpha_correction),
|
||||
HipT4(beta_correction),
|
||||
p_ds[i],
|
||||
p_m1_news[i],
|
||||
p_m2_news[i],
|
||||
tensor_sizes[i]);
|
||||
} else {
|
||||
std::vector<void*> ptrs(tensor_count_per_group);
|
||||
ptrs[0] = const_cast<HipT2*>(p_ws[i]); // weight tensor
|
||||
ptrs[1] = const_cast<HipT3*>(p_gs[i]); // gradient (reused to store update direction)
|
||||
ptrs[2] = const_cast<HipT4*>(p_m1s[i]); // 1st momentum
|
||||
ptrs[3] = const_cast<HipT4*>(p_m2s[i]); // 2nd momentum
|
||||
ptrs[4] = p_m1_news[i]; // new 1st momentum
|
||||
ptrs[5] = p_m2_news[i]; // new 2nd momentum
|
||||
|
||||
auto key = std::make_tuple(alphas[i], betas[i], lambdas[i], epsilons[i], max_norms[i]);
|
||||
buckets[key].push_back(ptrs);
|
||||
tensor_sizes_in_buckets[key].push_back(tensor_sizes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& pair : buckets) {
|
||||
const auto key = pair.first;
|
||||
float alpha = 0.f, beta = 0.f, lambda = 0.f, epsilon = 0.f, max_norm = 0.f;
|
||||
std::tie(alpha, beta, lambda, epsilon, max_norm) = key;
|
||||
|
||||
// For the first iteration (indexed by 0), the update count should be 1.
|
||||
const float alpha_correction =
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alpha, update_count) : 1.f;
|
||||
const float beta_correction =
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(beta, update_count) : 1.f;
|
||||
|
||||
typedef LambMultiTensorComputeDirectionFunctor<HipT2, HipT3, HipT4, HipT_GRAD_NORM> LambStage1;
|
||||
LambStage1 lamb_stage1;
|
||||
|
||||
launch_multi_tensor_functor<tensor_count_per_group, LambStage1>(
|
||||
stream,
|
||||
2048 * 32,
|
||||
tensor_sizes_in_buckets[key],
|
||||
buckets[key],
|
||||
lamb_stage1,
|
||||
p_loss_scale, p_g_norm, lambda, alpha, beta, epsilon, HipT2(max_norm), alpha_correction, beta_correction);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename HipTNorm, typename HipTIn1, typename HipTIn2>
|
||||
Status launch_lamb_reduction(
|
||||
const RocmKernel& kernel,
|
||||
const int group_count,
|
||||
std::vector<int>& tensor_sizes,
|
||||
std::vector<HipTNorm*>& p_w_norms,
|
||||
std::vector<HipTNorm*>& p_d_norms,
|
||||
std::vector<const HipTIn1*>& p_ws,
|
||||
std::vector<HipTIn2*>& p_ds,
|
||||
void* reduction_buffer,
|
||||
size_t reduction_buffer_size) {
|
||||
ORT_ENFORCE(group_count == static_cast<int>(tensor_sizes.size()));
|
||||
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_w_norms.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_d_norms.size()));
|
||||
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_ws.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_ds.size()));
|
||||
|
||||
constexpr int tensor_count_per_group = 4;
|
||||
hipStream_t stream = kernel.Stream();
|
||||
// Collect the tensor groups that fit in a multi-tensor launch into a single bucket.
|
||||
// Tensors larger than max_tensor_size are reduced one at a time instead.
|
||||
std::vector<std::vector<void*>> buckets;
|
||||
std::vector<int> tensor_sizes_in_buckets;
|
||||
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
for (int i = 0; i < group_count; ++i) {
|
||||
if (tensor_sizes[i] > max_tensor_size) {
|
||||
ORT_RETURN_IF_ERROR(reduce_square_sum(
|
||||
stream,
|
||||
p_ws[i],
|
||||
p_w_norms[i],
|
||||
tensor_sizes[i],
|
||||
reduction_buffer,
|
||||
reduction_buffer_size));
|
||||
ORT_RETURN_IF_ERROR(reduce_square_sum(
|
||||
stream,
|
||||
p_ds[i],
|
||||
p_d_norms[i],
|
||||
tensor_sizes[i],
|
||||
reduction_buffer,
|
||||
reduction_buffer_size));
|
||||
} else {
|
||||
std::vector<void*> ptrs(tensor_count_per_group);
|
||||
ptrs[0] = const_cast<HipTIn1*>(p_ws[i]); // weight tensor
|
||||
ptrs[1] = const_cast<HipTIn2*>(p_ds[i]); // update direction
|
||||
ptrs[2] = p_w_norms[i]; // weight tensor's norm
|
||||
ptrs[3] = p_d_norms[i]; // update direction's norm
|
||||
|
||||
buckets.push_back(ptrs);
|
||||
tensor_sizes_in_buckets.push_back(tensor_sizes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (buckets.size() > 0) {
|
||||
ORT_ENFORCE(tensor_sizes_in_buckets.size() > 0);
|
||||
}
|
||||
|
||||
if (tensor_sizes_in_buckets.size() > 0) {
|
||||
ORT_ENFORCE(buckets.size() > 0);
|
||||
}
|
||||
|
||||
// Only launch multi-tensor function if we have at least one tensor in the buckets.
|
||||
if (tensor_sizes_in_buckets.size() > 0 && buckets.size() > 0) {
|
||||
typedef LambMultiTensorReductionFunctor<HipTIn1, HipTIn2, HipTNorm, HipTNorm, HipTNorm> TReducer;
|
||||
TReducer reducer;
|
||||
launch_multi_tensor_functor<tensor_count_per_group, TReducer>(
|
||||
stream,
|
||||
2048 * 32,
|
||||
tensor_sizes_in_buckets,
|
||||
buckets,
|
||||
reducer,
|
||||
kernel,
|
||||
reduction_buffer,
|
||||
reduction_buffer_size);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename HipT1, typename HipT2, typename HipT3, typename HipT_MIXED_PRECISION_FP>
|
||||
Status launch_lamb_update(
|
||||
hipStream_t stream,
|
||||
const int group_count,
|
||||
const HipT1* eta,
|
||||
const float ratio_min,
|
||||
const float ratio_max,
|
||||
std::vector<int>& tensor_sizes,
|
||||
std::vector<HipT2*>& p_w_norms,
|
||||
std::vector<HipT2*>& p_d_norms,
|
||||
std::vector<const HipT2*>& p_ws,
|
||||
std::vector<HipT3*>& p_ds,
|
||||
/* output */ std::vector<HipT2*>& p_w_news,
|
||||
/* output */ std::vector<HipT3*>& p_g_news,
|
||||
/* output */ std::vector<HipT_MIXED_PRECISION_FP*>& p_w_mixed_precision_news) {
|
||||
ORT_ENFORCE(group_count == static_cast<int>(tensor_sizes.size()));
|
||||
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_w_norms.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_d_norms.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_ws.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_ds.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_w_news.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_g_news.size()));
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_w_mixed_precision_news.size()));
|
||||
|
||||
constexpr int tensor_count_per_group = 7;
|
||||
|
||||
// Collect the tensor groups that fit in a multi-tensor launch into a single bucket.
|
||||
// Tensors larger than max_tensor_size are updated one at a time instead.
|
||||
std::vector<std::vector<void*>> buckets;
|
||||
std::vector<int> tensor_sizes_in_bucket;
|
||||
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
for (int i = 0; i < group_count; ++i) {
|
||||
if (tensor_sizes[i] > max_tensor_size) {
|
||||
LambUpdate(
|
||||
stream,
|
||||
eta,
|
||||
ratio_min,
|
||||
ratio_max,
|
||||
p_d_norms[i],
|
||||
p_w_norms[i],
|
||||
p_ws[i],
|
||||
p_ds[i],
|
||||
p_w_news[i],
|
||||
p_g_news[i],
|
||||
p_w_mixed_precision_news[i],
|
||||
tensor_sizes[i]);
|
||||
} else {
|
||||
std::vector<void*> ptrs(tensor_count_per_group);
|
||||
ptrs[0] = p_w_norms[i]; // weight tensor's norm
|
||||
ptrs[1] = p_d_norms[i]; // direction's norm
|
||||
ptrs[2] = const_cast<HipT2*>(p_ws[i]); // weight tensor
|
||||
ptrs[3] = p_ds[i]; // direction
|
||||
ptrs[4] = p_w_news[i]; // new weight tensor
|
||||
ptrs[5] = p_g_news[i]; // new gradient tensor
|
||||
ptrs[6] = p_w_mixed_precision_news[i]; // new half-precision weight tensor
|
||||
buckets.push_back(ptrs);
|
||||
tensor_sizes_in_bucket.push_back(tensor_sizes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (buckets.size() > 0) {
|
||||
ORT_ENFORCE(tensor_sizes_in_bucket.size() > 0);
|
||||
}
|
||||
|
||||
if (tensor_sizes_in_bucket.size() > 0) {
|
||||
ORT_ENFORCE(buckets.size() > 0);
|
||||
}
|
||||
|
||||
// Only launch multi-tensor function if we have at least one tensor in the buckets.
|
||||
if (tensor_sizes_in_bucket.size() > 0 && buckets.size() > 0) {
|
||||
typedef LambMultiTensorUpdateFunctor<
|
||||
HipT1, HipT2, HipT3, HipT_MIXED_PRECISION_FP>
|
||||
LambStage2;
|
||||
LambStage2 lamb_stage2;
|
||||
|
||||
launch_multi_tensor_functor<tensor_count_per_group, LambStage2>(
|
||||
stream,
|
||||
2048 * 32,
|
||||
tensor_sizes_in_bucket,
|
||||
buckets,
|
||||
lamb_stage2,
|
||||
eta,
|
||||
ratio_min,
|
||||
ratio_max);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename T4, typename T_GRAD_NORM, typename T_MIXED_PRECISION_FP>
|
||||
Status LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
// HipT* are types used to invoke ROCM-based functions. It, for example, maps
|
||||
// MLFloat16 in ONNXRuntime to half in ROCM.
|
||||
typedef typename ToHipType<T1>::MappedType HipT1;
|
||||
typedef typename ToHipType<T2>::MappedType HipT2;
|
||||
typedef typename ToHipType<T3>::MappedType HipT3;
|
||||
typedef typename ToHipType<T4>::MappedType HipT4;
|
||||
typedef typename ToHipType<T_GRAD_NORM>::MappedType HipT_GRAD_NORM;
|
||||
typedef typename ToHipType<T_MIXED_PRECISION_FP>::MappedType HipT_MIXED_PRECISION_FP;
|
||||
|
||||
constexpr int non_grouped_input_count = 5;
|
||||
constexpr int input_group_size = 5;
|
||||
constexpr int output_group_size = 5;
|
||||
constexpr int non_grouped_output_count = 1;
|
||||
constexpr int minimal_input_count = non_grouped_input_count + 1 * input_group_size - 1;
|
||||
constexpr int minimal_output_count = non_grouped_output_count + 1 * output_group_size - 1;
|
||||
const int grouped_input_tensor_count = ctx->InputCount() - non_grouped_input_count;
|
||||
const int grouped_output_tensor_count = ctx->OutputCount() - non_grouped_output_count;
|
||||
|
||||
// At least one variable group for updating one weight tensor.
|
||||
ORT_ENFORCE(
|
||||
ctx->InputCount() >= minimal_input_count,
|
||||
"Expect at least ", minimal_input_count, " inputs but got ",
|
||||
ctx->InputCount());
|
||||
// At least one variable group for updating one weight tensor.
|
||||
ORT_ENFORCE(
|
||||
ctx->OutputCount() >= minimal_output_count,
|
||||
"Expect at least ", minimal_output_count, " outputs but got ",
|
||||
ctx->OutputCount());
|
||||
|
||||
// In addition to the first non_grouped_input_count inputs, all inputs are repeated sequence of [w, g, m1, m2, w_mixed_precision].
|
||||
ORT_ENFORCE(
|
||||
grouped_input_tensor_count % input_group_size == 0,
|
||||
"Input count must be ", non_grouped_input_count, " + ", input_group_size,
|
||||
" x (number of weights to optimize).");
|
||||
// Outputs are repeated sequence of [w_new, g_new, m1_new, m2_new, w_mixed_precision_new].
|
||||
ORT_ENFORCE(
|
||||
grouped_output_tensor_count % output_group_size == 0,
|
||||
"Output count must be ", non_grouped_output_count, " + ", output_group_size,
|
||||
" x (number of weights to optimize).");
|
||||
// Number of repeated [w, g, m1, m2, w_mixed_precision]'s should match number of repeated [w_new, g_new, m1_new, m2_new, w_mixed_precision_new].
|
||||
ORT_ENFORCE(
|
||||
grouped_input_tensor_count / input_group_size == grouped_output_tensor_count / output_group_size,
|
||||
"Input and output tensor counts are not aligned. Please check LambOptimizer's input and output lists.");
|
||||
|
||||
// Number of [w, g, m1, m2, (w_mixed_precision)] (or [w_new, g_new, m1_new, m2_new, (w_mixed_precision_new)]) groups.
|
||||
const int group_count = (grouped_input_tensor_count + input_group_size - 1) / input_group_size;
|
||||
|
||||
// At least we need one group of alpha, beta, lambda, ..., for processing one group.
|
||||
ORT_ENFORCE(alpha_.size() >= static_cast<size_t>(group_count));
|
||||
ORT_ENFORCE(beta_.size() >= static_cast<size_t>(group_count));
|
||||
ORT_ENFORCE(lambda_.size() >= static_cast<size_t>(group_count));
|
||||
ORT_ENFORCE(epsilon_.size() >= static_cast<size_t>(group_count));
|
||||
ORT_ENFORCE(max_norm_clip_.size() >= static_cast<size_t>(group_count));
|
||||
|
||||
// If gradient norm is not finite, we copy inputs to outputs directly.
|
||||
if (ctx->Input<Tensor>(0)) {
|
||||
auto update_signal_tensor = ctx->Input<Tensor>(0);
|
||||
auto update_signal = *update_signal_tensor->template Data<bool>();
|
||||
if (!update_signal) {
|
||||
return copy_inputs_to_outputs<T2, T3, T4, T_MIXED_PRECISION_FP>(
|
||||
Stream(),
|
||||
ctx,
|
||||
non_grouped_input_count,
|
||||
non_grouped_output_count,
|
||||
group_count,
|
||||
input_group_size,
|
||||
output_group_size);
|
||||
}
|
||||
}
|
||||
|
||||
const HipT2* loss_scale_data = nullptr;
|
||||
if (ctx->Input<Tensor>(1)) {
|
||||
const Tensor& loss_scale_tensor = *ctx->Input<Tensor>(1);
|
||||
loss_scale_data = reinterpret_cast<const HipT2*>(loss_scale_tensor.template Data<T2>());
|
||||
}
|
||||
|
||||
const HipT_GRAD_NORM* g_norm_data = nullptr;
|
||||
if (ctx->Input<Tensor>(2)) {
|
||||
const Tensor& g_norm_tensor = *ctx->Input<Tensor>(2);
|
||||
g_norm_data = reinterpret_cast<const HipT_GRAD_NORM*>(g_norm_tensor.template Data<T_GRAD_NORM>());
|
||||
}
|
||||
|
||||
const Tensor& eta = *ctx->Input<Tensor>(3);
|
||||
const HipT1* eta_data = reinterpret_cast<const HipT1*>(eta.template Data<T1>());
|
||||
|
||||
const Tensor* step_tensor = ctx->Input<Tensor>(4);
|
||||
const int64_t* step_data = nullptr;
|
||||
if (step_tensor) {
|
||||
step_data = step_tensor->template Data<int64_t>();
|
||||
}
|
||||
|
||||
// Allocate buffer for reduction computation of update directions.
|
||||
// The i-th update direction's norm is stored at the i-th element.
|
||||
// We reduce type T3 tensor to type T2 scalar. An example is that T3=float16
|
||||
// and T2=float.
|
||||
IAllocatorUniquePtr<T2> d_norm_buffer = GetScratchBuffer<T2>(group_count);
|
||||
HipT2* d_norm_data = reinterpret_cast<HipT2*>(d_norm_buffer.get());
|
||||
HIP_RETURN_IF_ERROR(hipMemsetAsync(d_norm_data, 0, group_count * sizeof(T2), Stream()));
|
||||
|
||||
// Allocate buffer for reduction computation of weight tensor.
|
||||
// The i-th weight's norm is stored at the i-th element.
|
||||
// We reduce type T2 tensor to type T2 scalar. An example is that T2=float.
|
||||
IAllocatorUniquePtr<T2> w_norm_buffer = GetScratchBuffer<T2>(group_count);
|
||||
HipT2* w_norm_data = reinterpret_cast<HipT2*>(w_norm_buffer.get());
|
||||
HIP_RETURN_IF_ERROR(hipMemsetAsync(w_norm_data, 0, group_count * sizeof(T2), Stream()));
|
||||
|
||||
// Find the max size of updated weight tensors.
|
||||
int max_tensor_size = 0;
|
||||
for (int group_index = 0; group_index < group_count; ++group_index) {
|
||||
// Prepare used input tensors for this group.
|
||||
const int input_start_index = non_grouped_input_count + group_index * input_group_size;
|
||||
const Tensor& w = *ctx->Input<Tensor>(input_start_index);
|
||||
max_tensor_size = std::max(max_tensor_size, static_cast<int>(w.Shape().Size()));
|
||||
}
|
||||
|
||||
const size_t reduction_buffer_size = [&]() {
|
||||
// Allocate a buffer in byte for reduction API calls.
|
||||
size_t rbs = compute_reduction_buffer_size<HipT2>(max_tensor_size);
|
||||
|
||||
// Enlarge the reduction buffer to accommodate the multi-tensor reduction kernel as well
|
||||
const int tensor_group_size = 4; // w, d, w_norm, d_norm
|
||||
const int max_blocks = ChunkGroup<tensor_group_size>::max_block_count;
|
||||
const size_t multitensor_block_reduce_buffer_size = 2 * max_blocks * sizeof(HipT2);
|
||||
rbs = std::max(rbs, multitensor_block_reduce_buffer_size);
|
||||
|
||||
return rbs;
|
||||
}();
|
||||
|
||||
// Allocate reduction buffer whose size is reduction_buffer_size bytes.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(reduction_buffer_size);
|
||||
|
||||
// Input tensors' pointers.
|
||||
std::vector<const HipT2*> p_ws(group_count);
|
||||
std::vector<const HipT3*> p_gs(group_count);
|
||||
std::vector<const HipT4*> p_m1s(group_count);
|
||||
std::vector<const HipT4*> p_m2s(group_count);
|
||||
std::vector<const HipT_MIXED_PRECISION_FP*> p_w_mixed_precisions(group_count);
|
||||
// ds' is a mutable version of gs' because we want to reuse
|
||||
// gs' memory to store the update direction to avoid allocating a model-scale buffer.
|
||||
std::vector<HipT3*> p_ds(group_count);
|
||||
// Intermediate tensors, weight tensors' and directions' norms.
|
||||
std::vector<HipT2*> p_w_norms(group_count);
|
||||
std::vector<HipT2*> p_d_norms(group_count);
|
||||
// Output tensors' pointers.
|
||||
std::vector<HipT2*> p_w_news(group_count);
|
||||
std::vector<HipT3*> p_g_news(group_count);
|
||||
std::vector<HipT4*> p_m1_news(group_count);
|
||||
std::vector<HipT4*> p_m2_news(group_count);
|
||||
std::vector<HipT_MIXED_PRECISION_FP*> p_w_mixed_precision_news(group_count);
|
||||
// The i-th element in following array is the size of
|
||||
// the i-th updated weight tensor and other related tensors.
|
||||
std::vector<int> tensor_sizes(group_count);
|
||||
|
||||
for (int group_index = 0; group_index < group_count; ++group_index) {
|
||||
// Prepare used input tensors for this group.
|
||||
const int input_start_index = non_grouped_input_count + group_index * input_group_size;
|
||||
const Tensor* w = ctx->Input<Tensor>(input_start_index);
|
||||
const Tensor* g = ctx->Input<Tensor>(input_start_index + 1);
|
||||
const Tensor* m1 = ctx->Input<Tensor>(input_start_index + 2);
|
||||
const Tensor* m2 = ctx->Input<Tensor>(input_start_index + 3);
|
||||
const Tensor* w_mixed_precision = ctx->Input<Tensor>(input_start_index + 4);
|
||||
|
||||
// Prepare used outputs tensors for this group.
|
||||
const int output_start_index = non_grouped_output_count + group_index * output_group_size;
|
||||
Tensor* w_new = ctx->Output(output_start_index, w->Shape());
|
||||
Tensor* g_new = ctx->Output(output_start_index + 1, g->Shape());
|
||||
Tensor* m1_new = ctx->Output(output_start_index + 2, m1->Shape());
|
||||
Tensor* m2_new = ctx->Output(output_start_index + 3, m2->Shape());
|
||||
Tensor* w_mixed_precision_new = w_mixed_precision != nullptr ? ctx->Output(output_start_index + 4, w_mixed_precision->Shape()) : nullptr;
|
||||
|
||||
// TODO: temporary hack until View is improved (it doesn't work with Alias)
|
||||
if (w_new != nullptr)
|
||||
w_new->SetByteOffset(w->ByteOffset());
|
||||
if (g_new != nullptr)
|
||||
g_new->SetByteOffset(g->ByteOffset());
|
||||
if (w_mixed_precision_new != nullptr)
|
||||
w_mixed_precision_new->SetByteOffset(w_mixed_precision->ByteOffset());
|
||||
|
||||
check_inputs_and_outputs(w, g, m1, m2, w_mixed_precision, w_new, g_new, m1_new, m2_new, w_mixed_precision_new);
|
||||
|
||||
// We should throw for preventing overflow in reduction APIs.
|
||||
// The index in ROCM system is 32-bit integer.
|
||||
ORT_ENFORCE(
|
||||
w->Shape().Size() <
|
||||
static_cast<int64_t>(std::numeric_limits<int>::max()));
|
||||
tensor_sizes[group_index] = static_cast<int>(w->Shape().Size());
|
||||
|
||||
// Input tensors' pointers.
|
||||
p_ws[group_index] = reinterpret_cast<const HipT2*>(w->template Data<T2>());
|
||||
p_gs[group_index] = reinterpret_cast<const HipT3*>(g->template Data<T3>());
|
||||
p_m1s[group_index] = reinterpret_cast<const HipT4*>(m1->template Data<T4>());
|
||||
p_m2s[group_index] = reinterpret_cast<const HipT4*>(m2->template Data<T4>());
|
||||
p_w_mixed_precisions[group_index] = w_mixed_precision != nullptr ? reinterpret_cast<const HipT_MIXED_PRECISION_FP*>(w_mixed_precision->template Data<T_MIXED_PRECISION_FP>()) : nullptr;
|
||||
|
||||
// The following cast is for reusing gradient tensor g to store update direction d.
|
||||
p_ds[group_index] = const_cast<HipT3*>(reinterpret_cast<const HipT3*>(g->template Data<T3>()));
|
||||
|
||||
// Set up which pointer to store which tensor's norm.
|
||||
p_w_norms[group_index] = w_norm_data + group_index;
|
||||
p_d_norms[group_index] = d_norm_data + group_index;
|
||||
|
||||
// Output tensors' pointers.
|
||||
p_w_news[group_index] = w_new != nullptr ? reinterpret_cast<HipT2*>(w_new->template MutableData<T2>()) : nullptr;
|
||||
p_g_news[group_index] = g_new != nullptr ? reinterpret_cast<HipT3*>(g_new->template MutableData<T3>()) : nullptr;
|
||||
p_m1_news[group_index] = reinterpret_cast<HipT4*>(m1_new->template MutableData<T4>());
|
||||
p_m2_news[group_index] = reinterpret_cast<HipT4*>(m2_new->template MutableData<T4>());
|
||||
p_w_mixed_precision_news[group_index] = w_mixed_precision_new != nullptr ? reinterpret_cast<HipT_MIXED_PRECISION_FP*>(w_mixed_precision_new->template MutableData<T_MIXED_PRECISION_FP>()) : nullptr;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(launch_lamb_compute_direction(
|
||||
Stream(),
|
||||
step_data ? *step_data : 0,
|
||||
group_count,
|
||||
loss_scale_data,
|
||||
g_norm_data,
|
||||
tensor_sizes,
|
||||
p_ws, p_gs, p_m1s, p_m2s,
|
||||
p_ds,
|
||||
p_m1_news, p_m2_news,
|
||||
alpha_, beta_, lambda_, epsilon_, max_norm_clip_,
|
||||
do_bias_correction_));
|
||||
|
||||
ORT_RETURN_IF_ERROR(launch_lamb_reduction(
|
||||
*this,
|
||||
group_count,
|
||||
tensor_sizes,
|
||||
p_w_norms,
|
||||
p_d_norms,
|
||||
p_ws,
|
||||
p_ds,
|
||||
reduction_buffer.get(),
|
||||
reduction_buffer_size));
|
||||
|
||||
ORT_RETURN_IF_ERROR(launch_lamb_update(
|
||||
Stream(),
|
||||
group_count,
|
||||
eta_data,
|
||||
ratio_min_,
|
||||
ratio_max_,
|
||||
tensor_sizes,
|
||||
p_w_norms,
|
||||
p_d_norms,
|
||||
p_ws,
|
||||
p_ds,
|
||||
p_w_news,
|
||||
p_g_news,
|
||||
p_w_mixed_precision_news));
|
||||
|
||||
if (step_tensor) {
|
||||
Tensor* step_tensor_new = ctx->Output(0, step_tensor->Shape());
|
||||
ORT_ENFORCE(step_tensor_new != nullptr, "Step tensor (input) and updated step tensor (output) must be specified together.");
|
||||
int64_t* step_data_new = step_tensor_new->template MutableData<int64_t>();
|
||||
*step_data_new = *step_data + 1;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -159,9 +159,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_float_float_float_MLFloat16, LambOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16_float_MLFloat16_MLFloat16, LambOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_float_MLFloat16_float_float_MLFloat16, LambOptimizer)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double_double_double_MLFloat16, LambOptimizer)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_MLFloat16_MLFloat16, LambOptimizer)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_float_MLFloat16, LambOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double_double_double_double_MLFloat16, LambOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_MLFloat16_MLFloat16, LambOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_MLFloat16_float_MLFloat16, LambOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_float_MLFloat16_MLFloat16, LambOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16_float_float_MLFloat16, LambOptimizer)>,
|
||||
|
||||
|
|
|
|||
|
|
@ -1,113 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/rocm/tensor/gather_grad.h"
|
||||
#include "orttraining/training_ops/rocm/tensor/gather_grad_impl.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
// Registers the GatherGrad kernel (kMSDomain, version 1) for the ROCm execution provider.
// Input 0 is declared as a CPU input; ComputeInternal reads it on the host to build the
// output shape.
ONNX_OPERATOR_KERNEL_EX(
    GatherGrad,
    kMSDomain,
    1,
    kRocmExecutionProvider,
    KernelDefBuilder()
        .InputMemoryType<OrtMemTypeCPUInput>(0)
        .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
        .TypeConstraint("T", std::vector<MLDataType>{
                                 DataTypeImpl::GetTensorType<float>(),
                                 DataTypeImpl::GetTensorType<MLFloat16>()})
        .TypeConstraint("Tind", std::vector<MLDataType>{
                                    DataTypeImpl::GetTensorType<int32_t>(),
                                    DataTypeImpl::GetTensorType<int64_t>()}),
    GatherGrad);
|
||||
|
||||
namespace {
|
||||
template <typename T, typename Tin>
|
||||
Status CallGatherGradImpl(
|
||||
const RocmKernel& rocm_kernel,
|
||||
int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs,
|
||||
const Tensor& grad, const Tensor& indices,
|
||||
Tensor& output) {
|
||||
using HipT = typename ToHipType<T>::MappedType;
|
||||
|
||||
const T* grad_data = grad.template Data<T>();
|
||||
T* output_data = output.template MutableData<T>();
|
||||
const Tin* indices_data = indices.template Data<Tin>();
|
||||
|
||||
GatherGradImpl(
|
||||
rocm_kernel,
|
||||
reinterpret_cast<const HipT*>(grad_data),
|
||||
indices_data,
|
||||
indices.Shape().Size(),
|
||||
num_weights,
|
||||
stride,
|
||||
reinterpret_cast<HipT*>(output_data),
|
||||
num_inputs,
|
||||
param_itrs);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status DispatchToGatherGradImplByTin(
|
||||
MLDataType tin_data_type,
|
||||
const RocmKernel& rocm_kernel,
|
||||
int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs,
|
||||
const Tensor& grad, const Tensor& indices,
|
||||
Tensor& output) {
|
||||
if (utils::IsPrimitiveDataType<int32_t>(tin_data_type)) {
|
||||
return CallGatherGradImpl<T, int32_t>(
|
||||
rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output);
|
||||
} else if (utils::IsPrimitiveDataType<int64_t>(tin_data_type)) {
|
||||
return CallGatherGradImpl<T, int64_t>(
|
||||
rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output);
|
||||
}
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GatherGrad unsupported Tin type: ", tin_data_type);
|
||||
}
|
||||
|
||||
// First dispatch stage: select the gradient element type (float or MLFloat16) from its
// runtime descriptor, then let DispatchToGatherGradImplByTin resolve the index type.
// Returns a FAIL status when the gradient type is not supported.
Status DispatchToGatherGradImpl(
    MLDataType t_data_type, MLDataType tin_data_type,
    const RocmKernel& rocm_kernel,
    int64_t num_weights, int64_t stride, int64_t num_inputs, int64_t param_itrs,
    const Tensor& grad, const Tensor& indices,
    Tensor& output) {
  if (utils::IsPrimitiveDataType<float>(t_data_type)) {
    return DispatchToGatherGradImplByTin<float>(
        tin_data_type, rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output);
  }

  if (utils::IsPrimitiveDataType<MLFloat16>(t_data_type)) {
    return DispatchToGatherGradImplByTin<MLFloat16>(
        tin_data_type, rocm_kernel, num_weights, stride, num_inputs, param_itrs, grad, indices, output);
  }

  // No supported gradient type matched.
  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GatherGrad unsupported T type: ", t_data_type);
}
|
||||
} // namespace
|
||||
|
||||
// Computes the gradient of Gather with respect to its data input.
//
// Inputs (by index):
//   0 - shape of the original data tensor (int64 values, read on the host),
//   1 - the indices that were used by the forward Gather,
//   2 - the incoming gradient.
// Output 0 has the shape given by input 0. It is zero-filled first and then the gathered
// gradient rows are accumulated into it.
//
// Returns an error status if the memset fails or the element/index types are unsupported.
Status GatherGrad::ComputeInternal(OpKernelContext* context) const {
  const Tensor* shape = context->Input<Tensor>(0);
  const TensorShape data_shape(shape->template Data<int64_t>(), shape->Shape().Size());
  const Tensor* indices = context->Input<Tensor>(1);
  const Tensor* grad = context->Input<Tensor>(2);

  Tensor* output = context->Output(0, data_shape);

  // An output with a zero-sized dimension has nothing to fill. Returning here also keeps
  // the divisions below (by `stride` and by `num_inputs`) from dividing by zero.
  if (data_shape.Size() == 0) {
    return Status::OK();
  }

  HIP_RETURN_IF_ERROR(hipMemsetAsync(output->MutableDataRaw(), 0, output->SizeInBytes(), Stream()));

  // Without indices nothing was gathered, so the gradient is the zero tensor written above.
  if (indices->Shape().Size() == 0) {
    return Status::OK();
  }

  MLDataType T_type = grad->DataType();
  MLDataType Tin_type = indices->DataType();

  const auto axis = HandleNegativeAxis(axis_, data_shape.NumDimensions());
  // Number of elements in one slice after the gather axis.
  const int64_t stride = data_shape.SizeFromDimension(axis + 1);
  const int64_t num_weights = data_shape.Size() / stride;
  // Number of elements from the gather axis onwards.
  const int64_t num_inputs = data_shape.SizeFromDimension(axis);
  // Product of the dimensions in front of the gather axis.
  const int64_t param_itrs = data_shape.SizeFromDimension(0) / num_inputs;

  return DispatchToGatherGradImpl(
      T_type, Tin_type, *this,
      num_weights, stride, num_inputs, param_itrs,
      *grad, *indices, *output);
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,216 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/rocm/tensor/gather_grad_impl.h"
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/shared_inc/rocm_call.h"
|
||||
|
||||
#include <hipcub/hipcub.hpp>
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
__global__ void _Iota(
|
||||
hipcub::CountingInputIterator<T> input,
|
||||
size_t length,
|
||||
T* output) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, length);
|
||||
output[idx] = input[idx];
|
||||
}
|
||||
|
||||
template <typename T, typename Tin, int NumElementsPerThread>
|
||||
__global__ void _GatherGradImpl(
|
||||
const Tin* input,
|
||||
const Tin* indices,
|
||||
const T* grad_output,
|
||||
T* grad_weight,
|
||||
int64_t numel,
|
||||
int64_t input_numel,
|
||||
int64_t param_itrs,
|
||||
int64_t stride) {
|
||||
int idx = blockIdx.x * 4 + threadIdx.y;
|
||||
|
||||
if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) {
|
||||
do {
|
||||
for (int itr = 0; itr < param_itrs; ++itr) {
|
||||
const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread;
|
||||
const int weight_row = itr * input_numel + ((int)input[idx]) * stride; //the offset of the input
|
||||
const int grad_row = (itr * numel + ((int)indices[idx])) * stride; //the offset of the gradient
|
||||
|
||||
float gradient[NumElementsPerThread];
|
||||
float weight[NumElementsPerThread];
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
int feature_dim = start_feature + ii * GPU_WARP_SIZE;
|
||||
if (feature_dim < stride) {
|
||||
gradient[ii] = static_cast<float>(grad_output[grad_row + feature_dim]);
|
||||
weight[ii] = static_cast<float>(grad_weight[weight_row + feature_dim]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
weight[ii] += gradient[ii];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
int feature_dim = start_feature + ii * GPU_WARP_SIZE;
|
||||
if (feature_dim < stride) {
|
||||
grad_weight[weight_row + feature_dim] = static_cast<T>(weight[ii]);
|
||||
}
|
||||
}
|
||||
}
|
||||
idx++;
|
||||
} while (idx < numel && input[idx] == input[idx - 1]);
|
||||
}
|
||||
}
|
||||
|
||||
// Special optimization for the case which the gather is on axis=0
|
||||
template <typename T, typename Tin, int NumElementsPerThread>
|
||||
__global__ void _GatherAxis0GradImpl(
|
||||
const Tin* input,
|
||||
const Tin* indices,
|
||||
const T* grad_output,
|
||||
T* grad_weight,
|
||||
int64_t numel,
|
||||
int64_t input_numel,
|
||||
int64_t stride)
|
||||
{
|
||||
int idx = blockIdx.x * 4 + threadIdx.y;
|
||||
|
||||
if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) {
|
||||
const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * NumElementsPerThread;
|
||||
const int weight_row = ((int)input[idx]) * stride; //the offset of the input
|
||||
|
||||
float weight[NumElementsPerThread];
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
int feature_dim = start_feature + ii * GPU_WARP_SIZE/4;
|
||||
if (feature_dim < stride)
|
||||
weight[ii] = static_cast<float>(grad_weight[weight_row + feature_dim]);
|
||||
}
|
||||
|
||||
do {
|
||||
const int grad_row = ((int)indices[idx]) * stride; //the offset of the gradient
|
||||
float gradient[NumElementsPerThread];
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
int feature_dim = start_feature + ii * GPU_WARP_SIZE/4;
|
||||
if (feature_dim < stride) {
|
||||
gradient[ii] = static_cast<float>(grad_output[grad_row + feature_dim]);
|
||||
weight[ii] += gradient[ii];
|
||||
}
|
||||
}
|
||||
idx++;
|
||||
} while (idx < numel && input[idx] == input[idx - 1]);
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < NumElementsPerThread; ii++) {
|
||||
int feature_dim = start_feature + ii * GPU_WARP_SIZE/4;
|
||||
if (feature_dim < stride)
|
||||
grad_weight[weight_row + feature_dim] = static_cast<T>(weight[ii]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Tin>
|
||||
void GatherGradImpl(
|
||||
const RocmKernel& rocm_kernel,
|
||||
const T* grad_data,
|
||||
const Tin* indices_data,
|
||||
const int64_t num_indices,
|
||||
const int64_t num_weights,
|
||||
const int64_t stride,
|
||||
T* output_data,
|
||||
const int64_t num_inputs, //The number of input elements starting from the gathering dimension
|
||||
const int64_t param_itrs //The size of dimensions of the data before gathering dimension
|
||||
) {
|
||||
// allocate intermediate buffers
|
||||
auto original_indices = rocm_kernel.template GetScratchBuffer<Tin>(num_indices);
|
||||
hipStream_t stream = rocm_kernel.Stream();
|
||||
|
||||
// initialize original_indices with [0, num_indices)
|
||||
{
|
||||
const auto blocks_per_grid = CeilDiv(num_indices, GridDim::maxThreadsPerBlock);
|
||||
hipcub::CountingInputIterator<Tin> counting_input(Tin{});
|
||||
hipLaunchKernelGGL(_Iota, dim3(blocks_per_grid), dim3(GridDim::maxThreadsPerBlock), 0, stream,
|
||||
counting_input, num_indices, original_indices.get());
|
||||
}
|
||||
|
||||
auto indices_data_sorted = rocm_kernel.template GetScratchBuffer<Tin>(num_indices);
|
||||
auto original_indices_sorted = rocm_kernel.template GetScratchBuffer<Tin>(num_indices);
|
||||
|
||||
// sort indices and original indices
|
||||
size_t sort_temp_storage_size_bytes = 0;
|
||||
HIP_CALL_THROW(hipcub::DeviceRadixSort::SortPairs(
|
||||
nullptr, sort_temp_storage_size_bytes,
|
||||
indices_data, indices_data_sorted.get(),
|
||||
original_indices.get(), original_indices_sorted.get(),
|
||||
num_indices, 0, sizeof(Tin)*8, stream));
|
||||
|
||||
auto sort_temp_storage = rocm_kernel.GetScratchBuffer<void>(sort_temp_storage_size_bytes);
|
||||
|
||||
HIP_CALL_THROW(hipcub::DeviceRadixSort::SortPairs(
|
||||
sort_temp_storage.get(), sort_temp_storage_size_bytes,
|
||||
indices_data, indices_data_sorted.get(),
|
||||
original_indices.get(), original_indices_sorted.get(),
|
||||
num_indices, 0, sizeof(Tin)*8, stream));
|
||||
|
||||
dim3 block(GPU_WARP_SIZE, 4);
|
||||
dim3 grid(CeilDiv(num_indices, 4), CeilDiv(stride, GridDim::maxElementsPerThread * GPU_WARP_SIZE));
|
||||
|
||||
// commented optimization resulted in increase variance of loss for BERT across multiple reasons
|
||||
//if (param_itrs == 1)
|
||||
//{
|
||||
// hipLaunchKernelGGL(HIP_KERNEL_NAME(_GatherAxis0GradImpl<T, Tin, GridDim::maxElementsPerThread>), dim3(grid), dim3(block), 0, stream,
|
||||
// indices_data_sorted.get(),
|
||||
// original_indices_sorted.get(),
|
||||
// grad_data,
|
||||
// output_data,
|
||||
// num_indices,
|
||||
// num_inputs,
|
||||
// stride);
|
||||
//} else {
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(_GatherGradImpl<T, Tin, GridDim::maxElementsPerThread>), dim3(grid), dim3(block), 0, stream,
|
||||
indices_data_sorted.get(),
|
||||
original_indices_sorted.get(),
|
||||
grad_data,
|
||||
output_data,
|
||||
num_indices,
|
||||
num_inputs,
|
||||
param_itrs,
|
||||
stride);
|
||||
//}
|
||||
}
|
||||
|
||||
#define SPECIALIZED_GRAD_IMPL2(T) \
|
||||
template void GatherGradImpl<T, int64_t>( \
|
||||
const RocmKernel& rocm_kernel, \
|
||||
const T* grad_data, \
|
||||
const int64_t* indices_data, \
|
||||
const int64_t num_indices, \
|
||||
const int64_t num_weights, \
|
||||
const int64_t stride, \
|
||||
T* output_data, \
|
||||
const int64_t num_inputs, \
|
||||
const int64_t param_itrs); \
|
||||
template void GatherGradImpl<T, int32_t>( \
|
||||
const RocmKernel& rocm_kernel, \
|
||||
const T* grad_data, \
|
||||
const int32_t* indices_data, \
|
||||
const int64_t num_indices, \
|
||||
const int64_t num_weights, \
|
||||
const int64_t stride, \
|
||||
T* output_data, \
|
||||
const int64_t num_inputs, \
|
||||
const int64_t param_itrs);
|
||||
|
||||
SPECIALIZED_GRAD_IMPL2(float)
|
||||
SPECIALIZED_GRAD_IMPL2(half)
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
#include "core/providers/rocm/rocm_kernel.h"
|
||||
#include "core/providers/rocm/shared_inc/rocm_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T, typename Tin>
|
||||
void GatherGradImpl(
|
||||
const RocmKernel& rocm_kernel,
|
||||
const T* grad_data,
|
||||
const Tin* indices_data,
|
||||
const int64_t num_indices,
|
||||
const int64_t num_weights,
|
||||
const int64_t stride,
|
||||
T* output_data,
|
||||
const int64_t num_inputs,
|
||||
const int64_t param_itrs);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -204,12 +204,8 @@ training_ops_excluded_files = [
|
|||
'nn/batch_norm_grad.h',
|
||||
'optimizer/adam.cc',
|
||||
'optimizer/adam.cu',
|
||||
'optimizer/lamb.cc',
|
||||
'reduction/reduction_all.cc',
|
||||
'reduction/reduction_ops.cc',
|
||||
'tensor/gather_grad.cc',
|
||||
'tensor/gather_grad_impl.cu',
|
||||
'tensor/gather_grad_impl.h',
|
||||
'tensor/gather_nd_grad_impl.cu',
|
||||
'cuda_training_kernels.cc',
|
||||
'cuda_training_kernels.h',
|
||||
|
|
@ -245,8 +241,18 @@ def hipify(src_file_path, dst_file_path):
|
|||
s = s.replace('GPU_WARP_SIZE = 32', 'GPU_WARP_SIZE = 64')
|
||||
s = s.replace('std::exp', 'expf')
|
||||
s = s.replace('std::log', 'logf')
|
||||
s = s.replace('#include <cub/device/device_radix_sort.cuh>', '#include <hipcub/hipcub.hpp>')
|
||||
s = s.replace('#include <cub/iterator/counting_input_iterator.cuh>', '')
|
||||
s = s.replace('#include <cub/device/device_radix_sort.cuh>',
|
||||
'#include <hipcub/hipcub.hpp>\n#include <hipcub/backend/rocprim/device/device_radix_sort.hpp>')
|
||||
s = s.replace('#include <cub/device/device_reduce.cuh>',
|
||||
'#include <hipcub/backend/rocprim/device/device_reduce.hpp>')
|
||||
s = s.replace('#include <cub/device/device_run_length_encode.cuh>',
|
||||
'#include <hipcub/backend/rocprim/device/device_run_length_encode.hpp>')
|
||||
s = s.replace('#include <cub/device/device_scan.cuh>',
|
||||
'#include <hipcub/backend/rocprim/device/device_scan.hpp>')
|
||||
s = s.replace('#include <cub/iterator/counting_input_iterator.cuh>',
|
||||
'#include <hipcub/backend/rocprim/iterator/counting_input_iterator.hpp>')
|
||||
s = s.replace('#include <cub/iterator/discard_output_iterator.cuh>',
|
||||
'#include <hipcub/backend/rocprim/iterator/discard_output_iterator.hpp>')
|
||||
s = s.replace('typedef half MappedType', 'typedef __half MappedType')
|
||||
# CUBLAS -> ROCBLAS
|
||||
# s = s.replace('CUBLAS', 'HIPBLAS')
|
||||
|
|
|
|||
Loading…
Reference in a new issue