mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
Exclude some training-specific code from the minimal build. Clean up some related aspects of the allocation planner. (#5861)
* Exclude some training-specific code around the allocation planning and initializer handling from the minimal build. Simplify the code around tracking start/end usage of a value.
This commit is contained in:
parent
b057b3d36e
commit
00412a76e9
14 changed files with 161 additions and 125 deletions
|
|
@ -22,6 +22,7 @@ namespace onnxruntime {
|
|||
// Generalizing this is future work.
|
||||
|
||||
enum class AllocKind {
|
||||
kNotSet = -1,
|
||||
kAllocate = 0,
|
||||
kReuse = 1,
|
||||
kPreExisting = 2,
|
||||
|
|
|
|||
|
|
@ -40,6 +40,9 @@ std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind) {
|
|||
case AllocKind::kShare:
|
||||
out << "Share";
|
||||
break;
|
||||
case AllocKind::kNotSet:
|
||||
out << "NotSet";
|
||||
break;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
|
@ -639,8 +642,7 @@ class PlannerImpl {
|
|||
} else if (IsNonTensor(*node_output)) {
|
||||
// we do not try sharing-optimization for non-tensors
|
||||
AllocPlan(current).alloc_kind = AllocKind::kAllocate;
|
||||
AllocPlan(current).program_counter_start.emplace_back(program_counter);
|
||||
AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX);
|
||||
AllocPlan(current).program_counter.AddStart(program_counter);
|
||||
} else if (FindReusableInput(*pnode, static_cast<int>(output_arg_def_index), &reused)) {
|
||||
// Reuse one of this node's input buffers as the output buffer (for in-place update)
|
||||
Reuse(reused, current, AllocKind::kReuse);
|
||||
|
|
@ -650,18 +652,12 @@ class PlannerImpl {
|
|||
Reuse(reused, current, AllocKind::kReuse);
|
||||
OrtValueIndex original = Buffer(reused);
|
||||
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() != SIZE_MAX);
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() < program_counter);
|
||||
|
||||
AllocPlan(original).program_counter_start.emplace_back(program_counter);
|
||||
AllocPlan(original).program_counter_end.emplace_back(SIZE_MAX);
|
||||
AllocPlan(original).program_counter.AddStart(program_counter);
|
||||
}
|
||||
} else {
|
||||
// otherwise: allocate a new buffer for this output
|
||||
AllocPlan(current).alloc_kind = AllocKind::kAllocate;
|
||||
AllocPlan(current).program_counter_start.emplace_back(program_counter);
|
||||
AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX);
|
||||
AllocPlan(current).program_counter.AddStart(program_counter);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -675,9 +671,7 @@ class PlannerImpl {
|
|||
if ((original != -1) && (0 == DecrementUseCount(original))) {
|
||||
freelist_.push_front(FreeBufferInfo(original, program_counter));
|
||||
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
|
||||
AllocPlan(original).program_counter_end.back() = program_counter;
|
||||
AllocPlan(original).program_counter.AddEnd(program_counter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -692,9 +686,7 @@ class PlannerImpl {
|
|||
if ((original != -1) && (0 == DecrementUseCount(original))) {
|
||||
freelist_.push_front(FreeBufferInfo(original, program_counter));
|
||||
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
|
||||
AllocPlan(original).program_counter_end.back() = program_counter;
|
||||
AllocPlan(original).program_counter.AddEnd(program_counter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -708,9 +700,7 @@ class PlannerImpl {
|
|||
if (0 == DecrementUseCount(original)) {
|
||||
freelist_.push_front(FreeBufferInfo(original, program_counter));
|
||||
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
|
||||
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
|
||||
AllocPlan(original).program_counter_end.back() = program_counter;
|
||||
AllocPlan(original).program_counter.AddEnd(program_counter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -719,6 +709,7 @@ class PlannerImpl {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
bool AllocateInputsContiguously(const Node& node) const {
|
||||
const KernelCreateInfo& ci = GetKernelCreateInfo(kernel_create_info_map_, node.Index());
|
||||
if (ci.kernel_def == nullptr) {
|
||||
|
|
@ -766,38 +757,20 @@ class PlannerImpl {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Ensure memory time schedule is sorted.
|
||||
Status VerifyMemoryTimeSchedule() {
|
||||
std::vector<SequentialExecutionPlan::NodeExecutionPlan>& execution_plan(plan_.execution_plan);
|
||||
for (size_t program_counter = 0; program_counter < execution_plan.size(); ++program_counter) {
|
||||
SequentialExecutionPlan::NodeExecutionPlan step = execution_plan[program_counter];
|
||||
const auto* pnode = graph_viewer_.GetNode(step.node_index);
|
||||
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot find the node ", step.node_index);
|
||||
const auto& input_defs = pnode->InputDefs();
|
||||
for (int input_arg_def_index = 0; static_cast<size_t>(input_arg_def_index) < input_defs.size(); ++input_arg_def_index) {
|
||||
const auto& node_input = input_defs[input_arg_def_index];
|
||||
if (!node_input->Exists()) continue;
|
||||
const auto& current_plan = AllocPlan(Index(node_input->Name()));
|
||||
if (current_plan.alloc_kind != AllocKind::kAllocate) continue;
|
||||
|
||||
ORT_ENFORCE(current_plan.program_counter_start.size() == current_plan.program_counter_end.size());
|
||||
|
||||
size_t start = 0;
|
||||
for (size_t index = 0; index < current_plan.program_counter_start.size(); index += 1) {
|
||||
ORT_ENFORCE((current_plan.program_counter_start[index] > start) || (start == 0));
|
||||
ORT_ENFORCE(current_plan.program_counter_start[index] <= current_plan.program_counter_end[index]);
|
||||
ORT_ENFORCE(current_plan.program_counter_start[index] < SIZE_MAX);
|
||||
ORT_ENFORCE((current_plan.program_counter_end[index] > 0) || (index == 0));
|
||||
|
||||
start = current_plan.program_counter_start[index];
|
||||
}
|
||||
void VerifyMemoryTimeSchedule() {
|
||||
size_t idx = 0;
|
||||
for (const auto& entry : plan_.allocation_plan) {
|
||||
if (entry.alloc_kind == AllocKind::kAllocate) {
|
||||
ORT_ENFORCE(entry.program_counter.HasValidEntries(), "Invalid program_counter entries at index ", idx);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Whether a given NodeArg has fence or not.
|
||||
// If the buffer is reused, need to check whether original OrtValue has fence or not.
|
||||
bool HasFence(const onnxruntime::NodeArg* arg) {
|
||||
|
|
@ -875,8 +848,7 @@ class PlannerImpl {
|
|||
for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) {
|
||||
auto ml_value_idx = plan_.to_be_freed[index];
|
||||
if (AllocPlan(ml_value_idx).alloc_kind == AllocKind::kAllocate) {
|
||||
ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_start.back() <= program_counter);
|
||||
ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_end.back() == program_counter);
|
||||
ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter.Ends().back() == program_counter);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -914,15 +886,17 @@ Status PlannerImpl::CreatePlan() {
|
|||
// Determine nodes that need fence check. This needs to be done after ComputeUseCounts and ComputeReusePlan.
|
||||
ORT_RETURN_IF_ERROR(ComputeFenceCheck());
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// Determine allocation order for weights and activations. This needs to be done after ComputeReusePlan.
|
||||
ORT_RETURN_IF_ERROR(ComputeAllocationOrder());
|
||||
#endif
|
||||
|
||||
// convert information in the freelist_ into a deallocation plan in required format
|
||||
GenerateDeallocationPlan();
|
||||
|
||||
// Ensure Memory-Time schedule is sorted. This should be called at the end because memory start/end timestamps
|
||||
// Ensure Memory-Time schedule is valid. This should be called at the end because memory start/end timestamps
|
||||
// are updated until GenerateDeallocationPlan is finished.
|
||||
ORT_RETURN_IF_ERROR(VerifyMemoryTimeSchedule());
|
||||
VerifyMemoryTimeSchedule();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,27 +29,33 @@ namespace onnxruntime {
|
|||
// Thread-safe.
|
||||
class MemPatternPlanner {
|
||||
public:
|
||||
MemPatternPlanner() = default;
|
||||
// only the Training code currently uses the program counter based logic
|
||||
MemPatternPlanner(bool using_counters) : using_counters_{using_counters} {}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// TODO: OverlappingTimeSchedules should be private
|
||||
// Returns true if there is an intersection between two time schedules.
|
||||
// ASSUMES EACH TIME SCHEDULE IS SORTED. THIS IS VALIDATED AT THE END OF MEMORY PLANNING.
|
||||
bool OverlappingTimeSchedules(const std::vector<size_t>& program_counter_start_1, const std::vector<size_t>& program_counter_end_1,
|
||||
const std::vector<size_t>& program_counter_start_2, const std::vector<size_t>& program_counter_end_2) {
|
||||
ORT_ENFORCE(program_counter_start_1.size() > 0);
|
||||
ORT_ENFORCE(program_counter_start_2.size() > 0);
|
||||
ORT_ENFORCE(program_counter_start_1.size() == program_counter_end_1.size());
|
||||
ORT_ENFORCE(program_counter_start_2.size() == program_counter_end_2.size());
|
||||
// ProgramCounter values are validated when the execution plan is created
|
||||
bool OverlappingTimeSchedules(const AllocPlanPerValue::ProgramCounter& counter1,
|
||||
const AllocPlanPerValue::ProgramCounter& counter2) const {
|
||||
const auto& starts_1 = counter1.Starts();
|
||||
const auto& ends_1 = counter1.Ends();
|
||||
const auto& starts_2 = counter2.Starts();
|
||||
const auto& ends_2 = counter2.Ends();
|
||||
|
||||
size_t index_1 = 0;
|
||||
size_t index_2 = 0;
|
||||
while ((index_1 < program_counter_start_1.size()) && (index_2 < program_counter_start_2.size())) {
|
||||
if (program_counter_start_1[index_1] <= program_counter_start_2[index_2]) {
|
||||
if (program_counter_end_1[index_1] >= program_counter_start_2[index_2]) {
|
||||
size_t index_1_end = starts_1.size();
|
||||
size_t index_2_end = starts_2.size();
|
||||
|
||||
while ((index_1 < index_1_end) && (index_2 < index_2_end)) {
|
||||
if (starts_1[index_1] <= starts_2[index_2]) {
|
||||
if (ends_1[index_1] >= starts_2[index_2]) {
|
||||
return true;
|
||||
}
|
||||
index_1 += 1;
|
||||
} else {
|
||||
if (program_counter_end_2[index_2] >= program_counter_start_1[index_1]) {
|
||||
if (ends_2[index_2] >= starts_1[index_1]) {
|
||||
return true;
|
||||
}
|
||||
index_2 += 1;
|
||||
|
|
@ -59,7 +65,9 @@ class MemPatternPlanner {
|
|||
return false;
|
||||
}
|
||||
|
||||
void TraceAllocation(int ml_value_idx, const std::vector<size_t>& program_counter_start, const std::vector<size_t>& program_counter_end, size_t size) {
|
||||
void TraceAllocation(int ml_value_idx, const AllocPlanPerValue::ProgramCounter& counter, size_t size) {
|
||||
ORT_ENFORCE(using_counters_);
|
||||
|
||||
std::lock_guard<OrtMutex> lock(lock_);
|
||||
|
||||
if (size == 0) {
|
||||
|
|
@ -73,8 +81,7 @@ class MemPatternPlanner {
|
|||
bool best_offset_found = false;
|
||||
for (auto it = blocks_.begin(); it != blocks_.end(); it++) {
|
||||
// Memory block can be re-used as long as there is no overlap between their time schedules.
|
||||
if (allocs_[*it].reuse_ && !OverlappingTimeSchedules(program_counter_start, program_counter_end,
|
||||
allocs_[*it].program_counter_start_, allocs_[*it].program_counter_end_)) {
|
||||
if (allocs_[*it].reuse_ && !OverlappingTimeSchedules(counter, *allocs_[*it].counter_)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -107,7 +114,7 @@ class MemPatternPlanner {
|
|||
// we only need to bounds check the addition of size to best_offset as that is the only time we extend
|
||||
// the maximum size of the buffer.
|
||||
buffer_size_ = std::max(buffer_size_, SafeInt<size_t>(best_offset) + size);
|
||||
allocs_.emplace_back(ml_value_idx, program_counter_start, program_counter_end, MemoryBlock(best_offset, size));
|
||||
allocs_.emplace_back(ml_value_idx, counter, MemoryBlock(best_offset, size));
|
||||
std::list<int>::iterator best_fit_it = blocks_.end();
|
||||
for (auto it = blocks_.begin(); it != blocks_.end(); it++) {
|
||||
if (allocs_[*it].block_.offset_ < best_offset)
|
||||
|
|
@ -121,8 +128,11 @@ class MemPatternPlanner {
|
|||
|
||||
blocks_.insert(best_fit_it, (static_cast<int>(allocs_.size()) - 1));
|
||||
}
|
||||
#endif
|
||||
|
||||
void TraceAllocation(int ml_value_idx, size_t size) {
|
||||
ORT_ENFORCE(!using_counters_);
|
||||
|
||||
std::lock_guard<OrtMutex> lock(lock_);
|
||||
|
||||
if (size == 0) {
|
||||
|
|
@ -190,35 +200,38 @@ class MemPatternPlanner {
|
|||
}
|
||||
}
|
||||
|
||||
MemoryPattern GenerateMemPattern() {
|
||||
MemoryPattern GenerateMemPattern() const {
|
||||
std::lock_guard<OrtMutex> lock(lock_);
|
||||
|
||||
// Time schedules of overlapping memory blocks SHOULD NOT intersect.
|
||||
for (size_t index_1 = 0; index_1 < allocs_.size(); index_1 += 1) {
|
||||
if (!allocs_[index_1].reuse_)
|
||||
continue;
|
||||
|
||||
for (size_t index_2 = index_1 + 1; index_2 < allocs_.size(); index_2 += 1) {
|
||||
if (!allocs_[index_2].reuse_)
|
||||
#ifdef ENABLE_TRAINING
|
||||
if (using_counters_) {
|
||||
// Time schedules of overlapping memory blocks SHOULD NOT intersect.
|
||||
for (size_t index_1 = 0; index_1 < allocs_.size(); index_1 += 1) {
|
||||
if (!allocs_[index_1].reuse_)
|
||||
continue;
|
||||
|
||||
size_t alloc_1_start = allocs_[index_1].block_.offset_;
|
||||
size_t alloc_1_end = alloc_1_start + allocs_[index_1].block_.size_ - 1;
|
||||
for (size_t index_2 = index_1 + 1; index_2 < allocs_.size(); index_2 += 1) {
|
||||
if (!allocs_[index_2].reuse_)
|
||||
continue;
|
||||
|
||||
ORT_ENFORCE(alloc_1_start <= alloc_1_end);
|
||||
size_t alloc_1_start = allocs_[index_1].block_.offset_;
|
||||
size_t alloc_1_end = alloc_1_start + allocs_[index_1].block_.size_ - 1;
|
||||
|
||||
size_t alloc_2_start = allocs_[index_2].block_.offset_;
|
||||
size_t alloc_2_end = alloc_2_start + allocs_[index_2].block_.size_ - 1;
|
||||
ORT_ENFORCE(alloc_1_start <= alloc_1_end);
|
||||
|
||||
ORT_ENFORCE(alloc_2_start <= alloc_2_end);
|
||||
size_t alloc_2_start = allocs_[index_2].block_.offset_;
|
||||
size_t alloc_2_end = alloc_2_start + allocs_[index_2].block_.size_ - 1;
|
||||
|
||||
if (((alloc_1_start >= alloc_2_start) && (alloc_1_start <= alloc_2_end)) ||
|
||||
((alloc_2_start >= alloc_1_start) && (alloc_2_start <= alloc_1_end))) {
|
||||
ORT_ENFORCE(!OverlappingTimeSchedules(allocs_[index_1].program_counter_start_, allocs_[index_1].program_counter_end_,
|
||||
allocs_[index_2].program_counter_start_, allocs_[index_2].program_counter_end_));
|
||||
ORT_ENFORCE(alloc_2_start <= alloc_2_end);
|
||||
|
||||
if (((alloc_1_start >= alloc_2_start) && (alloc_1_start <= alloc_2_end)) ||
|
||||
((alloc_2_start >= alloc_1_start) && (alloc_2_start <= alloc_1_end))) {
|
||||
ORT_ENFORCE(!OverlappingTimeSchedules(*allocs_[index_1].counter_, *allocs_[index_2].counter_));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
MemoryPattern pattern;
|
||||
pattern.peak_size_ = buffer_size_;
|
||||
|
|
@ -233,18 +246,20 @@ class MemPatternPlanner {
|
|||
struct OrtValueAllocationBlock {
|
||||
int index_{-1};
|
||||
MemoryBlock block_;
|
||||
const std::vector<size_t> program_counter_start_;
|
||||
const std::vector<size_t> program_counter_end_;
|
||||
const AllocPlanPerValue::ProgramCounter* counter_{nullptr};
|
||||
bool reuse_{false};
|
||||
OrtValueAllocationBlock() = default;
|
||||
OrtValueAllocationBlock(int index, const MemoryBlock& block) : index_(index), block_(block), reuse_{false} {}
|
||||
OrtValueAllocationBlock(int index, std::vector<size_t> program_counter_start, std::vector<size_t> program_counter_end, const MemoryBlock& block) : index_(index), block_(block), program_counter_start_(program_counter_start), program_counter_end_(program_counter_end), reuse_{true} {}
|
||||
OrtValueAllocationBlock(int index, const AllocPlanPerValue::ProgramCounter& counter, const MemoryBlock& block)
|
||||
: index_(index), block_(block), counter_(&counter), reuse_{true} {
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<OrtValueAllocationBlock> allocs_;
|
||||
// blocks_ the list of currently allocated memory blocks, sorted in order of their offset
|
||||
std::list<int> blocks_;
|
||||
SafeInt<size_t> buffer_size_{0};
|
||||
bool using_counters_;
|
||||
mutable OrtMutex lock_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -6,27 +6,31 @@
|
|||
#include "core/framework/execution_plan_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
OrtValuePatternPlanner::OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan)
|
||||
OrtValuePatternPlanner::OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan, bool trace_using_counters)
|
||||
: execution_planner_(execution_plan) {
|
||||
for (auto& location : execution_plan.GetAllLocations()) {
|
||||
planner_map_.emplace(location, onnxruntime::make_unique<MemPatternPlanner>());
|
||||
planner_map_.emplace(location, onnxruntime::make_unique<MemPatternPlanner>(trace_using_counters));
|
||||
}
|
||||
}
|
||||
|
||||
common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, const std::vector<size_t>& program_counter_start, const std::vector<size_t>& program_counter_end, size_t size) {
|
||||
#ifdef ENABLE_TRAINING
|
||||
common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx,
|
||||
const AllocPlanPerValue::ProgramCounter& counter,
|
||||
size_t size) {
|
||||
// TODO(codemzs): refactor code.
|
||||
auto location = execution_planner_.GetLocation(ort_value_idx);
|
||||
const auto& location = execution_planner_.GetLocation(ort_value_idx);
|
||||
auto it = planner_map_.find(location);
|
||||
if (it == planner_map_.end()) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
it->second->TraceAllocation(ort_value_idx, program_counter_start, program_counter_end, size);
|
||||
it->second->TraceAllocation(ort_value_idx, counter, size);
|
||||
return common::Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, size_t size) {
|
||||
auto location = execution_planner_.GetLocation(ort_value_idx);
|
||||
const auto& location = execution_planner_.GetLocation(ort_value_idx);
|
||||
auto it = planner_map_.find(location);
|
||||
if (it == planner_map_.end()) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
|
||||
|
|
@ -37,7 +41,7 @@ common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, size_t
|
|||
}
|
||||
|
||||
common::Status OrtValuePatternPlanner::TraceFree(int ort_value_index) {
|
||||
auto location = execution_planner_.GetLocation(ort_value_index);
|
||||
const auto& location = execution_planner_.GetLocation(ort_value_index);
|
||||
auto it = planner_map_.find(location);
|
||||
if (it == planner_map_.end()) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
|
||||
|
|
|
|||
|
|
@ -18,8 +18,12 @@ class ExecutionPlanBase;
|
|||
// SessionOptions.enable_mem_pattern
|
||||
class OrtValuePatternPlanner {
|
||||
public:
|
||||
explicit OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan);
|
||||
common::Status TraceAllocation(int ort_value_idx, const std::vector<size_t>& program_counter_start, const std::vector<size_t>& program_counter_end, size_t size);
|
||||
// trace_using_counters should be true if the TraceAllocation with ProgramCounter is used. Only one
|
||||
// variant of the TraceAllocation calls may be used.
|
||||
explicit OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan, bool trace_using_counters = false);
|
||||
#ifdef ENABLE_TRAINING
|
||||
common::Status TraceAllocation(int ort_value_idx, const AllocPlanPerValue::ProgramCounter& counter, size_t size);
|
||||
#endif
|
||||
common::Status TraceAllocation(int ort_value_idx, size_t size);
|
||||
common::Status TraceFree(int ort_value_index);
|
||||
common::Status GeneratePatterns(MemoryPatternGroup* out);
|
||||
|
|
|
|||
|
|
@ -191,8 +191,10 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index,
|
|||
|
||||
// Execute the kernel.
|
||||
ORT_TRY {
|
||||
#ifdef ENABLE_TRAINING
|
||||
if (p_op_kernel->KernelDef().AllocateInputsContiguously())
|
||||
utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context);
|
||||
#endif
|
||||
|
||||
status = p_op_kernel->Compute(&op_kernel_context);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class SessionState;
|
|||
|
||||
// Captures information required to allocate/reuse buffer for a ml-value
|
||||
struct AllocPlanPerValue {
|
||||
AllocKind alloc_kind{AllocKind::kAllocate};
|
||||
AllocKind alloc_kind{AllocKind::kNotSet};
|
||||
MLDataType value_type{nullptr};
|
||||
OrtMemoryInfo location;
|
||||
// reused_buffer is valid only if alloc_kind == kReuse. It indicates
|
||||
|
|
@ -31,8 +31,37 @@ struct AllocPlanPerValue {
|
|||
// if the value is used in async kernel, a fence object would be created
|
||||
// note the fence object would be shared between MLValues reusing the same buffer
|
||||
bool create_fence_if_async{false};
|
||||
std::vector<size_t> program_counter_start;
|
||||
std::vector<size_t> program_counter_end;
|
||||
|
||||
class ProgramCounter {
|
||||
public:
|
||||
ProgramCounter() = default;
|
||||
void AddStart(size_t start) {
|
||||
ORT_ENFORCE(starts_.size() == ends_.size(), "Previous entry was not terminated.");
|
||||
ORT_ENFORCE(starts_.empty() || start > ends_.back(), "Invalid 'start'. Value is smaller than previous 'end'.");
|
||||
starts_.push_back(start);
|
||||
}
|
||||
|
||||
void AddEnd(size_t end) {
|
||||
ORT_ENFORCE(starts_.size() == ends_.size() + 1, "No matching 'start' entry.");
|
||||
ORT_ENFORCE(end >= starts_.back(), "Invalid 'end'. Value is larger than 'start'.");
|
||||
ends_.push_back(end);
|
||||
}
|
||||
|
||||
// return true if there are entries, and the number of start/end pairs match.
|
||||
// validity of the individual start/end values is checked when they are added.
|
||||
bool HasValidEntries() const {
|
||||
return !starts_.empty() && starts_.size() == ends_.size();
|
||||
}
|
||||
|
||||
const std::vector<size_t>& Starts() const { return starts_; }
|
||||
const std::vector<size_t>& Ends() const { return ends_; }
|
||||
|
||||
private:
|
||||
std::vector<size_t> starts_;
|
||||
std::vector<size_t> ends_;
|
||||
};
|
||||
|
||||
ProgramCounter program_counter;
|
||||
|
||||
public:
|
||||
AllocPlanPerValue() : location(CPU, Invalid) {}
|
||||
|
|
|
|||
|
|
@ -307,9 +307,11 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
node_compute_range.Begin();
|
||||
#endif
|
||||
ORT_TRY {
|
||||
#ifdef ENABLE_TRAINING
|
||||
if (p_op_kernel->KernelDef().AllocateInputsContiguously())
|
||||
utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context);
|
||||
|
||||
#endif
|
||||
|
||||
compute_status = p_op_kernel->Compute(&op_kernel_context);
|
||||
}
|
||||
ORT_CATCH(const std::exception& ex) {
|
||||
|
|
|
|||
|
|
@ -396,14 +396,15 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
ORT_RETURN_IF_ERROR(ResolveDimParams(*graph_viewer_, feeds, map));
|
||||
auto* exe_plan = GetExecutionPlan();
|
||||
ORT_ENFORCE(exe_plan);
|
||||
OrtValuePatternPlanner mem_planner(*exe_plan);
|
||||
OrtValuePatternPlanner mem_planner(*exe_plan, /*using counters*/ true);
|
||||
|
||||
// Try to resolve shapes for activations.
|
||||
auto& node_index_info = GetNodeIndexInfo();
|
||||
for (auto& node_plan : exe_plan->execution_plan) {
|
||||
int node_index = node_index_info.GetNodeOffset(node_plan.node_index);
|
||||
auto* node = graph_viewer_->GetNode(node_plan.node_index);
|
||||
int output_start = node_index + static_cast<int>(node->InputDefs().size()) + static_cast<int>(node->ImplicitInputDefs().size());
|
||||
int output_start = node_index + static_cast<int>(node->InputDefs().size()) +
|
||||
static_cast<int>(node->ImplicitInputDefs().size());
|
||||
|
||||
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
|
||||
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
|
||||
|
|
@ -427,7 +428,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
}
|
||||
}
|
||||
|
||||
// Allocate activations that want to be laid out contigously in memory.
|
||||
// Allocate activations that want to be laid out contiguously in memory.
|
||||
for (auto ml_value_idx : exe_plan->activation_allocation_order) {
|
||||
ORT_ENFORCE(ml_value_idx >= 0);
|
||||
|
||||
|
|
@ -450,13 +451,9 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
}
|
||||
|
||||
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate);
|
||||
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start.size() == exe_plan->allocation_plan[ml_value_idx].program_counter_end.size());
|
||||
|
||||
for (size_t index = 0; index < exe_plan->allocation_plan[ml_value_idx].program_counter_start.size(); index += 1)
|
||||
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start[index] <= exe_plan->allocation_plan[ml_value_idx].program_counter_end[index]);
|
||||
|
||||
mem_planner.TraceAllocation(ml_value_idx, exe_plan->allocation_plan[ml_value_idx].program_counter_start,
|
||||
exe_plan->allocation_plan[ml_value_idx].program_counter_end, size);
|
||||
const auto& counter = exe_plan->allocation_plan[ml_value_idx].program_counter;
|
||||
mem_planner.TraceAllocation(ml_value_idx, counter, size);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -464,12 +461,15 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
for (auto& node_plan : exe_plan->execution_plan) {
|
||||
int node_index = node_index_info.GetNodeOffset(node_plan.node_index);
|
||||
auto* node = graph_viewer_->GetNode(node_plan.node_index);
|
||||
int output_start = node_index + static_cast<int>(node->InputDefs().size()) + static_cast<int>(node->ImplicitInputDefs().size());
|
||||
int output_start = node_index + static_cast<int>(node->InputDefs().size()) +
|
||||
static_cast<int>(node->ImplicitInputDefs().size());
|
||||
//allocate output
|
||||
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
|
||||
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
|
||||
if (ml_value_idx == NodeIndexInfo::kInvalidEntry ||
|
||||
(std::find(exe_plan->activation_allocation_order.begin(), exe_plan->activation_allocation_order.end(), ml_value_idx) != exe_plan->activation_allocation_order.end()))
|
||||
(std::find(exe_plan->activation_allocation_order.begin(),
|
||||
exe_plan->activation_allocation_order.end(), ml_value_idx) !=
|
||||
exe_plan->activation_allocation_order.end()))
|
||||
continue;
|
||||
const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type;
|
||||
if (!ml_type->IsTensorType())
|
||||
|
|
@ -487,13 +487,9 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
}
|
||||
|
||||
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate);
|
||||
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start.size() == exe_plan->allocation_plan[ml_value_idx].program_counter_end.size());
|
||||
|
||||
for (size_t index = 0; index < exe_plan->allocation_plan[ml_value_idx].program_counter_start.size(); index += 1)
|
||||
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start[index] <= exe_plan->allocation_plan[ml_value_idx].program_counter_end[index]);
|
||||
|
||||
mem_planner.TraceAllocation(ml_value_idx, exe_plan->allocation_plan[ml_value_idx].program_counter_start,
|
||||
exe_plan->allocation_plan[ml_value_idx].program_counter_end, aligned_size);
|
||||
const auto& counter = exe_plan->allocation_plan[ml_value_idx].program_counter;
|
||||
mem_planner.TraceAllocation(ml_value_idx, counter, aligned_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -991,7 +987,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
// Uncomment the below to dump the allocation plan to std::cout
|
||||
// LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this);
|
||||
|
||||
std::unique_ptr<ITensorAllocator> tensor_allocator_(
|
||||
std::unique_ptr<ITensorAllocator> tensor_allocator(
|
||||
ITensorAllocator::Create(enable_mem_pattern_, *p_seq_exec_plan_, *this, weights_buffers_));
|
||||
|
||||
const auto& initializer_allocation_order = p_seq_exec_plan_->initializer_allocation_order;
|
||||
|
|
@ -1001,7 +997,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
session_state_utils::SaveInitializedTensors(
|
||||
Env::Default(), graph_location, *graph_viewer_,
|
||||
execution_providers_.GetDefaultCpuMemoryInfo(),
|
||||
ort_value_name_idx_map_, initializer_allocation_order, *tensor_allocator_,
|
||||
ort_value_name_idx_map_, initializer_allocation_order, *tensor_allocator,
|
||||
[this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status {
|
||||
return AddInitializedTensor(idx, value, &d, constant);
|
||||
},
|
||||
|
|
|
|||
|
|
@ -90,7 +90,9 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
|
|||
common::Status SaveInitializedTensors(
|
||||
const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map, const std::vector<OrtValueIndex>& initializer_allocation_order, ITensorAllocator& planner,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const std::vector<OrtValueIndex>& initializer_allocation_order,
|
||||
ITensorAllocator& planner,
|
||||
const std::function<Status(int idx, const OrtValue& value, const OrtCallback& d, bool constant)>& save_tensor_func,
|
||||
const logging::Logger& logger, const DataTransferManager& data_transfer_mgr,
|
||||
const ExecutionPlanBase& exec_plan,
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
|
|||
TensorAllocatorWithMemPattern(const ExecutionPlanBase& execution_plan, const SessionState& session_state,
|
||||
std::vector<BufferUniquePtr>& weights_buffers)
|
||||
: ITensorAllocator(session_state),
|
||||
planner_(execution_plan),
|
||||
planner_(execution_plan, /*using counters*/ false),
|
||||
weights_buffers_(weights_buffers),
|
||||
seq_plan_(execution_plan) {}
|
||||
|
||||
|
|
|
|||
|
|
@ -570,6 +570,7 @@ int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType onn
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context) {
|
||||
const Tensor* prev_input = context->Input<Tensor>(0);
|
||||
for (int i = 1; i < context->InputCount(); i++) {
|
||||
|
|
@ -581,15 +582,18 @@ common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context)
|
|||
size_t input_element_size = prev_input->DataType()->Size();
|
||||
size_t input_aligned_bytes = 0;
|
||||
|
||||
ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size, &input_aligned_bytes));
|
||||
ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size,
|
||||
&input_aligned_bytes));
|
||||
|
||||
ORT_RETURN_IF_NOT(curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + input_aligned_bytes ||
|
||||
curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + prev_input->SizeInBytes());
|
||||
ORT_RETURN_IF_NOT(
|
||||
curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + input_aligned_bytes ||
|
||||
curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + prev_input->SizeInBytes());
|
||||
|
||||
prev_input = curr_input;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -156,7 +156,9 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint64_t>() {
|
|||
|
||||
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
|
||||
#endif
|
||||
|
||||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
TEST(MemPatternPlannerTest, TraceAllocaitonTest) {
|
||||
MemPatternPlanner planner;
|
||||
const bool using_counters = false; // we're not tracking start/end for use/re-use of each allocation via counters
|
||||
MemPatternPlanner planner{using_counters};
|
||||
planner.TraceAllocation(0, 1024);
|
||||
planner.TraceAllocation(1, 256);
|
||||
planner.TraceAllocation(2, 512);
|
||||
|
|
|
|||
Loading…
Reference in a new issue