diff --git a/orttraining/orttraining/training_ops/rocm/reduction/reduction_all.cc b/orttraining/orttraining/training_ops/rocm/reduction/reduction_all.cc index 17dad9e9c0..c054b96816 100644 --- a/orttraining/orttraining/training_ops/rocm/reduction/reduction_all.cc +++ b/orttraining/orttraining/training_ops/rocm/reduction/reduction_all.cc @@ -64,32 +64,36 @@ Status ReduceAllL2::ComputeInternal(OpKernelContext* ctx) const { // alternate path only for deterministic compute .. typedef AccumulationType_t HipTAcc; - // find scratch buffer size needed by 'reduce_square_sum' for each tensor - int scratch_size = 0; + // find reduction buffer size needed by 'reduce_square_sum' for each tensor + size_t reduction_buffer_size = 0; for (int i = 0; i < total_tensor_count; ++i) { - scratch_size = std::max(scratch_size, compute_reduction_buffer_size(sizeof(HipTAcc), tensor_sizes[i])); + reduction_buffer_size = + std::max(reduction_buffer_size, compute_reduction_buffer_size(tensor_sizes[i])); } - // enlarge scratch buffer size for 'reduce_sum' over tensor square norms - scratch_size = std::max(scratch_size, compute_reduction_buffer_size(sizeof(HipTAcc), total_tensor_count)); - - // add head room for final output and square norms of each tensor - scratch_size += (1 + total_tensor_count) * sizeof(HipTAcc); + // enlarge reduction buffer size for 'reduce_sum' over tensor square norms + reduction_buffer_size = + std::max(reduction_buffer_size, compute_reduction_buffer_size(total_tensor_count)); // create GPU scratch space and zero target for each tensor square norm - auto scratch_buffer = GetScratchBuffer(scratch_size); - HIP_RETURN_IF_ERROR(hipMemsetAsync(scratch_buffer.get(), 0, sizeof(HipTAcc) * (1 + total_tensor_count))); + auto reduction_buffer = GetScratchBuffer(reduction_buffer_size); - HipTAcc* p_global_sqnorm = reinterpret_cast(scratch_buffer.get()); + // buffer for final output and square norms of each tensor + auto results_buffer = GetScratchBuffer(1 + total_tensor_count); + + HIP_RETURN_IF_ERROR(hipMemsetAsync(results_buffer.get(), 0, sizeof(HipTAcc) * (1 + total_tensor_count))); + + HipTAcc* p_global_sqnorm = results_buffer.get(); HipTAcc* p_tensor_sqnorm = p_global_sqnorm + 1; - HipTAcc* p_reduce_buffer = p_tensor_sqnorm + total_tensor_count; // perform reduction l2norm = sqrt[sum(tensor[i][j]**2)] for i,j over all tensor elements for (int i = 0; i < total_tensor_count; ++i) { HipTIn* p_tensor_i = reinterpret_cast(grouped_tensor_pointers[i][0]); - reduce_square_sum(p_tensor_i, p_tensor_sqnorm + i, tensor_sizes[i], p_reduce_buffer); + ORT_RETURN_IF_ERROR(reduce_square_sum( + p_tensor_i, p_tensor_sqnorm + i, tensor_sizes[i], reduction_buffer.get(), reduction_buffer_size)); } - reduce_sum(p_tensor_sqnorm, p_global_sqnorm, total_tensor_count, p_reduce_buffer); + ORT_RETURN_IF_ERROR(reduce_sum( + p_tensor_sqnorm, p_global_sqnorm, total_tensor_count, reduction_buffer.get(), reduction_buffer_size)); ScalarSqrt(p_global_sqnorm, p_output); } diff --git a/orttraining/orttraining/training_ops/rocm/reduction/reduction_ops.cc b/orttraining/orttraining/training_ops/rocm/reduction/reduction_ops.cc index 3b9832d48a..2e4d64e347 100644 --- a/orttraining/orttraining/training_ops/rocm/reduction/reduction_ops.cc +++ b/orttraining/orttraining/training_ops/rocm/reduction/reduction_ops.cc @@ -63,5 +63,97 @@ Status ReduceKernel::ComputeImplEx(OpKernelContext* ctx, miope calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction); } +template <> +template <> +Status ReduceKernel::ComputeImplEx(OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const { + typedef typename ToHipType::MappedType HipT; + + const Tensor* X = ctx->Input(0); + + //override the attribute value with the input value for reduction_axes + const Tensor* axes_tensor = ctx->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->template Data(); + std::vector axes(data, data + nDims); + + // empty axes and no-op + if (axes.empty() && noop_with_empty_axes_) { + auto* Y = ctx->Output(0, X->Shape()); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData(), X->template Data(), X->SizeInBytes(), hipMemcpyDeviceToDevice)); + return Status::OK(); + } + + PrepareReduceMetadata prepare_reduce_metadata; + + ORT_RETURN_IF_ERROR(PrepareForReduce(X, + keepdims_, + axes, + prepare_reduce_metadata)); + + Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); + + int64_t input_count = prepare_reduce_metadata.input_count; + int64_t output_count = prepare_reduce_metadata.output_count; + std::vector& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen; + std::vector& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen; + + // special case when there is a dim value of 0 in the shape. + if (input_count == 0) { + assert(Y->Shape().Size() == 0); + return Status::OK(); + } + + // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case + if (input_count == output_count) { + if (Y->template MutableData() != X->template Data()) { + HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->template MutableData(), X->template Data(), input_count * sizeof(int32_t), hipMemcpyDeviceToDevice)); + } + return Status::OK(); + } + + // This reduction keep adding values to this buffer. If a non-zero value, say 1000, is here, the sum will start with 1000. + // Therefore zeroing out the memory is required + HIP_RETURN_IF_ERROR(hipMemsetAsync(Y->MutableDataRaw(), 0, Y->SizeInBytes())); + + size_t indices_bytes = 0; + size_t workspace_bytes = 0; + MiopenTensor input_tensor; + MiopenTensor output_tensor; + MiopenReduceDescriptor reduce_desc; + + miopenDataType_t miopen_type_X = miopenFloat; + IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count); + Impl_Cast(reinterpret_cast(X->template Data()), temp_X.get(), X->Shape().Size()); + + ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES)); + ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); + ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_miopen, miopen_type_X)); + MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(MiopenHandle(), reduce_desc, input_tensor, output_tensor, &indices_bytes)); + MIOPEN_RETURN_IF_ERROR(miopenGetReductionWorkspaceSize(MiopenHandle(), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); + IAllocatorUniquePtr indices_miopen = GetScratchBuffer(indices_bytes); + IAllocatorUniquePtr workspace_miopen = GetScratchBuffer(workspace_bytes); + + const auto one = Consts::One; + const auto zero = Consts::Zero; + auto temp_Y = GetScratchBuffer(output_count); + MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(MiopenHandle(), + reduce_desc, + indices_miopen.get(), + indices_bytes, + workspace_miopen.get(), + workspace_bytes, + &one, + input_tensor, + temp_X.get(), + &zero, + output_tensor, + temp_Y.get())); + + Impl_Cast(temp_Y.get(), Y->template MutableData(), output_count); + + return Status::OK(); +} + } // namespace rocm } // namespace onnxruntime \ No newline at end of file diff --git a/tools/ci_build/github/pai/pai-excluded-tests.txt b/tools/ci_build/github/pai/pai-excluded-tests.txt index 50318f1ee2..b8477d363d 100644 --- a/tools/ci_build/github/pai/pai-excluded-tests.txt +++ b/tools/ci_build/github/pai/pai-excluded-tests.txt @@ -32,6 +32,7 @@ ReductionOpTest.ReduceL1_do_not_keep_dims_2 ReductionOpTest.ReduceL1_keepdims ReductionOpTest.ReduceL1 ReductionOpTest.ReduceL1_int32 +ReductionOpTest.ReduceL10DTensor ReductionOpTest.ReduceL2_default_axes_keepdims ReductionOpTest.ReduceL2_default_axes_do_not_keep_dims ReductionOpTest.ReduceL2_do_not_keepdims @@ -39,6 +40,7 @@ ReductionOpTest.ReduceL2_do_not_keepdims_2 ReductionOpTest.ReduceL2_keepdims ReductionOpTest.ReduceL2 ReductionOpTest.ReduceL2_int32 +ReductionOpTest.ReduceL20DTensor ReductionOpTest.ReduceLogSum ReductionOpTest.ReduceLogSum_samesize ReductionOpTest.ReduceLogSum_do_not_keepdims_2 @@ -76,6 +78,8 @@ ReductionOpTest.ReduceMean_keepdims_double ReductionOpTest.ReduceMean ReductionOpTest.ReduceMean_double ReductionOpTest.ReduceMean_int32 +ReductionOpTest.ReduceMean0DTensor +ReductionOpTest.ReduceMean0DTensor_double ReductionOpTest.ReduceMin_default_axes_keepdims ReductionOpTest.ReduceMin_default_axes_do_not_keep_dims ReductionOpTest.ReduceMin_default_axes_do_not_keep_dims_2D