[caffe2] Extend dedup SparseAdagrad fusion with stochastic rounding FP16 (#43124)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43124

Add stochastic rounding FP16 support to the dedup version of the SparseAdagrad fusion.

ghstack-source-id: 111037723

Test Plan:
```
buck test mode/dev-nosan //caffe2/caffe2/fb/net_transforms/tests:fuse_sparse_ops_test -- 'test_fuse_sparse_adagrad_with_sparse_lengths_sum_gradient \(caffe2\.caffe2\.fb\.net_transforms\.tests\.fuse_sparse_ops_test\.TestFuseSparseOps\)' --print-passing-details
```
https://our.intern.facebook.com/intern/testinfra/testrun/5629499566042000

```
buck test mode/dev-nosan //caffe2/caffe2/fb/net_transforms/tests:fuse_sparse_ops_test -- 'test_fuse_sparse_adagrad_with_sparse_lengths_mean_gradient \(caffe2\.caffe2\.fb\.net_transforms\.tests\.fuse_sparse_ops_test\.TestFuseSparseOps\)' --print-passing-details
```
https://our.intern.facebook.com/intern/testinfra/testrun/1125900076333177

Reviewed By: xianjiec

Differential Revision: D22893851

fbshipit-source-id: 81c7a7fe4b0d2de0e6b4fc965c5d23210213c46c
parent f17d7a5556
commit 3c2f6d2ecf

1 changed file with 90 additions and 30 deletions
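For background, stochastic rounding quantizes an FP32 value to FP16 by rounding to one of the two nearest representable values with probability proportional to proximity, so the stored value is unbiased in expectation (E[round(x)] = x). With round-to-nearest, small Adagrad updates to an FP16 parameter can be lost entirely; stochastic rounding preserves them on average. The sketch below illustrates the idea only; it is not the `randFactor` helper this diff uses, and the function name and curand usage here are assumptions.

```
#include <cuda_fp16.h>
#include <curand_kernel.h>

// Illustrative sketch, not the randFactor implementation from this diff:
// round a float to __half such that the result equals x in expectation.
__device__ __half stochastic_round_fp16(
    float x,
    curandStatePhilox4_32_10_t* state) {
  const __half lo = __float2half_rd(x); // round toward -inf
  const __half hi = __float2half_ru(x); // round toward +inf
  const float flo = __half2float(lo);
  const float fhi = __half2float(hi);
  if (flo == fhi) {
    return lo; // x is exactly representable in FP16
  }
  // Round up with probability proportional to x's distance above flo.
  const float p = (x - flo) / (fhi - flo);
  return (curand_uniform(state) < p) ? hi : lo;
}
```

The kernel-side hunks follow: the dedup kernel gains a `roundOption roundOpt = NEAREST` template parameter plus a `ulong2 seed` argument, and routes parameter reads and writes through a per-thread `randFactor` helper.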
```
@@ -273,7 +273,12 @@ __global__ void linear_index_weight_offsets_dedup_kernel(
   }
 }
 
-template <typename SIndex, typename TParam, typename T, bool ExactBlock = false>
+template <
+    typename SIndex,
+    typename TParam,
+    typename T,
+    bool ExactBlock = false,
+    roundOption roundOpt = NEAREST>
 #ifdef __HIP_PLATFORM_HCC__
 C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
 #endif
@@ -293,7 +298,14 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
     const SIndex* sorted_linear_ind_data, // sorted linear indices
     const int* __restrict__ sorted_seg_id_data, // sorted segment id
     const float* lr,
+    ulong2 seed,
     float weight_decay = 0.f) {
+
+  class randFactor<TParam, T, roundOpt> rand_factor(
+      seed,
+      blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
+          threadIdx.x);
+
   const float LR = lr[0];
   // num_indices blocks, each block process one index
   int sorted_linear_indice_id = blockIdx.x; // the index of sorted_linear_ind
@@ -337,7 +349,8 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
         int group = sorted_seg_id_data[sorted_linear_indice_id + dup_id];
         x_ij += grad[group * block_size + i];
       }
-      x_ij += weight_decay * param[index * block_size + i];
+      x_ij += weight_decay *
+          rand_factor.convertTypeFromParamToTarget(param[index * block_size + i]);
       sum_squares += x_ij * x_ij;
     }
     float reduce_result = BlockReduce(temp_storage).Sum(sum_squares, valid);
@@ -358,9 +371,9 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
         x_ij += grad[group * block_size + i];
       }
       const size_t paramIdx = index * block_size + i; // index for param
-      x_ij += weight_decay * param[paramIdx];
-      float param_new = param[paramIdx] + x_ij * step;
-      param[paramIdx] = param_new;
+      x_ij += weight_decay * rand_factor.convertTypeFromParamToTarget(param[paramIdx]);
+      param[paramIdx] =
+          rand_factor.convertTypeFromTargetToParam(param[paramIdx] + x_ij * step);
     }
   }
 
```
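In the hunks above, `convertTypeFromParamToTarget` widens the stored parameter type to the compute type before the gradient is accumulated, and `convertTypeFromTargetToParam` narrows the updated value back, applying the selected rounding mode on the write. A minimal sketch of that shape, assuming `TParam = __half` and `T = float` and reusing the helper from the previous sketch (the real `randFactor` class takes `<TParam, T, roundOpt>` and differs in detail; names and bodies here are guesses):

```
#include <cuda_fp16.h>
#include <curand_kernel.h>

enum roundOption { NEAREST = 0, STOCHASTIC = 1 };

// From the previous sketch; assumed to live in the same translation unit.
__device__ __half stochastic_round_fp16(
    float x,
    curandStatePhilox4_32_10_t* state);

// Widening half -> float is exact, so the rounding mode only matters when
// narrowing float -> half on the write back.
template <roundOption roundOpt>
struct RandFactorSketch {
  curandStatePhilox4_32_10_t state_;

  __device__ RandFactorSketch(ulong2 seed, int thread_id) {
    if (roundOpt == STOCHASTIC) {
      curand_init(seed.x, thread_id, seed.y, &state_);
    }
  }

  __device__ float convertTypeFromParamToTarget(__half p) {
    return __half2float(p); // exact widening, no rounding involved
  }

  __device__ __half convertTypeFromTargetToParam(float v) {
    return (roundOpt == STOCHASTIC) ? stochastic_round_fp16(v, &state_)
                                    : __float2half_rn(v); // nearest even
  }
};
```

The operator-side hunks follow: the exact dedup operator reads a `round_option` argument, validates it, seeds the RNG for the stochastic path, and launches the matching kernel instantiation.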
```
@@ -930,7 +943,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final
 
     // 0: nearest rounding
     // 1: stochastic rounding
-    if (round_option_) {
+    if (round_option_ == STOCHASTIC) {
       seed.x = default_rng_seed_val;
       seed.y = maxThreads * block_size;
     }
@@ -939,7 +952,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final
       // WarpReduce.
       int multiple = std::min(maxThreads / block_size, SEGREDUCE_MINBLOCKS);
       dim3 block(block_size, multiple);
-      if (round_option_) {
+      if (round_option_ == STOCHASTIC) {
        rowwise_sparse_adagrad_fused_length_sum_gradient_kernel<
            IndexType,
            TParam,
@@ -1054,12 +1067,19 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final
       Workspace* ws)
       : Operator<Context>(operator_def, ws),
         epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)),
+        round_option_((roundOption)this->template GetSingleArgument<int>(
+            "round_option",
+            NEAREST)),
         weight_decay_(
             this->template GetSingleArgument<float>("weight_decay", 0.f)) {
     VLOG(1) << "gradient optimization operator in use: "
             << "CUDARowWiseSparseAdagradFusedWithSparseLengthSumGradientOp"
             << " weight_decay_=" << weight_decay_;
+
+    CAFFE_ENFORCE(
+        round_option_ == STOCHASTIC || round_option_ == NEAREST,
+        "round_option_ should be either NEAREST or STOCHATIC");
 
     const T decay = this->template GetSingleArgument<T>("decay", 1.0f);
     CAFFE_ENFORCE_EQ(decay, 1.0, "Decay is not supported for SparseAdagradOp");
   }
@@ -1180,29 +1200,68 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final
         &sorted_seg_id_buffer_,
         &context_);
 
-    rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel<
-        IndexType,
-        TParam,
-        T,
-        false>
-        <<<num_indices,
-           std::min(maxThreads, block_size),
-           0,
-           context_.cuda_stream()>>>(
-            prefix_sum_length_data,
-            N,
-            block_size,
-            num_lengths,
-            num_indices,
-            epsilon_,
-            paramOut,
-            momentOut,
-            indices,
-            is_mean ? grad_buffer_data : grad,
-            sorted_linear_ind_buffer_.template data<IndexType>(),
-            sorted_seg_id_buffer_.template data<int>(),
-            lr,
-            weight_decay_);
+    ulong2 seed;
+
+    // 0: nearest rounding
+    // 1: stochastic rounding
+    if (round_option_ == STOCHASTIC) {
+      seed.x = default_rng_seed_val;
+      seed.y = maxThreads * block_size;
+    }
+
+    if (round_option_ == STOCHASTIC) {
+      rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel<
+          IndexType,
+          TParam,
+          T,
+          false,
+          STOCHASTIC>
+          <<<num_indices,
+             std::min(maxThreads, block_size),
+             0,
+             context_.cuda_stream()>>>(
+              prefix_sum_length_data,
+              N,
+              block_size,
+              num_lengths,
+              num_indices,
+              epsilon_,
+              paramOut,
+              momentOut,
+              indices,
+              is_mean ? grad_buffer_data : grad,
+              sorted_linear_ind_buffer_.template data<IndexType>(),
+              sorted_seg_id_buffer_.template data<int>(),
+              lr,
+              seed,
+              weight_decay_);
+    } else {
+      rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel<
+          IndexType,
+          TParam,
+          T,
+          false,
+          NEAREST>
+          <<<num_indices,
+             std::min(maxThreads, block_size),
+             0,
+             context_.cuda_stream()>>>(
+              prefix_sum_length_data,
+              N,
+              block_size,
+              num_lengths,
+              num_indices,
+              epsilon_,
+              paramOut,
+              momentOut,
+              indices,
+              is_mean ? grad_buffer_data : grad,
+              sorted_linear_ind_buffer_.template data<IndexType>(),
+              sorted_seg_id_buffer_.template data<int>(),
+              lr,
+              seed,
+              weight_decay_);
+    }
 
     return true;
   }
@@ -1220,6 +1279,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final
 
  protected:
   T epsilon_;
+  roundOption round_option_;
   T weight_decay_;
   INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR, LENGTHS);
   OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
```
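A design note on the launch duplication above: the rounding mode is a compile-time template parameter, so the kernel body pays no per-element branch; the host side branches once on `round_option_` and launches the matching instantiation. A minimal standalone sketch of that dispatch pattern, with a simplified stand-in kernel (not the fused Adagrad kernel; compile as a .cu file):

```
#include <cuda_fp16.h>

enum roundOption { NEAREST = 0, STOCHASTIC = 1 };

// Stand-in kernel: roundOpt is fixed at compile time, so the compiler
// resolves the branch below and the untaken path generates no code.
template <roundOption roundOpt>
__global__ void apply_update(__half* param, const float* delta, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }
  const float updated = __half2float(param[i]) + delta[i];
  if (roundOpt == STOCHASTIC) {
    // A real kernel would stochastically round here (see the first sketch);
    // this stand-in truncates toward zero just to keep the paths distinct.
    param[i] = __float2half_rz(updated);
  } else {
    param[i] = __float2half_rn(updated); // round to nearest even
  }
}

// Host side: one runtime branch picks the instantiation, mirroring the
// if (round_option_ == STOCHASTIC) { ... } else { ... } launches in the diff.
void launch_apply_update(
    __half* param,
    const float* delta,
    int n,
    roundOption opt,
    cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  if (opt == STOCHASTIC) {
    apply_update<STOCHASTIC><<<blocks, threads, 0, stream>>>(param, delta, n);
  } else {
    apply_update<NEAREST><<<blocks, threads, 0, stream>>>(param, delta, n);
  }
}
```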