diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 0a97d39390..2287e5f43a 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -128,6 +128,16 @@ class OpKernelContext { // unless static optimization pre-allocates it. SparseTensor* Output(int index, size_t num_values, const TensorShape& shape); + // Retrieve the shape of the indexed input obtained from memory planning before actual + // computation. If the shape cannot be inferred, this function returns + // false. + bool TryGetInferredInputShape(int index, TensorShape& shape) const; + + // Retrieve the shape of the indexed output obtained from memory planning before actual + // computation. If the shape cannot be inferred, this function returns + // false. + bool TryGetInferredOutputShape(int index, TensorShape& shape) const; + const logging::Logger& Logger() const { return *logger_; } diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 6cc3345b1c..848b227692 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -73,6 +73,13 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(int index, const TensorShap return status; } +bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const { + // By default, there is no information about inferred shapes, so this default + // implementation always returns false. A derived class of IExecutionFrame + // can override this function to provide, for example, activations' shape information. + return false; +} + AllocatorPtr IExecutionFrame::GetAllocator(const OrtMemoryInfo& info) const { return GetAllocatorImpl(info); } @@ -235,7 +242,7 @@ ExecutionFrame::ExecutionFrame(const std::vector& feed_mlvalue_idxs, const //if there are some traditional ml value type in inputs disable the memory pattern optimization. 
if (all_tensors) { - mem_patterns_ = session_state.GetMemoryPatternGroup(input_shapes, feed_mlvalue_idxs); + mem_patterns_ = session_state.GetMemoryPatternGroup(input_shapes, feed_mlvalue_idxs, inferred_shapes_); // if no existing patterns, generate one in this executionframe if (!mem_patterns_) { planner_ = onnxruntime::make_unique(*session_state.GetExecutionPlan()); @@ -623,4 +630,26 @@ Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const { return planner_->GeneratePatterns(out); } + +bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const { + // NodeArg index to OrtValue index. + int ort_value_idx = GetNodeIdxToMLValueIdx(index); + + // Check if index is valid. + if (ort_value_idx == NodeIndexInfo::kInvalidEntry) { + return false; + } + + // Search for inferred shape. + // If inferred shape is found, it's assigned to "shape" so that caller can use it. + auto it = inferred_shapes_.find(ort_value_idx); + if (it != inferred_shapes_.end()) { + shape = it->second; + return true; + } + + // No inferred shape was found; tell the caller that the search failed. + return false; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index f0b211994f..dc500314d3 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" @@ -53,6 +54,10 @@ class IExecutionFrame { // Shape is required for tensors but not traditional ML values. Status GetOrCreateNodeOutputMLValue(int index, const TensorShape* shape, OrtValue*& p_ort_value, size_t nnz = 0); + // This function tries to retrieve the inferred shape for the given NodeArg index. + // If the retrieval is successful, this function returns true; otherwise it returns false. 
+ virtual bool TryGetInferredShape(int index, TensorShape& shape) const; + /** * write the output values to the 'fetches' vector * Don't access the values after SessionState is destroyed @@ -129,6 +134,10 @@ class ExecutionFrame final : public IExecutionFrame { return planner_ != nullptr; } + // This function tries to retrieve the inferred shape for the given NodeArg index. + // If the retrieval is successful, this function returns true; otherwise it returns false. + bool TryGetInferredShape(int index, TensorShape& shape) const override; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExecutionFrame); @@ -168,5 +177,11 @@ class ExecutionFrame final : public IExecutionFrame { // Big chunks on different locations that will be used by mem_pattern. std::map buffers_; + + // Given the input shapes of the executed graph, ExecutionFrame tries to infer + // all symbolic shapes. inferred_shapes_[i] is the shape of OrtValue indexed + // by i, if the key i exists. + // inferred_shapes_ is generated together with mem_patterns_. + std::unordered_map inferred_shapes_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index a21bab8166..b3c05bb4e5 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -38,6 +38,14 @@ SparseTensor* OpKernelContext::Output(int index, size_t nnz, const TensorShape& return p_ml_value ? 
p_ml_value->GetMutable() : nullptr; } +bool OpKernelContext::TryGetInferredInputShape(int index, TensorShape& shape) const { + return execution_frame_->TryGetInferredShape(GetInputArgIndex(index), shape); +} + +bool OpKernelContext::TryGetInferredOutputShape(int index, TensorShape& shape) const { + return execution_frame_->TryGetInferredShape(GetOutputArgIndex(index), shape); +} + OrtValue* OpKernelContext::OutputMLValue(int index, const TensorShape& shape, size_t nnz) { if (index < 0 || index >= OutputCount()) return nullptr; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index a376956e72..2a03f55a9d 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -252,11 +252,53 @@ Status ResolveDimParams(const GraphViewer& graph, } return Status::OK(); } + +Status ResolveSizeAndShape( + const NodeArg* arg, + const std::unordered_map& symbolic_dimensions, + size_t& size, // total number of elements. It's 0 if shape is unknown. + std::vector& resolved_shape) { + if (!arg->Shape()) { + // 0 means no shape information. + size = 0; + return Status::OK(); + } + + std::vector shape; + + SafeInt safe_size = 1; + for (auto& dim : arg->Shape()->dim()) { + if (dim.has_dim_param()) { + auto it = symbolic_dimensions.find(dim.dim_param()); + if (it == symbolic_dimensions.end()) { + return Status(ONNXRUNTIME, FAIL, "Unknown symbolic dimension, " + dim.dim_param() + ", found in memory pattern compute."); + } + safe_size *= it->second; + shape.push_back(it->second); + } else if (dim.has_dim_value() && dim.dim_value() > 0) { + safe_size *= dim.dim_value(); + shape.push_back(dim.dim_value()); + } else { + // tensor shape is unknown. + safe_size = 0; + } + } + + size = safe_size; + + // Only assign shape if all symbolic dimensions are resolved. 
+ if (size != 0) { + resolved_shape = std::move(shape); + } + + return Status::OK(); +} } // namespace Status SessionState::GeneratePatternGroupCache(const std::vector>& input_shape, const std::vector& feed_mlvalue_idxs, - MemoryPatternGroup* output) const { + MemoryPatternGroup* output, + std::unordered_map& resolved_shapes) const { std::map feeds; for (size_t i = 0, end = feed_mlvalue_idxs.size(); i < end; ++i) { std::string name; @@ -282,38 +324,27 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorIsTensorType()) continue; const auto* ml_data_type = static_cast(ml_type)->GetElementType(); + + auto* arg = node->OutputDefs()[i]; + size_t size = 0; + std::vector resolved_shape; + ORT_RETURN_IF_ERROR(ResolveSizeAndShape(arg, map, size, resolved_shape)); + + // Store all valid resolved shapes. They will be queried in, for example, + // Recv operator to bypass the dependency of output shapes on inputs. + if (size != 0) { + resolved_shapes[ml_value_idx] = resolved_shape; + } + + // Plan memory if conditions are met. 
if (exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate && - ml_data_type != DataTypeImpl::GetType()) { - //calculate size - auto* arg = node->OutputDefs()[i]; - if (!arg->Shape()) - continue; - size_t size = 0; - SafeInt len = 1; - for (auto& dim : arg->Shape()->dim()) { - if (dim.has_dim_param()) { - auto it = map.find(dim.dim_param()); - if (it == map.end()) { - return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute"); - } - len *= it->second; - } else if (dim.has_dim_value()) { - len *= dim.dim_value(); - } else { - // tensor shape is unknown - len = 0; - } - } - - // Skip planning for this tensor if shape is unknown - if (len == 0) { - continue; - } - - if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(len, ml_data_type->Size(), &size)) { + ml_data_type != DataTypeImpl::GetType() && size != 0) { + size_t aligned_size = 0; + if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(size, ml_data_type->Size(), &aligned_size)) { return Status(ONNXRUNTIME, FAIL, "Size overflow"); } - mem_planner.TraceAllocation(ml_value_idx, size); + + mem_planner.TraceAllocation(ml_value_idx, aligned_size); } } //release nodes @@ -337,7 +368,8 @@ Status SessionState::GeneratePatternGroupCache(const std::vector>& input_shapes, - const std::vector& feed_mlvalue_idxs) const { + const std::vector& feed_mlvalue_idxs, + std::unordered_map& inferred_shapes) const { int64_t key = CalculateMemoryPatternsKey(input_shapes); std::lock_guard lock(mem_patterns_lock_); @@ -345,10 +377,11 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector< if (it == mem_patterns_.end()) { #ifdef ENABLE_TRAINING auto mem_patterns = onnxruntime::make_unique(); - if (GeneratePatternGroupCache(input_shapes, feed_mlvalue_idxs, mem_patterns.get()).IsOK()) { + if (GeneratePatternGroupCache(input_shapes, feed_mlvalue_idxs, mem_patterns.get(), inferred_shapes).IsOK()) { key = CalculateMemoryPatternsKey(input_shapes); auto ptr = 
mem_patterns.get(); mem_patterns_[key] = std::move(mem_patterns); + shape_patterns_[key] = inferred_shapes; return ptr; } return nullptr; @@ -358,6 +391,7 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector< #endif } + inferred_shapes = shape_patterns_[key]; return it->second.get(); } diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index fd435f3dad..2f0dac6f20 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -187,7 +187,8 @@ class SessionState { */ const MemoryPatternGroup* GetMemoryPatternGroup( const std::vector>& input_shapes, - const std::vector& feed_mlvalue_idxs) const; + const std::vector& feed_mlvalue_idxs, + std::unordered_map& inferred_shapes) const; /** Set generated memory pattern with a given input shapes. @@ -278,7 +279,8 @@ class SessionState { Status GeneratePatternGroupCache( const std::vector>& input_shape, const std::vector& feed_mlvalue_idxs, - MemoryPatternGroup* output) const; + MemoryPatternGroup* output, + std::unordered_map& inferred_shapes) const; #endif // cache of the constructed kernels to avoid spending construction time per executor @@ -346,6 +348,7 @@ class SessionState { // cache for the generated mem_patterns. key is calculated based on input shapes. mutable std::map> mem_patterns_; + mutable std::map> shape_patterns_; NameNodeInfoMapType input_names_to_nodeinfo_mapping_; NameNodeInfoMapType output_names_to_nodeinfo_mapping_; diff --git a/orttraining/orttraining/training_ops/cuda/communication/common.h b/orttraining/orttraining/training_ops/cuda/communication/common.h index a87e6ab850..e6ec8afdb3 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/common.h +++ b/orttraining/orttraining/training_ops/cuda/communication/common.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "core/common/common.h" +#include "core/providers/cuda/cuda_common.h" #pragma once namespace onnxruntime { @@ -26,5 +27,80 @@ inline size_t GetAggregatedAlignedAddress(size_t old_addr) { return new_addr; } +// This function extracts shapes and sizes for tensors indexed by +// begin, begin + 1, ..., begin + count - 1. +// tensor_sizes[i]: i-th indexed tensor's size in bytes. +// tensor_shapes[i]: i-th indexed tensor's shape. +inline void GetTensorShapesAndSizes( + const bool is_index_input, // Index inputs from the context if true. Otherwise, use outputs. + const int begin, // Index of the first sent/received tensor in the context. + const int count, // Number of sent/received tensors. + OpKernelContext* ctx, // The context. + std::vector& tensor_sizes, // tensor_sizes[i] is the size of the i-th sent/received tensor in bytes. + std::vector& tensor_shapes) { // tensor_shapes[i] is the shape of the i-th sent/received tensor. + + // Helper function to retrieve input or output tensor. + auto get_tensor = [&](const int index) -> const Tensor* { + if (is_index_input) { + return ctx->Input(begin + index); + } else { + return ctx->Output(begin + index); + } + }; + + // Get sizes and shapes of the indexed tensors from the context. + tensor_sizes.resize(count); + tensor_shapes.resize(count); + for (int i = 0; i < count; ++i) { + const Tensor* tensor = get_tensor(i); + tensor_sizes[i] = tensor->SizeInBytes(); + tensor_shapes[i] = tensor->Shape(); + } +} + +// Compute shape-related information from given tensor shapes. +inline void ComputeShapeRelatedInfo( + // tensor_sizes[i] is the size of the i-th sent/received tensor in bytes. + const std::vector tensor_sizes, + // tensor_shapes[i] is the shape of the i-th sent/received tensor. + const std::vector tensor_shapes, + // The size in bytes if we concatenate all tensors into one single tensor. + // It may be larger than the original size due to memory alignment. 
+ size_t& aggregated_aligned_tensor_bytes, + // prefix_tensor_shape_sizes[i] is the total number of dimensions of tensors 0, 1, ..., i, so + // aggregated_tensor_shapes[prefix_tensor_shape_sizes[i]] is the element after the last dimension of the i-th tensor (the i-th tensor's first dimension is at index prefix_tensor_shape_sizes[i - 1], or 0 when i is 0). + std::vector& prefix_tensor_shape_sizes, + // This field is the concatenation of all sent/received tensors' shapes. + // Assume that there are two tensors A and B with rank NA and NB, respectively. + // aggregated_tensor_shapes = [A_shape[0], A_shape[1], ..., A_shape[NA-1], B_shape[0], B_shape[1], ..., B_shape[NB-1]]. + std::vector& aggregated_tensor_shapes, + // tensor_offsets_in_bytes[i] is the offset of the starting byte of the i-th tensor in the communicated tensor buffer. + // That is, i-th tensor's first element is tensor_buffer[tensor_offsets_in_bytes[i]]. + std::vector& tensor_offsets_in_bytes) { + // Initialize outputs. + aggregated_aligned_tensor_bytes = 0; + prefix_tensor_shape_sizes.resize(0); + aggregated_tensor_shapes.resize(0); + tensor_offsets_in_bytes.resize(0); + + // Compute shape information. + size_t prefix_tensor_shape_size_sum = 0; + for (int i = 0; static_cast(i) < tensor_shapes.size(); ++i) { + const auto& shape = tensor_shapes[i]; + prefix_tensor_shape_size_sum += shape.NumDimensions(); + prefix_tensor_shape_sizes.push_back(prefix_tensor_shape_size_sum); + aggregated_tensor_shapes.insert(aggregated_tensor_shapes.end(), + shape.GetDims().begin(), + shape.GetDims().end()); + + // aggregated_aligned_tensor_bytes is the first non-occupied address. + // Starting from aggregated_aligned_tensor_bytes, we find the next aligned offset in the + // tensor buffer to meet the alignment requirement. 
+ aggregated_aligned_tensor_bytes = GetAggregatedAlignedAddress(aggregated_aligned_tensor_bytes); + tensor_offsets_in_bytes.push_back(aggregated_aligned_tensor_bytes); + aggregated_aligned_tensor_bytes += tensor_sizes[i]; + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/communication/recv.cc b/orttraining/orttraining/training_ops/cuda/communication/recv.cc index 9d43323b88..42be0c0de9 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/recv.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/recv.cc @@ -6,6 +6,7 @@ #include "orttraining/training_ops/cuda/communication/recv.h" #include "orttraining/training_ops/cuda/communication/common.h" #include "core/profile/profile.h" +#include "core/providers/cuda/cuda_common.h" #include #include "orttraining/core/framework/mpi_setup.h" @@ -13,6 +14,102 @@ namespace onnxruntime { namespace cuda { +void Recv::ReceiveShapeInfo( + const int src, + const int num_tensors, + size_t& aggregated_aligned_tensor_bytes, + std::vector& prefix_tensor_shape_sizes, + std::vector& aggregated_tensor_shapes) const { + // Resize vector so that the following .data() returns meaningful pointer. + prefix_tensor_shape_sizes.resize(num_tensors); + CommInfo_t info_shape_sizes{prefix_tensor_shape_sizes.data(), + num_tensors * static_cast(sizeof(size_t)), + src, + static_cast(tag_)}; + CommInfo_t info_aggregated_size{&aggregated_aligned_tensor_bytes, + static_cast(sizeof(size_t)), + src, + static_cast(tag_)}; + // Directly use CPU to wait MPI_Recv. We cannot use GPU callback because + // MPI_Recv may block the entire GPU until it returns. 
+ MPI_CHECK(MPI_Recv( + info_shape_sizes.buffer, info_shape_sizes.size, MPI_CHAR, + info_shape_sizes.rank, info_shape_sizes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + + MPI_CHECK(MPI_Recv( + info_aggregated_size.buffer, info_aggregated_size.size, MPI_CHAR, + info_aggregated_size.rank, info_aggregated_size.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + + // prefix_tensor_shape_sizes's last element is the number of total dimensions. + // If a 3-D tensor and a 2-D tensor are sent, its value is 2 + 3 = 5. + aggregated_tensor_shapes.resize(prefix_tensor_shape_sizes[num_tensors - 1]); + CommInfo_t info_shapes{aggregated_tensor_shapes.data(), + static_cast(aggregated_tensor_shapes.size()) * static_cast(sizeof(int64_t)), + src, + static_cast(tag_)}; + MPI_CHECK(MPI_Recv( + info_shapes.buffer, info_shapes.size, MPI_CHAR, + info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); +} + +void Recv::ReceiveData( + const int num_tensors, + std::vector received_tensors, + const int src, + const size_t aggregated_aligned_tensor_bytes, + IAllocatorUniquePtr& buffer) const { +#ifdef ENABLE_NVTX_PROFILE + profile::NvtxRangeCreator recvRange( + "Recv-" + std::to_string(src), profile::Color::Green); + // Begin of major communication tasks. + // The first MPI_Recv is not included because we don't want to + // count waiting time before setting up the actual communication. + recvRange.Begin(); +#endif + buffer = AllocateBufferOnCPUPinned(static_cast(aggregated_aligned_tensor_bytes)); + CommInfo_t info_data{buffer.get(), + static_cast(aggregated_aligned_tensor_bytes), + src, + static_cast(tag_)}; + + MPI_CHECK(MPI_Recv( + info_data.buffer, info_data.size, MPI_CHAR, + info_data.rank, info_data.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + +#ifdef ENABLE_NVTX_PROFILE + // End of actual communication. 
+ recvRange.End(); +#endif + +#ifdef ENABLE_NVTX_PROFILE + profile::NvtxRangeCreator memcpyRange( + "RecvMemcpy-" + std::to_string(src), profile::Color::Green); + // Begin of host-to-device memory copy. + memcpyRange.Begin(); +#endif + + // Copy tensors from buffer to outputs. + size_t tensor_offset_in_bytes = 0; + for (int i = 0; i < num_tensors; ++i) { + Tensor* tensor = received_tensors[i]; + + // Find the next aligned offset in the tensor buffer to meet alignment requirement + tensor_offset_in_bytes = GetAggregatedAlignedAddress(tensor_offset_in_bytes); + + // Keep the sync copy in the previous design + // TODO they can be moved to async call after global stream becoming accessible + CUDA_CALL(cudaMemcpyAsync(tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes, + tensor->SizeInBytes(), cudaMemcpyHostToDevice)); + tensor_offset_in_bytes += tensor->SizeInBytes(); + } + AddDeferredReleaseCPUPtr(buffer.release()); + +#ifdef ENABLE_NVTX_PROFILE + // End of host-to-device copy. + memcpyRange.End(); +#endif +} + ONNX_OPERATOR_KERNEL_EX( Recv, kMSDomain, @@ -40,40 +137,89 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const { #ifdef ENABLE_NVTX_PROFILE profile::NvtxRangeCreator preRange( - "PreRecv-" + std::to_string(src), profile::Color::Green); + "PreRecv-" + std::to_string(src), profile::Color::Green); // Begin of preparation for receiving data. 
preRange.Begin(); #endif - // Create buffers - const int tensor_num = static_cast(element_types_.size()); - // TODO move the following variables to member variables for extending life-time - // if we want to make the entire call async - std::vector prefix_tensor_shape_sizes(tensor_num); - std::vector aggregated_tensor_shapes; - size_t aggregated_aligned_tensor_bytes = 0; - // Start communication int world_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &world_rank)); ORT_ENFORCE(world_rank != src, "Receive data from rank ", src, " on the rank ", world_rank, "."); - // Receive shape sizes and aggregated size - CommInfo_t info_shape_sizes{prefix_tensor_shape_sizes.data(), - tensor_num * static_cast(sizeof(size_t)), - src, - static_cast(tag_)}; - CommInfo_t info_aggregated_size{&aggregated_aligned_tensor_bytes, - static_cast(sizeof(size_t)), - src, - static_cast(tag_)}; + const int num_tensors = static_cast(element_types_.size()); + std::vector tensor_sizes_in_bytes; + std::vector tensor_shapes; + // TODO move the following variables to member variables for extending life-time + // if we want to make the entire call async + size_t aggregated_aligned_tensor_bytes = 0; + std::vector prefix_tensor_shape_sizes; + std::vector aggregated_tensor_shapes; + // tensor_offsets_in_bytes[i] is the starting byte of the i-th tensor in the received tensor buffer + std::vector tensor_offsets_in_bytes; + // Whether shapes are statically inferrable. + bool all_shapes_inferred = true; + // At iteration i, the i-th received tensor is processed. + for (int i = 0; i < num_tensors; ++i) { + TensorShape inferred_shape; + // The first output (index 0) is a boolean control signal, so received tensors start at output index 1. + auto shape_inferred = ctx->TryGetInferredOutputShape(i + 1, inferred_shape); + if (!shape_inferred) { + all_shapes_inferred = false; + break; + } + } - // Directly use CPU to wait MPI_Recv. 
We cannot use GPU callback because - // MPI_Recv may block the entire GPU until it returns. - MPI_CHECK(MPI_Recv( - info_shape_sizes.buffer, info_shape_sizes.size, MPI_CHAR, - info_shape_sizes.rank, info_shape_sizes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + std::vector received_tensors(num_tensors); + if (all_shapes_inferred) { + // Create outputs before communication because all shapes are inferred. + for (int i = 0; i < num_tensors; ++i) { + TensorShape inferred_shape; + // The first output (index 0) is a boolean control signal, so received tensors start at output index 1. + ORT_ENFORCE(ctx->TryGetInferredOutputShape(i + 1, inferred_shape)); + // If shape is statically inferred, we declare output here and + // access its shape from operator's context in GetTensorShapesAndSizes(...). + received_tensors[i] = ctx->Output(i + 1, inferred_shape); + } + + GetTensorShapesAndSizes( + false, // value of "is_index_input". Received tensors are "output"s so this flag is "false". + 1, // First received tensor's index. + num_tensors, // Number of tensors to receive. + ctx, + tensor_sizes_in_bytes, + tensor_shapes); + + // Extract information needed for copying received tensors from a big buffer + // to individual locations. + // Only that big buffer will be received through MPI. + ComputeShapeRelatedInfo( + tensor_sizes_in_bytes, + tensor_shapes, + aggregated_aligned_tensor_bytes, + prefix_tensor_shape_sizes, + aggregated_tensor_shapes, + tensor_offsets_in_bytes); + } else { + ReceiveShapeInfo( + src, + num_tensors, + aggregated_aligned_tensor_bytes, + prefix_tensor_shape_sizes, + aggregated_tensor_shapes); + + // Create output tensors. Unlike the case where we can infer output shapes before communication, + // we need to create outputs after receiving shapes. 
+ size_t begin = 0; + for (int i = 0; i < num_tensors; ++i) { + std::vector tensor_shape(aggregated_tensor_shapes.begin() + begin, + aggregated_tensor_shapes.begin() + prefix_tensor_shape_sizes[i]); + received_tensors[i] = ctx->Output(i + 1, tensor_shape); + // Move the "begin" to the beginning dimension of the next received tensor. + begin = prefix_tensor_shape_sizes[i]; + } + } #ifdef ENABLE_NVTX_PROFILE // This range object includes the first MPI_Recv which receives a scalar. @@ -81,72 +227,18 @@ Status Recv::ComputeInternal(OpKernelContext* ctx) const { preRange.End(); #endif -#ifdef ENABLE_NVTX_PROFILE - profile::NvtxRangeCreator recvRange( - "Recv-" + std::to_string(src), profile::Color::Green); - // Begin of major communication tasks. - // The first MPI_Recv is not included because we don't want to - // count waiting time before setting up the actual communication. - recvRange.Begin(); -#endif - - MPI_CHECK(MPI_Recv( - info_aggregated_size.buffer, info_aggregated_size.size, MPI_CHAR, - info_aggregated_size.rank, info_aggregated_size.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); - - // Prepare receive shapes and data buffer - aggregated_tensor_shapes.resize(prefix_tensor_shape_sizes[tensor_num - 1]); - IAllocatorUniquePtr buffer = - AllocateBufferOnCPUPinned(static_cast(aggregated_aligned_tensor_bytes)); - CommInfo_t info_shapes{aggregated_tensor_shapes.data(), - static_cast(aggregated_tensor_shapes.size()) * static_cast(sizeof(int64_t)), - src, - static_cast(tag_)}; - CommInfo_t info_data{buffer.get(), - static_cast(aggregated_aligned_tensor_bytes), - src, - static_cast(tag_)}; - - // Directly use CPU to wait MPI_Recv. We cannot use GPU callback because - // MPI_Recv may block the entire GPU until it returns. 
- MPI_CHECK(MPI_Recv( - info_shapes.buffer, info_shapes.size, MPI_CHAR, - info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); - - MPI_CHECK(MPI_Recv( - info_data.buffer, info_data.size, MPI_CHAR, - info_data.rank, info_data.tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); - -#ifdef ENABLE_NVTX_PROFILE - // End of actual communication. - recvRange.End(); -#endif + // At this stage, all shape information (either inferred locally or received from the source process) + // required to receive tensors are ready. + // Create buffer and receive data. + IAllocatorUniquePtr buffer; + ReceiveData(num_tensors, received_tensors, src, aggregated_aligned_tensor_bytes, buffer); #ifdef ENABLE_NVTX_PROFILE profile::NvtxRangeCreator postRange( - "PostRecv-" + std::to_string(src), profile::Color::Green); + "PostRecv-" + std::to_string(src), profile::Color::Green); postRange.Begin(); #endif - // Create Tensors - size_t begin = 0; - size_t tensor_offset_in_bytes = 0; - for (int i = 0; i < tensor_num; ++i) { - std::vector tensor_shape(aggregated_tensor_shapes.begin() + begin, - aggregated_tensor_shapes.begin() + prefix_tensor_shape_sizes[i]); - begin = prefix_tensor_shape_sizes[i]; - - Tensor* x_tensor = ctx->Output(i + 1, tensor_shape); - // Find the next aligned offset in the tensor buffer to meet alignment requirement - tensor_offset_in_bytes = GetAggregatedAlignedAddress(tensor_offset_in_bytes); - - // Keep the sync copy in the previous design - // TODO they can be moved to async call after global stream becoming accessible - ORT_ENFORCE(cudaMemcpy(x_tensor->MutableDataRaw(), buffer.get() + tensor_offset_in_bytes, - x_tensor->SizeInBytes(), cudaMemcpyHostToDevice) == cudaSuccess); - tensor_offset_in_bytes += x_tensor->SizeInBytes(); - } - // Set first output after communication is done. 
Tensor* output_signal_tensor = ctx->Output(0, {}); bool* output_signal = output_signal_tensor->template MutableData(); diff --git a/orttraining/orttraining/training_ops/cuda/communication/recv.h b/orttraining/orttraining/training_ops/cuda/communication/recv.h index cf24a8b84c..2cef45a6cf 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/recv.h +++ b/orttraining/orttraining/training_ops/cuda/communication/recv.h @@ -20,6 +20,18 @@ public: Status ComputeInternal(OpKernelContext* context) const override; private: + void ReceiveShapeInfo( + const int src, + const int num_tensors, + size_t& aggregated_aligned_tensor_bytes, + std::vector& prefix_tensor_shape_sizes, + std::vector& aggregated_tensor_shapes) const; + void ReceiveData( + const int num_tensors, + std::vector received_tensors, + const int src, + const size_t aggregated_aligned_tensor_bytes, + IAllocatorUniquePtr& buffer) const; int64_t tag_; std::vector element_types_; }; diff --git a/orttraining/orttraining/training_ops/cuda/communication/send.cc b/orttraining/orttraining/training_ops/cuda/communication/send.cc index de03538aa8..13855d016a 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/send.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/send.cc @@ -6,6 +6,7 @@ #include "orttraining/training_ops/cuda/communication/send.h" #include "orttraining/training_ops/cuda/communication/common.h" #include "core/profile/profile.h" +#include "core/providers/cuda/cuda_common.h" #include #include @@ -28,9 +29,105 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), Send); -void CUDART_CB HostSend(void* args) { - CommInfo_t* info = reinterpret_cast(args); - MPI_CHECK(MPI_Send(info->buffer, info->size, MPI_CHAR, info->rank, info->tag, MPI_COMM_WORLD)); +void Send::SendShapeInfo( + const int dst, + const int num_tensors, // Number of sent tensors. 
+ size_t aggregated_aligned_tensor_bytes, + std::vector prefix_tensor_shape_sizes, + std::vector aggregated_tensor_shapes) const { + const int num_tensors_in_bytes = num_tensors * static_cast(sizeof(size_t)); + ORT_ENFORCE(num_tensors_in_bytes < INT_MAX, + "Total tensor number larger than MPI size limit"); + + CommInfo_t info_shape_sizes{prefix_tensor_shape_sizes.data(), + num_tensors_in_bytes, + dst, + static_cast(tag_)}; + ORT_ENFORCE(aggregated_aligned_tensor_bytes < INT_MAX, + "Aggregated tensor size larger than MPI size limit"); + + CommInfo_t info_aggregated_size{&aggregated_aligned_tensor_bytes, + static_cast(sizeof(size_t)), + dst, + static_cast(tag_)}; + + int total_tensor_dim_in_bytes = static_cast( + aggregated_tensor_shapes.size()) * + static_cast(sizeof(int64_t)); + ORT_ENFORCE(total_tensor_dim_in_bytes < INT_MAX, + "Total dimensions of tensors larger than MPI size limit"); + + CommInfo_t info_shapes{aggregated_tensor_shapes.data(), + total_tensor_dim_in_bytes, + dst, + static_cast(tag_)}; + + // Directly use CPU to wait MPI_Send. We cannot use GPU callback because + // MPI_Send may block the entire GPU until it returns. + MPI_CHECK(MPI_Send( + info_shape_sizes.buffer, info_shape_sizes.size, MPI_CHAR, + info_shape_sizes.rank, info_shape_sizes.tag, MPI_COMM_WORLD)); + + MPI_CHECK(MPI_Send( + info_aggregated_size.buffer, info_aggregated_size.size, MPI_CHAR, + info_aggregated_size.rank, info_aggregated_size.tag, MPI_COMM_WORLD)); + + MPI_CHECK(MPI_Send( + info_shapes.buffer, info_shapes.size, MPI_CHAR, + info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD)); +} + +void Send::SendData( + OpKernelContext* ctx, + const int dst, + const int num_tensors, + size_t aggregated_aligned_tensor_bytes, + std::vector tensor_offsets_in_bytes, + std::vector tensor_sizes_in_bytes) const { +#ifdef ENABLE_NVTX_PROFILE + profile::NvtxRangeCreator memcpyRange( + "SendMemcpy-" + std::to_string(dst), profile::Color::Red); + // Begin of major communication tasks. 
+ // The previous MPI_Send's are not included because we don't want to + // count waiting time before setting up the actual communication. + memcpyRange.Begin(); +#endif + + IAllocatorUniquePtr buffer = AllocateBufferOnCPUPinned( + aggregated_aligned_tensor_bytes); + + for (int i = 0; i < num_tensors; ++i) { + const Tensor* tensor = ctx->Input(i + 2); + CUDA_CALL(cudaMemcpy(buffer.get() + tensor_offsets_in_bytes[i], tensor->DataRaw(), + tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost)); + } + +#ifdef ENABLE_NVTX_PROFILE + memcpyRange.End(); +#endif + +#ifdef ENABLE_NVTX_PROFILE + profile::NvtxRangeCreator sendRange( + "Send-" + std::to_string(dst), profile::Color::Red); + // Begin of major communication tasks. + // The previous MPI_Send's are not included because we don't want to + // count waiting time before setting up the actual communication. + sendRange.Begin(); +#endif + + CommInfo_t info_data{buffer.get(), + static_cast(aggregated_aligned_tensor_bytes), + dst, + static_cast(tag_)}; + + MPI_CHECK(MPI_Send( + info_data.buffer, info_data.size, MPI_CHAR, + info_data.rank, info_data.tag, MPI_COMM_WORLD)); + +#ifdef ENABLE_NVTX_PROFILE + // End of major communication tasks. + sendRange.End(); +#endif } Status Send::ComputeInternal(OpKernelContext* ctx) const { @@ -44,129 +141,73 @@ Status Send::ComputeInternal(OpKernelContext* ctx) const { const int64_t* remote_rank = remote_rank_tensor->template Data(); const int dst = static_cast(*remote_rank); + // Same-rank communication is not allowed because we currently don't have async Send/Recv. + int world_rank; + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &world_rank)); + ORT_ENFORCE(world_rank != dst, "Sending data to rank ", dst, " on the rank ", world_rank, "."); + #ifdef ENABLE_NVTX_PROFILE profile::NvtxRangeCreator preRange( - "PreSend-" + std::to_string(dst), profile::Color::Red); + "PreSend-" + std::to_string(dst), profile::Color::Red); // Begin of preparation for sending data. 
This time range includes // the time for sending a scalar. preRange.Begin(); #endif - // Create buffers - const int tensor_num = static_cast(element_types_.size()); + const int num_tensors = static_cast(element_types_.size()); + std::vector tensor_sizes_in_bytes; + std::vector tensor_shapes; + GetTensorShapesAndSizes( + true, + 2, // First sent tensor's index. + num_tensors, // Number of tensors to send + ctx, + tensor_sizes_in_bytes, + tensor_shapes); + // TODO move the following variables to member variables for extending life-time // if we want to make the entire call async + size_t aggregated_aligned_tensor_bytes = 0; std::vector prefix_tensor_shape_sizes; std::vector aggregated_tensor_shapes; - size_t aggregated_aligned_tensor_bytes = 0; // tensor_offsets_in_bytes[i] is the starting byte of the i-th tensor in the send tensor buffer std::vector tensor_offsets_in_bytes; - // tensor_sizes_in_bytes[i] = (# of elements in the i-th tensor) * sizeof(the i-th tensor's element type) - std::vector tensor_sizes_in_bytes; - // Same-rank communication is not allowed because we currently don't have async Send/Recv. - int world_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); - ORT_ENFORCE(world_rank != dst, "Sending data to rank ", dst, " on the rank ", world_rank, "."); + // Extract information needed for copying input tensors into a big buffer. + // Only that big buffer will be sent. 
+ ComputeShapeRelatedInfo( + tensor_sizes_in_bytes, + tensor_shapes, + aggregated_aligned_tensor_bytes, + prefix_tensor_shape_sizes, + aggregated_tensor_shapes, + tensor_offsets_in_bytes); - // Compute tensor shapes and sizes - size_t prefix_tensor_shape_size_sum = 0; - for (int i = 0; i < tensor_num; ++i) { - const Tensor* x_tensor = ctx->Input(i + 2); - prefix_tensor_shape_size_sum += x_tensor->Shape().NumDimensions(); - prefix_tensor_shape_sizes.push_back(prefix_tensor_shape_size_sum); - aggregated_tensor_shapes.insert(aggregated_tensor_shapes.end(), - x_tensor->Shape().GetDims().begin(), - x_tensor->Shape().GetDims().end()); - - // Find the next aligned offset in the tensor buffer to meet alignment requirement - aggregated_aligned_tensor_bytes = GetAggregatedAlignedAddress(aggregated_aligned_tensor_bytes); - tensor_offsets_in_bytes.push_back(aggregated_aligned_tensor_bytes); - aggregated_aligned_tensor_bytes += x_tensor->SizeInBytes(); - tensor_sizes_in_bytes.push_back(x_tensor->SizeInBytes()); + bool all_shapes_inferred = true; + for (int i = 0; i < num_tensors; ++i) { + TensorShape inferred_shape; + auto shape_inferred = ctx->TryGetInferredInputShape(i + 2, inferred_shape); + if (!shape_inferred) { + all_shapes_inferred = false; + break; + } } - IAllocatorUniquePtr buffer = AllocateBufferOnCPUPinned( - static_cast(aggregated_aligned_tensor_bytes)); - - // Keep the sync copy in the previous design - // TODO they can be moved to async call after global stream becoming accessible - for (int i = 0; i < tensor_num; ++i) { - const Tensor* x_tensor = ctx->Input(i + 2); - ORT_ENFORCE(cudaMemcpy(buffer.get() + tensor_offsets_in_bytes[i], x_tensor->DataRaw(), - tensor_sizes_in_bytes[i], cudaMemcpyDeviceToHost) == cudaSuccess); + // Communicate shape information when it cannot be inferred. 
+ if (!all_shapes_inferred) { + SendShapeInfo(dst, num_tensors, aggregated_aligned_tensor_bytes, prefix_tensor_shape_sizes, aggregated_tensor_shapes); } - - // Prepare MPI communication info - int tensor_num_in_bytes = tensor_num * static_cast(sizeof(size_t)); - ORT_ENFORCE(tensor_num_in_bytes < INT_MAX, - "Total tensor number larger than MPI size limit"); - CommInfo_t info_shape_sizes{prefix_tensor_shape_sizes.data(), - tensor_num_in_bytes, - dst, - static_cast(tag_)}; - - ORT_ENFORCE(aggregated_aligned_tensor_bytes < INT_MAX, - "Aggregated tensor size larger than MPI size limit"); - CommInfo_t info_aggregated_size{&aggregated_aligned_tensor_bytes, - static_cast(sizeof(size_t)), - dst, - static_cast(tag_)}; - - int total_tensor_dim_in_bytes = static_cast( - aggregated_tensor_shapes.size()) * static_cast(sizeof(int64_t)); - ORT_ENFORCE(total_tensor_dim_in_bytes < INT_MAX, - "Total dimensions of tensors larger than MPI size limit"); - CommInfo_t info_shapes{aggregated_tensor_shapes.data(), - total_tensor_dim_in_bytes, - dst, - static_cast(tag_)}; - - CommInfo_t info_data{buffer.get(), - static_cast(aggregated_aligned_tensor_bytes), - dst, - static_cast(tag_)}; - - - // Directly use CPU to wait MPI_Send. We cannot use GPU callback because - // MPI_Send may block the entire GPU until it returns. - MPI_CHECK(MPI_Send( - info_shape_sizes.buffer, info_shape_sizes.size, MPI_CHAR, - info_shape_sizes.rank, info_shape_sizes.tag, MPI_COMM_WORLD)); - #ifdef ENABLE_NVTX_PROFILE + // End of data preparation and shape communication. preRange.End(); #endif -#ifdef ENABLE_NVTX_PROFILE - profile::NvtxRangeCreator sendRange( - "Send-" + std::to_string(dst), profile::Color::Red); - // Begin of major communication tasks. - // The first MPI_Send is not included because we don't want to - // count waiting time before setting up the actual communication. 
- sendRange.Begin(); -#endif - - MPI_CHECK(MPI_Send( - info_aggregated_size.buffer, info_aggregated_size.size, MPI_CHAR, - info_aggregated_size.rank, info_aggregated_size.tag, MPI_COMM_WORLD)); - - MPI_CHECK(MPI_Send( - info_shapes.buffer, info_shapes.size, MPI_CHAR, - info_shapes.rank, info_shapes.tag, MPI_COMM_WORLD)); - - MPI_CHECK(MPI_Send( - info_data.buffer, info_data.size, MPI_CHAR, - info_data.rank, info_data.tag, MPI_COMM_WORLD)); - -#ifdef ENABLE_NVTX_PROFILE - // End of major communication tasks. - sendRange.End(); -#endif + // Send tensors. + SendData(ctx, dst, num_tensors, aggregated_aligned_tensor_bytes, tensor_offsets_in_bytes, tensor_sizes_in_bytes); #ifdef ENABLE_NVTX_PROFILE profile::NvtxRangeCreator postRange( - "PostSend-" + std::to_string(dst), profile::Color::Red); + "PostSend-" + std::to_string(dst), profile::Color::Red); // Begin of post communication tasks. postRange.Begin(); #endif diff --git a/orttraining/orttraining/training_ops/cuda/communication/send.h b/orttraining/orttraining/training_ops/cuda/communication/send.h index 170b2aff6e..ace4bacb37 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/send.h +++ b/orttraining/orttraining/training_ops/cuda/communication/send.h @@ -21,6 +21,20 @@ public: Status ComputeInternal(OpKernelContext* context) const override; private: + void SendShapeInfo( + const int dst, + const int num_tensors, + size_t aggregated_aligned_tensor_bytes, + std::vector prefix_tensor_shape_sizes, + std::vector aggregated_tensor_shapes) const; + void SendData( + OpKernelContext* ctx, + const int dst, + const int num_tensors, + size_t aggregated_aligned_tensor_bytes, + std::vector tensor_offsets_in_bytes, + std::vector tensor_sizes_in_bytes) const; + int64_t tag_; std::vector element_types_; };