Avoid reusing buffer for node outputs with no consumers (#21019)

This commit is contained in:
Baiju Meswani 2024-06-13 16:08:16 -07:00 committed by GitHub
parent 846cac6e2c
commit fff68c3151
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 152 additions and 15 deletions

View file

@ -469,6 +469,15 @@ class PlannerImpl {
*/
}
static bool OutputHasConsumerNode(const Node& node, int output_idx) {
// there will be an edge to all consumer nodes.
// if consumed in a subgraph the edge will be to an implicit input of the node containing the subgraph.
return std::any_of(node.OutputEdgesBegin(), node.OutputEdgesEnd(),
[&output_idx](const Node::EdgeEnd& edge) {
return edge.GetSrcArgIndex() == output_idx;
});
}
bool SameSize(const onnxruntime::NodeArg& arg1, const onnxruntime::NodeArg& arg2) {
if ((!arg1.Exists()) || (!arg2.Exists())) return false;
auto p_shape1 = context_->GetShape(arg1);
@ -1172,8 +1181,8 @@ class PlannerImpl {
value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
continue;
} // if
} // if
}
}
}
}
@ -1456,7 +1465,13 @@ class PlannerImpl {
} else if (IsNonTensor(*node_output)) {
AllocPlan(current).alloc_kind = AllocKind::kAllocate;
} else if (!context_->IsParallelExecutionEnabled() &&
OutputHasConsumerNode(*pnode, static_cast<int>(output_arg_def_index)) &&
FindReusableTensor(*node_output, &reused)) {
// The check that OutputHasConsumerNode is to handle an edge case where a node produces a value that is
// not consumed by any other nodes. If we set it to kReuse the buffer will be freed prematurely as the
// logic in GenerateDeallocationPlan is based on processing consumer nodes. Changing the implementation of
// GenerateDeallocationPlan is an alternative but that would be a much bigger change.
// Reuse an available (dead) buffer for this output, this is only for sequential execution.
Reuse(reused, current, AllocKind::kReuse);
} else {
@ -1906,8 +1921,8 @@ class PlannerImpl {
node_to_wait[it->Index()].insert({node_index, wait_handle});
}
}
} // output->Exists
} // for each output
}
}
if (output_consumed_in_subgraph) {
const auto downstream = plan_.node_stream_map_[it->Index()];
if (downstream != i) {

View file

@ -1410,7 +1410,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
// Record the allocation plan
// Uncomment the below to dump the allocation plan to std::cout
// LOGS(logger_, VERBOSE) << std::make_pair(p_seq_exec_plan_.get(), this);
// std::cout << std::make_pair(&*p_seq_exec_plan_, this);
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
GetMemoryProfiler()->Init(GetExecutionPlan(), GetOrtValueNameIdxMap());

View file

@ -2040,5 +2040,42 @@ TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) {
ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream";
}
#endif
#ifdef ENABLE_TRAINING_OPS
// use a carefully constructed model to re-produce a customer reported issue where a model produced invalid output.
// this issue required:
// - buffer A that is re-used later in the model
// - output of the first Shape node
// - first usage completes after the following Cast node
// - buffer B which has the same size requirement and is used after the first usage of A is complete
// - buffer B is used for the output from `squeeze2` and a number of other nodes in that part of the model.
// - re-use of buffer A for an output of a node that has no consumers whilst buffer B is still in use
// - this is the `per_input_length` output of the ConcatTraining node
//
// Because the logic to determine when a buffer can be freed is based on consumers, buffer A gets freed after the
// Cast node. It is then re-used as buffer B because the memory pattern planner believes that block to be available.
// When we re-use buffer A for the ConcatTraining output we are using the same address for two different node output
// buffers, leading to corruption of the output.
// This tests that the change in allocation planner to not re-use a buffer for outputs with no consumers prevents this.
TEST(AllocationPlannerTest, AvoidReuseOfBufferForNodeOutputWithNoConsumers) {
SessionOptions sess_opt;
sess_opt.graph_optimization_level = TransformerLevel::Default;
InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/avoid_reuse_of_buffer_for_node_output_with_no_consumers.onnx"));
auto status = sess.Load();
status = sess.Initialize();
ASSERT_TRUE(status.IsOK());
const auto& session_state = sess.GetSessionState();
const auto& ort_value_index_map = session_state.GetOrtValueNameIdxMap();
const SequentialExecutionPlan* plan = session_state.GetExecutionPlan();
OrtValueIndex concat_training_unused_out_index;
// Here per_input_length output of the ConcatTraining node has no consumers, so it should not reuse the buffer.
ASSERT_STATUS_OK(ort_value_index_map.GetIdx("per_input_length", concat_training_unused_out_index));
EXPECT_EQ(plan->allocation_plan[concat_training_unused_out_index].alloc_kind, AllocKind::kAllocate);
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -454,9 +454,14 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) {
#endif
TEST(ExecutionFrameTestWithoutSessionState, BadModelInvalidDimParamUsage) {
// load model with 2 Scan ops that both incorrectly use shapes of { 'None', 'None' } for their outputs.
// as 'None' is not a special value it's treated as a variable name, leading to a runtime error when we
// attempt to re-use the output from the first Scan node for the second. validate we detect this and error out.
// Model that has 2 inputs with shape {'Symbolic', 'Symbolic'} that is carefully constructed to re-use a
// buffer the size of one input for output the size of the other input.
// The model is fine if all values of 'Symbolic' are the same, but invalid if they are not.
// As both inputs claim to have the same size, the allocation plan is based on that.
// Code in ExecutionFrame catches what would result in buffer overflow if input 2 is actually larger than input 1
// and we're attempting to re-use a buffer the size of input 1.
// The 'real' problem being tested is inconsistent values for a dim_param in a model, which could occur anywhere
// in the model.
SessionOptions so;
so.session_logid = "BadModelInvalidDimParamUsage";
@ -464,17 +469,27 @@ TEST(ExecutionFrameTestWithoutSessionState, BadModelInvalidDimParamUsage) {
ASSERT_STATUS_OK(session_object.Load("testdata/invalid_dim_param_value_repetition.onnx"));
ASSERT_STATUS_OK(session_object.Initialize());
std::vector<int64_t> dims_X = {10, 6};
std::vector<float> values_X;
values_X.reserve(60);
std::vector<int64_t> dims_X1 = {10, 6};
std::vector<float> values_X1;
values_X1.reserve(60);
for (int i = 0; i < 60; ++i) {
values_X.push_back(float(i));
values_X1.push_back(float(i));
}
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_X, values_X, &ml_value);
std::vector<int64_t> dims_X2 = {10, 12};
std::vector<float> values_X2;
values_X2.reserve(120);
for (int i = 0; i < 120; ++i) {
values_X2.push_back(float(i));
}
OrtValue ml_value1;
CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_X1, values_X1, &ml_value1);
OrtValue ml_value2;
CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_X2, values_X2, &ml_value2);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value));
feeds.insert({"X1", ml_value1});
feeds.insert({"X2", ml_value2});
// prepare outputs
std::vector<std::string> output_names;

