Take ownership of node attributes for consistency

Updates comments for clarity.
Copy external data into initializer when saving model for debugging.
This commit is contained in:
Scott McKay 2025-01-07 16:11:35 +10:00
parent 6fb01c19a7
commit 347bd7a3f2
8 changed files with 124 additions and 45 deletions

View file

@ -5158,16 +5158,29 @@ struct OrtModelBuilderApi {
*
* Two options:
*
* Pre-existing memory:
* Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue
* with a tensor that contains a pointer to the existing data.
* User must keep pointer valid for lifetime of the inference session.
* Set `data_is_external` to true.
*
* Allocated memory:
* Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data.
* Set `data_is_external` to false.
*
* Pre-existing memory:
* Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue
* with a tensor that contains a pointer to the existing data.
* Set `data_is_external` to true.
*
* The pointer must remain valid for the duration of the inference session.
* If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session
* is released.
* If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as
* soon as the OrtValue is no longer in use.
*
* NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more.
* For smaller tensors use CreateTensorAsOrtValue.
*
* ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is
* small (typically a single value or limited by the rank of a tensor) and uses less than 128 bytes of
* memory, so this limit acts as a simple catch-all rule to avoid issues.
* e.g. Reshape's `shape`, Clip's `min` and `max`, various ops' `axes`.
*
* \param[in] graph The OrtGraph instance to update.
* \param[in] name The value name for the initializer.
* \param[in] tensor The OrtValue instance containing the tensor data.

View file

@ -2418,10 +2418,8 @@ template <>
inline void GraphImpl<OrtGraph>::SetInputs(std::vector<ValueInfo>& inputs) {
std::vector<OrtValueInfo*> inputs_ptrs;
inputs_ptrs.reserve(inputs.size());
// Graph takes ownership.
std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs),
[](ValueInfo& vi) -> OrtValueInfo* { return vi.release(); });
[](ValueInfo& vi) -> OrtValueInfo* { return vi; });
ThrowOnError(GetModelBuilderApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size()));

View file

@ -266,6 +266,18 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
return Status::OK();
}
// Returns true when `tensor_proto` has an external data entry whose location is
// kTensorProtoMemoryAddressTag rather than anything else (e.g. a file path).
// A failure to parse the external data entries is reported via ORT_THROW_IF_ERROR.
bool HasExternallyAllocatedMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
  // No external data at all: nothing further to inspect.
  if (!utils::HasExternalData(tensor_proto)) {
    return false;
  }

  std::unique_ptr<onnxruntime::ExternalDataInfo> info;
  ORT_THROW_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), info));

  // The location is compared against the in-memory address tag.
  return info->GetRelPath() == onnxruntime::utils::kTensorProtoMemoryAddressTag;
}
// Sets the raw_data field of `tensor_proto` from `param`.
// `param` is taken by rvalue reference and moved into set_raw_data, so the caller should treat it as
// consumed after this call.
void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) {
tensor_proto.set_raw_data(std::move(param));
}

View file

@ -514,6 +514,10 @@ inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) {
}
#endif
// Check if the TensorProto has an external data entry that points to memory rather than an external file.
// The external data location will be kTensorProtoMemoryAddressTag in this case.
// Returns false if the TensorProto has no external data.
// NOTE(review): the implementation uses ORT_THROW_IF_ERROR when parsing the external data entries, so this
// presumably throws on malformed external data rather than returning false - confirm.
bool HasExternallyAllocatedMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto);
// UnpackTensor from raw data or the type specific data field. Does not handle external data.
// If the tensor does not contain raw data then raw_data should be nullptr and raw_data_len should be 0.
template <typename T>

View file

@ -4093,30 +4093,59 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const {
// This is used for constructing full path for external data
// if it exists
// Appends a copy of `initializer` to `output_initializers`.
// If the initializer's external data location is kTensorProtoMemoryAddressTag, the referenced bytes are
// copied into the raw_data field of the appended TensorProto and its data_location is cleared.
// Returns the error from utils::GetExternalDataInfo if the external data info cannot be read.
auto add_initializer = [](TensorList& output_initializers, const TensorProto& initializer) -> Status {
  // Always start from a plain copy of the source initializer.
  TensorProto& copied = *output_initializers.Add();
  copied = initializer;

  // Initializers without external data need no further handling.
  if (!utils::HasExternalData(initializer)) {
    return Status::OK();
  }

  const std::filesystem::path ignored;
  std::basic_string<ORTCHAR_T> location;
  onnxruntime::FileOffsetType file_offset;
  SafeInt<size_t> tensor_byte_size;
  ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo(initializer, ignored, location, file_offset, tensor_byte_size));

  // Any location other than the in-memory address tag is left untouched.
  const bool is_in_memory = location == onnxruntime::utils::kTensorProtoMemoryAddressTag;
  if (!is_in_memory) {
    return Status::OK();
  }

  // For in-memory external data the "file offset" field carries the address of the data.
  void* address = reinterpret_cast<void*>(file_offset);

  // Inline the bytes as raw data.
  copied.clear_data_location();
  copied.set_raw_data(address, tensor_byte_size);

  return Status::OK();
};
auto* mutable_initializers = result.mutable_initializer();
#if !defined(DISABLE_SPARSE_TENSORS)
const auto& model_path = ModelPath();
// We want to make sure that sparse initializers do not appear
// as dense duplicates within the initializers list.
if (!sparse_tensor_names_.empty()) {
const auto sparse_end = sparse_tensor_names_.end();
auto* mutable_initializer = result.mutable_initializer();
for (const auto& initializer : graph_proto_->initializer()) {
if (sparse_end == sparse_tensor_names_.find(initializer.name())) {
*mutable_initializer->Add() = initializer;
add_initializer(*mutable_initializers, initializer);
} else {
auto& sparse_initializer = *result.add_sparse_initializer();
auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer);
ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse");
}
}
} else {
*result.mutable_initializer() = graph_proto_->initializer();
}
} else
#else
*result.mutable_initializer() = graph_proto_->initializer();
{
for (const auto& initializer : graph_proto_->initializer()) {
add_initializer(*mutable_initializers, initializer);
}
}
#endif
return result;
return result;
}
Status Graph::AddExternalInitializersToGraphProtoImpl(

View file

@ -627,6 +627,12 @@ class InferenceSession {
/// convenience pointer to logger. should always be the same as session_state_.Logger();
const logging::Logger* session_logger_;
// The list of execution providers.
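// This MUST be declared prior to model_ in case there are values in the model that were allocated using an
// allocator provided by the EP. Class members are destroyed in reverse order of declaration, so declaring this
// first ensures the EPs are destroyed after the model. If that is the case the allocator's `free` implementation
// may depend on other parts of the EP instance.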
ExecutionProviders execution_providers_;
// The model served by this inference session instance.
// Currently this has to be a shared ptr because the Model::Load method
// returns a shared_ptr only. Ideally factory functions should always return
@ -637,9 +643,6 @@ class InferenceSession {
// The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx
PathString model_location_;
// The list of execution providers.
ExecutionProviders execution_providers_;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession);
void SetLoggingManager(const SessionOptions& session_options,

View file

@ -93,6 +93,9 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateNode, const char* operator_name, c
n->attributes.reserve(attribs_len);
for (size_t i = 0; i < attribs_len; ++i) {
n->attributes.push_back(*reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(attributes[i]));
// take ownership. as we took a copy that means releasing the original value
OrtApis::ReleaseOpAttr(attributes[i]);
attributes[i] = nullptr;
}
}
@ -156,12 +159,31 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::SetGraphOutputs, _In_ OrtGraph* graph,
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name,
_Inout_ OrtValue* tensor, bool data_is_external) {
API_IMPL_BEGIN
if (!tensor->IsTensor()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only Tensor is currently supported.");
}
if (!tensor->IsAllocated()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Tensor must be allocated.");
}
const auto& t = tensor->Get<onnxruntime::Tensor>();
if (t.Location().device.Type() != OrtDevice::CPU) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only CPU based tensors are currently supported.");
}
if (data_is_external) {
#if !defined(DISABLE_EXTERNAL_INITIALIZERS)
// enforce that an external initializer is not used if the data size is < 128 bytes.
// the reason for this is to avoid potential shape inferencing errors if this initializer is providing an
// input involved in that. the ONNX shape inferencing does not support external data for those values.
// e.g. Reshape's `shape` input, Reduce's `axes`, Slice's `starts`, `ends`, `steps`, Clip's `min`, `max`, etc.
if (t.SizeInBytes() < 128) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"External initializer should only be used for data >= 128 bytes. "
"Please use CreateTensorAsOrtValue instead.");
}
graph->external_initializers[name] = std::unique_ptr<OrtValue>(tensor); // take ownership
#else
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "External initializers are not supported in this build");
#endif
} else {
graph->initializers[name] = std::unique_ptr<OrtValue>(tensor); // take ownership
}

View file

@ -131,14 +131,14 @@ struct TestAllocator : public OrtAllocator {
// Uses the ORT C++ api for the rest for simplicity
TEST(ModelBuilderAPITest, Basic_CApi) {
const auto& api = Ort::GetApi();
const auto& graph_api = Ort::GetModelBuilderApi();
const auto& model_builder_api = Ort::GetModelBuilderApi();
TestAllocator deleter;
// return void so we can use ASSERT_* in the lambda
const auto build_model = [&](bool use_constant_node, OrtModel*& model) -> void {
OrtGraph* graph = nullptr;
Ort::ThrowOnError(graph_api.CreateGraph(&graph));
Ort::ThrowOnError(model_builder_api.CreateGraph(&graph));
//
// Create OrtModel with a Gemm. X input is 3x2, Y input is 2x3, Z output is 3x3.
@ -164,7 +164,7 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
// create ValueInfo and release the type info as CreateValueInfo takes a copy.
OrtValueInfo* input_value_info = nullptr;
Ort::ThrowOnError(graph_api.CreateValueInfo("X", input_type_info, &input_value_info));
Ort::ThrowOnError(model_builder_api.CreateValueInfo("X", input_type_info, &input_value_info));
api.ReleaseTypeInfo(input_type_info); // input_value_info took a copy
tensor_type_info = nullptr;
@ -180,13 +180,15 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy
OrtValueInfo* output_value_info = nullptr;
Ort::ThrowOnError(graph_api.CreateValueInfo("Z", output_type_info, &output_value_info));
Ort::ThrowOnError(model_builder_api.CreateValueInfo("Z", output_type_info, &output_value_info));
api.ReleaseTypeInfo(output_type_info);
std::vector<OrtValueInfo*> graph_inputs = {input_value_info};
std::vector<OrtValueInfo*> graph_outputs = {output_value_info};
Ort::ThrowOnError(graph_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size()));
Ort::ThrowOnError(graph_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size()));
Ort::ThrowOnError(model_builder_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size()));
Ort::ThrowOnError(model_builder_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size()));
input_value_info = nullptr; // graph now owns the input/output values
output_value_info = nullptr;
//
// Gemm node
@ -200,11 +202,10 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
const std::string gemm_output_name = use_constant_node ? "Z_temp" : "Z";
std::vector<const char*> node_output_names = {gemm_output_name.c_str()};
std::vector<OrtOpAttr*> node_attributes{alpha_attr};
OrtNode* node = CreateNode(graph_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes);
OrtNode* node = CreateNode(model_builder_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes);
alpha_attr = nullptr; // Node now owns
api.ReleaseOpAttr(alpha_attr); // CreateNode copies all OrtOpAttr instances
Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node));
Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node));
node = nullptr; // graph now owns node
// Y input
@ -214,11 +215,8 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
4.0f, 5.0f, 6.0f}));
auto& y_values = *deleter.weights.back();
// create an initializer for the Y input. add to `weights` so the memory remains valid
// create an initializer for the Y input. add to `weights` so the memory remains valid.
OrtValue* y_tensor = nullptr;
auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
// if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession
Ort::ThrowOnError(
api.CreateTensorWithDataAndDeleterAsOrtValue(&deleter,
y_values.data(), y_values.size() * sizeof(y_values[0]),
@ -226,7 +224,7 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
&y_tensor));
Ort::ThrowOnError(graph_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true));
Ort::ThrowOnError(model_builder_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true));
y_tensor = nullptr; // graph now owns
if (use_constant_node) {
@ -237,20 +235,20 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
float max = 60.0f;
Ort::ThrowOnError(api.CreateOpAttr("value", &max, sizeof(max), ORT_OP_ATTR_FLOAT, &value_attr));
node = CreateNode(graph_api, "Constant", "clip_max", {}, {"max"}, {value_attr});
Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node));
node = CreateNode(model_builder_api, "Constant", "clip_max", {}, {"max"}, {value_attr});
Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node));
node = nullptr; // graph now owns node
node = CreateNode(graph_api, "Clip", "Clip1", {gemm_output_name.c_str(), "", "max"}, {"Z"});
Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node));
node = CreateNode(model_builder_api, "Clip", "Clip1", {gemm_output_name.c_str(), "", "max"}, {"Z"});
Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node));
node = nullptr; // graph now owns node
}
std::vector<const char*> domain_names = {onnxruntime::kOnnxDomain};
std::vector<int> opset_versions = {18};
Ort::ThrowOnError(graph_api.CreateModel(domain_names.data(), opset_versions.data(), domain_names.size(),
&model));
Ort::ThrowOnError(graph_api.AddGraphToModel(model, graph));
Ort::ThrowOnError(model_builder_api.CreateModel(domain_names.data(), opset_versions.data(), domain_names.size(),
&model));
Ort::ThrowOnError(model_builder_api.AddGraphToModel(model, graph));
graph = nullptr; // model now owns
};