From 7a43fa00285ce8a923bf1a1d7907818cdaa0d063 Mon Sep 17 00:00:00 2001 From: Sherlock Date: Tue, 8 Dec 2020 15:55:13 -0800 Subject: [PATCH] Fix AllReduce kernel for contiguous buffer (#6064) --- .../cuda/collective/nccl_kernels.cc | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/training_ops/cuda/collective/nccl_kernels.cc b/orttraining/orttraining/training_ops/cuda/collective/nccl_kernels.cc index 0dc3209d12..26bf59f730 100644 --- a/orttraining/orttraining/training_ops/cuda/collective/nccl_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/collective/nccl_kernels.cc @@ -13,18 +13,25 @@ Status NcclAllReduce::ComputeInternal(OpKernelContext* context) const { cudaStream_t stream = nullptr; // Default stream ncclComm_t comm = nccl_->Comm(group_type_); - size_t input_count = 0; const void* input_data = context->Input(0)->DataRaw(); void* output_data = context->Output(0, context->Input(0)->Shape())->MutableDataRaw(); MLDataType onnx_type = context->Input(0)->DataType(); + + // Although we assumed the memory address is contiguous for the input, ORT pads activation tensors to 64 bytes aligned + // and initializers to 256 bytes aligned. There are tiny padding gaps in the contiguous buffer space. + // We have to AllReduce on the entire buffer, including the padding space. + const Tensor* last_tensor = context->Input(context->InputCount() - 1); + int8_t* end_address = (int8_t*)last_tensor->DataRaw() + last_tensor->SizeInBytes(); + size_t num_bytes = end_address - (int8_t*)input_data; + size_t count = num_bytes / onnx_type->Size(); + ORT_ENFORCE(num_bytes % onnx_type->Size() == 0); + for (int i = 0; i < context->InputCount(); i++) { - const Tensor* input_tensor = context->Input(i); - input_count += input_tensor->Shape().Size(); - context->Output(i, input_tensor->Shape()); + context->Output(i, context->Input(i)->Shape()); } ncclDataType_t dtype = GetNcclDataType(onnx_type); - NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, input_count, dtype, ncclSum, comm, stream)); + NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, count, dtype, ncclSum, comm, stream)); return Status::OK(); }