mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Fix AllReduce kernel for contiguous buffer (#6064)
This commit is contained in:
parent
e357486707
commit
7a43fa0028
1 changed file with 12 additions and 5 deletions
|
|
@@ -13,18 +13,25 @@ Status NcclAllReduce::ComputeInternal(OpKernelContext* context) const {
|
|||
cudaStream_t stream = nullptr; // Default stream
|
||||
ncclComm_t comm = nccl_->Comm(group_type_);
|
||||
|
||||
size_t input_count = 0;
|
||||
const void* input_data = context->Input<Tensor>(0)->DataRaw();
|
||||
void* output_data = context->Output(0, context->Input<Tensor>(0)->Shape())->MutableDataRaw();
|
||||
MLDataType onnx_type = context->Input<Tensor>(0)->DataType();
|
||||
|
||||
// Although the inputs are assumed to occupy one contiguous memory range, ORT pads activation tensors to 64-byte alignment
|
||||
// and initializers to 256-byte alignment, so there are small padding gaps inside that contiguous buffer.
|
||||
// We therefore have to AllReduce over the entire buffer, including the padding.
|
||||
const Tensor* last_tensor = context->Input<Tensor>(context->InputCount() - 1);
|
||||
int8_t* end_address = (int8_t*)last_tensor->DataRaw() + last_tensor->SizeInBytes();
|
||||
size_t num_bytes = end_address - (int8_t*)input_data;
|
||||
size_t count = num_bytes / onnx_type->Size();
|
||||
ORT_ENFORCE(num_bytes % onnx_type->Size() == 0);
|
||||
|
||||
for (int i = 0; i < context->InputCount(); i++) {
|
||||
const Tensor* input_tensor = context->Input<Tensor>(i);
|
||||
input_count += input_tensor->Shape().Size();
|
||||
context->Output(i, input_tensor->Shape());
|
||||
context->Output(i, context->Input<Tensor>(i)->Shape());
|
||||
}
|
||||
|
||||
ncclDataType_t dtype = GetNcclDataType(onnx_type);
|
||||
NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, input_count, dtype, ncclSum, comm, stream));
|
||||
NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, count, dtype, ncclSum, comm, stream));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue