diff --git a/caffe2/sgd/adagrad_fused_op_gpu.cu b/caffe2/sgd/adagrad_fused_op_gpu.cu index cbe22cf9208..ba10eb608a6 100644 --- a/caffe2/sgd/adagrad_fused_op_gpu.cu +++ b/caffe2/sgd/adagrad_fused_op_gpu.cu @@ -273,7 +273,12 @@ __global__ void linear_index_weight_offsets_dedup_kernel( } } -template +template < + typename SIndex, + typename TParam, + typename T, + bool ExactBlock = false, + roundOption roundOpt = NEAREST> #ifdef __HIP_PLATFORM_HCC__ C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS) #endif @@ -293,7 +298,14 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel( const SIndex* sorted_linear_ind_data, // sorted linear indices const int* __restrict__ sorted_seg_id_data, // sorted segment id const float* lr, + ulong2 seed, float weight_decay = 0.f) { + + class randFactor rand_factor( + seed, + blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + + threadIdx.x); + const float LR = lr[0]; // num_indices blocks, each block process one index int sorted_linear_indice_id = blockIdx.x; // the index of sorted_linear_ind @@ -337,7 +349,8 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel( int group = sorted_seg_id_data[sorted_linear_indice_id + dup_id]; x_ij += grad[group * block_size + i]; } - x_ij += weight_decay * param[index * block_size + i]; + x_ij += weight_decay * + rand_factor.convertTypeFromParamToTarget(param[index * block_size + i]); sum_squares += x_ij * x_ij; } float reduce_result = BlockReduce(temp_storage).Sum(sum_squares, valid); @@ -358,9 +371,9 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel( x_ij += grad[group * block_size + i]; } const size_t paramIdx = index * block_size + i; // index for param - x_ij += weight_decay * param[paramIdx]; - float param_new = param[paramIdx] + x_ij * step; - param[paramIdx] = param_new; + x_ij += weight_decay * rand_factor.convertTypeFromParamToTarget(param[paramIdx]); + param[paramIdx] = + 
rand_factor.convertTypeFromTargetToParam(param[paramIdx] + x_ij * step); } } @@ -930,7 +943,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final // 0: nearest rounding // 1: stochastic rounding - if (round_option_) { + if (round_option_ == STOCHASTIC) { seed.x = default_rng_seed_val; seed.y = maxThreads * block_size; } @@ -939,7 +952,7 @@ // WarpReduce. int multiple = std::min(maxThreads / block_size, SEGREDUCE_MINBLOCKS); dim3 block(block_size, multiple); - if (round_option_) { + if (round_option_ == STOCHASTIC) { rowwise_sparse_adagrad_fused_length_sum_gradient_kernel< IndexType, TParam, @@ -1054,12 +1067,19 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final Workspace* ws) : Operator(operator_def, ws), epsilon_(this->template GetSingleArgument("epsilon", 1e-5f)), + round_option_((roundOption)this->template GetSingleArgument( + "round_option", + NEAREST)), weight_decay_( this->template GetSingleArgument("weight_decay", 0.f)) { VLOG(1) << "gradient optimization operator in use: " << "CUDARowWiseSparseAdagradFusedWithSparseLengthSumGradientOp" << " weight_decay_=" << weight_decay_; + CAFFE_ENFORCE( + round_option_ == STOCHASTIC || round_option_ == NEAREST, + "round_option_ should be either NEAREST or STOCHASTIC"); + const T decay = this->template GetSingleArgument("decay", 1.0f); CAFFE_ENFORCE_EQ(decay, 1.0, "Decay is not supported for SparseAdagradOp"); } @@ -1180,29 +1200,68 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final &sorted_seg_id_buffer_, &context_); - rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< - IndexType, - TParam, - T, - false> - <<>>( - prefix_sum_length_data, - N, - block_size, - num_lengths, - num_indices, - epsilon_, - paramOut, - momentOut, - indices, - is_mean ? 
grad_buffer_data : grad, - sorted_linear_ind_buffer_.template data(), - sorted_seg_id_buffer_.template data(), - lr, - weight_decay_); + ulong2 seed; + + // 0: nearest rounding + // 1: stochastic rounding + if (round_option_ == STOCHASTIC) { + seed.x = default_rng_seed_val; + seed.y = maxThreads * block_size; + } + + if (round_option_ == STOCHASTIC) { + rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< + IndexType, + TParam, + T, + false, + STOCHASTIC> + <<>>( + prefix_sum_length_data, + N, + block_size, + num_lengths, + num_indices, + epsilon_, + paramOut, + momentOut, + indices, + is_mean ? grad_buffer_data : grad, + sorted_linear_ind_buffer_.template data(), + sorted_seg_id_buffer_.template data(), + lr, + seed, + weight_decay_); + } else { + rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< + IndexType, + TParam, + T, + false, + NEAREST> + <<>>( + prefix_sum_length_data, + N, + block_size, + num_lengths, + num_indices, + epsilon_, + paramOut, + momentOut, + indices, + is_mean ? grad_buffer_data : grad, + sorted_linear_ind_buffer_.template data(), + sorted_seg_id_buffer_.template data(), + lr, + seed, + weight_decay_); + } return true; } @@ -1220,6 +1279,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final protected: T epsilon_; + roundOption round_option_; T weight_decay_; INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR, LENGTHS); OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);