View file

@ -0,0 +1,70 @@
"""
Run this script to recreate the original onnx model.
Example usage:
python invalid_dim_param_value_repetition.py
"""
import numpy as np
import onnx
def order_repeated_field(repeated_proto, key_name, order):
order = list(order)
repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))
def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
node = onnx.helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
if doc_string == "":
node.doc_string = ""
order_repeated_field(node.attribute, "name", kwargs.keys())
return node
def make_graph(*args, doc_string=None, **kwargs):
graph = onnx.helper.make_graph(*args, doc_string=doc_string, **kwargs)
if doc_string == "":
graph.doc_string = ""
return graph
model = onnx.helper.make_model(
opset_imports=[onnx.helper.make_operatorsetid("", 11)],
ir_version=5,
producer_name="skl2onnx",
producer_version="1.5.9999",
domain="ai.onnx",
model_version=0,
graph=make_graph(
name="OnnxIdentity",
inputs=[
onnx.helper.make_tensor_value_info("X1", onnx.TensorProto.FLOAT, shape=["Symbolic", "Symbolic"]),
onnx.helper.make_tensor_value_info("X2", onnx.TensorProto.FLOAT, shape=["Symbolic", "Symbolic"]),
],
outputs=[
onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=[None, None]),
],
initializer=[
onnx.numpy_helper.from_array(np.array([0.10000000149011612], dtype="float32"), name="Addcst"),
],
nodes=[
# take an input. Add to create a local output buffer for O01.
make_node("Add", inputs=["X1", "Addcst"], outputs=["O01"], name="Add1", domain=""),
# Use Shape -> ConstantOfShape to make O01 available for reuse
make_node("Shape", inputs=["O01"], outputs=["O02"], name="Shape1", domain=""),
# ConstantOfShape to get back to the right rank, and ReduceSum so the value is broadcastable in the
# the downstream Add
make_node("ConstantOfShape", inputs=["O02"], outputs=["O03"], name="ConstantOfShape ", domain=""),
make_node("ReduceSum", inputs=["O03"], outputs=["O04"], name="ReduceSum1", domain=""),
# Two Add nodes with the ReduceSum output. One could be in-place, but the other needs a buffer.
# This should trigger attempted re-use of O01, so provided X2 is larger than X1 that should break
make_node("Add", inputs=["O04", "X2"], outputs=["O05"], name="Add2", domain=""),
make_node("Add", inputs=["X2", "O04"], outputs=["O06"], name="Add3", domain=""),
# concat to separate the Add outputs from graph output (which is always allocated)
make_node("Concat", inputs=["O05", "O06"], outputs=["Y"], axis=-1, name="Concat", domain=""),
],
),
)
if __name__ == "__main__":
onnx.save(model, "invalid_dim_param_value_repetition.onnx")