mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Address PR comments.
Add remaining unit tests.
This commit is contained in:
parent
4e2d061977
commit
0b853cae75
10 changed files with 311 additions and 88 deletions
|
|
@ -40,7 +40,7 @@
|
|||
#include "core/graph/node_arg.h"
|
||||
#include "core/graph/ort_format_load_options.h"
|
||||
|
||||
// Type from Graph API in ORT C API so can't be in a namespace
|
||||
// Type from Model Builder API in ORT C API so can't be in a namespace
|
||||
struct OrtGraph;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -5047,7 +5047,7 @@ struct OrtModelBuilderApi {
|
|||
/** \brief Create an OrtValueInfo for use as an OrtGraph input or output.
|
||||
*
|
||||
* \param[in] name The name of the input or output.
|
||||
* \param[in] type_info The type information for the input or output.
|
||||
* \param[in] type_info The type information for the input or output. The provided value is copied.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
|
|
@ -5102,7 +5102,7 @@ struct OrtModelBuilderApi {
|
|||
*
|
||||
* \since Version 1.21.
|
||||
*/
|
||||
ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, const char* domain_name, _In_ const char* node_name,
|
||||
ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name,
|
||||
_In_reads_(input_names_len) const char* const* input_names, size_t input_names_len,
|
||||
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
|
||||
_In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len,
|
||||
|
|
@ -5263,9 +5263,16 @@ struct OrtModelBuilderApi {
|
|||
|
||||
/** \brief Create an OrtSession using the OrtModel.
|
||||
*
|
||||
* Create an inference session using the OrtModel.
|
||||
* Create an inference session using the OrtModel instance.
|
||||
* The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs
|
||||
* and SetGraphOutputs must have been called.
|
||||
* This will validate the model, run optimizers, and prepare the session for inferencing.
|
||||
*
|
||||
* \param[in] env The OrtEnv instance.
|
||||
* \param[in] model The OrtModel instance.
|
||||
* \param[in] options The OrtSessionOptions instance.
|
||||
* \param[out] out The OrtSession instance.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
* \since Version 1.21.
|
||||
|
|
@ -5275,11 +5282,17 @@ struct OrtModelBuilderApi {
|
|||
|
||||
/** \brief Create an OrtSession to augment an existing model.
|
||||
*
|
||||
* Create an OrtSession with an existing model that can be augmented with additional nodes.
|
||||
* Nodes can be added to the model using AddNodeToGraph.
|
||||
* Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes.
|
||||
* Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling
|
||||
* FinalizeModelBuilderSession.
|
||||
* Create an OrtSession with an existing model that will be augmented with additional nodes and initializers.
|
||||
* Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the
|
||||
* model is finalized.
|
||||
*
|
||||
* To add nodes and initializers to the existing model, first create an OrtModel using CreateModel.
|
||||
* Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph.
|
||||
* Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made
|
||||
* by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes.
|
||||
*
|
||||
* Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the
|
||||
* session for inferencing by calling FinalizeModelBuilderSession.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
|
|
@ -5291,11 +5304,22 @@ struct OrtModelBuilderApi {
|
|||
|
||||
/** \brief Create an OrtSession to augment an existing model.
|
||||
*
|
||||
* Create an OrtSession with an existing model that can be augmented with additional nodes.
|
||||
* Nodes can be added to the model using AddNodeToGraph.
|
||||
* Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes.
|
||||
* Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling
|
||||
* FinalizeModelBuilderSession.
|
||||
* Create an OrtSession with an existing model that will be augmented with additional nodes and initializers.
|
||||
* Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the
|
||||
* model is finalized.
|
||||
*
|
||||
* To add nodes and initializers to the existing model, first create an OrtModel using CreateModel.
|
||||
* Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph.
|
||||
* Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made
|
||||
* by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes.
|
||||
*
|
||||
* Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the
|
||||
* session for inferencing by calling FinalizeModelBuilderSession.
|
||||
*
|
||||
* \param{in} env The OrtEnv instance.
|
||||
* \param{in} model_data The model data for the existing model to augment.
|
||||
* \param{in} model_data_length The length of the model data.
|
||||
* \param{in} options The OrtSessionOptions instance.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
|
|
@ -5308,12 +5332,12 @@ struct OrtModelBuilderApi {
|
|||
|
||||
/** \brief Apply the changes from the model to the session.
|
||||
*
|
||||
* Apply the changes from the model to the session that was created using CreateModelBuilderSession[FromArray].
|
||||
* Apply the changes from the model to a session that was created using CreateModelBuilderSession[FromArray].
|
||||
* All changes will be validated.
|
||||
* Call FinalizeModelBuilderSession to prepare the session for inferencing.
|
||||
*
|
||||
* Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel.
|
||||
* i.e. you don't need to call SetGraphInputs/Outputs if they are unchanged.
|
||||
* i.e. you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
|
|
|
|||
|
|
@ -667,7 +667,8 @@ struct ModelMetadata;
|
|||
|
||||
namespace ModelBuilderAPI {
|
||||
struct Model;
|
||||
}
|
||||
struct ValueInfo;
|
||||
} // namespace ModelBuilderAPI
|
||||
|
||||
/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
|
||||
* and release them at the end of the scope. The lifespan of the given allocator
|
||||
|
|
@ -1120,6 +1121,11 @@ struct ConstSessionImpl : Base<T> {
|
|||
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
|
||||
|
||||
int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain
|
||||
|
||||
// NOTE: We will probably move ValueInfo from ModelBuilderAPI to the ORT API as it will also be relevant to the Plugin EP API.
|
||||
// Will move before checkin if that's the case.
|
||||
std::vector<ModelBuilderAPI::ValueInfo> GetInputs() const;
|
||||
std::vector<ModelBuilderAPI::ValueInfo> GetOutputs() const;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -2592,9 +2598,6 @@ struct NodeImpl : Ort::detail::Base<T> {
|
|||
};
|
||||
} // namespace detail
|
||||
|
||||
// Const object holder that does not own the underlying object
|
||||
using ConstNode = detail::NodeImpl<Ort::detail::Unowned<const OrtNode>>;
|
||||
|
||||
/** \brief Wrapper around ::OrtNode
|
||||
*
|
||||
*/
|
||||
|
|
@ -2616,8 +2619,6 @@ struct Node : detail::NodeImpl<OrtNode> {
|
|||
const std::vector<std::string>& output_names,
|
||||
std::vector<OpAttr>& attributes);
|
||||
|
||||
ConstNode GetConst() const { return ConstNode{this->p_}; }
|
||||
|
||||
private:
|
||||
static void Init(const std::string& operator_name, const std::string& operator_domain,
|
||||
const std::string& node_name,
|
||||
|
|
@ -2640,9 +2641,6 @@ struct GraphImpl : Ort::detail::Base<T> {
|
|||
};
|
||||
} // namespace detail
|
||||
|
||||
// Const object holder that does not own the underlying object
|
||||
using ConstGraph = detail::GraphImpl<Ort::detail::Unowned<const OrtGraph>>;
|
||||
|
||||
/** \brief Wrapper around ::OrtGraph
|
||||
*
|
||||
*/
|
||||
|
|
@ -2650,8 +2648,6 @@ struct Graph : detail::GraphImpl<OrtGraph> {
|
|||
explicit Graph(std::nullptr_t) {} ///< No instance is created
|
||||
explicit Graph(OrtGraph* p) : GraphImpl<OrtGraph>{p} {} ///< Take ownership of a pointer created by C API
|
||||
Graph();
|
||||
|
||||
ConstGraph GetConst() const { return ConstGraph{this->p_}; }
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
|
@ -2675,7 +2671,7 @@ struct Model : detail::ModelImpl<OrtModel> {
|
|||
|
||||
explicit Model(std::nullptr_t) {} ///< No instance is created
|
||||
explicit Model(OrtModel* p) : ModelImpl<OrtModel>{p} {} ///< Take ownership of a pointer created by C API
|
||||
Model(const std::vector<DomainOpsetPair>& opsets);
|
||||
explicit Model(const std::vector<DomainOpsetPair>& opsets);
|
||||
|
||||
ConstModel GetConst() const { return ConstModel{this->p_}; }
|
||||
};
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@
|
|||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
// Convert OrtStatus to Ort::Status and return
|
||||
// instead of throwing
|
||||
|
|
@ -1109,6 +1111,36 @@ inline int ConstSessionImpl<T>::GetOpset(const std::string& domain) const {
|
|||
return opset;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<ModelBuilderAPI::ValueInfo> ConstSessionImpl<T>::GetInputs() const {
|
||||
const std::vector<std::string> input_names = GetInputNames();
|
||||
|
||||
std::vector<ModelBuilderAPI::ValueInfo> inputs;
|
||||
inputs.reserve(input_names.size());
|
||||
|
||||
for (size_t i = 0; i < input_names.size(); ++i) {
|
||||
auto type_info = GetInputTypeInfo(i);
|
||||
inputs.emplace_back(ModelBuilderAPI::ValueInfo{input_names[i], type_info.GetConst()});
|
||||
}
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<ModelBuilderAPI::ValueInfo> ConstSessionImpl<T>::GetOutputs() const {
|
||||
const std::vector<std::string> output_names = GetOutputNames();
|
||||
|
||||
std::vector<ModelBuilderAPI::ValueInfo> outputs;
|
||||
outputs.reserve(output_names.size());
|
||||
|
||||
for (size_t i = 0; i < output_names.size(); ++i) {
|
||||
auto type_info = GetOutputTypeInfo(i);
|
||||
outputs.emplace_back(ModelBuilderAPI::ValueInfo{output_names[i], type_info.GetConst()});
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
||||
const char* const* output_names, size_t output_count) {
|
||||
|
|
@ -2328,6 +2360,7 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con
|
|||
}
|
||||
|
||||
namespace ModelBuilderAPI {
|
||||
namespace detail {
|
||||
inline std::vector<const char*> StringsToCharPtrs(const std::vector<std::string>& strings) {
|
||||
std::vector<const char*> ptrs;
|
||||
ptrs.reserve(strings.size());
|
||||
|
|
@ -2336,6 +2369,7 @@ inline std::vector<const char*> StringsToCharPtrs(const std::vector<std::string>
|
|||
|
||||
return ptrs;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// static
|
||||
inline void Node::Init(const std::string& operator_name, const std::string& operator_domain,
|
||||
|
|
@ -2344,8 +2378,8 @@ inline void Node::Init(const std::string& operator_name, const std::string& oper
|
|||
const std::vector<std::string>& output_names,
|
||||
std::vector<OpAttr>& attributes,
|
||||
OrtNode*& node) {
|
||||
auto inputs = StringsToCharPtrs(input_names);
|
||||
auto outputs = StringsToCharPtrs(output_names);
|
||||
auto inputs = detail::StringsToCharPtrs(input_names);
|
||||
auto outputs = detail::StringsToCharPtrs(output_names);
|
||||
|
||||
std::vector<OrtOpAttr*> attributes_ptrs;
|
||||
attributes_ptrs.reserve(attributes.size());
|
||||
|
|
|
|||
|
|
@ -1289,22 +1289,15 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor,
|
|||
const auto* raw_data = tensor.DataRaw();
|
||||
ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor.");
|
||||
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
|
||||
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
|
||||
|
||||
// we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto.
|
||||
// use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the
|
||||
// high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path.
|
||||
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(raw_data));
|
||||
|
||||
ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("location");
|
||||
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
|
||||
entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("offset");
|
||||
entry->set_value(std::to_string(offset));
|
||||
entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("length");
|
||||
entry->set_value(std::to_string(tensor.SizeInBytes()));
|
||||
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
|
||||
offset, tensor.SizeInBytes(), tensor_proto);
|
||||
|
||||
} else {
|
||||
utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5807,21 +5807,11 @@ Status Graph::LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updat
|
|||
|
||||
if (is_external) {
|
||||
// pre-existing memory that we don't own. avoid a copy by storing the pointer in the ExternalDataInfo
|
||||
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
|
||||
|
||||
const void* data_offset = t.DataRaw(); // address of memory not offset into file
|
||||
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(data_offset));
|
||||
|
||||
ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("location");
|
||||
// magic tag for existing memory that causes 'offset' to be treated as a pointer to the memory
|
||||
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
|
||||
entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("offset");
|
||||
entry->set_value(std::to_string(offset));
|
||||
entry = tensor_proto.mutable_external_data()->Add();
|
||||
entry->set_key("length");
|
||||
entry->set_value(std::to_string(t.SizeInBytes()));
|
||||
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
|
||||
offset, t.SizeInBytes(), tensor_proto);
|
||||
|
||||
// copy OrtValue to keep it alive and to store the deleter if provided.
|
||||
ortvalue_initializers_.emplace(name, v);
|
||||
|
|
|
|||
|
|
@ -300,8 +300,6 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
|
|||
const auto* fbs_raw_data = fbs_tensor.raw_data();
|
||||
if (fbs_raw_data) {
|
||||
if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) {
|
||||
initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
|
||||
|
||||
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
|
||||
const void* data_offset = fbs_raw_data->Data();
|
||||
// we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto.
|
||||
|
|
@ -309,15 +307,9 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init
|
|||
// high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path.
|
||||
auto offset = narrow<ExternalDataInfo::OFFSET_TYPE>(reinterpret_cast<intptr_t>(data_offset));
|
||||
|
||||
ONNX_NAMESPACE::StringStringEntryProto* entry = initializer.mutable_external_data()->Add();
|
||||
entry->set_key("location");
|
||||
entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag));
|
||||
entry = initializer.mutable_external_data()->Add();
|
||||
entry->set_key("offset");
|
||||
entry->set_value(std::to_string(offset));
|
||||
entry = initializer.mutable_external_data()->Add();
|
||||
entry->set_key("length");
|
||||
entry->set_value(std::to_string(fbs_raw_data->size()));
|
||||
ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag,
|
||||
offset, fbs_raw_data->size(), initializer);
|
||||
|
||||
} else {
|
||||
// fbs_raw_data is uint8_t vector, so the size is byte size
|
||||
initializer.set_raw_data(fbs_raw_data->Data(), fbs_raw_data->size());
|
||||
|
|
|
|||
|
|
@ -627,7 +627,7 @@ class InferenceSession {
|
|||
/// convenience pointer to logger. should always be the same as session_state_.Logger();
|
||||
const logging::Logger* session_logger_;
|
||||
|
||||
// The list of execution providers.
|
||||
// The list of execution providers.
|
||||
// This MUST be prior to model_ in case there are values in the model that were allocated using an allocator
|
||||
// provided by the EP. If that is the case the allocator's `free` implementation may depend on other parts of the
|
||||
// EP instance.
|
||||
|
|
|
|||
|
|
@ -48,13 +48,15 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateValueInfo, _In_ const char* name,
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name) {
|
||||
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info,
|
||||
_Out_ const char** name) {
|
||||
API_IMPL_BEGIN
|
||||
*name = value_info->name.c_str();
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info) {
|
||||
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info,
|
||||
_Outptr_ const OrtTypeInfo** type_info) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
*type_info = value_info->type_info.get();
|
||||
|
|
|
|||
|
|
@ -294,7 +294,7 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
|
|||
|
||||
api.ReleaseSession(session.release());
|
||||
|
||||
ASSERT_EQ(deleter.weights.size(), 0) << "All weights should have been freed";
|
||||
ASSERT_EQ(deleter.weights.size(), size_t(0)) << "All weights should have been freed";
|
||||
};
|
||||
|
||||
run_test(false);
|
||||
|
|
@ -410,12 +410,15 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
|
|||
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets; // no additional opsets required
|
||||
ModelBuilderAPI::Model model(opsets);
|
||||
|
||||
std::vector<std::string> input_names = session.GetInputNames();
|
||||
ASSERT_EQ(input_names.size(), 1);
|
||||
|
||||
TypeInfo orig_input = session.GetInputTypeInfo(0);
|
||||
ASSERT_EQ(orig_input.GetTensorTypeAndShapeInfo().GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs = session.GetInputs();
|
||||
ASSERT_EQ(graph_inputs.size(), size_t(1));
|
||||
ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(),
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
||||
// typically this isn't needed. we replace this input but need to read info from it later on in the test
|
||||
// validation so we save the info locally to keep it accessible.
|
||||
auto orig_input_name = graph_inputs[0].Name();
|
||||
auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape();
|
||||
const std::string new_input_name = "Int64Input";
|
||||
|
||||
// Add Cast node to convert input from float to int64
|
||||
|
|
@ -423,16 +426,16 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
|
|||
int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT));
|
||||
|
||||
ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, {input_names[0]},
|
||||
ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"},
|
||||
// the existing node will now consume the output from the Cast instead of a graph input
|
||||
{orig_input_name},
|
||||
attributes);
|
||||
|
||||
// we're replacing the only input, so we don't need to call session.GetInputTypeInfo(x) to copy other inputs
|
||||
// in order to preserve them
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
|
||||
// we're replacing the only input. the shape is the same but the name and data type change.
|
||||
TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
|
||||
orig_input.GetTensorTypeAndShapeInfo().GetShape());
|
||||
input_shape);
|
||||
auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst());
|
||||
graph_inputs.emplace_back(new_input_name, input_type_info.GetConst());
|
||||
graph_inputs[0] = ModelBuilderAPI::ValueInfo(new_input_name, input_type_info.GetConst());
|
||||
|
||||
ModelBuilderAPI::Graph graph; // new info to augment the model with
|
||||
|
||||
|
|
@ -441,13 +444,12 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
|
|||
|
||||
// the node we added does not require any new opsets.
|
||||
model.AddGraph(graph);
|
||||
|
||||
session.FinalizeModelBuilderSession(model, so);
|
||||
|
||||
std::vector<Input<int64_t>> inputs(1);
|
||||
auto& input = inputs[0];
|
||||
input.name = new_input_name.c_str();
|
||||
input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape();
|
||||
input.dims = input_shape;
|
||||
|
||||
auto num_values = std::accumulate(input.dims.begin(), input.dims.end(), int64_t(1), std::multiplies<int64_t>());
|
||||
input.values.resize(size_t(num_values));
|
||||
|
|
@ -465,8 +467,8 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
|
|||
Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so);
|
||||
std::vector<Input<float>> expected_inputs(1);
|
||||
auto& expected_input = expected_inputs[0];
|
||||
expected_input.name = input_names[0].c_str();
|
||||
expected_input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape();
|
||||
expected_input.name = orig_input_name.c_str();
|
||||
expected_input.dims = input_shape;
|
||||
expected_input.values.reserve(size_t(num_values));
|
||||
std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values),
|
||||
[&](int64_t value) { return float(value); });
|
||||
|
|
@ -489,10 +491,200 @@ TEST(ModelBuilderAPITest, InvalidDimension) {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Tests required
|
||||
TEST(ModelBuilderAPITest, CreateInvalidModel_NoOpsets) {
|
||||
Ort::ModelBuilderAPI::Graph graph;
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_outputs;
|
||||
|
||||
- Create invalid model. Graph::Resolve should fail.
|
||||
- Invalid edit. Graph::Resolve should fail.
|
||||
- All the non-tensor Create*TypeInfo functions need to be validated
|
||||
*/
|
||||
std::vector<int64_t> dims({4});
|
||||
TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims);
|
||||
auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst());
|
||||
graph_inputs.emplace_back("X", type_info.GetConst());
|
||||
graph_outputs.emplace_back("Z", type_info.GetConst());
|
||||
|
||||
graph.SetInputs(graph_inputs);
|
||||
graph.SetOutputs(graph_outputs);
|
||||
|
||||
ModelBuilderAPI::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "X"}, {"Z"});
|
||||
|
||||
graph.AddNode(node);
|
||||
|
||||
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets;
|
||||
ModelBuilderAPI::Model model(opsets);
|
||||
model.AddGraph(graph);
|
||||
|
||||
try {
|
||||
auto session = CreateSession(*ort_env, model);
|
||||
FAIL();
|
||||
} catch (const Ort::Exception& e) {
|
||||
ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain"));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ModelBuilderAPITest, CreateInvalidModel_MissingValue) {
|
||||
Ort::ModelBuilderAPI::Graph graph;
|
||||
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs;
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_outputs;
|
||||
|
||||
std::vector<int64_t> dims({4});
|
||||
TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims);
|
||||
auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst());
|
||||
graph_inputs.emplace_back("X", type_info.GetConst());
|
||||
graph_outputs.emplace_back("Z", type_info.GetConst());
|
||||
|
||||
graph.SetInputs(graph_inputs);
|
||||
graph.SetOutputs(graph_outputs);
|
||||
|
||||
ModelBuilderAPI::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "missing"}, {"Z"});
|
||||
graph.AddNode(node);
|
||||
|
||||
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets{{onnxruntime::kOnnxDomain, 18}};
|
||||
ModelBuilderAPI::Model model(opsets);
|
||||
model.AddGraph(graph);
|
||||
|
||||
try {
|
||||
auto session = CreateSession(*ort_env, model);
|
||||
FAIL();
|
||||
} catch (const Ort::Exception& e) {
|
||||
ASSERT_THAT(e.what(), ::testing::HasSubstr("Node input 'missing' is not a graph input, "
|
||||
"initializer, or output of a previous node."));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ModelBuilderAPITest, InvalidModelEdit) {
|
||||
// Add a node but make the edit invalid in various ways
|
||||
// - add node but don't update graph inputs
|
||||
// - add node with invalid domain
|
||||
const auto edit_model = [](bool invalid_domain) {
|
||||
SessionOptions so;
|
||||
|
||||
// Set this to save the model if you want to debug.
|
||||
// so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx"));
|
||||
|
||||
Session session = Session::CreateModelBuilderSession(*ort_env, TSTR("testdata/mnist.onnx"), so);
|
||||
|
||||
ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string
|
||||
|
||||
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets; // no additional opsets required
|
||||
ModelBuilderAPI::Model model(opsets);
|
||||
ModelBuilderAPI::Graph graph; // new info to augment the model with
|
||||
|
||||
const char* domain = invalid_domain ? "invalid_domain" : onnxruntime::kOnnxDomain;
|
||||
|
||||
std::vector<ModelBuilderAPI::ValueInfo> graph_inputs = session.GetInputs();
|
||||
ASSERT_EQ(graph_inputs.size(), size_t(1));
|
||||
ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(),
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
||||
const std::string new_input_name = "Int64Input";
|
||||
|
||||
// Add Cast node to convert input from float to int64
|
||||
std::vector<OpAttr> attributes;
|
||||
int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT));
|
||||
|
||||
ModelBuilderAPI::Node node("Cast", domain, "NewInputNode", {new_input_name},
|
||||
// the existing node will now consume the output from the Cast instead of a graph input
|
||||
{graph_inputs[0].Name()},
|
||||
attributes);
|
||||
graph.AddNode(node);
|
||||
|
||||
if (invalid_domain) {
|
||||
// we're replacing the only input. the shape is the same but the name and data type change.
|
||||
TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
|
||||
graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape());
|
||||
auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst());
|
||||
graph_inputs[0] = ModelBuilderAPI::ValueInfo(new_input_name, input_type_info.GetConst());
|
||||
graph.SetInputs(graph_inputs);
|
||||
} else {
|
||||
// model should be invalid as we didn't connect the new node up to the graph inputs
|
||||
}
|
||||
|
||||
// the node we added does not require any new opsets.
|
||||
model.AddGraph(graph);
|
||||
|
||||
try {
|
||||
session.FinalizeModelBuilderSession(model, so);
|
||||
FAIL() << "Should have failed to resolve graph due to invalid edits.";
|
||||
} catch (const Ort::Exception& e) {
|
||||
if (invalid_domain) {
|
||||
ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain 'invalid_domain'"));
|
||||
} else {
|
||||
ASSERT_THAT(e.what(), ::testing::HasSubstr("This is an invalid model"));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
edit_model(false);
|
||||
edit_model(true); // add node with invalid domain
|
||||
}
|
||||
|
||||
TEST(ModelBuilderAPITest, CreateTypeInfo) {
|
||||
const auto& api = Ort::GetApi();
|
||||
TensorTypeAndShapeInfo base_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
|
||||
{2, 4});
|
||||
|
||||
OrtTypeInfo* base_tensor_type_info = nullptr;
|
||||
Ort::ThrowOnError(api.CreateTensorTypeInfo(base_tensor_info, &base_tensor_type_info));
|
||||
|
||||
ONNXType onnx_type = ONNX_TYPE_UNKNOWN;
|
||||
OrtTypeInfo* sparse_tensor_type_info = nullptr;
|
||||
const OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
|
||||
ONNXTensorElementDataType onnx_element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
|
||||
// sparse tensor
|
||||
Ort::ThrowOnError(api.CreateSparseTensorTypeInfo(base_tensor_info, &sparse_tensor_type_info));
|
||||
Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sparse_tensor_type_info, &onnx_type));
|
||||
ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_SPARSETENSOR);
|
||||
Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(sparse_tensor_type_info, &tensor_info));
|
||||
Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type));
|
||||
ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
||||
// sequence
|
||||
OrtTypeInfo* sequence_type_info = nullptr;
|
||||
const OrtSequenceTypeInfo* sequence_info = nullptr;
|
||||
OrtTypeInfo* sequence_element_type_info = nullptr;
|
||||
|
||||
Ort::ThrowOnError(api.CreateSequenceTypeInfo(base_tensor_type_info, &sequence_type_info));
|
||||
Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sequence_type_info, &onnx_type));
|
||||
ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_SEQUENCE);
|
||||
Ort::ThrowOnError(api.CastTypeInfoToSequenceTypeInfo(sequence_type_info, &sequence_info));
|
||||
Ort::ThrowOnError(api.GetSequenceElementType(sequence_info, &sequence_element_type_info));
|
||||
Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sequence_element_type_info, &onnx_type));
|
||||
ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR);
|
||||
Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(sequence_element_type_info, &tensor_info));
|
||||
Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type));
|
||||
ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
||||
// map
|
||||
OrtTypeInfo* map_type_info = nullptr;
|
||||
const OrtMapTypeInfo* map_info = nullptr;
|
||||
ONNXTensorElementDataType map_key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
OrtTypeInfo* map_value_type_info = nullptr;
|
||||
Ort::ThrowOnError(api.CreateMapTypeInfo(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, base_tensor_type_info,
|
||||
&map_type_info));
|
||||
Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(map_type_info, &onnx_type));
|
||||
ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_MAP);
|
||||
Ort::ThrowOnError(api.CastTypeInfoToMapTypeInfo(map_type_info, &map_info));
|
||||
Ort::ThrowOnError(api.GetMapKeyType(map_info, &map_key_type));
|
||||
ASSERT_EQ(map_key_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
|
||||
Ort::ThrowOnError(api.GetMapValueType(map_info, &map_value_type_info));
|
||||
Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(map_value_type_info, &onnx_type));
|
||||
ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR);
|
||||
Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(map_value_type_info, &tensor_info));
|
||||
Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type));
|
||||
ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
||||
// optional
|
||||
OrtTypeInfo* optional_type_info = nullptr;
|
||||
const OrtOptionalTypeInfo* optional_info = nullptr;
|
||||
OrtTypeInfo* optional_contained_type_info = nullptr;
|
||||
Ort::ThrowOnError(api.CreateOptionalTypeInfo(base_tensor_type_info, &optional_type_info));
|
||||
Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(optional_type_info, &onnx_type));
|
||||
ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_OPTIONAL);
|
||||
Ort::ThrowOnError(api.CastTypeInfoToOptionalTypeInfo(optional_type_info, &optional_info));
|
||||
Ort::ThrowOnError(api.GetOptionalContainedTypeInfo(optional_info, &optional_contained_type_info));
|
||||
Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(optional_contained_type_info, &onnx_type));
|
||||
ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue