[caffe2] Extend dedup SparseAdagrad fusion with stochastic rounding FP16 (#43124)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43124

Add stochastic rounding FP16 support to the dedup version of the SparseAdagrad fusion.
ghstack-source-id: 111037723
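
For background: with FP16 parameters, nearest rounding silently discards any update smaller than half a ULP, which can stall training; stochastic rounding rounds up or down with probability proportional to the distance to each neighbor, so small updates survive in expectation. A minimal host-side sketch of the idea (illustrative only; in the kernel this lives inside the `randFactor` helper, not in this code):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <random>

// Stochastically round an FP32 value to FP16 precision: add random bits
// below the truncation point, then clear the 13 mantissa bits FP16 drops
// (FP32 has 23 mantissa bits, FP16 has 10). Sketch only: FP16 range,
// subnormals, and NaN handling are ignored.
float stochastic_round_to_fp16(float x, std::mt19937& rng) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += rng() & 0x1fff;     // random carry into the dropped bits
  bits &= ~uint32_t(0x1fff);  // truncate to an FP16-representable mantissa
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  std::mt19937 rng(0);
  // 1.0001f is less than half an FP16 ULP (2^-10 / 2) away from 1.0f, so
  // nearest rounding returns 1.0f every time and the update is lost;
  // stochastic rounding recovers it on average.
  double acc = 0.0;
  const int trials = 100000;
  for (int i = 0; i < trials; ++i)
    acc += stochastic_round_to_fp16(1.0001f, rng);
  std::cout << acc / trials << "\n";  // ~1.0001, not 1.0
}
```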

Test Plan:
```
buck test mode/dev-nosan //caffe2/caffe2/fb/net_transforms/tests:fuse_sparse_ops_test -- 'test_fuse_sparse_adagrad_with_sparse_lengths_sum_gradient \(caffe2\.caffe2\.fb\.net_transforms\.tests\.fuse_sparse_ops_test\.TestFuseSparseOps\)' --print-passing-details
```

https://our.intern.facebook.com/intern/testinfra/testrun/5629499566042000

```
buck test mode/dev-nosan //caffe2/caffe2/fb/net_transforms/tests:fuse_sparse_ops_test -- 'test_fuse_sparse_adagrad_with_sparse_lengths_mean_gradient \(caffe2\.caffe2\.fb\.net_transforms\.tests\.fuse_sparse_ops_test\.TestFuseSparseOps\)' --print-passing-details
```

https://our.intern.facebook.com/intern/testinfra/testrun/1125900076333177

Reviewed By: xianjiec

Differential Revision: D22893851

fbshipit-source-id: 81c7a7fe4b0d2de0e6b4fc965c5d23210213c46c
Author: Jianyu Huang (2020-08-31 20:33:48 -07:00), committed by Facebook GitHub Bot
Parent: f17d7a5556
Commit: 3c2f6d2ecf

```diff
@@ -273,7 +273,12 @@ __global__ void linear_index_weight_offsets_dedup_kernel(
   }
 }
 
-template <typename SIndex, typename TParam, typename T, bool ExactBlock = false>
+template <
+    typename SIndex,
+    typename TParam,
+    typename T,
+    bool ExactBlock = false,
+    roundOption roundOpt = NEAREST>
 #ifdef __HIP_PLATFORM_HCC__
 C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
 #endif
```

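The rounding mode enters the dedup kernel as a non-type template parameter (`roundOption roundOpt = NEAREST`), so each mode compiles to a separate kernel and the NEAREST instantiation carries no RNG work at all. A minimal sketch of the pattern, with rounding to whole numbers standing in for FP32-to-FP16 rounding (names here are illustrative, not the operator's):

```cpp
#include <cmath>
#include <cstdio>
#include <random>

enum roundOption { NEAREST = 0, STOCHASTIC = 1 };

// The branch on `opt` is resolved at compile time: the NEAREST
// instantiation contains no random-number code.
template <roundOption opt = NEAREST>
float round_value(float x, float u01 /* uniform sample in [0,1) */) {
  if (opt == STOCHASTIC)
    return std::floor(x + u01);  // rounds up with probability frac(x)
  return std::nearbyint(x);
}

int main() {
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> uni(0.f, 1.f);
  int round_option = STOCHASTIC;  // runtime flag, as on the operator
  // Runtime flag -> compile-time instantiation, the same shape as the
  // operator's if/else around its two kernel launches further below.
  float y = (round_option == STOCHASTIC)
      ? round_value<STOCHASTIC>(2.3f, uni(rng))
      : round_value<NEAREST>(2.3f, uni(rng));
  std::printf("%.1f\n", y);  // 2.0 with prob. 0.7, 3.0 with prob. 0.3
}
```
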
```diff
@@ -293,7 +298,14 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
     const SIndex* sorted_linear_ind_data, // sorted linear indices
     const int* __restrict__ sorted_seg_id_data, // sorted segment id
     const float* lr,
+    ulong2 seed,
     float weight_decay = 0.f) {
+  class randFactor<TParam, T, roundOpt> rand_factor(
+      seed,
+      blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
+          threadIdx.x);
   const float LR = lr[0];
   // num_indices blocks, each block process one index
   int sorted_linear_indice_id = blockIdx.x; // the index of sorted_linear_ind
```

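The host passes a `seed`, and each thread combines it with a globally unique id (`blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x`) when constructing `rand_factor`, so every thread draws an independent random stream without any cross-thread coordination. A standalone CUDA sketch of the same id scheme, using cuRAND's Philox generator as a stand-in for whatever generator `randFactor` wraps:

```cuda
#include <cstdio>
#include <curand_kernel.h>

// Each thread derives its own RNG stream from (seed, global thread id);
// the id expression mirrors the one the dedup kernel passes to randFactor.
__global__ void per_thread_uniform(unsigned long long seed, float* out) {
  int tid = blockIdx.x * blockDim.x * blockDim.y +
      threadIdx.y * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/tid, /*offset=*/0, &state);
  out[tid] = curand_uniform(&state);  // independent U(0,1) per thread
}

int main() {
  const int threads = 64;
  float* out;
  cudaMallocManaged(&out, threads * sizeof(float));
  per_thread_uniform<<<1, dim3(32, 2)>>>(42ULL, out);  // 32x2 = 64 threads
  cudaDeviceSynchronize();
  std::printf("thread 0: %f  thread 63: %f\n", out[0], out[threads - 1]);
  cudaFree(out);
  return 0;
}
```

Seeding by thread id, rather than sharing one generator, keeps the update loop free of atomics or shared RNG state.
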
```diff
@@ -337,7 +349,8 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
       int group = sorted_seg_id_data[sorted_linear_indice_id + dup_id];
       x_ij += grad[group * block_size + i];
     }
-    x_ij += weight_decay * param[index * block_size + i];
+    x_ij += weight_decay *
+        rand_factor.convertTypeFromParamToTarget(param[index * block_size + i]);
     sum_squares += x_ij * x_ij;
   }
   float reduce_result = BlockReduce(temp_storage).Sum(sum_squares, valid);
```

```diff
@@ -358,9 +371,9 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
         x_ij += grad[group * block_size + i];
       }
       const size_t paramIdx = index * block_size + i; // index for param
-      x_ij += weight_decay * param[paramIdx];
-      float param_new = param[paramIdx] + x_ij * step;
-      param[paramIdx] = param_new;
+      x_ij += weight_decay * rand_factor.convertTypeFromParamToTarget(param[paramIdx]);
+      param[paramIdx] =
+          rand_factor.convertTypeFromTargetToParam(param[paramIdx] + x_ij * step);
     }
   }
```

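With FP16 parameters, the storage type `TParam` and the compute type `T` differ, so every read widens through `convertTypeFromParamToTarget` and the final store narrows through `convertTypeFromTargetToParam`, which is where stochastic rounding is applied. A host-side sketch of that read-widen/update/narrow round trip (helper names are hypothetical, not the `randFactor` API; FP16 storage is emulated by clearing the 13 low mantissa bits of an FP32 value):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <random>

// Hypothetical stand-ins for randFactor's conversions.
float param_to_target(float stored) {
  return stored;  // widen FP16 -> FP32 (a no-op in this emulation)
}

float target_to_param(float x, std::mt19937& rng) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += rng() & 0x1fff;     // stochastic rounding on the narrowing store
  bits &= ~uint32_t(0x1fff);
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  std::mt19937 rng(0);
  float param = 1.0f;  // parameter stored at FP16 precision
  float grad = 2e-4f, weight_decay = 0.01f;
  float step = -0.5f;  // stands in for the Adagrad step (a function of
                       // lr and the row-wise moment)

  // Mirrors the kernel's update: widen the parameter for the
  // weight-decay term and the add, then round back to storage precision.
  float x_ij = grad + weight_decay * param_to_target(param);
  param = target_to_param(param_to_target(param) + x_ij * step, rng);
  std::printf("new param: %f\n", param);
}
```
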
```diff
@@ -930,7 +943,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final
     // 0: nearest rounding
     // 1: stochastic rounding
-    if (round_option_) {
+    if (round_option_ == STOCHASTIC) {
       seed.x = default_rng_seed_val;
       seed.y = maxThreads * block_size;
     }
```

```diff
@@ -939,7 +952,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final
     // WarpReduce.
     int multiple = std::min(maxThreads / block_size, SEGREDUCE_MINBLOCKS);
     dim3 block(block_size, multiple);
-    if (round_option_) {
+    if (round_option_ == STOCHASTIC) {
       rowwise_sparse_adagrad_fused_length_sum_gradient_kernel<
           IndexType,
           TParam,
```

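On the host side, `if (round_option_)` becomes `if (round_option_ == STOCHASTIC)`. Behavior is unchanged while STOCHASTIC (1) is the only non-zero mode, but the explicit comparison keeps the dispatch correct if more rounding modes are ever added; a tiny sketch of the distinction:

```cpp
enum roundOption { NEAREST = 0, STOCHASTIC = 1 };

int main() {
  roundOption round_option_ = STOCHASTIC;
  // Old check: any non-zero mode (including every future one) looks "stochastic".
  bool old_check = static_cast<bool>(round_option_);
  // New check: only STOCHASTIC selects the stochastic kernel.
  bool new_check = (round_option_ == STOCHASTIC);
  return old_check == new_check ? 0 : 1;  // identical while only two modes exist
}
```
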
```diff
@@ -1054,12 +1067,19 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final
       Workspace* ws)
       : Operator<Context>(operator_def, ws),
         epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)),
+        round_option_((roundOption)this->template GetSingleArgument<int>(
+            "round_option",
+            NEAREST)),
         weight_decay_(
             this->template GetSingleArgument<float>("weight_decay", 0.f)) {
     VLOG(1) << "gradient optimization operator in use: "
             << "CUDARowWiseSparseAdagradFusedWithSparseLengthSumGradientOp"
             << " weight_decay_=" << weight_decay_;
+    CAFFE_ENFORCE(
+        round_option_ == STOCHASTIC || round_option_ == NEAREST,
+        "round_option_ should be either NEAREST or STOCHASTIC");
     const T decay = this->template GetSingleArgument<T>("decay", 1.0f);
     CAFFE_ENFORCE_EQ(decay, 1.0, "Decay is not supported for SparseAdagradOp");
   }
```

```diff
@@ -1180,29 +1200,68 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final
         &sorted_seg_id_buffer_,
         &context_);
 
-    rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel<
-        IndexType,
-        TParam,
-        T,
-        false>
-        <<<num_indices,
-           std::min(maxThreads, block_size),
-           0,
-           context_.cuda_stream()>>>(
-            prefix_sum_length_data,
-            N,
-            block_size,
-            num_lengths,
-            num_indices,
-            epsilon_,
-            paramOut,
-            momentOut,
-            indices,
-            is_mean ? grad_buffer_data : grad,
-            sorted_linear_ind_buffer_.template data<IndexType>(),
-            sorted_seg_id_buffer_.template data<int>(),
-            lr,
-            weight_decay_);
+    ulong2 seed;
+    // 0: nearest rounding
+    // 1: stochastic rounding
+    if (round_option_ == STOCHASTIC) {
+      seed.x = default_rng_seed_val;
+      seed.y = maxThreads * block_size;
+    }
+
+    if (round_option_ == STOCHASTIC) {
+      rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel<
+          IndexType,
+          TParam,
+          T,
+          false,
+          STOCHASTIC>
+          <<<num_indices,
+             std::min(maxThreads, block_size),
+             0,
+             context_.cuda_stream()>>>(
+              prefix_sum_length_data,
+              N,
+              block_size,
+              num_lengths,
+              num_indices,
+              epsilon_,
+              paramOut,
+              momentOut,
+              indices,
+              is_mean ? grad_buffer_data : grad,
+              sorted_linear_ind_buffer_.template data<IndexType>(),
+              sorted_seg_id_buffer_.template data<int>(),
+              lr,
+              seed,
+              weight_decay_);
+    } else {
+      rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel<
+          IndexType,
+          TParam,
+          T,
+          false,
+          NEAREST>
+          <<<num_indices,
+             std::min(maxThreads, block_size),
+             0,
+             context_.cuda_stream()>>>(
+              prefix_sum_length_data,
+              N,
+              block_size,
+              num_lengths,
+              num_indices,
+              epsilon_,
+              paramOut,
+              momentOut,
+              indices,
+              is_mean ? grad_buffer_data : grad,
+              sorted_linear_ind_buffer_.template data<IndexType>(),
+              sorted_seg_id_buffer_.template data<int>(),
+              lr,
+              seed,
+              weight_decay_);
+    }
 
     return true;
   }
```

```diff
@@ -1220,6 +1279,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final
  protected:
   T epsilon_;
+  roundOption round_option_;
   T weight_decay_;
   INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR, LENGTHS);
   OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
```