diff --git a/include/onnxruntime/core/framework/alloc_kind.h b/include/onnxruntime/core/framework/alloc_kind.h index a749e6b26c..4534d08470 100644 --- a/include/onnxruntime/core/framework/alloc_kind.h +++ b/include/onnxruntime/core/framework/alloc_kind.h @@ -22,6 +22,7 @@ namespace onnxruntime { // Generalizing this is future work. enum class AllocKind { + kNotSet = -1, kAllocate = 0, kReuse = 1, kPreExisting = 2, diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index c6bbc2b522..b7c6f690d9 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -40,6 +40,9 @@ std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind) { case AllocKind::kShare: out << "Share"; break; + case AllocKind::kNotSet: + out << "NotSet"; + break; } return out; } @@ -639,8 +642,7 @@ class PlannerImpl { } else if (IsNonTensor(*node_output)) { // we do not try sharing-optimization for non-tensors AllocPlan(current).alloc_kind = AllocKind::kAllocate; - AllocPlan(current).program_counter_start.emplace_back(program_counter); - AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX); + AllocPlan(current).program_counter.AddStart(program_counter); } else if (FindReusableInput(*pnode, static_cast(output_arg_def_index), &reused)) { // Reuse one of this node's input buffers as the output buffer (for in-place update) Reuse(reused, current, AllocKind::kReuse); @@ -650,18 +652,12 @@ class PlannerImpl { Reuse(reused, current, AllocKind::kReuse); OrtValueIndex original = Buffer(reused); if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) { - ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0); - ORT_ENFORCE(AllocPlan(original).program_counter_end.back() != SIZE_MAX); - ORT_ENFORCE(AllocPlan(original).program_counter_end.back() < program_counter); - - AllocPlan(original).program_counter_start.emplace_back(program_counter); - AllocPlan(original).program_counter_end.emplace_back(SIZE_MAX); + AllocPlan(original).program_counter.AddStart(program_counter); } } else { // otherwise: allocate a new buffer for this output AllocPlan(current).alloc_kind = AllocKind::kAllocate; - AllocPlan(current).program_counter_start.emplace_back(program_counter); - AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX); + AllocPlan(current).program_counter.AddStart(program_counter); } } @@ -675,9 +671,7 @@ class PlannerImpl { if ((original != -1) && (0 == DecrementUseCount(original))) { freelist_.push_front(FreeBufferInfo(original, program_counter)); if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) { - ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0); - ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX); - AllocPlan(original).program_counter_end.back() = program_counter; + AllocPlan(original).program_counter.AddEnd(program_counter); } } } @@ -692,9 +686,7 @@ class PlannerImpl { if ((original != -1) && (0 == DecrementUseCount(original))) { freelist_.push_front(FreeBufferInfo(original, program_counter)); if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) { - ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0); - ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX); - AllocPlan(original).program_counter_end.back() = program_counter; + AllocPlan(original).program_counter.AddEnd(program_counter); } } } @@ -708,9 +700,7 @@ class PlannerImpl { if (0 == DecrementUseCount(original)) { freelist_.push_front(FreeBufferInfo(original, program_counter)); if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) { - ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0); - ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX); - AllocPlan(original).program_counter_end.back() = program_counter; + AllocPlan(original).program_counter.AddEnd(program_counter); } } } @@ -719,6 +709,7 @@ class PlannerImpl { return Status::OK(); } +#ifdef ENABLE_TRAINING bool AllocateInputsContiguously(const Node& node) const { const KernelCreateInfo& ci = GetKernelCreateInfo(kernel_create_info_map_, node.Index()); if (ci.kernel_def == nullptr) { @@ -766,38 +757,20 @@ class PlannerImpl { } return Status::OK(); } +#endif - // Ensure memory time schedule is sorted. - Status VerifyMemoryTimeSchedule() { - std::vector& execution_plan(plan_.execution_plan); - for (size_t program_counter = 0; program_counter < execution_plan.size(); ++program_counter) { - SequentialExecutionPlan::NodeExecutionPlan step = execution_plan[program_counter]; - const auto* pnode = graph_viewer_.GetNode(step.node_index); - if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot find the node ", step.node_index); - const auto& input_defs = pnode->InputDefs(); - for (int input_arg_def_index = 0; static_cast(input_arg_def_index) < input_defs.size(); ++input_arg_def_index) { - const auto& node_input = input_defs[input_arg_def_index]; - if (!node_input->Exists()) continue; - const auto& current_plan = AllocPlan(Index(node_input->Name())); - if (current_plan.alloc_kind != AllocKind::kAllocate) continue; - - ORT_ENFORCE(current_plan.program_counter_start.size() == current_plan.program_counter_end.size()); - - size_t start = 0; - for (size_t index = 0; index < current_plan.program_counter_start.size(); index += 1) { - ORT_ENFORCE((current_plan.program_counter_start[index] > start) || (start == 0)); - ORT_ENFORCE(current_plan.program_counter_start[index] <= current_plan.program_counter_end[index]); - ORT_ENFORCE(current_plan.program_counter_start[index] < SIZE_MAX); - ORT_ENFORCE((current_plan.program_counter_end[index] > 0) || (index == 0)); - - start = current_plan.program_counter_start[index]; - } + void VerifyMemoryTimeSchedule() { + size_t idx = 0; + for (const auto& entry : plan_.allocation_plan) { + if (entry.alloc_kind == AllocKind::kAllocate) { + ORT_ENFORCE(entry.program_counter.HasValidEntries(), "Invalid program_counter entries at index ", idx); } - } - return Status::OK(); + ++idx; + } } + // Whether a given NodeArg has fence or not. // If the buffer is reused, need to check whether original OrtValue has fence or not. bool HasFence(const onnxruntime::NodeArg* arg) { @@ -875,8 +848,7 @@ class PlannerImpl { for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) { auto ml_value_idx = plan_.to_be_freed[index]; if (AllocPlan(ml_value_idx).alloc_kind == AllocKind::kAllocate) { - ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_start.back() <= program_counter); - ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_end.back() == program_counter); + ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter.Ends().back() == program_counter); } } @@ -914,15 +886,17 @@ Status PlannerImpl::CreatePlan() { // Determine nodes that need fence check. This needs to be done after ComputeUseCounts and ComputeReusePlan. ORT_RETURN_IF_ERROR(ComputeFenceCheck()); +#ifdef ENABLE_TRAINING // Determine allocation order for weights and activations. This needs to be done after ComputeReusePlan. ORT_RETURN_IF_ERROR(ComputeAllocationOrder()); +#endif // convert information in the freelist_ into a deallocation plan in required format GenerateDeallocationPlan(); - // Ensure Memory-Time schedule is sorted. This should be called at the end because memory start/end timestamps + // Ensure Memory-Time schedule is valid. This should be called at the end because memory start/end timestamps // are updated until GenerateDeallocationPlan is finished. - ORT_RETURN_IF_ERROR(VerifyMemoryTimeSchedule()); + VerifyMemoryTimeSchedule(); return Status::OK(); } diff --git a/onnxruntime/core/framework/mem_pattern_planner.h b/onnxruntime/core/framework/mem_pattern_planner.h index 3f45ac5d71..0c7de29919 100644 --- a/onnxruntime/core/framework/mem_pattern_planner.h +++ b/onnxruntime/core/framework/mem_pattern_planner.h @@ -29,27 +29,33 @@ namespace onnxruntime { // Thread-safe. class MemPatternPlanner { public: - MemPatternPlanner() = default; + // only the Training code currently uses the program counter based logic + MemPatternPlanner(bool using_counters) : using_counters_{using_counters} {} +#ifdef ENABLE_TRAINING + // TODO: OverlappingTimeSchedules should be private // Returns true if there is an intersection between two time schedules. - // ASSUMES EACH TIME SCHEDULE IS SORTED. THIS IS VALIDATED AT THE END OF MEMORY PLANNING. - bool OverlappingTimeSchedules(const std::vector& program_counter_start_1, const std::vector& program_counter_end_1, - const std::vector& program_counter_start_2, const std::vector& program_counter_end_2) { - ORT_ENFORCE(program_counter_start_1.size() > 0); - ORT_ENFORCE(program_counter_start_2.size() > 0); - ORT_ENFORCE(program_counter_start_1.size() == program_counter_end_1.size()); - ORT_ENFORCE(program_counter_start_2.size() == program_counter_end_2.size()); + // ProgramCounter values are validated when the execution plan is created + bool OverlappingTimeSchedules(const AllocPlanPerValue::ProgramCounter& counter1, + const AllocPlanPerValue::ProgramCounter& counter2) const { + const auto& starts_1 = counter1.Starts(); + const auto& ends_1 = counter1.Ends(); + const auto& starts_2 = counter2.Starts(); + const auto& ends_2 = counter2.Ends(); size_t index_1 = 0; size_t index_2 = 0; - while ((index_1 < program_counter_start_1.size()) && (index_2 < program_counter_start_2.size())) { - if (program_counter_start_1[index_1] <= program_counter_start_2[index_2]) { - if (program_counter_end_1[index_1] >= program_counter_start_2[index_2]) { + size_t index_1_end = starts_1.size(); + size_t index_2_end = starts_2.size(); + + while ((index_1 < index_1_end) && (index_2 < index_2_end)) { + if (starts_1[index_1] <= starts_2[index_2]) { + if (ends_1[index_1] >= starts_2[index_2]) { return true; } index_1 += 1; } else { - if (program_counter_end_2[index_2] >= program_counter_start_1[index_1]) { + if (ends_2[index_2] >= starts_1[index_1]) { return true; } index_2 += 1; @@ -59,7 +65,9 @@ class MemPatternPlanner { return false; } - void TraceAllocation(int ml_value_idx, const std::vector& program_counter_start, const std::vector& program_counter_end, size_t size) { + void TraceAllocation(int ml_value_idx, const AllocPlanPerValue::ProgramCounter& counter, size_t size) { + ORT_ENFORCE(using_counters_); + std::lock_guard lock(lock_); if (size == 0) { @@ -73,8 +81,7 @@ class MemPatternPlanner { bool best_offset_found = false; for (auto it = blocks_.begin(); it != blocks_.end(); it++) { // Memory block can be re-used as long as there is no overlap between their time schedules. - if (allocs_[*it].reuse_ && !OverlappingTimeSchedules(program_counter_start, program_counter_end, - allocs_[*it].program_counter_start_, allocs_[*it].program_counter_end_)) { + if (allocs_[*it].reuse_ && !OverlappingTimeSchedules(counter, *allocs_[*it].counter_)) { continue; } @@ -107,7 +114,7 @@ class MemPatternPlanner { // we only need to bounds check the addition of size to best_offset as that is the only time we extend // the maximum size of the buffer. buffer_size_ = std::max(buffer_size_, SafeInt(best_offset) + size); - allocs_.emplace_back(ml_value_idx, program_counter_start, program_counter_end, MemoryBlock(best_offset, size)); + allocs_.emplace_back(ml_value_idx, counter, MemoryBlock(best_offset, size)); std::list::iterator best_fit_it = blocks_.end(); for (auto it = blocks_.begin(); it != blocks_.end(); it++) { if (allocs_[*it].block_.offset_ < best_offset) @@ -121,8 +128,11 @@ class MemPatternPlanner { blocks_.insert(best_fit_it, (static_cast(allocs_.size()) - 1)); } +#endif void TraceAllocation(int ml_value_idx, size_t size) { + ORT_ENFORCE(!using_counters_); + std::lock_guard lock(lock_); if (size == 0) { @@ -190,35 +200,38 @@ class MemPatternPlanner { } } - MemoryPattern GenerateMemPattern() { + MemoryPattern GenerateMemPattern() const { std::lock_guard lock(lock_); - // Time schedules of overlapping memory blocks SHOULD NOT intersect. - for (size_t index_1 = 0; index_1 < allocs_.size(); index_1 += 1) { - if (!allocs_[index_1].reuse_) - continue; - - for (size_t index_2 = index_1 + 1; index_2 < allocs_.size(); index_2 += 1) { - if (!allocs_[index_2].reuse_) +#ifdef ENABLE_TRAINING + if (using_counters_) { + // Time schedules of overlapping memory blocks SHOULD NOT intersect. + for (size_t index_1 = 0; index_1 < allocs_.size(); index_1 += 1) { + if (!allocs_[index_1].reuse_) continue; - size_t alloc_1_start = allocs_[index_1].block_.offset_; - size_t alloc_1_end = alloc_1_start + allocs_[index_1].block_.size_ - 1; + for (size_t index_2 = index_1 + 1; index_2 < allocs_.size(); index_2 += 1) { + if (!allocs_[index_2].reuse_) + continue; - ORT_ENFORCE(alloc_1_start <= alloc_1_end); + size_t alloc_1_start = allocs_[index_1].block_.offset_; + size_t alloc_1_end = alloc_1_start + allocs_[index_1].block_.size_ - 1; - size_t alloc_2_start = allocs_[index_2].block_.offset_; - size_t alloc_2_end = alloc_2_start + allocs_[index_2].block_.size_ - 1; + ORT_ENFORCE(alloc_1_start <= alloc_1_end); - ORT_ENFORCE(alloc_2_start <= alloc_2_end); + size_t alloc_2_start = allocs_[index_2].block_.offset_; + size_t alloc_2_end = alloc_2_start + allocs_[index_2].block_.size_ - 1; - if (((alloc_1_start >= alloc_2_start) && (alloc_1_start <= alloc_2_end)) || - ((alloc_2_start >= alloc_1_start) && (alloc_2_start <= alloc_1_end))) { - ORT_ENFORCE(!OverlappingTimeSchedules(allocs_[index_1].program_counter_start_, allocs_[index_1].program_counter_end_, - allocs_[index_2].program_counter_start_, allocs_[index_2].program_counter_end_)); + ORT_ENFORCE(alloc_2_start <= alloc_2_end); + + if (((alloc_1_start >= alloc_2_start) && (alloc_1_start <= alloc_2_end)) || + ((alloc_2_start >= alloc_1_start) && (alloc_2_start <= alloc_1_end))) { + ORT_ENFORCE(!OverlappingTimeSchedules(*allocs_[index_1].counter_, *allocs_[index_2].counter_)); + } } } } +#endif MemoryPattern pattern; pattern.peak_size_ = buffer_size_; @@ -233,18 +246,20 @@ class MemPatternPlanner { struct OrtValueAllocationBlock { int index_{-1}; MemoryBlock block_; - const std::vector program_counter_start_; - const std::vector program_counter_end_; + const AllocPlanPerValue::ProgramCounter* counter_{nullptr}; bool reuse_{false}; OrtValueAllocationBlock() = default; OrtValueAllocationBlock(int index, const MemoryBlock& block) : index_(index), block_(block), reuse_{false} {} - OrtValueAllocationBlock(int index, std::vector program_counter_start, std::vector program_counter_end, const MemoryBlock& block) : index_(index), block_(block), program_counter_start_(program_counter_start), program_counter_end_(program_counter_end), reuse_{true} {} + OrtValueAllocationBlock(int index, const AllocPlanPerValue::ProgramCounter& counter, const MemoryBlock& block) + : index_(index), block_(block), counter_(&counter), reuse_{true} { + } }; std::vector allocs_; // blocks_ the list of currently allocated memory blocks, sorted in order of their offset std::list blocks_; SafeInt buffer_size_{0}; + bool using_counters_; mutable OrtMutex lock_; }; diff --git a/onnxruntime/core/framework/ort_value_pattern_planner.cc b/onnxruntime/core/framework/ort_value_pattern_planner.cc index 6fd8b140f5..a525cb5a1d 100644 --- a/onnxruntime/core/framework/ort_value_pattern_planner.cc +++ b/onnxruntime/core/framework/ort_value_pattern_planner.cc @@ -6,27 +6,31 @@ #include "core/framework/execution_plan_base.h" namespace onnxruntime { -OrtValuePatternPlanner::OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan) +OrtValuePatternPlanner::OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan, bool trace_using_counters) : execution_planner_(execution_plan) { for (auto& location : execution_plan.GetAllLocations()) { - planner_map_.emplace(location, onnxruntime::make_unique()); + planner_map_.emplace(location, onnxruntime::make_unique(trace_using_counters)); } } -common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, const std::vector& program_counter_start, const std::vector& program_counter_end, size_t size) { +#ifdef ENABLE_TRAINING +common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, + const AllocPlanPerValue::ProgramCounter& counter, + size_t size) { // TODO(codemzs): refactor code. - auto location = execution_planner_.GetLocation(ort_value_idx); + const auto& location = execution_planner_.GetLocation(ort_value_idx); auto it = planner_map_.find(location); if (it == planner_map_.end()) { return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); } - it->second->TraceAllocation(ort_value_idx, program_counter_start, program_counter_end, size); + it->second->TraceAllocation(ort_value_idx, counter, size); return common::Status::OK(); } +#endif common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, size_t size) { - auto location = execution_planner_.GetLocation(ort_value_idx); + const auto& location = execution_planner_.GetLocation(ort_value_idx); auto it = planner_map_.find(location); if (it == planner_map_.end()) { return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); @@ -37,7 +41,7 @@ common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, size_t } common::Status OrtValuePatternPlanner::TraceFree(int ort_value_index) { - auto location = execution_planner_.GetLocation(ort_value_index); + const auto& location = execution_planner_.GetLocation(ort_value_index); auto it = planner_map_.find(location); if (it == planner_map_.end()) { return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); diff --git a/onnxruntime/core/framework/ort_value_pattern_planner.h b/onnxruntime/core/framework/ort_value_pattern_planner.h index fd25e49e92..77fcb41728 100644 --- a/onnxruntime/core/framework/ort_value_pattern_planner.h +++ b/onnxruntime/core/framework/ort_value_pattern_planner.h @@ -18,8 +18,12 @@ class ExecutionPlanBase; // SessionOptions.enable_mem_pattern class OrtValuePatternPlanner { public: - explicit OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan); - common::Status TraceAllocation(int ort_value_idx, const std::vector& program_counter_start, const std::vector& program_counter_end, size_t size); + // trace_using_counters should be true if the TraceAllocation with ProgramCounter is used. Only one + // variant of the TraceAllocation calls may be used. + explicit OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan, bool trace_using_counters = false); +#ifdef ENABLE_TRAINING + common::Status TraceAllocation(int ort_value_idx, const AllocPlanPerValue::ProgramCounter& counter, size_t size); +#endif common::Status TraceAllocation(int ort_value_idx, size_t size); common::Status TraceFree(int ort_value_index); common::Status GeneratePatterns(MemoryPatternGroup* out); diff --git a/onnxruntime/core/framework/parallel_executor.cc b/onnxruntime/core/framework/parallel_executor.cc index 10ddb74856..9c0c70770c 100644 --- a/onnxruntime/core/framework/parallel_executor.cc +++ b/onnxruntime/core/framework/parallel_executor.cc @@ -191,8 +191,10 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index, // Execute the kernel. ORT_TRY { +#ifdef ENABLE_TRAINING if (p_op_kernel->KernelDef().AllocateInputsContiguously()) utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context); +#endif status = p_op_kernel->Compute(&op_kernel_context); } diff --git a/onnxruntime/core/framework/sequential_execution_plan.h b/onnxruntime/core/framework/sequential_execution_plan.h index e806e65cd2..fb997d689e 100644 --- a/onnxruntime/core/framework/sequential_execution_plan.h +++ b/onnxruntime/core/framework/sequential_execution_plan.h @@ -22,7 +22,7 @@ class SessionState; // Captures information required to allocate/reuse buffer for a ml-value struct AllocPlanPerValue { - AllocKind alloc_kind{AllocKind::kAllocate}; + AllocKind alloc_kind{AllocKind::kNotSet}; MLDataType value_type{nullptr}; OrtMemoryInfo location; // reused_buffer is valid only if alloc_kind == kReuse. It indicates @@ -31,8 +31,37 @@ struct AllocPlanPerValue { // if the value is used in async kernel, a fence object would be created // note the fence object would be shared between MLValues reusing the same buffer bool create_fence_if_async{false}; - std::vector program_counter_start; - std::vector program_counter_end; + + class ProgramCounter { + public: + ProgramCounter() = default; + void AddStart(size_t start) { + ORT_ENFORCE(starts_.size() == ends_.size(), "Previous entry was not terminated."); + ORT_ENFORCE(starts_.empty() || start > ends_.back(), "Invalid 'start'. Value is smaller than previous 'end'."); + starts_.push_back(start); + } + + void AddEnd(size_t end) { + ORT_ENFORCE(starts_.size() == ends_.size() + 1, "No matching 'start' entry."); + ORT_ENFORCE(end >= starts_.back(), "Invalid 'end'. Value is larger than 'start'."); + ends_.push_back(end); + } + + // return true if there are entries, and the number of start/end pairs match. + // validity of the individual start/end values is checked when they are added. + bool HasValidEntries() const { + return !starts_.empty() && starts_.size() == ends_.size(); + } + + const std::vector& Starts() const { return starts_; } + const std::vector& Ends() const { return ends_; } + + private: + std::vector starts_; + std::vector ends_; + }; + + ProgramCounter program_counter; public: AllocPlanPerValue() : location(CPU, Invalid) {} diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 6877dac0de..e3806a8091 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -307,9 +307,11 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: node_compute_range.Begin(); #endif ORT_TRY { +#ifdef ENABLE_TRAINING if (p_op_kernel->KernelDef().AllocateInputsContiguously()) utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context); - +#endif + compute_status = p_op_kernel->Compute(&op_kernel_context); } ORT_CATCH(const std::exception& ex) { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 260b975c73..4838fa575b 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -396,14 +396,15 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorexecution_plan) { int node_index = node_index_info.GetNodeOffset(node_plan.node_index); auto* node = graph_viewer_->GetNode(node_plan.node_index); - int output_start = node_index + static_cast(node->InputDefs().size()) + static_cast(node->ImplicitInputDefs().size()); + int output_start = node_index + static_cast(node->InputDefs().size()) + + static_cast(node->ImplicitInputDefs().size()); for (int i = 0, end = static_cast(node->OutputDefs().size()); i < end; ++i) { const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i); @@ -427,7 +428,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vectoractivation_allocation_order) { ORT_ENFORCE(ml_value_idx >= 0); @@ -450,13 +451,9 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorallocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate); - ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start.size() == exe_plan->allocation_plan[ml_value_idx].program_counter_end.size()); - for (size_t index = 0; index < exe_plan->allocation_plan[ml_value_idx].program_counter_start.size(); index += 1) - ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start[index] <= exe_plan->allocation_plan[ml_value_idx].program_counter_end[index]); - - mem_planner.TraceAllocation(ml_value_idx, exe_plan->allocation_plan[ml_value_idx].program_counter_start, - exe_plan->allocation_plan[ml_value_idx].program_counter_end, size); + const auto& counter = exe_plan->allocation_plan[ml_value_idx].program_counter; + mem_planner.TraceAllocation(ml_value_idx, counter, size); } } @@ -464,12 +461,15 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorexecution_plan) { int node_index = node_index_info.GetNodeOffset(node_plan.node_index); auto* node = graph_viewer_->GetNode(node_plan.node_index); - int output_start = node_index + static_cast(node->InputDefs().size()) + static_cast(node->ImplicitInputDefs().size()); + int output_start = node_index + static_cast(node->InputDefs().size()) + + static_cast(node->ImplicitInputDefs().size()); //allocate output for (int i = 0, end = static_cast(node->OutputDefs().size()); i < end; ++i) { const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i); if (ml_value_idx == NodeIndexInfo::kInvalidEntry || - (std::find(exe_plan->activation_allocation_order.begin(), exe_plan->activation_allocation_order.end(), ml_value_idx) != exe_plan->activation_allocation_order.end())) + (std::find(exe_plan->activation_allocation_order.begin(), + exe_plan->activation_allocation_order.end(), ml_value_idx) != + exe_plan->activation_allocation_order.end())) continue; const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type; if (!ml_type->IsTensorType()) @@ -487,13 +487,9 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorallocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate); - ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start.size() == exe_plan->allocation_plan[ml_value_idx].program_counter_end.size()); - for (size_t index = 0; index < exe_plan->allocation_plan[ml_value_idx].program_counter_start.size(); index += 1) - ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start[index] <= exe_plan->allocation_plan[ml_value_idx].program_counter_end[index]); - - mem_planner.TraceAllocation(ml_value_idx, exe_plan->allocation_plan[ml_value_idx].program_counter_start, - exe_plan->allocation_plan[ml_value_idx].program_counter_end, aligned_size); + const auto& counter = exe_plan->allocation_plan[ml_value_idx].program_counter; + mem_planner.TraceAllocation(ml_value_idx, counter, aligned_size); } } @@ -991,7 +987,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string tensor_allocator_( + std::unique_ptr tensor_allocator( ITensorAllocator::Create(enable_mem_pattern_, *p_seq_exec_plan_, *this, weights_buffers_)); const auto& initializer_allocation_order = p_seq_exec_plan_->initializer_allocation_order; @@ -1001,7 +997,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string Status { return AddInitializedTensor(idx, value, &d, constant); }, diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 7521e9ab37..c9b3ce0250 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -90,7 +90,9 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st common::Status SaveInitializedTensors( const Env& env, const std::basic_string& graph_loc, const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, - const OrtValueNameIdxMap& ort_value_name_idx_map, const std::vector& initializer_allocation_order, ITensorAllocator& planner, + const OrtValueNameIdxMap& ort_value_name_idx_map, + const std::vector& initializer_allocation_order, + ITensorAllocator& planner, const std::function& save_tensor_func, const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, diff --git a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h index 247fbadc61..179c19e1d0 100644 --- a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h +++ b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h @@ -61,7 +61,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { TensorAllocatorWithMemPattern(const ExecutionPlanBase& execution_plan, const SessionState& session_state, std::vector& weights_buffers) : ITensorAllocator(session_state), - planner_(execution_plan), + planner_(execution_plan, /*using counters*/ false), weights_buffers_(weights_buffers), seq_plan_(execution_plan) {} diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 936da1294c..bebfcb2c26 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -570,6 +570,7 @@ int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType onn } } +#ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context) { const Tensor* prev_input = context->Input(0); for (int i = 1; i < context->InputCount(); i++) { @@ -581,15 +582,18 @@ common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context) size_t input_element_size = prev_input->DataType()->Size(); size_t input_aligned_bytes = 0; - ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size, &input_aligned_bytes)); + ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size, + &input_aligned_bytes)); - ORT_RETURN_IF_NOT(curr_input->DataRaw() == static_cast(prev_input->DataRaw()) + input_aligned_bytes || - curr_input->DataRaw() == static_cast(prev_input->DataRaw()) + prev_input->SizeInBytes()); + ORT_RETURN_IF_NOT( + curr_input->DataRaw() == static_cast(prev_input->DataRaw()) + input_aligned_bytes || + curr_input->DataRaw() == static_cast(prev_input->DataRaw()) + prev_input->SizeInBytes()); prev_input = curr_input; } return Status::OK(); } +#endif } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index fa380027e3..2b69e4f864 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -156,7 +156,9 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); +#ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); +#endif } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/test/framework/mem_pattern_planner_test.cc b/onnxruntime/test/framework/mem_pattern_planner_test.cc index 6ea448b9f7..42e9018e0c 100644 --- a/onnxruntime/test/framework/mem_pattern_planner_test.cc +++ b/onnxruntime/test/framework/mem_pattern_planner_test.cc @@ -7,7 +7,8 @@ namespace onnxruntime { namespace test { TEST(MemPatternPlannerTest, TraceAllocaitonTest) { - MemPatternPlanner planner; + const bool using_counters = false; // we're not tracking start/end for use/re-use of each allocation via counters + MemPatternPlanner planner{using_counters}; planner.TraceAllocation(0, 1024); planner.TraceAllocation(1, 256); planner.TraceAllocation(2, 512);