From e9d20e9dba844aeed1d71c54e9219587f7b91c8c Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 30 Jul 2020 23:02:45 -0700 Subject: [PATCH] Revise Send and Recv (#4547) * Add ability to retrieve inferred shapes when executing a kernel. This ability helps Recv to know its output shapes without doing actual cummunication. Of course, if the output shapes cannot be inferred, Recv still needs to do communication to get shapes from Send. * Avoid communicating shape information when it can be inferred statically * Replace unordered_map with thread-safe wrapper. We don't want to have racing condition and undefined behavior when using parallel executor.y * Remove cout * Add missing file * Address comments * Check dim_value. -1 means missing * lock properly * Address comments (remove thread-safe map) * Remove poc header * Replace Stream with DeferredReleaseCPUPtr --- .../onnxruntime/core/framework/op_kernel.h | 10 + onnxruntime/core/framework/execution_frame.cc | 31 ++- onnxruntime/core/framework/execution_frame.h | 15 + onnxruntime/core/framework/op_kernel.cc | 8 + onnxruntime/core/framework/session_state.cc | 98 ++++--- onnxruntime/core/framework/session_state.h | 7 +- .../training_ops/cuda/communication/common.h | 76 +++++ .../training_ops/cuda/communication/recv.cc | 260 ++++++++++++------ .../training_ops/cuda/communication/recv.h | 12 + .../training_ops/cuda/communication/send.cc | 243 +++++++++------- .../training_ops/cuda/communication/send.h | 14 + 11 files changed, 554 insertions(+), 220 deletions(-) diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 0a97d39390..2287e5f43a 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -128,6 +128,16 @@ class OpKernelContext { // unless static optimization pre-allocates it. SparseTensor* Output(int index, size_t num_values, const TensorShape& shape); + // Retrieve indexed shape obtained from memory planning before actual + // computation. If the indexed shape cannot be inferred, this function returns + // false. + bool TryGetInferredInputShape(int index, TensorShape& shape) const; + + // Retrieve indexed shape obtained from memory planning before actual + // computation. If the indexed shape cannot be inferred, this function returns + // false. + bool TryGetInferredOutputShape(int index, TensorShape& shape) const; + const logging::Logger& Logger() const { return *logger_; } diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 6cc3345b1c..848b227692 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -73,6 +73,13 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(int index, const TensorShap return status; } +bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const { + // By default, there is not information about inferred shape, so this default + // implementation always returns false. The derived class of IExecutionFrame + // can override this function to provide, for example, activations' shape information. + return false; +} + AllocatorPtr IExecutionFrame::GetAllocator(const OrtMemoryInfo& info) const { return GetAllocatorImpl(info); } @@ -235,7 +242,7 @@ ExecutionFrame::ExecutionFrame(const std::vector& feed_mlvalue_idxs, const //if there are some traditional ml value type in inputs disable the memory pattern optimization. if (all_tensors) { - mem_patterns_ = session_state.GetMemoryPatternGroup(input_shapes, feed_mlvalue_idxs); + mem_patterns_ = session_state.GetMemoryPatternGroup(input_shapes, feed_mlvalue_idxs, inferred_shapes_); // if no existing patterns, generate one in this executionframe if (!mem_patterns_) { planner_ = onnxruntime::make_unique(*session_state.GetExecutionPlan()); @@ -623,4 +630,26 @@ Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const { return planner_->GeneratePatterns(out); } + +bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const { + // NodeArg index to OrtValue index. + int ort_value_idx = GetNodeIdxToMLValueIdx(index); + + // Check if index is valid. + if (ort_value_idx == NodeIndexInfo::kInvalidEntry) { + return false; + } + + // Search for inferred shape. + // If inferred shape is found, it's assigned to "shape" so that caller can use it. + auto it = inferred_shapes_.find(ort_value_idx); + if (it != inferred_shapes_.end()) { + shape = it->second; + return true; + } + + // Tell the caller if the search is successful or not. + return false; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index f0b211994f..dc500314d3 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" @@ -53,6 +54,10 @@ class IExecutionFrame { // Shape is required for tensors but not traditional ML values. Status GetOrCreateNodeOutputMLValue(int index, const TensorShape* shape, OrtValue*& p_ort_value, size_t nnz = 0); + // This function try retrieve the inferred shapes for the given NodeArg index. + // If the retrival is sucessful, this function returns true and false otherwise. + virtual bool TryGetInferredShape(int index, TensorShape& shape) const; + /** * write the output values to the 'fetches' vector * Don't access the values after SessionState is destroyed @@ -129,6 +134,10 @@ class ExecutionFrame final : public IExecutionFrame { return planner_ != nullptr; } + // This function try retrieve the inferred shapes for the given NodeArg index. + // If the retrival is sucessful, this function returns true and false otherwise. + bool TryGetInferredShape(int index, TensorShape& shape) const override; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExecutionFrame); @@ -168,5 +177,11 @@ class ExecutionFrame final : public IExecutionFrame { // Big chunks on different locations that will be used by mem_pattern. std::map buffers_; + + // Given the input shapes of the executed graph, ExecutionFrame tries inferring + // all symbolic shapes. inferred_shapes_[i] is the shape of OrtValue indexed + // by i, if the key i exists. + // inferred_shapes_ is generated togehter with mem_patterns_. + std::unordered_map inferred_shapes_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index a21bab8166..b3c05bb4e5 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -38,6 +38,14 @@ SparseTensor* OpKernelContext::Output(int index, size_t nnz, const TensorShape& return p_ml_value ? p_ml_value->GetMutable() : nullptr; } +bool OpKernelContext::TryGetInferredInputShape(int index, TensorShape& shape) const { + return execution_frame_->TryGetInferredShape(GetInputArgIndex(index), shape); +} + +bool OpKernelContext::TryGetInferredOutputShape(int index, TensorShape& shape) const { + return execution_frame_->TryGetInferredShape(GetOutputArgIndex(index), shape); +} + OrtValue* OpKernelContext::OutputMLValue(int index, const TensorShape& shape, size_t nnz) { if (index < 0 || index >= OutputCount()) return nullptr; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index a376956e72..2a03f55a9d 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -252,11 +252,53 @@ Status ResolveDimParams(const GraphViewer& graph, } return Status::OK(); } + +Status ResolveSizeAndShape( + const NodeArg* arg, + const std::unordered_map& symbolic_dimensions, + size_t& size, // total number of elements. It's 0 if shape is unknown. + std::vector& resolved_shape) { + if (!arg->Shape()) { + // 0 means no shape information. + size = 0; + return Status::OK(); + } + + std::vector shape; + + SafeInt safe_size = 1; + for (auto& dim : arg->Shape()->dim()) { + if (dim.has_dim_param()) { + auto it = symbolic_dimensions.find(dim.dim_param()); + if (it == symbolic_dimensions.end()) { + return Status(ONNXRUNTIME, FAIL, "Unknown symbolic dimension, " + dim.dim_param() + ", found in memory pattern compute."); + } + safe_size *= it->second; + shape.push_back(it->second); + } else if (dim.has_dim_value() && dim.dim_value() > 0) { + safe_size *= dim.dim_value(); + shape.push_back(dim.dim_value()); + } else { + // tensor shape is unknown. + safe_size = 0; + } + } + + size = safe_size; + + // Only assign shape if all symbolic dimensions are resolved. + if (size != 0) { + resolved_shape = std::move(shape); + } + + return Status::OK(); +} } // namespace Status SessionState::GeneratePatternGroupCache(const std::vector>& input_shape, const std::vector& feed_mlvalue_idxs, - MemoryPatternGroup* output) const { + MemoryPatternGroup* output, + std::unordered_map& resolved_shapes) const { std::map feeds; for (size_t i = 0, end = feed_mlvalue_idxs.size(); i < end; ++i) { std::string name; @@ -282,38 +324,27 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorIsTensorType()) continue; const auto* ml_data_type = static_cast(ml_type)->GetElementType(); + + auto* arg = node->OutputDefs()[i]; + size_t size = 0; + std::vector resolved_shape; + ORT_RETURN_IF_ERROR(ResolveSizeAndShape(arg, map, size, resolved_shape)); + + // Store all valid resolved shapes. They will be queried in, for example, + // Recv operator to bypass the dependency of output shapes on inputs. + if (size != 0) { + resolved_shapes[ml_value_idx] = resolved_shape; + } + + // Plan memory if conditions are met. if (exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate && - ml_data_type != DataTypeImpl::GetType()) { - //calculate size - auto* arg = node->OutputDefs()[i]; - if (!arg->Shape()) - continue; - size_t size = 0; - SafeInt len = 1; - for (auto& dim : arg->Shape()->dim()) { - if (dim.has_dim_param()) { - auto it = map.find(dim.dim_param()); - if (it == map.end()) { - return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute"); - } - len *= it->second; - } else if (dim.has_dim_value()) { - len *= dim.dim_value(); - } else { - // tensor shape is unknown - len = 0; - } - } - - // Skip planning for this tensor if shape is unknown - if (len == 0) { - continue; - } - - if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(len, ml_data_type->Size(), &size)) { + ml_data_type != DataTypeImpl::GetType() && size != 0) { + size_t aligned_size = 0; + if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(size, ml_data_type->Size(), &aligned_size)) { return Status(ONNXRUNTIME, FAIL, "Size overflow"); } - mem_planner.TraceAllocation(ml_value_idx, size); + + mem_planner.TraceAllocation(ml_value_idx, aligned_size); } } //release nodes @@ -337,7 +368,8 @@ Status SessionState::GeneratePatternGroupCache(const std::vector>& input_shapes, - const std::vector& feed_mlvalue_idxs) const { + const std::vector& feed_mlvalue_idxs, + std::unordered_map& inferred_shapes) const { int64_t key = CalculateMemoryPatternsKey(input_shapes); std::lock_guard lock(mem_patterns_lock_); @@ -345,10 +377,11 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector< if (it == mem_patterns_.end()) { #ifdef ENABLE_TRAINING auto mem_patterns = onnxruntime::make_unique(); - if (GeneratePatternGroupCache(input_shapes, feed_mlvalue_idxs, mem_patterns.get()).IsOK()) { + if (GeneratePatternGroupCache(input_shapes, feed_mlvalue_idxs, mem_patterns.get(), inferred_shapes).IsOK()) { key = CalculateMemoryPatternsKey(input_shapes); auto ptr = mem_patterns.get(); mem_patterns_[key] = std::move(mem_patterns); + shape_patterns_[key] = inferred_shapes; return ptr; } return nullptr; @@ -358,6 +391,7 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector< #endif } + inferred_shapes = shape_patterns_[key]; return it->second.get(); } diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index fd435f3dad..2f0dac6f20 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -187,7 +187,8 @@ class SessionState { */ const MemoryPatternGroup* GetMemoryPatternGroup( const std::vector>& input_shapes, - const std::vector& feed_mlvalue_idxs) const; + const std::vector& feed_mlvalue_idxs, + std::unordered_map& inferred_shapes) const; /** Set generated memory pattern with a given input shapes. @@ -278,7 +279,8 @@ class SessionState { Status GeneratePatternGroupCache( const std::vector>& input_shape, const std::vector& feed_mlvalue_idxs, - MemoryPatternGroup* output) const; + MemoryPatternGroup* output, + std::unordered_map& inferred_shapes) const; #endif // cache of the constructed kernels to avoid spending construction time per executor @@ -346,6 +348,7 @@ class SessionState { // cache for the generated mem_patterns. key is calculated based on input shapes. mutable std::map> mem_patterns_; + mutable std::map> shape_patterns_; NameNodeInfoMapType input_names_to_nodeinfo_mapping_; NameNodeInfoMapType output_names_to_nodeinfo_mapping_; diff --git a/orttraining/orttraining/training_ops/cuda/communication/common.h b/orttraining/orttraining/training_ops/cuda/communication/common.h index a87e6ab850..e6ec8afdb3 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/common.h +++ b/orttraining/orttraining/training_ops/cuda/communication/common.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/common/common.h" +#include "core/providers/cuda/cuda_common.h" #pragma once namespace onnxruntime { @@ -26,5 +27,80 @@ inline size_t GetAggregatedAlignedAddress(size_t old_addr) { return new_addr; } +// This function extracts shapes and sizes for tensors indexed by +// begin, begin + 1, ..., begin + count - 1. +// tensor_sizes[i]: i-th indexed tensor's size. +// tensor_shapes[i]: i-th indexed tensor's shape. +inline void GetTensorShapesAndSizes( + const bool is_index_input, // Index inputs from the context if true. Otherwise, use outputs. + const int begin, // Index of the first sent/received tensor in the context. + const int count, // Number of sent/received tensors. + OpKernelContext* ctx, // The context. + std::vector& tensor_sizes, // tensor_sizes[i] is the size of i-th sent/received tensor in byte. + std::vector& tensor_shapes) { // tensor_shapes[i] is the i-th sent/received tensor. + + // Helper function to retrieve input or output tensor. + auto get_tensor = [&](const int index) -> const Tensor* { + if (is_index_input) { + return ctx->Input(begin + index); + } else { + return ctx->Output(begin + index); + } + }; + + // Get tensors and shapes for indexed tensors from context. + tensor_sizes.resize(count); + tensor_shapes.resize(count); + for (int i = 0; i < count; ++i) { + const Tensor* tensor = get_tensor(i); + tensor_sizes[i] = tensor->SizeInBytes(); + tensor_shapes[i] = tensor->Shape(); + } +} + +// Compute shape-related information from given tensor shapes. +inline void ComputeShapeRelatedInfo( + // tensor_sizes[i] is the size of i-th sent/received tensor in byte. + const std::vector tensor_sizes, + // tensor_shapes[i] is the i-th sent/received tensor. + const std::vector tensor_shapes, + // The size in bytes if we concatenate all tensors into one single tensor. + // It may be larger than the original size due to memory alignment. + size_t& aggregated_aligned_tensor_bytes, + // aggregated_tensor_shapes[prefix_tensor_shape_sizes[i]] is the first dimension of the i-th tensor. + // aggregated_tensor_shapes[prefix_tensor_shape_sizes[i + 1]] is the element after the last dimension of the i-th tensor. + std::vector& prefix_tensor_shape_sizes, + // This field is the concatenation of all received tensors' shapes. + // Assume that there are two tensors A and B with rank NA and NB, respectively. + // aggregated_tensor_shapes = [A_shape[0], A_shape[1], ..., A_shape[NA-1], B_shape[0], B_shape[1], ..., B_shape[NB-1]]. + std::vector& aggregated_tensor_shapes, + // tensor_offsets_in_bytes[i] is the offset of the starting byte of the i-th tensor in the communicated tensor buffer. + // That is, i-th tensor's first element is tensor_buffer[tensor_offsets_in_bytes[i]]. + std::vector& tensor_offsets_in_bytes) { + // Initialize outputs. + aggregated_aligned_tensor_bytes = 0; + prefix_tensor_shape_sizes.resize(0); + aggregated_tensor_shapes.resize(0); + tensor_offsets_in_bytes.resize(0); + + // Compute shape information. + size_t prefix_tensor_shape_size_sum = 0; + for (int i = 0; static_cast(i) < tensor_shapes.size(); ++i) { + const auto& shape = tensor_shapes[i]; + prefix_tensor_shape_size_sum += shape.NumDimensions(); + prefix_tensor_shape_sizes.push_back(prefix_tensor_shape_size_sum); + aggregated_tensor_shapes.insert(aggregated_tensor_shapes.end(), + shape.GetDims().begin(), + shape.GetDims().end()); + + // aggregated_aligned_tensor_bytes is the first non-occupied address. + // Starting form aggregated_aligned_tensor_bytes, we find the next aligned offset in the + // tensor buffer to meet alignment requirement. + aggregated_aligned_tensor_bytes = GetAggregatedAlignedAddress(aggregated_aligned_tensor_bytes); + tensor_offsets_in_bytes.push_back(aggregated_aligned_tensor_bytes); + aggregated_aligned_tensor_bytes += tensor_sizes[i]; + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/communication/recv.cc b/orttraining/orttraining/training_ops/cuda/communication/recv.cc index 9d43323b88..42be0c0de9 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/recv.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/recv.cc @@ -6,6 +6,7 @@ #include "orttraining/training_ops/cuda/communication/recv.h" #include "orttraining/training_ops/cuda/communication/common.h" #include "core/profile/profile.h" +#include "core/providers/cuda/cuda_common.h" #include #include "orttraining/core/framework/mpi_setup.h" @@ -13,6 +14,102 @@ namespace onnxruntime { namespace cuda { +void Recv::ReceiveShapeInfo( + const int src, + const int num_tensors, + size_t& aggregated_aligned_tensor_bytes, + std::vector& prefix_tensor_shape_sizes, + std::vector& aggregated_tensor_shapes) const { + // Resize vector so that the following .data() returns meaningful pointer. + prefix_tensor_shape_sizes.resize(num_tensors); + CommInfo_t info_shape_sizes{prefix_tensor_shape_sizes.data(), + num_tensors * static_cast(sizeof(size_t)), + src, + static_cast(tag_)}; + CommInfo_t info_aggregated_size{&aggregated_aligned_tensor_bytes, + static_cast(sizeof(size_t)), + src, + static_cast(tag_)}; + // Directly use CPU to wait MPI_Recv. We cannot use GPU callback because + // MPI_Recv may block the entire GPU until it returns. + MPI_CHECK(MPI_Recv( + info_shape_sizes.buffer, info_shape_sizes.size, MPI_CHAR, + info_shape_sizes.rank, info_shape_sizes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + + MPI_CHECK(MPI_Recv( + info_aggregated_size.buffer, info_aggregated_size.size, MPI_CHAR, + info_aggregated_size.rank, info_aggregated_size.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + + // prefix_tensor_shape_sizes's last element is the number of total dimensions. + // If a 3-D tensor and a 2-D tensor are sent, its value is 2 + 3 = 5. + aggregated_tensor_shapes.resize(prefix_tensor_shape_sizes[num_tensors - 1]); + CommInfo_t info_shapes{aggregated_tensor_shapes.data(), + static_cast(aggregated_tensor_shapes.size()) * static_cast(sizeof(int64_t)), + src, + static_cast(tag_)}; + MPI_CHECK(MPI_Recv( + info_shapes.buffer, info_shapes.size, MPI_CHAR, + info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); +} + +void Recv::ReceiveData( + const int num_tensors, + std::vector received_tensors, + const int src, + const size_t aggregated_aligned_tensor_bytes, + IAllocatorUniquePtr& buffer) const { +#ifdef ENABLE_NVTX_PROFILE + profile::NvtxRangeCreator recvRange( + "Recv-" + std::to_string(src), profile::Color::Green); + // Begin of major communication tasks. + // The first MPI_Recv is not included because we don't want to + // count waiting time before setting up the actual communication. + recvRange.Begin(); +#endif + buffer = AllocateBufferOnCPUPinned(static_cast(aggregated_aligned_tensor_bytes)); + CommInfo_t info_data{buffer.get(), + static_cast(aggregated_aligned_tensor_bytes), + src, + static_cast(tag_)}; + + MPI_CHECK(MPI_Recv( + info_data.buffer, info_data.size, MPI_CHAR, + info_data.rank, info_data.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + +#ifdef ENABLE_NVTX_PROFILE + // End of actual communication. + recvRange.End(); +#endif + +#ifdef ENABLE_NVTX_PROFILE + profile::NvtxRangeCreator memcpyRange( + "RecvMemcpy-" + std::to_string(src), profile::Color::Green); + // Begin of host-to-device memory copy. + memcpyRange.Begin(); +#endif + + // Copy tensors from buffer to outputs. + size_t tensor_offset_in_bytes = 0; + for (int i = 0; i < num_tensors; ++i) { + Tensor* tensor = received_tensors[i]; + + // Find the next aligned offset in the tensor buffer to meet alignment requirement + tensor_offset_in_bytes = GetAggregatedAlignedAddress(tensor_offset_in_bytes); + + // Keep the sync copy in the previous design + // TODO they can be moved to async call after global stream becoming accessible + CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes, + tensor->SizeInBytes(), cudaMemcpyHostToDevice)); + tensor_offset_in_bytes += tensor->SizeInBytes(); + } + AddDeferredReleaseCPUPtr(buffer.release()); + +#ifdef ENABLE_NVTX_PROFILE + // End of host-to-device copy. + memcpyRange.End(); +#endif +} + ONNX_OPERATOR_KERNEL_EX( Recv, kMSDomain, @@ -40,40 +137,89 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const { #ifdef ENABLE_NVTX_PROFILE profile::NvtxRangeCreator preRange( - "PreRecv-" + std::to_string(src), profile::Color::Green); + "PreRecv-" + std::to_string(src), profile::Color::Green); // Begin of preparation for receiving data. preRange.Begin(); #endif - // Create buffers - const int tensor_num = static_cast(element_types_.size()); - // TODO move the following variables to member variables for extending life-time - // if we want to make the entire call async - std::vector prefix_tensor_shape_sizes(tensor_num); - std::vector aggregated_tensor_shapes; - size_t aggregated_aligned_tensor_bytes = 0; - // Start communication int world_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &world_rank)); ORT_ENFORCE(world_rank != src, "Receive data from rank ", src, " on the rank ", world_rank, "."); - // Receive shape sizes and aggregated size - CommInfo_t info_shape_sizes{prefix_tensor_shape_sizes.data(), - tensor_num * static_cast(sizeof(size_t)), - src, - static_cast(tag_)}; - CommInfo_t info_aggregated_size{&aggregated_aligned_tensor_bytes, - static_cast(sizeof(size_t)), - src, - static_cast(tag_)}; + const int num_tensors = static_cast(element_types_.size()); + std::vector tensor_sizes_in_bytes; + std::vector tensor_shapes; + // TODO move the following variables to member variables for extending life-time + // if we want to make the entire call async + size_t aggregated_aligned_tensor_bytes = 0; + std::vector prefix_tensor_shape_sizes; + std::vector aggregated_tensor_shapes; + // tensor_offsets_in_bytes[i] is the starting byte of the i-th tensor in the send tensor buffer + std::vector tensor_offsets_in_bytes; + // Whether shapes are statically inferrable. + bool all_shapes_inferred = true; + // At iteration i, the i-th received tensor is processed. + for (int i = 0; i < num_tensors; ++i) { + TensorShape inferred_shape; + // The first input is a boolean control signal. We only check actual received tensors. + auto shape_inferred = ctx->TryGetInferredOutputShape(i + 1, inferred_shape); + if (!shape_inferred) { + all_shapes_inferred = false; + break; + } + } - // Directly use CPU to wait MPI_Recv. We cannot use GPU callback because - // MPI_Recv may block the entire GPU until it returns. - MPI_CHECK(MPI_Recv( - info_shape_sizes.buffer, info_shape_sizes.size, MPI_CHAR, - info_shape_sizes.rank, info_shape_sizes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + std::vector received_tensors(num_tensors); + if (all_shapes_inferred) { + // Create outputs before communication because all shapes are inferred. + for (int i = 0; i < num_tensors; ++i) { + TensorShape inferred_shape; + // The first input is a boolean control signal. We only work on actual received tensors before that. + ORT_ENFORCE(ctx->TryGetInferredOutputShape(i + 1, inferred_shape)); + // If shape is statically inferred, we declare output here and + // access its shape from operator's context in GetTensorShapesAndSizes(...). + received_tensors[i] = ctx->Output(i + 1, inferred_shape); + } + + GetTensorShapesAndSizes( + false, // value of "is_index_input". Received tensors are "output"s so this flag is "false". + 1, // First received tensor's index. + num_tensors, // Number of tensors to received. + ctx, + tensor_sizes_in_bytes, + tensor_shapes); + + // Extract information needed for copying input tensors from a big buffer + // to individual locations. + // Only that big buffer will be received through MPI. + ComputeShapeRelatedInfo( + tensor_sizes_in_bytes, + tensor_shapes, + aggregated_aligned_tensor_bytes, + prefix_tensor_shape_sizes, + aggregated_tensor_shapes, + tensor_offsets_in_bytes); + } else { + ReceiveShapeInfo( + src, + num_tensors, + aggregated_aligned_tensor_bytes, + prefix_tensor_shape_sizes, + aggregated_tensor_shapes); + + // Create output tensors. Unlike the case where we can infer output shapes before communication, + // we need to create outputs after receiving shapes. + size_t begin = 0; + for (int i = 0; i < num_tensors; ++i) { + std::vector tensor_shape(aggregated_tensor_shapes.begin() + begin, + aggregated_tensor_shapes.begin() + prefix_tensor_shape_sizes[i]); + received_tensors[i] = ctx->Output(i + 1, tensor_shape); + // Move the "begin" to the beginning dimension of the next received tensor. + begin = prefix_tensor_shape_sizes[i]; + } + } #ifdef ENABLE_NVTX_PROFILE // This range object includes the first MPI_Recv which receives a scalar. @@ -81,72 +227,18 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const { preRange.End(); #endif -#ifdef ENABLE_NVTX_PROFILE - profile::NvtxRangeCreator recvRange( - "Recv-" + std::to_string(src), profile::Color::Green); - // Begin of major communication tasks. - // The first MPI_Recv is not included because we don't want to - // count waiting time before setting up the actual communication. - recvRange.Begin(); -#endif - - MPI_CHECK(MPI_Recv( - info_aggregated_size.buffer, info_aggregated_size.size, MPI_CHAR, - info_aggregated_size.rank, info_aggregated_size.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); - - // Prepare receive shapes and data buffer - aggregated_tensor_shapes.resize(prefix_tensor_shape_sizes[tensor_num - 1]); - IAllocatorUniquePtr buffer = - AllocateBufferOnCPUPinned(static_cast(aggregated_aligned_tensor_bytes)); - CommInfo_t info_shapes{aggregated_tensor_shapes.data(), - static_cast(aggregated_tensor_shapes.size()) * static_cast(sizeof(int64_t)), - src, - static_cast(tag_)}; - CommInfo_t info_data{buffer.get(), - static_cast(aggregated_aligned_tensor_bytes), - src, - static_cast(tag_)}; - - // Directly use CPU to wait MPI_Recv. We cannot use GPU callback because - // MPI_Recv may block the entire GPU until it returns. - MPI_CHECK(MPI_Recv( - info_shapes.buffer, info_shapes.size, MPI_CHAR, - info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); - - MPI_CHECK(MPI_Recv( - info_data.buffer, info_data.size, MPI_CHAR, - info_data.rank, info_data.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); - -#ifdef ENABLE_NVTX_PROFILE - // End of actual communication. - recvRange.End(); -#endif + // At this stage, all shape information (either inferred locally or received from the source process) + // required to receive tensors are ready. + // Create buffer and receive data. + IAllocatorUniquePtr buffer; + ReceiveData(num_tensors, received_tensors, src, aggregated_aligned_tensor_bytes, buffer); #ifdef ENABLE_NVTX_PROFILE profile::NvtxRangeCreator postRange( - "PostRecv-" + std::to_string(src), profile::Color::Green); + "PostRecv-" + std::to_string(src), profile::Color::Green); postRange.Begin(); #endif - // Create Tensors - size_t begin = 0; - size_t tensor_offset_in_bytes = 0; - for (int i = 0; i < tensor_num; ++i) { - std::vector tensor_shape(aggregated_tensor_shapes.begin() + begin, - aggregated_tensor_shapes.begin() + prefix_tensor_shape_sizes[i]); - begin = prefix_tensor_shape_sizes[i]; - - Tensor* x_tensor = ctx->Output(i + 1, tensor_shape); - // Find the next aligned offset in the tensor buffer to meet alignment requirement - tensor_offset_in_bytes = GetAggregatedAlignedAddress(tensor_offset_in_bytes); - - // Keep the sync copy in the previous design - // TODO they can be moved to async call after global stream becoming accessible - ORT_ENFORCE(cudaMemcpy(x_tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes, - x_tensor->SizeInBytes(), cudaMemcpyHostToDevice) == cudaSuccess); - tensor_offset_in_bytes += x_tensor->SizeInBytes(); - } - // Set first output after communication is done. Tensor* output_signal_tensor = ctx->Output(0, {}); bool* output_signal = output_signal_tensor->template MutableData(); diff --git a/orttraining/orttraining/training_ops/cuda/communication/recv.h b/orttraining/orttraining/training_ops/cuda/communication/recv.h index cf24a8b84c..2cef45a6cf 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/recv.h +++ b/orttraining/orttraining/training_ops/cuda/communication/recv.h @@ -20,6 +20,18 @@ public: Status ComputeInternal(OpKernelContext* context) const override; private: + void ReceiveShapeInfo( + const int src, + const int num_tensors, + size_t& aggregated_aligned_tensor_bytes, + std::vector& prefix_tensor_shape_sizes, + std::vector& aggregated_tensor_shapes) const; + void ReceiveData( + const int num_tensors, + std::vector received_tensors, + const int src, + const size_t aggregated_aligned_tensor_bytes, + IAllocatorUniquePtr& buffer) const; int64_t tag_; std::vector element_types_; }; diff --git a/orttraining/orttraining/training_ops/cuda/communication/send.cc b/orttraining/orttraining/training_ops/cuda/communication/send.cc index de03538aa8..13855d016a 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/send.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/send.cc @@ -6,6 +6,7 @@ #include "orttraining/training_ops/cuda/communication/send.h" #include "orttraining/training_ops/cuda/communication/common.h" #include "core/profile/profile.h" +#include "core/providers/cuda/cuda_common.h" #include #include @@ -28,9 +29,105 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), Send); -void CUDART_CB HostSend(void* args) { - CommInfo_t* info = reinterpret_cast(args); - MPI_CHECK(MPI_Send(info->buffer, info->size, MPI_CHAR, info->rank, info->tag, MPI_COMM_WORLD)); +void Send::SendShapeInfo( + const int dst, + const int num_tensors, // Number of sent tensors. + size_t aggregated_aligned_tensor_bytes, + std::vector prefix_tensor_shape_sizes, + std::vector aggregated_tensor_shapes) const { + const int num_tensors_in_bytes = num_tensors * static_cast(sizeof(size_t)); + ORT_ENFORCE(num_tensors_in_bytes < INT_MAX, + "Total tensor number larger than MPI size limit"); + + CommInfo_t info_shape_sizes{prefix_tensor_shape_sizes.data(), + num_tensors_in_bytes, + dst, + static_cast(tag_)}; + ORT_ENFORCE(aggregated_aligned_tensor_bytes < INT_MAX, + "Aggregated tensor size larger than MPI size limit"); + + CommInfo_t info_aggregated_size{&aggregated_aligned_tensor_bytes, + static_cast(sizeof(size_t)), + dst, + static_cast(tag_)}; + + int total_tensor_dim_in_bytes = static_cast( + aggregated_tensor_shapes.size()) * + static_cast(sizeof(int64_t)); + ORT_ENFORCE(total_tensor_dim_in_bytes < INT_MAX, + "Total dimensions of tensors larger than MPI size limit"); + + CommInfo_t info_shapes{aggregated_tensor_shapes.data(), + total_tensor_dim_in_bytes, + dst, + static_cast(tag_)}; + + // Directly use CPU to wait MPI_Send. We cannot use GPU callback because + // MPI_Send may block the entire GPU until it returns. + MPI_CHECK(MPI_Send( + info_shape_sizes.buffer, info_shape_sizes.size, MPI_CHAR, + info_shape_sizes.rank, info_shape_sizes.tag, MPI_COMM_WORLD)); + + MPI_CHECK(MPI_Send( + info_aggregated_size.buffer, info_aggregated_size.size, MPI_CHAR, + info_aggregated_size.rank, info_aggregated_size.tag, MPI_COMM_WORLD)); + + MPI_CHECK(MPI_Send( + info_shapes.buffer, info_shapes.size, MPI_CHAR, + info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD)); +} + +void Send::SendData( + OpKernelContext* ctx, + const int dst, + const int num_tensors, + size_t aggregated_aligned_tensor_bytes, + std::vector tensor_offsets_in_bytes, + std::vector tensor_sizes_in_bytes) const { +#ifdef ENABLE_NVTX_PROFILE + profile::NvtxRangeCreator memcpyRange( + "SendMemcpy-" + std::to_string(dst), profile::Color::Red); + // Begin of major communication tasks. + // The previous MPI_Send's are not included because we don't want to + // count waiting time before setting up the actual communication. + memcpyRange.Begin(); +#endif + + IAllocatorUniquePtr buffer = AllocateBufferOnCPUPinned( + aggregated_aligned_tensor_bytes); + + for (int i = 0; i < num_tensors; ++i) { + const Tensor* tensor = ctx->Input(i + 2); + CUDA_CALL(cudaMemcpy(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(), + tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost)); + } + +#ifdef ENABLE_NVTX_PROFILE + memcpyRange.End(); +#endif + +#ifdef ENABLE_NVTX_PROFILE + profile::NvtxRangeCreator sendRange( + "Send-" + std::to_string(dst), profile::Color::Red); + // Begin of major communication tasks. + // The previous MPI_Send's are not included because we don't want to + // count waiting time before setting up the actual communication. + sendRange.Begin(); +#endif + + CommInfo_t info_data{buffer.get(), + static_cast(aggregated_aligned_tensor_bytes), + dst, + static_cast(tag_)}; + + MPI_CHECK(MPI_Send( + info_data.buffer, info_data.size, MPI_CHAR, + info_data.rank, info_data.tag, MPI_COMM_WORLD)); + +#ifdef ENABLE_NVTX_PROFILE + // End of major communication tasks. + sendRange.End(); +#endif } Status Send::ComputeInternal(OpKernelContext* ctx) const { @@ -44,129 +141,73 @@ Status Send::ComputeInternal(OpKernelContext* ctx) const { const int64_t* remote_rank = remote_rank_tensor->template Data(); const int dst = static_cast(*remote_rank); + // Same-rank communication is not allowed because we currently don't have async Send/Recv. + int world_rank; + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &world_rank)); + ORT_ENFORCE(world_rank != dst, "Sending data to rank ", dst, " on the rank ", world_rank, "."); + #ifdef ENABLE_NVTX_PROFILE profile::NvtxRangeCreator preRange( - "PreSend-" + std::to_string(dst), profile::Color::Red); + "PreSend-" + std::to_string(dst), profile::Color::Red); // Begin of preparation for sending data. This time range includes // the time for sending a scalar. preRange.Begin(); #endif - // Create buffers - const int tensor_num = static_cast(element_types_.size()); + const int num_tensors = static_cast(element_types_.size()); + std::vector tensor_sizes_in_bytes; + std::vector tensor_shapes; + GetTensorShapesAndSizes( + true, + 2, // First sent tensor's index. + num_tensors, // Number of tensors to send + ctx, + tensor_sizes_in_bytes, + tensor_shapes); + // TODO move the following variables to member variables for extending life-time // if we want to make the entire call async + size_t aggregated_aligned_tensor_bytes = 0; std::vector prefix_tensor_shape_sizes; std::vector aggregated_tensor_shapes; - size_t aggregated_aligned_tensor_bytes = 0; // tensor_offsets_in_bytes[i] is the starting byte of the i-th tensor in the send tensor buffer std::vector tensor_offsets_in_bytes; - // tensor_sizes_in_bytes[i] = (# of elements in the i-th tensor) * sizeof(the i-th tensor's element type) - std::vector tensor_sizes_in_bytes; - // Same-rank communication is not allowed because we currently don't have async Send/Recv. - int world_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); - ORT_ENFORCE(world_rank != dst, "Sending data to rank ", dst, " on the rank ", world_rank, "."); + // Extract information needed for copying input tensors into a big buffer. + // Only that big buffer will be sent. + ComputeShapeRelatedInfo( + tensor_sizes_in_bytes, + tensor_shapes, + aggregated_aligned_tensor_bytes, + prefix_tensor_shape_sizes, + aggregated_tensor_shapes, + tensor_offsets_in_bytes); - // Compute tensor shapes and sizes - size_t prefix_tensor_shape_size_sum = 0; - for (int i = 0; i < tensor_num; ++i) { - const Tensor* x_tensor = ctx->Input(i + 2); - prefix_tensor_shape_size_sum += x_tensor->Shape().NumDimensions(); - prefix_tensor_shape_sizes.push_back(prefix_tensor_shape_size_sum); - aggregated_tensor_shapes.insert(aggregated_tensor_shapes.end(), - x_tensor->Shape().GetDims().begin(), - x_tensor->Shape().GetDims().end()); - - // Find the next aligned offset in the tensor buffer to meet alignment requirement - aggregated_aligned_tensor_bytes = GetAggregatedAlignedAddress(aggregated_aligned_tensor_bytes); - tensor_offsets_in_bytes.push_back(aggregated_aligned_tensor_bytes); - aggregated_aligned_tensor_bytes += x_tensor->SizeInBytes(); - tensor_sizes_in_bytes.push_back(x_tensor->SizeInBytes()); + bool all_shapes_inferred = true; + for (int i = 0; i < num_tensors; ++i) { + TensorShape inferred_shape; + auto shape_inferred = ctx->TryGetInferredInputShape(i + 2, inferred_shape); + if (!shape_inferred) { + all_shapes_inferred = false; + break; + } } - IAllocatorUniquePtr buffer = AllocateBufferOnCPUPinned( - static_cast(aggregated_aligned_tensor_bytes)); - - // Keep the sync copy in the previous design - // TODO they can be moved to async call after global stream becoming accessible - for (int i = 0; i < tensor_num; ++i) { - const Tensor* x_tensor = ctx->Input(i + 2); - ORT_ENFORCE(cudaMemcpy(buffer.get() + tensor_offsets_in_bytes[i], x_tensor->DataRaw(), - tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost) == cudaSuccess); + // Communicate shape information when it cannot be inferred. + if (!all_shapes_inferred) { + SendShapeInfo(dst, num_tensors, aggregated_aligned_tensor_bytes, prefix_tensor_shape_sizes, aggregated_tensor_shapes); } - - // Prepare MPI communication info - int tensor_num_in_bytes = tensor_num * static_cast(sizeof(size_t)); - ORT_ENFORCE(tensor_num_in_bytes < INT_MAX, - "Total tensor number larger than MPI size limit"); - CommInfo_t info_shape_sizes{prefix_tensor_shape_sizes.data(), - tensor_num_in_bytes, - dst, - static_cast(tag_)}; - - ORT_ENFORCE(aggregated_aligned_tensor_bytes < INT_MAX, - "Aggregated tensor size larger than MPI size limit"); - CommInfo_t info_aggregated_size{&aggregated_aligned_tensor_bytes, - static_cast(sizeof(size_t)), - dst, - static_cast(tag_)}; - - int total_tensor_dim_in_bytes = static_cast( - aggregated_tensor_shapes.size()) * static_cast(sizeof(int64_t)); - ORT_ENFORCE(total_tensor_dim_in_bytes < INT_MAX, - "Total dimensions of tensors larger than MPI size limit"); - CommInfo_t info_shapes{aggregated_tensor_shapes.data(), - total_tensor_dim_in_bytes, - dst, - static_cast(tag_)}; - - CommInfo_t info_data{buffer.get(), - static_cast(aggregated_aligned_tensor_bytes), - dst, - static_cast(tag_)}; - - - // Directly use CPU to wait MPI_Send. We cannot use GPU callback because - // MPI_Send may block the entire GPU until it returns. - MPI_CHECK(MPI_Send( - info_shape_sizes.buffer, info_shape_sizes.size, MPI_CHAR, - info_shape_sizes.rank, info_shape_sizes.tag, MPI_COMM_WORLD)); - #ifdef ENABLE_NVTX_PROFILE + // End of data preparation and shape communication. preRange.End(); #endif -#ifdef ENABLE_NVTX_PROFILE - profile::NvtxRangeCreator sendRange( - "Send-" + std::to_string(dst), profile::Color::Red); - // Begin of major communication tasks. - // The first MPI_Send is not included because we don't want to - // count waiting time before setting up the actual communication. - sendRange.Begin(); -#endif - - MPI_CHECK(MPI_Send( - info_aggregated_size.buffer, info_aggregated_size.size, MPI_CHAR, - info_aggregated_size.rank, info_aggregated_size.tag, MPI_COMM_WORLD)); - - MPI_CHECK(MPI_Send( - info_shapes.buffer, info_shapes.size, MPI_CHAR, - info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD)); - - MPI_CHECK(MPI_Send( - info_data.buffer, info_data.size, MPI_CHAR, - info_data.rank, info_data.tag, MPI_COMM_WORLD)); - -#ifdef ENABLE_NVTX_PROFILE - // End of major communication tasks. - sendRange.End(); -#endif + // Send tensors. + SendData(ctx, dst, num_tensors, aggregated_aligned_tensor_bytes, tensor_offsets_in_bytes, tensor_sizes_in_bytes); #ifdef ENABLE_NVTX_PROFILE profile::NvtxRangeCreator postRange( - "PostSend-" + std::to_string(dst), profile::Color::Red); + "PostSend-" + std::to_string(dst), profile::Color::Red); // Begin of post communication tasks. postRange.Begin(); #endif diff --git a/orttraining/orttraining/training_ops/cuda/communication/send.h b/orttraining/orttraining/training_ops/cuda/communication/send.h index 170b2aff6e..ace4bacb37 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/send.h +++ b/orttraining/orttraining/training_ops/cuda/communication/send.h @@ -21,6 +21,20 @@ public: Status ComputeInternal(OpKernelContext* context) const override; private: + void SendShapeInfo( + const int dst, + const int num_tensors, + size_t aggregated_aligned_tensor_bytes, + std::vector prefix_tensor_shape_sizes, + std::vector aggregated_tensor_shapes) const; + void SendData( + OpKernelContext* ctx, + const int dst, + const int num_tensors, + size_t aggregated_aligned_tensor_bytes, + std::vector tensor_offsets_in_bytes, + std::vector tensor_sizes_in_bytes) const; + int64_t tag_; std::vector element_types_; };