diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index a6e040812a..902d9baf53 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1745,9 +1745,8 @@ class PlannerImpl { #else - void - PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers, - const PathString& partition_config_file) { + void PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers, + const PathString& partition_config_file) { auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file); auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder()); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); @@ -1760,7 +1759,7 @@ class PlannerImpl { num_logic_streams_ = stream_nodes_.size(); } - // build each logic streams + // Build each logic streams Status BuildExecutionPlan(const ExecutionProviders& execution_providers, const IStreamCommandHandleRegistry& stream_handle_registry) { // 1. create logic stream instance @@ -1780,12 +1779,12 @@ class PlannerImpl { execution_plan.emplace_back(nullptr); } } - // 2. determing following things: - // a. which node need to generate notification - // b. which node need to trigger downstream + // 2. Determining following things: + // a. which node needs to generate the notification + // b. which node needs to trigger downstream #ifdef ENABLE_TRAINING // We will leverage the topological order for the training scenario. - // The nodes before yieldOp in topo order will be executed in RunForward() and nodes after will be executed in RunBackward() + // The nodes before yieldOp in topo-order will be executed in RunForward() and nodes after will be executed in RunBackward() // This partition may not be exactly the same as forward model/gradient model, for example, some nodes in gradient model are // before yieldOp thus will be executed in RunForward() // But the final result is still correct, as long as all the nodes will be executed in either RunForward() or RunBackward() @@ -1820,7 +1819,7 @@ class PlannerImpl { if (node_stream_map_[it->Index()] != i #ifdef ENABLE_TRAINING // Do not insert Barrier/TriggerDownStream step if the producer and consumer are in different sides of yieldOp - // As in this case producer will surely be ready before consumer is running. + // As in this case producer will surely be ready before the consumer is running. && !AreNodesSeparatedByYield(node_index, it->Index()) #endif ) { @@ -2048,8 +2047,7 @@ class PlannerImpl { } #endif - static bool - IsNonTensor(const onnxruntime::NodeArg& nodearg) { + static bool IsNonTensor(const onnxruntime::NodeArg& nodearg) { // TODO: unclear why we should go through a string-representation of type auto ptype = nodearg.Type(); auto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(ptype); diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 00aeff37d7..c21d67d2e9 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -28,6 +28,7 @@ using namespace onnxruntime::common; namespace onnxruntime { #ifdef ORT_ENABLE_STREAM static StreamAwareArena* AsStreamBasedAllocator(AllocatorPtr allocator) { + ORT_ENFORCE(allocator.get() != nullptr, "allocator is nullptr"); if (allocator->Info().alloc_type == OrtArenaAllocator) { BFCArena* arena_ptr = static_cast(allocator.get()); return StreamAwareArena::FromBFCArena(*arena_ptr); @@ -137,7 +138,7 @@ Status IExecutionFrame::GetOutputs(gsl::span fetch_mlvalue_idxs, std: #endif -// Return nullptr if index map to an value that is an unused optional input/output +// Return nullptr if index map to a value that is an unused optional input/output const OrtValue* IExecutionFrame::GetNodeInputOrOutputMLValue(int index) const { int ort_value_idx = GetNodeIdxToMLValueIdx(index); return ort_value_idx != NodeIndexInfo::kInvalidEntry ? &(all_values_[ort_value_idx]) : nullptr; @@ -147,9 +148,9 @@ OrtValue* IExecutionFrame::GetMutableNodeInputOrOutputMLValue(int index) { return const_cast(GetNodeInputOrOutputMLValue(index)); } -// TO DO: make it thread safe -// This method is not thread safe! -// Return S_OK and nullptr if index map to an value that is an unused optional input/output +// TO DO: make it thread-safe +// This method is not thread-safe! +// Return S_OK and nullptr if index map to a value that is an unused optional input/output Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int output_arg_index, const TensorShape* shape, OrtValue*& p_ort_value, @@ -191,7 +192,7 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int } bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const { - // By default, there is not information about inferred shape, so this default + // By default, there is no information about inferred shape, so this default // implementation always returns false. The derived class of IExecutionFrame // can override this function to provide, for example, activations' shape information. return false; @@ -213,7 +214,7 @@ Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { } int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const { - // the validity of index is checked by GetMLValueIndex + // The validity of the index is checked by GetMLValueIndex int ort_value_idx = node_index_info_.GetMLValueIndex(index); return ort_value_idx; } @@ -241,7 +242,7 @@ void IExecutionFrame::Init(gsl::span feed_mlvalue_idxs, gsl::span feed_mlvalue_idxs, gsl::span feed_mlvalue_idxs, gsl::spansize_ - << " but the actually size is: " << size + << " but the actual size is: " << size << ", fall back to default allocation behavior"; } } @@ -572,6 +573,8 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va // no memory pattern, or the pattern is not correct. if (!alloc) alloc = GetAllocator(location); + ORT_ENFORCE(alloc && alloc.get() != nullptr, "Failed to get allocator for ", location.ToString()); + Stream* current_stream = GetValueStream(ort_value_index); if (current_stream) { #ifdef ORT_ENABLE_STREAM @@ -825,7 +828,7 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { } // This method is not thread safe! -// Return S_OK and nullptr if index map to an value that is an unused optional input/output +// Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); } @@ -930,7 +933,7 @@ bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const { } // Search for inferred shape. - // If inferred shape is found, it's assigned to "shape" so that caller can use it. + // If the inferred shape is found, it's assigned to "shape" so that caller can use it. if (inferred_shapes_ != nullptr) { auto it = inferred_shapes_->find(ort_value_idx); if (it != inferred_shapes_->end()) { diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index d06946579b..6db63752c5 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1511,7 +1511,7 @@ void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& OrtDevice CANNExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, default_device_.Id()); + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, 0 /*CPU device id always be 0*/); } return default_device_; } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 0d3c54d676..cba0fde0a9 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2533,7 +2533,7 @@ void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, default_device_.Id()); + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/); } return default_device_; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 08f50454ad..ba98a79a72 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1210,7 +1210,7 @@ void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, default_device_.Id()); + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/); } return default_device_; } diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index de3db2108b..79dfb95f8b 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2338,7 +2338,7 @@ void ROCMExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, default_device_.Id()); + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/); } return default_device_; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7162c737b8..5526ee1151 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2844,7 +2844,7 @@ void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, default_device_.Id()); + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/); } return default_device_; }