Exclude some training specific code from the minimal build. Cleanup some related aspects of allocation planner. (#5861)

* Exclude some training specific code around the allocation planning and initializer handling from the minimal build.
Simplify the code around tracking start/end usage of a value.
This commit is contained in:
Scott McKay 2020-11-20 20:25:46 +10:00 committed by GitHub
parent b057b3d36e
commit 00412a76e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 161 additions and 125 deletions

View file

@ -22,6 +22,7 @@ namespace onnxruntime {
// Generalizing this is future work.
enum class AllocKind {
kNotSet = -1,
kAllocate = 0,
kReuse = 1,
kPreExisting = 2,

View file

@ -40,6 +40,9 @@ std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind) {
case AllocKind::kShare:
out << "Share";
break;
case AllocKind::kNotSet:
out << "NotSet";
break;
}
return out;
}
@ -639,8 +642,7 @@ class PlannerImpl {
} else if (IsNonTensor(*node_output)) {
// we do not try sharing-optimization for non-tensors
AllocPlan(current).alloc_kind = AllocKind::kAllocate;
AllocPlan(current).program_counter_start.emplace_back(program_counter);
AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX);
AllocPlan(current).program_counter.AddStart(program_counter);
} else if (FindReusableInput(*pnode, static_cast<int>(output_arg_def_index), &reused)) {
// Reuse one of this node's input buffers as the output buffer (for in-place update)
Reuse(reused, current, AllocKind::kReuse);
@ -650,18 +652,12 @@ class PlannerImpl {
Reuse(reused, current, AllocKind::kReuse);
OrtValueIndex original = Buffer(reused);
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() != SIZE_MAX);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() < program_counter);
AllocPlan(original).program_counter_start.emplace_back(program_counter);
AllocPlan(original).program_counter_end.emplace_back(SIZE_MAX);
AllocPlan(original).program_counter.AddStart(program_counter);
}
} else {
// otherwise: allocate a new buffer for this output
AllocPlan(current).alloc_kind = AllocKind::kAllocate;
AllocPlan(current).program_counter_start.emplace_back(program_counter);
AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX);
AllocPlan(current).program_counter.AddStart(program_counter);
}
}
@ -675,9 +671,7 @@ class PlannerImpl {
if ((original != -1) && (0 == DecrementUseCount(original))) {
freelist_.push_front(FreeBufferInfo(original, program_counter));
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
AllocPlan(original).program_counter_end.back() = program_counter;
AllocPlan(original).program_counter.AddEnd(program_counter);
}
}
}
@ -692,9 +686,7 @@ class PlannerImpl {
if ((original != -1) && (0 == DecrementUseCount(original))) {
freelist_.push_front(FreeBufferInfo(original, program_counter));
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
AllocPlan(original).program_counter_end.back() = program_counter;
AllocPlan(original).program_counter.AddEnd(program_counter);
}
}
}
@ -708,9 +700,7 @@ class PlannerImpl {
if (0 == DecrementUseCount(original)) {
freelist_.push_front(FreeBufferInfo(original, program_counter));
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
AllocPlan(original).program_counter_end.back() = program_counter;
AllocPlan(original).program_counter.AddEnd(program_counter);
}
}
}
@ -719,6 +709,7 @@ class PlannerImpl {
return Status::OK();
}
#ifdef ENABLE_TRAINING
bool AllocateInputsContiguously(const Node& node) const {
const KernelCreateInfo& ci = GetKernelCreateInfo(kernel_create_info_map_, node.Index());
if (ci.kernel_def == nullptr) {
@ -766,38 +757,20 @@ class PlannerImpl {
}
return Status::OK();
}
#endif
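// Ensure memory time schedule is valid: every value with AllocKind::kAllocate must have
// program counter entries, with a matching 'end' for each 'start'.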
Status VerifyMemoryTimeSchedule() {
std::vector<SequentialExecutionPlan::NodeExecutionPlan>& execution_plan(plan_.execution_plan);
for (size_t program_counter = 0; program_counter < execution_plan.size(); ++program_counter) {
SequentialExecutionPlan::NodeExecutionPlan step = execution_plan[program_counter];
const auto* pnode = graph_viewer_.GetNode(step.node_index);
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot find the node ", step.node_index);
const auto& input_defs = pnode->InputDefs();
for (int input_arg_def_index = 0; static_cast<size_t>(input_arg_def_index) < input_defs.size(); ++input_arg_def_index) {
const auto& node_input = input_defs[input_arg_def_index];
if (!node_input->Exists()) continue;
const auto& current_plan = AllocPlan(Index(node_input->Name()));
if (current_plan.alloc_kind != AllocKind::kAllocate) continue;
ORT_ENFORCE(current_plan.program_counter_start.size() == current_plan.program_counter_end.size());
size_t start = 0;
for (size_t index = 0; index < current_plan.program_counter_start.size(); index += 1) {
ORT_ENFORCE((current_plan.program_counter_start[index] > start) || (start == 0));
ORT_ENFORCE(current_plan.program_counter_start[index] <= current_plan.program_counter_end[index]);
ORT_ENFORCE(current_plan.program_counter_start[index] < SIZE_MAX);
ORT_ENFORCE((current_plan.program_counter_end[index] > 0) || (index == 0));
start = current_plan.program_counter_start[index];
}
void VerifyMemoryTimeSchedule() {
size_t idx = 0;
for (const auto& entry : plan_.allocation_plan) {
if (entry.alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(entry.program_counter.HasValidEntries(), "Invalid program_counter entries at index ", idx);
}
}
return Status::OK();
++idx;
}
}
// Whether a given NodeArg has fence or not.
// If the buffer is reused, need to check whether original OrtValue has fence or not.
bool HasFence(const onnxruntime::NodeArg* arg) {
@ -875,8 +848,7 @@ class PlannerImpl {
for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) {
auto ml_value_idx = plan_.to_be_freed[index];
if (AllocPlan(ml_value_idx).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_start.back() <= program_counter);
ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_end.back() == program_counter);
ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter.Ends().back() == program_counter);
}
}
@ -914,15 +886,17 @@ Status PlannerImpl::CreatePlan() {
// Determine nodes that need fence check. This needs to be done after ComputeUseCounts and ComputeReusePlan.
ORT_RETURN_IF_ERROR(ComputeFenceCheck());
#ifdef ENABLE_TRAINING
// Determine allocation order for weights and activations. This needs to be done after ComputeReusePlan.
ORT_RETURN_IF_ERROR(ComputeAllocationOrder());
#endif
// convert information in the freelist_ into a deallocation plan in required format
GenerateDeallocationPlan();
// Ensure Memory-Time schedule is sorted. This should be called at the end because memory start/end timestamps
// Ensure Memory-Time schedule is valid. This should be called at the end because memory start/end timestamps
// are updated until GenerateDeallocationPlan is finished.
ORT_RETURN_IF_ERROR(VerifyMemoryTimeSchedule());
VerifyMemoryTimeSchedule();
return Status::OK();
}

View file

@ -29,27 +29,33 @@ namespace onnxruntime {
// Thread-safe.
class MemPatternPlanner {
public:
MemPatternPlanner() = default;
// only the Training code currently uses the program counter based logic
MemPatternPlanner(bool using_counters) : using_counters_{using_counters} {}
#ifdef ENABLE_TRAINING
// TODO: OverlappingTimeSchedules should be private
// Returns true if there is an intersection between two time schedules.
// ASSUMES EACH TIME SCHEDULE IS SORTED. THIS IS VALIDATED AT THE END OF MEMORY PLANNING.
bool OverlappingTimeSchedules(const std::vector<size_t>& program_counter_start_1, const std::vector<size_t>& program_counter_end_1,
const std::vector<size_t>& program_counter_start_2, const std::vector<size_t>& program_counter_end_2) {
ORT_ENFORCE(program_counter_start_1.size() > 0);
ORT_ENFORCE(program_counter_start_2.size() > 0);
ORT_ENFORCE(program_counter_start_1.size() == program_counter_end_1.size());
ORT_ENFORCE(program_counter_start_2.size() == program_counter_end_2.size());
// ProgramCounter values are validated when the execution plan is created
bool OverlappingTimeSchedules(const AllocPlanPerValue::ProgramCounter& counter1,
const AllocPlanPerValue::ProgramCounter& counter2) const {
const auto& starts_1 = counter1.Starts();
const auto& ends_1 = counter1.Ends();
const auto& starts_2 = counter2.Starts();
const auto& ends_2 = counter2.Ends();
size_t index_1 = 0;
size_t index_2 = 0;
while ((index_1 < program_counter_start_1.size()) && (index_2 < program_counter_start_2.size())) {
if (program_counter_start_1[index_1] <= program_counter_start_2[index_2]) {
if (program_counter_end_1[index_1] >= program_counter_start_2[index_2]) {
size_t index_1_end = starts_1.size();
size_t index_2_end = starts_2.size();
while ((index_1 < index_1_end) && (index_2 < index_2_end)) {
if (starts_1[index_1] <= starts_2[index_2]) {
if (ends_1[index_1] >= starts_2[index_2]) {
return true;
}
index_1 += 1;
} else {
if (program_counter_end_2[index_2] >= program_counter_start_1[index_1]) {
if (ends_2[index_2] >= starts_1[index_1]) {
return true;
}
index_2 += 1;
@ -59,7 +65,9 @@ class MemPatternPlanner {
return false;
}
void TraceAllocation(int ml_value_idx, const std::vector<size_t>& program_counter_start, const std::vector<size_t>& program_counter_end, size_t size) {
void TraceAllocation(int ml_value_idx, const AllocPlanPerValue::ProgramCounter& counter, size_t size) {
ORT_ENFORCE(using_counters_);
std::lock_guard<OrtMutex> lock(lock_);
if (size == 0) {
@ -73,8 +81,7 @@ class MemPatternPlanner {
bool best_offset_found = false;
for (auto it = blocks_.begin(); it != blocks_.end(); it++) {
// Memory block can be re-used as long as there is no overlap between their time schedules.
if (allocs_[*it].reuse_ && !OverlappingTimeSchedules(program_counter_start, program_counter_end,
allocs_[*it].program_counter_start_, allocs_[*it].program_counter_end_)) {
if (allocs_[*it].reuse_ && !OverlappingTimeSchedules(counter, *allocs_[*it].counter_)) {
continue;
}
@ -107,7 +114,7 @@ class MemPatternPlanner {
// we only need to bounds check the addition of size to best_offset as that is the only time we extend
// the maximum size of the buffer.
buffer_size_ = std::max(buffer_size_, SafeInt<size_t>(best_offset) + size);
allocs_.emplace_back(ml_value_idx, program_counter_start, program_counter_end, MemoryBlock(best_offset, size));
allocs_.emplace_back(ml_value_idx, counter, MemoryBlock(best_offset, size));
std::list<int>::iterator best_fit_it = blocks_.end();
for (auto it = blocks_.begin(); it != blocks_.end(); it++) {
if (allocs_[*it].block_.offset_ < best_offset)
@ -121,8 +128,11 @@ class MemPatternPlanner {
blocks_.insert(best_fit_it, (static_cast<int>(allocs_.size()) - 1));
}
#endif
void TraceAllocation(int ml_value_idx, size_t size) {
ORT_ENFORCE(!using_counters_);
std::lock_guard<OrtMutex> lock(lock_);
if (size == 0) {
@ -190,35 +200,38 @@ class MemPatternPlanner {
}
}
MemoryPattern GenerateMemPattern() {
MemoryPattern GenerateMemPattern() const {
std::lock_guard<OrtMutex> lock(lock_);
// Time schedules of overlapping memory blocks SHOULD NOT intersect.
for (size_t index_1 = 0; index_1 < allocs_.size(); index_1 += 1) {
if (!allocs_[index_1].reuse_)
continue;
for (size_t index_2 = index_1 + 1; index_2 < allocs_.size(); index_2 += 1) {
if (!allocs_[index_2].reuse_)
#ifdef ENABLE_TRAINING
if (using_counters_) {
// Time schedules of overlapping memory blocks SHOULD NOT intersect.
for (size_t index_1 = 0; index_1 < allocs_.size(); index_1 += 1) {
if (!allocs_[index_1].reuse_)
continue;
size_t alloc_1_start = allocs_[index_1].block_.offset_;
size_t alloc_1_end = alloc_1_start + allocs_[index_1].block_.size_ - 1;
for (size_t index_2 = index_1 + 1; index_2 < allocs_.size(); index_2 += 1) {
if (!allocs_[index_2].reuse_)
continue;
ORT_ENFORCE(alloc_1_start <= alloc_1_end);
size_t alloc_1_start = allocs_[index_1].block_.offset_;
size_t alloc_1_end = alloc_1_start + allocs_[index_1].block_.size_ - 1;
size_t alloc_2_start = allocs_[index_2].block_.offset_;
size_t alloc_2_end = alloc_2_start + allocs_[index_2].block_.size_ - 1;
ORT_ENFORCE(alloc_1_start <= alloc_1_end);
ORT_ENFORCE(alloc_2_start <= alloc_2_end);
size_t alloc_2_start = allocs_[index_2].block_.offset_;
size_t alloc_2_end = alloc_2_start + allocs_[index_2].block_.size_ - 1;
if (((alloc_1_start >= alloc_2_start) && (alloc_1_start <= alloc_2_end)) ||
((alloc_2_start >= alloc_1_start) && (alloc_2_start <= alloc_1_end))) {
ORT_ENFORCE(!OverlappingTimeSchedules(allocs_[index_1].program_counter_start_, allocs_[index_1].program_counter_end_,
allocs_[index_2].program_counter_start_, allocs_[index_2].program_counter_end_));
ORT_ENFORCE(alloc_2_start <= alloc_2_end);
if (((alloc_1_start >= alloc_2_start) && (alloc_1_start <= alloc_2_end)) ||
((alloc_2_start >= alloc_1_start) && (alloc_2_start <= alloc_1_end))) {
ORT_ENFORCE(!OverlappingTimeSchedules(*allocs_[index_1].counter_, *allocs_[index_2].counter_));
}
}
}
}
#endif
MemoryPattern pattern;
pattern.peak_size_ = buffer_size_;
@ -233,18 +246,20 @@ class MemPatternPlanner {
struct OrtValueAllocationBlock {
int index_{-1};
MemoryBlock block_;
const std::vector<size_t> program_counter_start_;
const std::vector<size_t> program_counter_end_;
const AllocPlanPerValue::ProgramCounter* counter_{nullptr};
bool reuse_{false};
OrtValueAllocationBlock() = default;
OrtValueAllocationBlock(int index, const MemoryBlock& block) : index_(index), block_(block), reuse_{false} {}
OrtValueAllocationBlock(int index, std::vector<size_t> program_counter_start, std::vector<size_t> program_counter_end, const MemoryBlock& block) : index_(index), block_(block), program_counter_start_(program_counter_start), program_counter_end_(program_counter_end), reuse_{true} {}
OrtValueAllocationBlock(int index, const AllocPlanPerValue::ProgramCounter& counter, const MemoryBlock& block)
: index_(index), block_(block), counter_(&counter), reuse_{true} {
}
};
std::vector<OrtValueAllocationBlock> allocs_;
// blocks_ the list of currently allocated memory blocks, sorted in order of their offset
std::list<int> blocks_;
SafeInt<size_t> buffer_size_{0};
bool using_counters_;
mutable OrtMutex lock_;
};

View file

@ -6,27 +6,31 @@
#include "core/framework/execution_plan_base.h"
namespace onnxruntime {
OrtValuePatternPlanner::OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan)
OrtValuePatternPlanner::OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan, bool trace_using_counters)
: execution_planner_(execution_plan) {
for (auto& location : execution_plan.GetAllLocations()) {
planner_map_.emplace(location, onnxruntime::make_unique<MemPatternPlanner>());
planner_map_.emplace(location, onnxruntime::make_unique<MemPatternPlanner>(trace_using_counters));
}
}
common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, const std::vector<size_t>& program_counter_start, const std::vector<size_t>& program_counter_end, size_t size) {
#ifdef ENABLE_TRAINING
common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx,
const AllocPlanPerValue::ProgramCounter& counter,
size_t size) {
// TODO(codemzs): refactor code.
auto location = execution_planner_.GetLocation(ort_value_idx);
const auto& location = execution_planner_.GetLocation(ort_value_idx);
auto it = planner_map_.find(location);
if (it == planner_map_.end()) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
}
it->second->TraceAllocation(ort_value_idx, program_counter_start, program_counter_end, size);
it->second->TraceAllocation(ort_value_idx, counter, size);
return common::Status::OK();
}
#endif
common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, size_t size) {
auto location = execution_planner_.GetLocation(ort_value_idx);
const auto& location = execution_planner_.GetLocation(ort_value_idx);
auto it = planner_map_.find(location);
if (it == planner_map_.end()) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
@ -37,7 +41,7 @@ common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, size_t
}
common::Status OrtValuePatternPlanner::TraceFree(int ort_value_index) {
auto location = execution_planner_.GetLocation(ort_value_index);
const auto& location = execution_planner_.GetLocation(ort_value_index);
auto it = planner_map_.find(location);
if (it == planner_map_.end()) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);

View file

@ -18,8 +18,12 @@ class ExecutionPlanBase;
// SessionOptions.enable_mem_pattern
class OrtValuePatternPlanner {
public:
explicit OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan);
common::Status TraceAllocation(int ort_value_idx, const std::vector<size_t>& program_counter_start, const std::vector<size_t>& program_counter_end, size_t size);
// trace_using_counters should be true if the TraceAllocation with ProgramCounter is used. Only one
// variant of the TraceAllocation calls may be used.
explicit OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan, bool trace_using_counters = false);
#ifdef ENABLE_TRAINING
common::Status TraceAllocation(int ort_value_idx, const AllocPlanPerValue::ProgramCounter& counter, size_t size);
#endif
common::Status TraceAllocation(int ort_value_idx, size_t size);
common::Status TraceFree(int ort_value_index);
common::Status GeneratePatterns(MemoryPatternGroup* out);

View file

@ -191,8 +191,10 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index,
// Execute the kernel.
ORT_TRY {
#ifdef ENABLE_TRAINING
if (p_op_kernel->KernelDef().AllocateInputsContiguously())
utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context);
#endif
status = p_op_kernel->Compute(&op_kernel_context);
}

View file

@ -22,7 +22,7 @@ class SessionState;
// Captures information required to allocate/reuse buffer for a ml-value
struct AllocPlanPerValue {
AllocKind alloc_kind{AllocKind::kAllocate};
AllocKind alloc_kind{AllocKind::kNotSet};
MLDataType value_type{nullptr};
OrtMemoryInfo location;
// reused_buffer is valid only if alloc_kind == kReuse. It indicates
@ -31,8 +31,37 @@ struct AllocPlanPerValue {
// if the value is used in async kernel, a fence object would be created
// note the fence object would be shared between MLValues reusing the same buffer
bool create_fence_if_async{false};
std::vector<size_t> program_counter_start;
std::vector<size_t> program_counter_end;
// Records the program counter values at which use of a value starts and ends, as a sequence
// of start/end pairs. Pairs must be added in order: AddStart, then AddEnd, then optionally the
// next AddStart, and so on. Ordering of the values is enforced as each one is added.
class ProgramCounter {
 public:
  ProgramCounter() = default;

  // Begin a new start/end pair at `start`.
  // Enforces that every previous start has a matching end, and that `start` is strictly
  // larger than the most recently added end (if there is one).
  void AddStart(size_t start) {
    ORT_ENFORCE(starts_.size() == ends_.size(), "Previous entry was not terminated.");
    // equality is rejected too, so the message says "not larger" rather than "smaller"
    ORT_ENFORCE(starts_.empty() || start > ends_.back(),
                "Invalid 'start'. Value is not larger than previous 'end'.");
    starts_.push_back(start);
  }

  // Terminate the currently open pair at `end`.
  // Enforces that exactly one start is waiting for its end, and that `end` is not
  // smaller than that start (end == start is allowed).
  void AddEnd(size_t end) {
    ORT_ENFORCE(starts_.size() == ends_.size() + 1, "No matching 'start' entry.");
    // the check fails when `end` is below the open start, so report it as "smaller"
    ORT_ENFORCE(end >= starts_.back(), "Invalid 'end'. Value is smaller than 'start'.");
    ends_.push_back(end);
  }

  // return true if there are entries, and the number of start/end pairs match.
  // validity of the individual start/end values is checked when they are added.
  bool HasValidEntries() const {
    return !starts_.empty() && starts_.size() == ends_.size();
  }

  // Start values in the order they were added.
  const std::vector<size_t>& Starts() const { return starts_; }
  // End values in the order they were added; Ends()[i] terminates Starts()[i].
  const std::vector<size_t>& Ends() const { return ends_; }

 private:
  std::vector<size_t> starts_;
  std::vector<size_t> ends_;
};
ProgramCounter program_counter;
public:
AllocPlanPerValue() : location(CPU, Invalid) {}

View file

@ -307,9 +307,11 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
node_compute_range.Begin();
#endif
ORT_TRY {
#ifdef ENABLE_TRAINING
if (p_op_kernel->KernelDef().AllocateInputsContiguously())
utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context);
#endif
compute_status = p_op_kernel->Compute(&op_kernel_context);
}
ORT_CATCH(const std::exception& ex) {

View file

@ -396,14 +396,15 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
ORT_RETURN_IF_ERROR(ResolveDimParams(*graph_viewer_, feeds, map));
auto* exe_plan = GetExecutionPlan();
ORT_ENFORCE(exe_plan);
OrtValuePatternPlanner mem_planner(*exe_plan);
OrtValuePatternPlanner mem_planner(*exe_plan, /*using counters*/ true);
// Try to resolve shapes for activations.
auto& node_index_info = GetNodeIndexInfo();
for (auto& node_plan : exe_plan->execution_plan) {
int node_index = node_index_info.GetNodeOffset(node_plan.node_index);
auto* node = graph_viewer_->GetNode(node_plan.node_index);
int output_start = node_index + static_cast<int>(node->InputDefs().size()) + static_cast<int>(node->ImplicitInputDefs().size());
int output_start = node_index + static_cast<int>(node->InputDefs().size()) +
static_cast<int>(node->ImplicitInputDefs().size());
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
@ -427,7 +428,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
}
}
// Allocate activations that want to be laid out contigously in memory.
// Allocate activations that want to be laid out contiguously in memory.
for (auto ml_value_idx : exe_plan->activation_allocation_order) {
ORT_ENFORCE(ml_value_idx >= 0);
@ -450,13 +451,9 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
}
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate);
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start.size() == exe_plan->allocation_plan[ml_value_idx].program_counter_end.size());
for (size_t index = 0; index < exe_plan->allocation_plan[ml_value_idx].program_counter_start.size(); index += 1)
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start[index] <= exe_plan->allocation_plan[ml_value_idx].program_counter_end[index]);
mem_planner.TraceAllocation(ml_value_idx, exe_plan->allocation_plan[ml_value_idx].program_counter_start,
exe_plan->allocation_plan[ml_value_idx].program_counter_end, size);
const auto& counter = exe_plan->allocation_plan[ml_value_idx].program_counter;
mem_planner.TraceAllocation(ml_value_idx, counter, size);
}
}
@ -464,12 +461,15 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
for (auto& node_plan : exe_plan->execution_plan) {
int node_index = node_index_info.GetNodeOffset(node_plan.node_index);
auto* node = graph_viewer_->GetNode(node_plan.node_index);
int output_start = node_index + static_cast<int>(node->InputDefs().size()) + static_cast<int>(node->ImplicitInputDefs().size());
int output_start = node_index + static_cast<int>(node->InputDefs().size()) +
static_cast<int>(node->ImplicitInputDefs().size());
//allocate output
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
if (ml_value_idx == NodeIndexInfo::kInvalidEntry ||
(std::find(exe_plan->activation_allocation_order.begin(), exe_plan->activation_allocation_order.end(), ml_value_idx) != exe_plan->activation_allocation_order.end()))
(std::find(exe_plan->activation_allocation_order.begin(),
exe_plan->activation_allocation_order.end(), ml_value_idx) !=
exe_plan->activation_allocation_order.end()))
continue;
const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type;
if (!ml_type->IsTensorType())
@ -487,13 +487,9 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
}
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate);
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start.size() == exe_plan->allocation_plan[ml_value_idx].program_counter_end.size());
for (size_t index = 0; index < exe_plan->allocation_plan[ml_value_idx].program_counter_start.size(); index += 1)
ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start[index] <= exe_plan->allocation_plan[ml_value_idx].program_counter_end[index]);
mem_planner.TraceAllocation(ml_value_idx, exe_plan->allocation_plan[ml_value_idx].program_counter_start,
exe_plan->allocation_plan[ml_value_idx].program_counter_end, aligned_size);
const auto& counter = exe_plan->allocation_plan[ml_value_idx].program_counter;
mem_planner.TraceAllocation(ml_value_idx, counter, aligned_size);
}
}
@ -991,7 +987,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
// Uncomment the below to dump the allocation plan to std::cout
// LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this);
std::unique_ptr<ITensorAllocator> tensor_allocator_(
std::unique_ptr<ITensorAllocator> tensor_allocator(
ITensorAllocator::Create(enable_mem_pattern_, *p_seq_exec_plan_, *this, weights_buffers_));
const auto& initializer_allocation_order = p_seq_exec_plan_->initializer_allocation_order;
@ -1001,7 +997,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
session_state_utils::SaveInitializedTensors(
Env::Default(), graph_location, *graph_viewer_,
execution_providers_.GetDefaultCpuMemoryInfo(),
ort_value_name_idx_map_, initializer_allocation_order, *tensor_allocator_,
ort_value_name_idx_map_, initializer_allocation_order, *tensor_allocator,
[this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status {
return AddInitializedTensor(idx, value, &d, constant);
},

View file

@ -90,7 +90,9 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
common::Status SaveInitializedTensors(
const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info,
const OrtValueNameIdxMap& ort_value_name_idx_map, const std::vector<OrtValueIndex>& initializer_allocation_order, ITensorAllocator& planner,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const std::vector<OrtValueIndex>& initializer_allocation_order,
ITensorAllocator& planner,
const std::function<Status(int idx, const OrtValue& value, const OrtCallback& d, bool constant)>& save_tensor_func,
const logging::Logger& logger, const DataTransferManager& data_transfer_mgr,
const ExecutionPlanBase& exec_plan,

View file

@ -61,7 +61,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
TensorAllocatorWithMemPattern(const ExecutionPlanBase& execution_plan, const SessionState& session_state,
std::vector<BufferUniquePtr>& weights_buffers)
: ITensorAllocator(session_state),
planner_(execution_plan),
planner_(execution_plan, /*using counters*/ false),
weights_buffers_(weights_buffers),
seq_plan_(execution_plan) {}

View file

@ -570,6 +570,7 @@ int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType onn
}
}
#ifdef ENABLE_TRAINING
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context) {
const Tensor* prev_input = context->Input<Tensor>(0);
for (int i = 1; i < context->InputCount(); i++) {
@ -581,15 +582,18 @@ common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context)
size_t input_element_size = prev_input->DataType()->Size();
size_t input_aligned_bytes = 0;
ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size, &input_aligned_bytes));
ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size,
&input_aligned_bytes));
ORT_RETURN_IF_NOT(curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + input_aligned_bytes ||
curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + prev_input->SizeInBytes());
ORT_RETURN_IF_NOT(
curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + input_aligned_bytes ||
curr_input->DataRaw() == static_cast<const int8_t*>(prev_input->DataRaw()) + prev_input->SizeInBytes());
prev_input = curr_input;
}
return Status::OK();
}
#endif
} // namespace utils
} // namespace onnxruntime

View file

@ -156,7 +156,9 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint64_t>() {
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
#ifdef ENABLE_TRAINING
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
#endif
} // namespace utils
} // namespace onnxruntime

View file

@ -7,7 +7,8 @@
namespace onnxruntime {
namespace test {
TEST(MemPatternPlannerTest, TraceAllocaitonTest) {
MemPatternPlanner planner;
const bool using_counters = false; // we're not tracking start/end for use/re-use of each allocation via counters
MemPatternPlanner planner{using_counters};
planner.TraceAllocation(0, 1024);
planner.TraceAllocation(1, 256);
planner.TraceAllocation(2, 512);