mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Fix memory planning issues (#5752)
* Fix memory planning issues * fix build * fix the wrong line...
This commit is contained in:
parent
44d3c31200
commit
49288de17c
5 changed files with 23 additions and 19 deletions
|
|
@ -733,8 +733,7 @@ class PlannerImpl {
|
|||
std::vector<SequentialExecutionPlan::NodeExecutionPlan>& execution_plan(plan_.execution_plan);
|
||||
std::vector<OrtValueIndex>& initializer_allocation_order(plan_.initializer_allocation_order);
|
||||
std::vector<OrtValueIndex>& activation_allocation_order(plan_.activation_allocation_order);
|
||||
for (size_t program_counter = 0; program_counter < execution_plan.size(); ++program_counter) {
|
||||
SequentialExecutionPlan::NodeExecutionPlan step = execution_plan[program_counter];
|
||||
for (auto& step : execution_plan) {
|
||||
const auto* pnode = graph_viewer_.GetNode(step.node_index);
|
||||
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot find the node ", step.node_index);
|
||||
if (!AllocateInputsContiguously(*pnode)) continue;
|
||||
|
|
@ -742,8 +741,7 @@ class PlannerImpl {
|
|||
const auto& input_defs = pnode->InputDefs();
|
||||
onnxruntime::AllocKind input_kind = AllocKind::kAllocateStatically;
|
||||
bool set_input_kind = true;
|
||||
for (int input_arg_def_index = 0; static_cast<size_t>(input_arg_def_index) < input_defs.size(); ++input_arg_def_index) {
|
||||
const auto& node_input = input_defs[input_arg_def_index];
|
||||
for (const auto& node_input : input_defs) {
|
||||
if (!node_input->Exists()) continue;
|
||||
const auto current_idx = Index(node_input->Name());
|
||||
const auto& current_plan = AllocPlan(current_idx);
|
||||
|
|
@ -784,7 +782,7 @@ class PlannerImpl {
|
|||
if (current_plan.alloc_kind != AllocKind::kAllocate) continue;
|
||||
|
||||
ORT_ENFORCE(current_plan.program_counter_start.size() == current_plan.program_counter_end.size());
|
||||
|
||||
|
||||
size_t start = 0;
|
||||
for (size_t index = 0; index < current_plan.program_counter_start.size(); index += 1) {
|
||||
ORT_ENFORCE((current_plan.program_counter_start[index] > start) || (start == 0));
|
||||
|
|
|
|||
|
|
@ -318,8 +318,10 @@ Status ResolveDimParams(const GraphViewer& graph,
|
|||
Status TryResolveShape(
|
||||
const NodeArg* arg,
|
||||
const std::unordered_map<std::string, int64_t>& symbolic_dimensions,
|
||||
size_t& is_resolved, // indicate whether resolve successfully or not.
|
||||
std::vector<int64_t>& resolved_shape) {
|
||||
if (!arg->Shape()) {
|
||||
is_resolved = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -343,8 +345,9 @@ Status TryResolveShape(
|
|||
}
|
||||
}
|
||||
|
||||
is_resolved = safe_size;
|
||||
// Only assign shape if all symbolic dimensions are resolved.
|
||||
if (safe_size != 0) {
|
||||
if (is_resolved != 0) {
|
||||
resolved_shape = std::move(shape);
|
||||
}
|
||||
|
||||
|
|
@ -396,12 +399,13 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
continue;
|
||||
|
||||
auto* arg = node->OutputDefs()[i];
|
||||
size_t is_resolved = 0;
|
||||
std::vector<int64_t> resolved_shape;
|
||||
ORT_RETURN_IF_ERROR(TryResolveShape(arg, map, resolved_shape));
|
||||
ORT_RETURN_IF_ERROR(TryResolveShape(arg, map, is_resolved, resolved_shape));
|
||||
|
||||
// Store all valid resolved shapes. They will be queried in, for example,
|
||||
// Recv operator to bypass the dependency of output shapes on inputs.
|
||||
if (resolved_shape.size() > 0) {
|
||||
if (is_resolved != 0) {
|
||||
resolved_shapes[ml_value_idx] = resolved_shape;
|
||||
}
|
||||
}
|
||||
|
|
@ -420,7 +424,9 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
|
|||
size_t size = 0;
|
||||
TryCalculateSizeFromResolvedShape(ml_value_idx, resolved_shapes, size);
|
||||
if (size == 0) {
|
||||
return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute");
|
||||
std::string node_name;
|
||||
ORT_RETURN_IF_ERROR(this->ort_value_name_idx_map_.GetName(ml_value_idx, node_name));
|
||||
return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute, node name is : " + node_name);
|
||||
}
|
||||
|
||||
if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(size, ml_data_type->Size(), &size)) {
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ common::Status SaveInitializedTensors(
|
|||
}
|
||||
id_to_initialized_tensor[ort_value_index] = entry.second;
|
||||
}
|
||||
|
||||
|
||||
// tensors requiring a specific allocation order are traced first, to ensure they are allocated in order
|
||||
auto initialized_tensors_to_allocate = id_to_initialized_tensor;
|
||||
for (int ort_value_index : initializer_allocation_order) {
|
||||
|
|
@ -149,7 +149,7 @@ common::Status SaveInitializedTensors(
|
|||
initialized_tensors_to_allocate.erase(entry);
|
||||
}
|
||||
|
||||
for (const auto& entry : id_to_initialized_tensor) {
|
||||
for (const auto& entry : initialized_tensors_to_allocate) {
|
||||
// We don't want to trace shared initializers since their memory is provided by the user
|
||||
if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
|
|||
return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"Failed to get allocator for location: " + location.ToString());
|
||||
|
||||
// Don't allocate memory when there is no memory usage..
|
||||
// Don't allocate memory when there is no memory usage.
|
||||
if (mem_patterns_.patterns[i].PeakSize() <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -42,8 +42,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
|
|||
// Arena has a specific way to store static memory.
|
||||
// Arena does not reuse static memory allocated by Reserve.
|
||||
buffer = static_cast<IArenaAllocator*>(alloc.get())->Reserve(peak_size);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
buffer = alloc->Alloc(peak_size);
|
||||
}
|
||||
weights_buffers_.push_back(BufferUniquePtr(buffer, alloc));
|
||||
|
|
|
|||
|
|
@ -513,8 +513,7 @@ void RegisterTrainingOpSchemas() {
|
|||
|
||||
// dA = reshape(reduce_sum(dY / B, axes_A), shape_A)
|
||||
{{"dY_over_B"}, "Div", {"dY", "B"}},
|
||||
{{"reduce_dA"}, "ReduceSumTraining", {"dY_over_B", "axes_A"},
|
||||
{ONNX_NAMESPACE::MakeAttribute("noop_with_empty_axes", int64_t(1))}},
|
||||
{{"reduce_dA"}, "ReduceSumTraining", {"dY_over_B", "axes_A"}, {ONNX_NAMESPACE::MakeAttribute("noop_with_empty_axes", int64_t(1))}},
|
||||
{{"dA"}, "Reshape", {"reduce_dA", "shape_A"}},
|
||||
|
||||
// dB = reshape(reduce_sum(dY * -A / (B * B)), axes_B), shape_B)
|
||||
|
|
@ -522,8 +521,7 @@ void RegisterTrainingOpSchemas() {
|
|||
{{"minus_A"}, "Neg", {"A"}},
|
||||
{{"minus_A_over_B_squared"}, "Div", {"minus_A", "B_squared"}},
|
||||
{{"pre_reduce_dB"}, "Mul", {"dY", "minus_A_over_B_squared"}},
|
||||
{{"reduce_dB"}, "ReduceSumTraining", {"pre_reduce_dB", "axes_B"},
|
||||
{ONNX_NAMESPACE::MakeAttribute("noop_with_empty_axes", int64_t(1))}},
|
||||
{{"reduce_dB"}, "ReduceSumTraining", {"pre_reduce_dB", "axes_B"}, {ONNX_NAMESPACE::MakeAttribute("noop_with_empty_axes", int64_t(1))}},
|
||||
{{"dB"}, "Reshape", {"reduce_dB", "shape_B"}}});
|
||||
|
||||
for (size_t contrib_node_index : {2, 4, 10}) {
|
||||
|
|
@ -1929,7 +1927,10 @@ Return true if all elements are true and false otherwise.
|
|||
.TypeConstraint(
|
||||
"TOut",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)"},
|
||||
"Constrain scale types to float tensors.");
|
||||
"Constrain scale types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
updateOutputShape(ctx, 0, {});
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Send)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
|
|||
Loading…
Reference in a new issue