mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
VitisAI EP Context Model (#20926)
# Why so many commits - Runtime debugging - which is necessary - Three different approaches to EP context model - as a result testing back and forth - Windows compatibility issues - this development has been done on Linux for convenience # "Open" (?) questions - Full offloading to a specific EP - Dumping EP context models by EPs vs [by ONNXRT](e2abba18ea/onnxruntime/core/framework/graph_partitioner.cc (L725)) - [Node name to pick nodes](e2abba18ea/onnxruntime/core/framework/graph_partitioner.cc (L654)) # VitisAI EP made three variant implementations that have respective pros and cons (and of course we can combine them) ## Serialize and cache the list of compute capabilities and the original ONNX model itself ## In `ComputeCapability()`, serialize and cache the backend compilation cache and the related necessary cache info such as cache dir and cache key ## In `Compile()`, serialize and cache the backend compilation cache and the related necessary cache info such as cache dir and cache key # EP context model creation - Precondition Session option configuration `kOrtSessionOptionEpContextEnable` (aka "ep.context_enable") is enabled. - Approach 1 - Steps 1. EP creates an ONNX model whose main graph has EP context nodes (i.e., node type is "EPContext"). 2. EP implements/overrides `IExecutionProvider::GetEpContextNodes()` method. 3. ONNXRT core creates an EP context model and saves/dumps it. - `CreateEpContextModel()` in the file "graph_partitioner.cc" - In `get_ep_context_node()`, `Node::Name()` is used to check whether a node is an EP context node. This limits that EP model creation can only happen in `IExecutionProvider::Compile()`. - The workaround is (1) not implementing `IExecutionProvider::GetEpContextNodes()` and (2) dumping the EP context model by EP itself. 4. Optionally, EP can also dump the EP context model it created by iteself. - Examples - `QNNExecutionProvider` - `VitisAIExecutionProvider` - Approach 2 - Steps 1. EP creates an ONNX model whose main graph has EP context nodes (i.e., node type is "EPContext"). 2. EP does NOT implement `IExecutionProvider::GetEpContextNodes()` at all. 3. EP dumps the EP context model it created. - Examples - `TensorrtExecutionProvider` - UPDATES - TRT EP is switching to leveraging `IExecutionProvider::GetEpContextNodes()` - `OpenVINOExecutionProvider` (?) # What to cache in EP context nodes - Non Compilation based EPs - Examples - `VitisAIExecutionProvider` - Characteristics - Heavy lifting work happens in `IExecutionProvider::GetCapability()`. - Preconditions - `IExecutionProvider::GetCapability()` is only called once by ONNXRT. - Cache content - Serialization of a list of `ComputeCapability` - Not EP-specific - Serialized using `onnx::FunctionProto` - EP-specific cache - Compilation based EPs - Examples - `QNNExecutionProvider` - `TensorrtExecutionProvider` - `MIGraphXExecutionProvider` - `OpenVINOExecutionProvider` - Cache content - EP-specific cache # Requirements - Offline / AOT compilation of ONNX models with EP context cache - Compile somewhere, run everywhere - Pseudo code with brief explanation ``` GenerateCache(original_onnx_file, cache_onnx_file) model_buffer = load(original_onnx_file) --> Load the original ONNX model file model_buffer = decrypt(model_buffer) session_options = { kOrtSessionOptionEpContextEnable: true, kOrtSessionOptionEpContextFilePath: temp_file } --> Set necessary configs Ort::CreateSessionFromArray(model_buffer, session_options) --> The new ONNX model with EP context is created and dumped into the user specified file "temp_file" temp_buffer = encrypt(temp_file) write(temp_buffer, cache_onnx_file) --> Write the encypted context of "temp_file" into the "cache_onnx_file" file InitializeInferenceSession(cache_onnx_file) model_buffer = load(cache_onnx_file) --> Load the ONNX model with EP context from the file generated in the previous step model_buffer = decrypt(model_buffer) session_options = { } Ort::CreateSessionFromArray(model_buffer, session_options) --> Create and initalize an session with the EP context model ``` - Python code with comments - EP context model creation ```python import onnxruntime as onnxrt # Session options for creating an ONNX model with EP context cache. sess_opts = onnxrt.SessionOptions() # Verbose. sess_opts.log_severity_level = 0 # This is REQUIRED. sess_opts.add_session_config_entry("ep.context_enable", "1") # This is OPTIONAL. # Either an absolute path (preferred for now) or a relative path (WIP) is okay. # sess_opts.add_session_config_entry("ep.context_file_path", "/some/path/to/original_model_ctx.onnx") # This is OPTIONAL. sess_opts.add_session_config_entry("ep.context_embed_mode", "1") orig_model_location = "/some/path/to/original_model.onnx" sess = onnxrt.InferenceSession(orig_model_location, sess_opts, providers=["VitisAIExecutionProvider"], provider_options=[]) ``` - Inference run with an EP context model ```python import onnxruntime as onnxrt # Session options for creating an ONNX model with EP context cache. sess_opts = onnxrt.SessionOptions() # Default EP context model path. # ep_ctx_model_location = "/some/path/to/origina_model.onnx_ctx.onnx" # User configured EP context model path. ep_ctx_model_location = "/some/path/to/origina_model_ctx.onnx" sess = onnxrt.InferenceSession(ep_ctx_model_location, sess_opts, providers=["VitisAIExecutionProvider"], provider_options=[]) model_inputs = {} run_opts = onnxrt.RunOptions() # Verbose. run_opts.log_severity_level = 1 sess.run(None, model_inputs, run_opts) ``` --------- Co-authored-by: Glen Cao <glen@Glens-MacBook-Air.local>
This commit is contained in:
parent
92a8407b39
commit
281ed8c12d
13 changed files with 1195 additions and 16 deletions
|
|
@ -19,6 +19,8 @@ class TensorProto;
|
|||
class SparseTensorProto;
|
||||
class TypeProto;
|
||||
class AttributeProto;
|
||||
class FunctionProto;
|
||||
class OperatorSetIdProto;
|
||||
// define types that would come from the ONNX library if we were building against it.
|
||||
#if defined(ORT_MINIMAL_BUILD)
|
||||
using OperatorSetVersion = int;
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ struct NodeProto;
|
|||
struct SparseTensorProto;
|
||||
struct StringStringEntryProto;
|
||||
struct StringStringEntryProtos; // RepeatedPtrField
|
||||
struct OperatorSetIdProto;
|
||||
struct TensorProto;
|
||||
struct TensorProtos; // RepeatedPtrField
|
||||
struct TensorShapeProto_Dimension;
|
||||
|
|
@ -120,6 +121,7 @@ struct TypeProto_Sequence;
|
|||
struct TypeProto;
|
||||
struct ValueInfoProto;
|
||||
struct ValueInfoProtos; // RepeatedPtrField
|
||||
struct FunctionProto;
|
||||
struct InferenceContext;
|
||||
class GraphInferencer;
|
||||
using InferenceFunction = std::function<void(InferenceContext&)>;
|
||||
|
|
@ -146,6 +148,7 @@ struct ConfigOptions;
|
|||
struct DataTransferManager;
|
||||
struct IndexedSubGraph;
|
||||
struct IndexedSubGraph_MetaDef;
|
||||
enum class IndexedSubGraph_SourceOfSchema : uint8_t;
|
||||
struct KernelCreateInfo;
|
||||
struct KernelDef;
|
||||
struct KernelDefBuilder;
|
||||
|
|
|
|||
|
|
@ -304,6 +304,11 @@ struct ProviderHost {
|
|||
virtual int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0;
|
||||
virtual ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) = 0;
|
||||
|
||||
// OperatorSetIdProto
|
||||
virtual std::string* OperatorSetIdProto__mutable_domain(ONNX_NAMESPACE::OperatorSetIdProto* p) = 0;
|
||||
virtual void OperatorSetIdProto__set_version(ONNX_NAMESPACE::OperatorSetIdProto* p, int64_t version) = 0;
|
||||
virtual int64_t OperatorSetIdProto__version(const ONNX_NAMESPACE::OperatorSetIdProto* p) = 0;
|
||||
|
||||
#if !defined(DISABLE_OPTIONAL_TYPE)
|
||||
// TypeProto_Optional
|
||||
virtual const ONNX_NAMESPACE::TypeProto& TypeProto_Optional__elem_type(const ONNX_NAMESPACE::TypeProto_Optional* p) = 0;
|
||||
|
|
@ -420,6 +425,11 @@ struct ProviderHost {
|
|||
virtual void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) = 0;
|
||||
virtual ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) = 0;
|
||||
|
||||
virtual const ONNX_NAMESPACE::OperatorSetIdProto& ModelProto__opset_import(const ONNX_NAMESPACE::ModelProto* p, int index) = 0;
|
||||
virtual ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__mutable_opset_import(ONNX_NAMESPACE::ModelProto* p, int index) = 0;
|
||||
virtual int ModelProto__opset_import_size(const ONNX_NAMESPACE::ModelProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__add_opset_import(ONNX_NAMESPACE::ModelProto* p) = 0;
|
||||
|
||||
// NodeProto
|
||||
virtual std::unique_ptr<ONNX_NAMESPACE::NodeProto> NodeProto__construct() = 0;
|
||||
virtual void NodeProto__operator_delete(ONNX_NAMESPACE::NodeProto* p) = 0;
|
||||
|
|
@ -427,6 +437,7 @@ struct ProviderHost {
|
|||
virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0;
|
||||
virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0;
|
||||
virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0;
|
||||
virtual ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) = 0;
|
||||
|
||||
// TensorProto
|
||||
virtual std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() = 0;
|
||||
|
|
@ -495,6 +506,64 @@ struct ProviderHost {
|
|||
|
||||
virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0;
|
||||
|
||||
// FunctionProto
|
||||
virtual std::unique_ptr<ONNX_NAMESPACE::FunctionProto> FunctionProto__construct() = 0;
|
||||
virtual void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
|
||||
virtual bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) = 0;
|
||||
virtual bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) = 0;
|
||||
virtual bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) = 0;
|
||||
virtual std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
|
||||
virtual bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
|
||||
virtual void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& name) = 0;
|
||||
|
||||
virtual bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
|
||||
virtual void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& doc_string) = 0;
|
||||
|
||||
virtual bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
|
||||
virtual void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& domain) = 0;
|
||||
|
||||
virtual const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;
|
||||
|
||||
virtual const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;
|
||||
|
||||
virtual const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;
|
||||
|
||||
virtual const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
|
||||
virtual const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
|
||||
virtual const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
|
||||
virtual const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
|
||||
virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
|
||||
virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0;
|
||||
|
||||
// ConfigOptions
|
||||
|
|
@ -546,6 +615,9 @@ struct ProviderHost {
|
|||
virtual void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr<IndexedSubGraph_MetaDef>&& meta_def_) = 0;
|
||||
virtual const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) = 0;
|
||||
|
||||
virtual void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) = 0;
|
||||
virtual IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) = 0;
|
||||
|
||||
// KernelDef
|
||||
virtual void KernelDef__operator_delete(KernelDef* p) = 0;
|
||||
virtual int KernelDef__ExecQueueId(const KernelDef* p) = 0;
|
||||
|
|
|
|||
|
|
@ -80,6 +80,15 @@ struct StringStringEntryProtos final {
|
|||
|
||||
PROVIDER_DISALLOW_ALL(StringStringEntryProtos)
|
||||
};
|
||||
|
||||
struct OperatorSetIdProto final {
|
||||
std::string* mutable_domain() { return g_host->OperatorSetIdProto__mutable_domain(this); }
|
||||
void set_version(int64_t version) { return g_host->OperatorSetIdProto__set_version(this, version); }
|
||||
int64_t version() { return g_host->OperatorSetIdProto__version(this); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(OperatorSetIdProto)
|
||||
};
|
||||
|
||||
struct AttributeProto final {
|
||||
static std::unique_ptr<AttributeProto> Create() { return g_host->AttributeProto__construct(); }
|
||||
void operator=(const AttributeProto& v) { g_host->AttributeProto__operator_assign(this, v); }
|
||||
|
|
@ -178,6 +187,11 @@ struct ModelProto final {
|
|||
|
||||
void set_ir_version(int64_t value) { return g_host->ModelProto__set_ir_version(this, value); }
|
||||
|
||||
const OperatorSetIdProto& opset_import(int index) const { return g_host->ModelProto__opset_import(this, index); }
|
||||
OperatorSetIdProto* mutable_opset_import(int index) { return g_host->ModelProto__mutable_opset_import(this, index); }
|
||||
int opset_import_size() const { return g_host->ModelProto__opset_import_size(this); }
|
||||
OperatorSetIdProto* add_opset_import() { return g_host->ModelProto__add_opset_import(this); }
|
||||
|
||||
ModelProto() = delete;
|
||||
ModelProto(const ModelProto&) = delete;
|
||||
void operator=(const ModelProto&) = delete;
|
||||
|
|
@ -190,6 +204,7 @@ struct NodeProto final {
|
|||
int attribute_size() { return g_host->NodeProto__attribute_size(this); }
|
||||
const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); }
|
||||
AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); }
|
||||
AttributeProto* add_attribute() { return g_host->NodeProto__add_attribute(this); }
|
||||
|
||||
NodeProto() = delete;
|
||||
NodeProto(const NodeProto&) = delete;
|
||||
|
|
@ -372,6 +387,69 @@ struct ValueInfoProtos final {
|
|||
|
||||
PROVIDER_DISALLOW_ALL(ValueInfoProtos)
|
||||
};
|
||||
|
||||
struct FunctionProto final {
|
||||
static std::unique_ptr<FunctionProto> Create() { return g_host->FunctionProto__construct(); }
|
||||
static void operator delete(void* p) { g_host->FunctionProto__operator_delete(reinterpret_cast<FunctionProto*>(p)); }
|
||||
|
||||
bool SerializeToString(std::string& string) const { return g_host->FunctionProto__SerializeToString(this, string); }
|
||||
bool SerializeToOstream(std::ostream& output) const { return g_host->FunctionProto__SerializeToOstream(this, output); }
|
||||
bool ParseFromString(const std::string& data) { return g_host->FunctionProto__ParseFromString(this, data); }
|
||||
std::string SerializeAsString() const { return g_host->FunctionProto__SerializeAsString(this); }
|
||||
|
||||
bool has_name() const { return g_host->FunctionProto__has_name(this); }
|
||||
const std::string& name() const { return g_host->FunctionProto__name(this); }
|
||||
void set_name(const std::string& name) { g_host->FunctionProto__set_name(this, name); }
|
||||
|
||||
bool has_doc_string() const { return g_host->FunctionProto__has_doc_string(this); }
|
||||
const std::string& doc_string() const { return g_host->FunctionProto__doc_string(this); }
|
||||
void set_doc_string(const std::string& doc_string) { g_host->FunctionProto__set_doc_string(this, doc_string); }
|
||||
|
||||
bool has_domain() const { return g_host->FunctionProto__has_domain(this); }
|
||||
const std::string& domain() const { return g_host->FunctionProto__domain(this); }
|
||||
void set_domain(const std::string& domain) { g_host->FunctionProto__set_domain(this, domain); }
|
||||
|
||||
const std::string& input(int index) const { return g_host->FunctionProto__input(this, index); }
|
||||
std::string* mutable_input(int index) { return g_host->FunctionProto__mutable_input(this, index); }
|
||||
int input_size() const { return g_host->FunctionProto__input_size(this); }
|
||||
void add_input(const std::string& value) { g_host->FunctionProto__add_input(this, value); }
|
||||
|
||||
const std::string& output(int index) const { return g_host->FunctionProto__output(this, index); }
|
||||
std::string* mutable_output(int index) { return g_host->FunctionProto__mutable_output(this, index); }
|
||||
int output_size() const { return g_host->FunctionProto__output_size(this); }
|
||||
void add_output(const std::string& value) { g_host->FunctionProto__add_output(this, value); }
|
||||
|
||||
const std::string& attribute(int index) const { return g_host->FunctionProto__attribute(this, index); }
|
||||
std::string* mutable_attribute(int index) { return g_host->FunctionProto__mutable_attribute(this, index); }
|
||||
int attribute_size() const { return g_host->FunctionProto__attribute_size(this); }
|
||||
void add_attribute(const std::string& value) { g_host->FunctionProto__add_attribute(this, value); }
|
||||
|
||||
const AttributeProto& attribute_proto(int index) const { return g_host->FunctionProto__attribute_proto(this, index); }
|
||||
AttributeProto* mutable_attribute_proto(int index) { return g_host->FunctionProto__mutable_attribute_proto(this, index); }
|
||||
int attribute_proto_size() const { return g_host->FunctionProto__attribute_proto_size(this); }
|
||||
AttributeProto* add_attribute_proto() { return g_host->FunctionProto__add_attribute_proto(this); }
|
||||
|
||||
const NodeProto& node(int index) const { return g_host->FunctionProto__node(this, index); }
|
||||
NodeProto* mutable_node(int index) { return g_host->FunctionProto__mutable_node(this, index); }
|
||||
int node_size() const { return g_host->FunctionProto__node_size(this); }
|
||||
NodeProto* add_node() { return g_host->FunctionProto__add_node(this); }
|
||||
|
||||
const ValueInfoProto& value_info(int index) const { return g_host->FunctionProto__value_info(this, index); }
|
||||
ValueInfoProtos* mutable_value_info() { return g_host->FunctionProto__mutable_value_info(this); }
|
||||
ValueInfoProto* mutable_value_info(int index) { return g_host->FunctionProto__mutable_value_info(this, index); }
|
||||
int value_info_size() const { return g_host->FunctionProto__value_info_size(this); }
|
||||
ValueInfoProto* add_value_info() { return g_host->FunctionProto__add_value_info(this); }
|
||||
|
||||
const StringStringEntryProto& metadata_props(int index) const { return g_host->FunctionProto__metadata_props(this, index); }
|
||||
StringStringEntryProtos* mutable_metadata_props() { return g_host->FunctionProto__mutable_metadata_props(this); }
|
||||
StringStringEntryProto* mutable_metadata_props(int index) { return g_host->FunctionProto__mutable_metadata_props(this, index); }
|
||||
int metadata_props_size() const { return g_host->FunctionProto__metadata_props_size(this); }
|
||||
StringStringEntryProto* add_metadata_props() { return g_host->FunctionProto__add_metadata_props(this); }
|
||||
|
||||
FunctionProto() = delete;
|
||||
FunctionProto(const FunctionProto&) = delete;
|
||||
void operator=(const FunctionProto&) = delete;
|
||||
};
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -449,6 +527,12 @@ struct IndexedSubGraph_MetaDef final {
|
|||
void operator=(const IndexedSubGraph_MetaDef&) = delete;
|
||||
};
|
||||
|
||||
enum class IndexedSubGraph_SourceOfSchema : uint8_t {
|
||||
CREATE,
|
||||
REUSE_OR_CREATE,
|
||||
EXISTING,
|
||||
};
|
||||
|
||||
struct IndexedSubGraph final {
|
||||
static std::unique_ptr<IndexedSubGraph> Create() { return g_host->IndexedSubGraph__construct(); }
|
||||
static void operator delete(void* p) { g_host->IndexedSubGraph__operator_delete(reinterpret_cast<IndexedSubGraph*>(p)); }
|
||||
|
|
@ -458,6 +542,9 @@ struct IndexedSubGraph final {
|
|||
void SetMetaDef(std::unique_ptr<IndexedSubGraph_MetaDef>&& meta_def_) { return g_host->IndexedSubGraph__SetMetaDef(this, std::move(*reinterpret_cast<std::unique_ptr<IndexedSubGraph_MetaDef>*>(&meta_def_))); }
|
||||
const IndexedSubGraph_MetaDef* GetMetaDef() const { return reinterpret_cast<const IndexedSubGraph_MetaDef*>(g_host->IndexedSubGraph__GetMetaDef(this)); }
|
||||
|
||||
void SetSchemaSource(IndexedSubGraph_SourceOfSchema schema_source) { return g_host->IndexedSubGraph__SetSchemaSource(this, schema_source); }
|
||||
IndexedSubGraph_SourceOfSchema GetSchemaSource() const { return g_host->IndexedSubGraph__GetSchemaSource(this); }
|
||||
|
||||
IndexedSubGraph() = delete;
|
||||
IndexedSubGraph(const IndexedSubGraph&) = delete;
|
||||
void operator=(const IndexedSubGraph&) = delete;
|
||||
|
|
|
|||
682
onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc
Normal file
682
onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc
Normal file
|
|
@ -0,0 +1,682 @@
|
|||
// Standard headers/libs.
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <cctype>
|
||||
#include <cstring>
|
||||
|
||||
// 3rd-party headers/libs.
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "ep_context_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
constexpr const char* kVitisAI = "vitisai";
|
||||
|
||||
std::unique_ptr<ONNX_NAMESPACE::FunctionProto> ConvertIndexedSubGraphToFunctionProto(
|
||||
const IndexedSubGraph& sub_graph, const Graph& parent_graph) {
|
||||
auto p_func_proto = ONNX_NAMESPACE::FunctionProto::Create();
|
||||
auto* p_meta_def = const_cast<IndexedSubGraph_MetaDef*>(sub_graph.GetMetaDef());
|
||||
if (p_meta_def) {
|
||||
p_func_proto->set_name(p_meta_def->name());
|
||||
p_func_proto->set_domain(p_meta_def->domain());
|
||||
for (const auto& input : p_meta_def->inputs()) {
|
||||
p_func_proto->add_input(input);
|
||||
}
|
||||
auto* p_metadata_props_0 = p_func_proto->add_metadata_props();
|
||||
*(p_metadata_props_0->mutable_key()) = "meta_def_inputs_size";
|
||||
*(p_metadata_props_0->mutable_value()) = std::to_string(p_meta_def->inputs().size());
|
||||
for (const auto& output : p_meta_def->outputs()) {
|
||||
p_func_proto->add_output(output);
|
||||
}
|
||||
// XXX: SerDes with different fields.
|
||||
for (const auto& initializer : p_meta_def->constant_initializers()) {
|
||||
p_func_proto->add_input(initializer);
|
||||
}
|
||||
// XXX: SerDes with different numbers of fields.
|
||||
for (const auto& attr_pair : p_meta_def->attributes()) {
|
||||
p_func_proto->add_attribute(attr_pair.first);
|
||||
auto* p_attr_proto = p_func_proto->add_attribute_proto();
|
||||
*p_attr_proto = attr_pair.second;
|
||||
}
|
||||
p_func_proto->set_doc_string(p_meta_def->doc_string());
|
||||
// "since_version"
|
||||
auto* p_metadata_props_1 = p_func_proto->add_metadata_props();
|
||||
*(p_metadata_props_1->mutable_key()) = "meta_def_since_version";
|
||||
*(p_metadata_props_1->mutable_value()) = std::to_string(p_meta_def->since_version());
|
||||
// "status"
|
||||
auto* p_metadata_props_2 = p_func_proto->add_metadata_props();
|
||||
*(p_metadata_props_2->mutable_key()) = "meta_def_status";
|
||||
*(p_metadata_props_2->mutable_value()) =
|
||||
std::to_string(static_cast<int>(p_meta_def->status()));
|
||||
// TODO: `MetaDef::type_and_shape_inference_function`.
|
||||
}
|
||||
auto p_parent_graph_proto = parent_graph.ToGraphProto();
|
||||
for (auto node_index : const_cast<IndexedSubGraph&>(sub_graph).Nodes()) {
|
||||
auto* p_node_proto = p_parent_graph_proto->mutable_node(static_cast<int>(node_index));
|
||||
auto* p_attr_proto = p_node_proto->add_attribute();
|
||||
p_attr_proto->set_name("parent_graph_node_index");
|
||||
p_attr_proto->set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
p_attr_proto->set_i(node_index);
|
||||
*(p_func_proto->add_node()) = *p_node_proto;
|
||||
}
|
||||
#if 0
|
||||
// Alternative.
|
||||
for (const auto node_index : sub_graph.Nodes()) {
|
||||
const auto* p_node = parent_graph.GetNode(node_index);
|
||||
auto p_node_proto = ONNX_NAMESPACE::NodeProto::Create();
|
||||
// XXX
|
||||
p_node->ToProto(*p_node_proto, true);
|
||||
auto* p_attr_proto = p_node_proto->add_attribute();
|
||||
p_attr_proto->set_name("parent_graph_node_index");
|
||||
p_attr_proto->set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
p_attr_proto->set_i(node_index);
|
||||
*(p_func_proto.add_node()) = *p_node_proto;
|
||||
}
|
||||
#endif
|
||||
auto* p_metadata_props_3 = p_func_proto->add_metadata_props();
|
||||
*(p_metadata_props_3->mutable_key()) = "schema_source";
|
||||
*(p_metadata_props_3->mutable_value()) =
|
||||
std::to_string(static_cast<uint8_t>(sub_graph.GetSchemaSource()));
|
||||
return p_func_proto;
|
||||
}
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> ConvertFunctionProtoToIndexedSubGraph(
|
||||
const std::unique_ptr<ONNX_NAMESPACE::FunctionProto>& p_func_proto) {
|
||||
auto p_isg = IndexedSubGraph::Create();
|
||||
// "meta_def_inputs_size" (optional) and "schema_source".
|
||||
int func_metadata_props_size = p_func_proto->metadata_props_size();
|
||||
// Precisely, func_metadata_props_size == 4, which implies
|
||||
// `IndexedSubGraph::meta_def_` is not null and `IndexedSubGraph::nodes` > 1.
|
||||
if (func_metadata_props_size > 1) {
|
||||
auto& prop0 = const_cast<ONNX_NAMESPACE::StringStringEntryProto&>(p_func_proto->metadata_props(0));
|
||||
int isg_meta_def_inputs_size = std::stoi(*(prop0.mutable_value()));
|
||||
auto p_meta_def = IndexedSubGraph_MetaDef::Create();
|
||||
p_meta_def->name() = p_func_proto->name();
|
||||
p_meta_def->domain() = p_func_proto->domain();
|
||||
auto& prop1 = const_cast<ONNX_NAMESPACE::StringStringEntryProto&>(p_func_proto->metadata_props(1));
|
||||
p_meta_def->since_version() = std::stoi(*(prop1.mutable_value()));
|
||||
auto& prop2 = const_cast<ONNX_NAMESPACE::StringStringEntryProto&>(p_func_proto->metadata_props(2));
|
||||
p_meta_def->status() = static_cast<ONNX_NAMESPACE::OperatorStatus>(std::stoi(*(prop2.mutable_value())));
|
||||
auto& meta_def_inputs = p_meta_def->inputs();
|
||||
for (int i = 0; i < isg_meta_def_inputs_size; i++) {
|
||||
meta_def_inputs.push_back(p_func_proto->input(i));
|
||||
}
|
||||
auto& meta_def_outputs = p_meta_def->outputs();
|
||||
for (int i = 0, l = p_func_proto->output_size(); i < l; i++) {
|
||||
meta_def_outputs.push_back(p_func_proto->output(i));
|
||||
}
|
||||
auto& meta_def_initializers = p_meta_def->constant_initializers();
|
||||
for (int i = isg_meta_def_inputs_size, l = p_func_proto->input_size(); i < l; i++) {
|
||||
meta_def_initializers.push_back(p_func_proto->input(i));
|
||||
}
|
||||
auto& meta_def_attrs = p_meta_def->attributes();
|
||||
for (int i = 0, l = p_func_proto->attribute_size(); i < l; i++) {
|
||||
meta_def_attrs.emplace(p_func_proto->attribute(i), p_func_proto->attribute_proto(i));
|
||||
}
|
||||
p_meta_def->doc_string() = p_func_proto->doc_string();
|
||||
// TODO: `IndexedSubGraph::type_and_shape_inference_function`.
|
||||
p_isg->SetMetaDef(std::move(p_meta_def));
|
||||
}
|
||||
auto& isg_nodes = p_isg->Nodes();
|
||||
for (int i = 0, l = p_func_proto->node_size(); i < l; i++) {
|
||||
const auto& node_proto = p_func_proto->node(i);
|
||||
isg_nodes.push_back(
|
||||
node_proto.attribute(const_cast<ONNX_NAMESPACE::NodeProto&>(node_proto).attribute_size() - 1).i());
|
||||
}
|
||||
auto schema_source = static_cast<IndexedSubGraph_SourceOfSchema>(
|
||||
std::stoi(*(const_cast<ONNX_NAMESPACE::StringStringEntryProto&>(p_func_proto->metadata_props(func_metadata_props_size - 1)).mutable_value())));
|
||||
p_isg->SetSchemaSource(schema_source);
|
||||
return p_isg;
|
||||
}
|
||||
|
||||
std::string SerializeCapabilities(
|
||||
const std::vector<std::unique_ptr<ComputeCapability>>& capability_ptrs,
|
||||
const Graph& graph) {
|
||||
std::stringstream ss;
|
||||
for (const auto& p : capability_ptrs) {
|
||||
auto& p_subgraph = p->SubGraph();
|
||||
auto p_func_proto = ConvertIndexedSubGraphToFunctionProto(*p_subgraph, graph);
|
||||
std::string func_proto_buf;
|
||||
p_func_proto->SerializeToString(func_proto_buf);
|
||||
size_t buf_len = func_proto_buf.length();
|
||||
ss.write(reinterpret_cast<const char*>(&buf_len), sizeof(buf_len));
|
||||
ss.write(func_proto_buf.data(), buf_len);
|
||||
}
|
||||
if (!ss.good()) {
|
||||
ORT_THROW("Serialization stream bad");
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void DeserializeCapabilities(const std::string& ser_capabilities,
|
||||
std::vector<std::unique_ptr<ComputeCapability>>& capability_ptrs) {
|
||||
std::istringstream ss(ser_capabilities);
|
||||
while (!ss.eof()) {
|
||||
size_t buf_len;
|
||||
ss.read(reinterpret_cast<char*>(&buf_len), sizeof(buf_len));
|
||||
std::string buf(buf_len, '\0');
|
||||
ss.read(&buf[0], buf_len);
|
||||
auto p_func_proto = ONNX_NAMESPACE::FunctionProto::Create();
|
||||
p_func_proto->ParseFromString(buf);
|
||||
auto p_subgraph = ConvertFunctionProtoToIndexedSubGraph(p_func_proto);
|
||||
capability_ptrs.push_back(ComputeCapability::Create(std::move(p_subgraph)));
|
||||
}
|
||||
}
|
||||
|
||||
std::string SerializeOrigialGraph(const GraphViewer& graph_viewer) {
|
||||
// XXX: Will Steps 1/2/3 suffice for restoring a model/graph later?
|
||||
// Any information loss or mismatch?
|
||||
// Step 1
|
||||
const Graph& orig_graph = graph_viewer.GetGraph();
|
||||
// Step 2
|
||||
const Model& orig_model = orig_graph.GetModel();
|
||||
// Step 3
|
||||
auto p_orig_model_proto = const_cast<Model&>(orig_model).ToProto();
|
||||
if (p_orig_model_proto->opset_import_size() == 0) {
|
||||
for (const auto& it : graph_viewer.DomainToVersionMap()) {
|
||||
auto* p_opset_import = p_orig_model_proto->add_opset_import();
|
||||
*(p_opset_import->mutable_domain()) = it.first;
|
||||
p_opset_import->set_version(it.second);
|
||||
}
|
||||
}
|
||||
|
||||
nlohmann::json j_obj;
|
||||
if (p_orig_model_proto->opset_import_size() > 0) {
|
||||
for (int i = 0, n = p_orig_model_proto->opset_import_size(); i < n; ++i) {
|
||||
auto& op_set_id_proto = const_cast<ONNX_NAMESPACE::OperatorSetIdProto&>(p_orig_model_proto->opset_import(i));
|
||||
j_obj[*op_set_id_proto.mutable_domain()] = std::to_string(op_set_id_proto.version());
|
||||
}
|
||||
}
|
||||
j_obj["orig_graph_name"] = graph_viewer.Name();
|
||||
// TODO: platform dependency (Linux vs Windows).
|
||||
j_obj["orig_model_path"] = graph_viewer.ModelPath().string();
|
||||
|
||||
// XXX: `ModelProto::SerializeToString` will lose some info,
|
||||
// e.g., ModelProto.opset_import.
|
||||
std::string ser_buf;
|
||||
p_orig_model_proto->SerializeToString(ser_buf);
|
||||
j_obj["orig_model_proto_ser_str"] = ser_buf;
|
||||
|
||||
return j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace);
|
||||
}
|
||||
|
||||
// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc".
|
||||
ONNX_NAMESPACE::ModelProto* CreateEPContexModel(
|
||||
const GraphViewer& graph_viewer,
|
||||
const std::string& serialized_ctx_cache,
|
||||
const std::string& ctx_cache_file_loc,
|
||||
const int64_t embed_mode,
|
||||
const std::string& backend_cache_dir,
|
||||
const std::string& backend_cache_key,
|
||||
bool saving_orig_graph,
|
||||
const logging::Logger* p_logger) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[VitisAI EP]Creating EP context node";
|
||||
// Create a new graph/model, reusing the graph name,
|
||||
// the op-domain-to-opset-version map,
|
||||
// and the op schema registry of the current graph.
|
||||
// XXX: This approach (immediately below) has a memory fault issue (std::bad_alloc).
|
||||
// auto& ep_ctx_graph = graph_viewer.CreateModel(*p_logger)->MainGraph();
|
||||
// This apporach (immediately below) has no memory falut issue.
|
||||
auto p_temp_model = graph_viewer.CreateModel(*p_logger);
|
||||
auto& ep_ctx_graph = p_temp_model->MainGraph();
|
||||
|
||||
const auto& graph_inputs = graph_viewer.GetInputs();
|
||||
std::vector<NodeArg*> input_node_arg_ptrs;
|
||||
input_node_arg_ptrs.reserve(graph_inputs.size());
|
||||
// XXX: vs `GraphViewer::GetInputsIncludingInitializers()`.
|
||||
for (const auto* p_node_arg : graph_inputs) {
|
||||
auto& temp_node_arg = ep_ctx_graph.GetOrCreateNodeArg(
|
||||
p_node_arg->Name(), p_node_arg->TypeAsProto());
|
||||
input_node_arg_ptrs.push_back(&temp_node_arg);
|
||||
}
|
||||
const auto& graph_outputs = graph_viewer.GetOutputs();
|
||||
std::vector<NodeArg*> output_node_arg_ptrs;
|
||||
output_node_arg_ptrs.reserve(graph_outputs.size());
|
||||
for (const auto* p_node_arg : graph_outputs) {
|
||||
auto& temp_node_arg = ep_ctx_graph.GetOrCreateNodeArg(p_node_arg->Name(), p_node_arg->TypeAsProto());
|
||||
output_node_arg_ptrs.push_back(&temp_node_arg);
|
||||
}
|
||||
|
||||
// Attr "embed_mode".
|
||||
auto p_attr_0 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_0->set_name(kEmbedModeAttr);
|
||||
// p_attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||
p_attr_0->set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
p_attr_0->set_i(embed_mode);
|
||||
// Attr "ep_cache_context".
|
||||
auto p_attr_1 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_1->set_name(kEPCacheContextAttr);
|
||||
// p_attr_1->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||
p_attr_1->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
// Relative to the ONNX model file.
|
||||
p_attr_1->set_s(
|
||||
embed_mode == 0 ? fs::path(ctx_cache_file_loc).filename().string() : serialized_ctx_cache);
|
||||
// Attr "source".
|
||||
auto p_attr_2 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_2->set_name(kSourceAttr);
|
||||
// p_attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||
p_attr_2->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
p_attr_2->set_s(kVitisAIExecutionProvider);
|
||||
// Attr "onnx_model_filename".
|
||||
auto p_attr_3 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_3->set_name(kONNXModelFileNameAttr);
|
||||
// p_attr_3->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||
p_attr_3->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
p_attr_3->set_s(graph_viewer.ModelPath().filename().string());
|
||||
// Attr "notes".
|
||||
auto p_attr_4 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_4->set_name(kNotesAttr);
|
||||
// p_attr_4->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||
p_attr_4->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
// FIXME: 2G-limit of ProtoBuf.
|
||||
if (saving_orig_graph) {
|
||||
p_attr_4->set_s(SerializeOrigialGraph(graph_viewer));
|
||||
} else {
|
||||
nlohmann::json j_obj;
|
||||
j_obj["backend_cache_dir"] = backend_cache_dir;
|
||||
j_obj["backend_cache_key"] = backend_cache_key;
|
||||
p_attr_4->set_s(j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace));
|
||||
}
|
||||
|
||||
auto p_node_attrs = NodeAttributes::Create();
|
||||
constexpr int num_attrs = 5;
|
||||
p_node_attrs->reserve(num_attrs);
|
||||
p_node_attrs->emplace(kEmbedModeAttr, *p_attr_0);
|
||||
p_node_attrs->emplace(kEPCacheContextAttr, *p_attr_1);
|
||||
p_node_attrs->emplace(kSourceAttr, *p_attr_2);
|
||||
p_node_attrs->emplace(kONNXModelFileNameAttr, *p_attr_3);
|
||||
p_node_attrs->emplace(kNotesAttr, *p_attr_4);
|
||||
|
||||
// Since we don't implement `IExecutionProvider::GetEpContextNodes()` and
|
||||
// thus don't leverage `CreateEpContextModel()` in the file "graph_partitioner.cc",
|
||||
// we specify a brand-new node name here.
|
||||
ep_ctx_graph.AddNode(kEPContextOpName, kEPContextOp, "", input_node_arg_ptrs, output_node_arg_ptrs, p_node_attrs.get(), kEPContextOpDomain);
|
||||
|
||||
auto res_status = ep_ctx_graph.Resolve();
|
||||
ORT_ENFORCE(res_status.IsOK(), res_status.ErrorMessage());
|
||||
LOGS_DEFAULT(VERBOSE) << "Created EP context model graph resolved";
|
||||
|
||||
auto p_ep_ctx_graph_viewer = ep_ctx_graph.CreateGraphViewer();
|
||||
auto p_temp_model_2 = p_ep_ctx_graph_viewer->CreateModel(*p_logger);
|
||||
auto p_ep_ctx_model_proto = p_temp_model_2->ToProto();
|
||||
p_ep_ctx_graph_viewer->ToProto(*p_ep_ctx_model_proto->mutable_graph(), true, true);
|
||||
p_ep_ctx_model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
||||
return p_ep_ctx_model_proto.release();
|
||||
}
|
||||
|
||||
// Ref.: `static common::Status Save(Model& model, int fd)` in the file "model.h".
|
||||
void DumpEPContextModel(
|
||||
const std::unique_ptr<ONNX_NAMESPACE::ModelProto>& p_model_proto, const std::string& ep_ctx_model_file_loc) {
|
||||
std::fstream dump_stream(ep_ctx_model_file_loc, std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
p_model_proto->SerializeToOstream(dump_stream);
|
||||
LOGS_DEFAULT(VERBOSE) << "[VitisAI EP] Dumped " << ep_ctx_model_file_loc;
|
||||
}
|
||||
|
||||
const Node* GetEPContextNodePtr(const Graph& graph) {
|
||||
// TODO: Support for multi-node EP context model.
|
||||
for (const auto* p_node : graph.Nodes()) {
|
||||
if (p_node->OpType() == kEPContextOp) {
|
||||
return p_node;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool ValidateEPContextNode(const Graph& graph) {
|
||||
// TODO: Support for multi-node EP context model.
|
||||
const auto* p_node = GetEPContextNodePtr(graph);
|
||||
assert(p_node != nullptr);
|
||||
auto& attrs = p_node->GetAttributes();
|
||||
assert(attrs.count(kEmbedModeAttr) > 0);
|
||||
assert(attrs.count(kEPCacheContextAttr) > 0);
|
||||
assert(attrs.count(kSourceAttr) > 0);
|
||||
const auto& source_val = attrs.at(kSourceAttr).s();
|
||||
if (source_val == kVitisAIExecutionProvider) {
|
||||
return true;
|
||||
}
|
||||
size_t vitisai_len = std::strlen(kVitisAI);
|
||||
assert(source_val.length() == vitisai_len);
|
||||
for (size_t i = 0; i < vitisai_len; ++i) {
|
||||
assert(static_cast<unsigned char>(std::tolower(source_val[i])) == kVitisAI[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc".
|
||||
void CreateEPContexNodes(
|
||||
Graph* p_ep_ctx_graph,
|
||||
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
|
||||
const std::string& serialized_ctx_cache,
|
||||
const std::string& ctx_cache_file_loc,
|
||||
const int64_t embed_mode,
|
||||
const std::string& backend_cache_dir,
|
||||
const std::string& backend_cache_key,
|
||||
bool saving_orig_graph,
|
||||
const logging::Logger* p_logger) {
|
||||
LOGS_DEFAULT(VERBOSE) << "[VitisAI EP]Creating EP context nodes";
|
||||
int fused_index = 0;
|
||||
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
Node& fused_node = fused_node_graph.fused_node;
|
||||
const auto& fused_name = fused_node.Name();
|
||||
const GraphViewer& graph_viewer = fused_node_graph.filtered_graph;
|
||||
// FIXME
|
||||
const auto& graph_inputs = graph_viewer.GetInputs();
|
||||
std::vector<NodeArg*> input_node_arg_ptrs;
|
||||
input_node_arg_ptrs.reserve(graph_inputs.size());
|
||||
// XXX: vs `GraphViewer::GetInputsIncludingInitializers()`.
|
||||
for (const auto* p_node_arg : graph_inputs) {
|
||||
auto& temp_node_arg = p_ep_ctx_graph->GetOrCreateNodeArg(
|
||||
p_node_arg->Name(), p_node_arg->TypeAsProto());
|
||||
input_node_arg_ptrs.push_back(&temp_node_arg);
|
||||
}
|
||||
const auto& graph_outputs = graph_viewer.GetOutputs();
|
||||
std::vector<NodeArg*> output_node_arg_ptrs;
|
||||
output_node_arg_ptrs.reserve(graph_outputs.size());
|
||||
for (const auto* p_node_arg : graph_outputs) {
|
||||
auto& temp_node_arg = p_ep_ctx_graph->GetOrCreateNodeArg(p_node_arg->Name(), p_node_arg->TypeAsProto());
|
||||
output_node_arg_ptrs.push_back(&temp_node_arg);
|
||||
}
|
||||
|
||||
auto p_node_attrs = NodeAttributes::Create();
|
||||
if (fused_index == 0) {
|
||||
p_node_attrs->reserve(7);
|
||||
// Attr "ep_cache_context".
|
||||
auto p_attr_1 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_1->set_name(kEPCacheContextAttr);
|
||||
p_attr_1->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
// Relative to the ONNX model file.
|
||||
p_attr_1->set_s(
|
||||
embed_mode == 0 ? fs::path(ctx_cache_file_loc).filename().string() : serialized_ctx_cache);
|
||||
p_node_attrs->emplace(kEPCacheContextAttr, *p_attr_1);
|
||||
// Attr "notes".
|
||||
auto p_attr_4 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_4->set_name(kNotesAttr);
|
||||
p_attr_4->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
// FIXME: 2G-limit of ProtoBuf.
|
||||
if (saving_orig_graph) {
|
||||
p_attr_4->set_s(SerializeOrigialGraph(graph_viewer));
|
||||
} else {
|
||||
nlohmann::json j_obj;
|
||||
j_obj["backend_cache_dir"] = backend_cache_dir;
|
||||
j_obj["backend_cache_key"] = backend_cache_key;
|
||||
p_attr_4->set_s(j_obj.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace));
|
||||
}
|
||||
p_node_attrs->emplace(kNotesAttr, *p_attr_4);
|
||||
// Attr "main_context".
|
||||
auto p_attr_5 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_5->set_name(kMainContextAttr);
|
||||
p_attr_5->set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
p_attr_5->set_i(1);
|
||||
p_node_attrs->emplace(kMainContextAttr, *p_attr_5);
|
||||
} else {
|
||||
p_node_attrs->reserve(5);
|
||||
// Attr "main_context".
|
||||
auto p_attr_5 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_5->set_name(kMainContextAttr);
|
||||
p_attr_5->set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
p_attr_5->set_i(0);
|
||||
p_node_attrs->emplace(kMainContextAttr, *p_attr_5);
|
||||
}
|
||||
// Attr "embed_mode".
|
||||
auto p_attr_0 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_0->set_name(kEmbedModeAttr);
|
||||
p_attr_0->set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
p_attr_0->set_i(embed_mode);
|
||||
p_node_attrs->emplace(kEmbedModeAttr, *p_attr_0);
|
||||
// Attr "source".
|
||||
auto p_attr_2 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_2->set_name(kSourceAttr);
|
||||
p_attr_2->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
p_attr_2->set_s(kVitisAIExecutionProvider);
|
||||
p_node_attrs->emplace(kSourceAttr, *p_attr_2);
|
||||
// Attr "onnx_model_filename".
|
||||
auto p_attr_3 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_3->set_name(kONNXModelFileNameAttr);
|
||||
p_attr_3->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
p_attr_3->set_s(graph_viewer.ModelPath().filename().string());
|
||||
p_node_attrs->emplace(kONNXModelFileNameAttr, *p_attr_3);
|
||||
// Attr "partition_name".
|
||||
auto p_attr_6 = ONNX_NAMESPACE::AttributeProto::Create();
|
||||
p_attr_6->set_name(kPartitionNameAttr);
|
||||
p_attr_6->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
|
||||
p_attr_6->set_s(fused_name);
|
||||
p_node_attrs->emplace(kPartitionNameAttr, *p_attr_6);
|
||||
|
||||
p_ep_ctx_graph->AddNode(fused_name, kEPContextOp, "", input_node_arg_ptrs, output_node_arg_ptrs, p_node_attrs.get(), kEPContextOpDomain);
|
||||
|
||||
++fused_index;
|
||||
}
|
||||
auto res_status = p_ep_ctx_graph->Resolve();
|
||||
ORT_ENFORCE(res_status.IsOK(), res_status.ErrorMessage());
|
||||
LOGS_DEFAULT(VERBOSE) << "Created EP context model graph resolved";
|
||||
}
|
||||
|
||||
std::string RetrieveEPContextCache(
|
||||
const Graph& graph, const PathString& ep_ctx_model_loc, bool binary_mode) {
|
||||
// TODO: Support for multi-node EP context model.
|
||||
const auto* p_node = GetEPContextNodePtr(graph);
|
||||
const auto& attrs = p_node->GetAttributes();
|
||||
int64_t embed_mode = attrs.at(kEmbedModeAttr).i();
|
||||
const std::string& ep_ctx_cache = attrs.at(kEPCacheContextAttr).s();
|
||||
if (embed_mode) {
|
||||
return ep_ctx_cache;
|
||||
}
|
||||
fs::path ep_ctx_fs_path(ep_ctx_model_loc);
|
||||
// Attr "ep_cache_context" stores a relative path.
|
||||
ep_ctx_fs_path.replace_filename(fs::path(ep_ctx_cache));
|
||||
// TODO: Validaion of the file location to make sure security is met.
|
||||
if (!fs::exists(ep_ctx_fs_path) || !fs::is_regular_file(ep_ctx_fs_path)) {
|
||||
ORT_THROW("File for EP context cache is missing");
|
||||
}
|
||||
auto open_mode = binary_mode ? (std::ios::in | std::ios::binary) : std::ios::in;
|
||||
std::ifstream ifs(ep_ctx_fs_path.string().c_str(), open_mode);
|
||||
if (!ifs.is_open()) {
|
||||
ORT_THROW("Exception opening EP context cache file");
|
||||
}
|
||||
ifs.seekg(0, ifs.end);
|
||||
std::streampos cache_len = ifs.tellg();
|
||||
if (cache_len == -1) {
|
||||
ifs.close();
|
||||
ORT_THROW("Error when operating EP context cache file");
|
||||
} else if (cache_len == 0) {
|
||||
ifs.close();
|
||||
LOGS_DEFAULT(WARNING) << "Empty EP context cache file: " << ep_ctx_fs_path.string();
|
||||
return "";
|
||||
}
|
||||
ifs.seekg(0, ifs.beg);
|
||||
char* buf = new char[static_cast<size_t>(cache_len)];
|
||||
ifs.read(buf, cache_len);
|
||||
if (!ifs.good()) {
|
||||
ifs.close();
|
||||
ORT_THROW("Exception reading EP context cache file");
|
||||
}
|
||||
ifs.close();
|
||||
std::string cache_payload(buf);
|
||||
delete[] buf;
|
||||
return cache_payload;
|
||||
}
|
||||
|
||||
void RetrieveBackendCacheInfo(const Graph& graph, std::string& cache_dir, std::string& cache_key) {
|
||||
// TODO: Support for multi-node EP context model.
|
||||
const auto* p_node = GetEPContextNodePtr(graph);
|
||||
if (p_node == nullptr) {
|
||||
LOGS_DEFAULT(WARNING) << "Failed to retrieve cache info due to no EP context nodes";
|
||||
return;
|
||||
}
|
||||
const auto& attrs = p_node->GetAttributes();
|
||||
const auto& notes_str = attrs.at(kNotesAttr).s();
|
||||
nlohmann::json j_obj = nlohmann::json::parse(notes_str);
|
||||
cache_dir = j_obj["backend_cache_dir"].get<std::string>();
|
||||
cache_key = j_obj["backend_cache_key"].get<std::string>();
|
||||
if (cache_dir.empty()) {
|
||||
LOGS_DEFAULT(WARNING) << "Retrieved backend cache dir empty";
|
||||
}
|
||||
if (cache_key.empty()) {
|
||||
LOGS_DEFAULT(WARNING) << "Retrieved backend cache key empty";
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<GraphViewer> RetrieveOriginalGraph(const Graph& ep_ctx_graph) {
|
||||
// TODO: Support for multi-node EP context model.
|
||||
const auto* p_node = GetEPContextNodePtr(ep_ctx_graph);
|
||||
const auto& attrs = p_node->GetAttributes();
|
||||
const auto& notes_str = attrs.at(kNotesAttr).s();
|
||||
nlohmann::json j_obj = nlohmann::json::parse(notes_str);
|
||||
|
||||
const auto& orig_model_path = j_obj["orig_model_path"].get<std::string>();
|
||||
bool model_loaded = false;
|
||||
auto p_model_proto = ONNX_NAMESPACE::ModelProto::Create();
|
||||
if (!orig_model_path.empty() && fs::exists(orig_model_path) && fs::is_regular_file(orig_model_path)) {
|
||||
auto load_status = Model::Load(ToPathString(orig_model_path), *p_model_proto);
|
||||
model_loaded = load_status.IsOK();
|
||||
}
|
||||
if (!model_loaded) {
|
||||
p_model_proto->ParseFromString(j_obj["orig_model_proto_ser_str"].get<std::string>());
|
||||
if (p_model_proto->opset_import_size() == 0) {
|
||||
for (auto& elem : j_obj.items()) {
|
||||
if (elem.key() == "orig_model_path" || elem.key() == "orig_graph_name" || elem.key() == "orig_model_proto_ser_str") {
|
||||
continue;
|
||||
}
|
||||
auto* p_op_set_id_proto = p_model_proto->add_opset_import();
|
||||
*(p_op_set_id_proto->mutable_domain()) = elem.key();
|
||||
p_op_set_id_proto->set_version(std::stoll(elem.value().get<std::string>()));
|
||||
}
|
||||
}
|
||||
}
|
||||
auto& logger = logging::LoggingManager::DefaultLogger();
|
||||
auto p_model = Model::Create(std::move(*p_model_proto), ToPathString(orig_model_path), nullptr, logger);
|
||||
auto& graph = p_model->MainGraph();
|
||||
graph.ToGraphProto()->set_name(j_obj["orig_graph_name"].get<std::string>());
|
||||
|
||||
return graph.CreateGraphViewer();
|
||||
}
|
||||
|
||||
bool GraphHasEPContextNode(const Graph& graph) {
|
||||
size_t vitisai_len = std::strlen(kVitisAI);
|
||||
for (const auto* p_node : graph.Nodes()) {
|
||||
if (p_node->OpType() != kEPContextOp) {
|
||||
continue;
|
||||
}
|
||||
const auto& attrs = p_node->GetAttributes();
|
||||
if (attrs.count(kSourceAttr) == 0) {
|
||||
continue;
|
||||
}
|
||||
const auto& source_val = attrs.at(kSourceAttr).s();
|
||||
if (source_val == kVitisAIExecutionProvider) {
|
||||
return true;
|
||||
}
|
||||
if (source_val.length() != vitisai_len) {
|
||||
continue;
|
||||
}
|
||||
size_t j = 0;
|
||||
do {
|
||||
if (static_cast<unsigned char>(std::tolower(source_val[j])) != kVitisAI[j]) {
|
||||
break;
|
||||
}
|
||||
++j;
|
||||
} while (j < vitisai_len);
|
||||
if (j == vitisai_len) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool FusedGraphHasEPContextNode(
|
||||
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs) {
|
||||
for (const auto& fused_node_graph : fused_nodes_and_graphs) {
|
||||
bool has_node = GraphHasEPContextNode(fused_node_graph.filtered_graph.get().GetGraph());
|
||||
if (has_node) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const fs::path& GetTopLevelModelPath(const GraphViewer& graph_viewer) {
|
||||
const auto& graph = graph_viewer.GetGraph();
|
||||
const Graph* p_graph = &graph;
|
||||
while (p_graph->IsSubgraph()) {
|
||||
p_graph = p_graph->ParentGraph();
|
||||
}
|
||||
return p_graph->ModelPath();
|
||||
}
|
||||
|
||||
bool GetEPContextModelFileLocation(
|
||||
const std::string& ep_ctx_model_path_cfg,
|
||||
const PathString& model_path_str,
|
||||
bool is_ep_ctx_model,
|
||||
PathString& ep_ctx_model_file_loc) {
|
||||
if (!ep_ctx_model_file_loc.empty()) {
|
||||
return true;
|
||||
}
|
||||
if (!ep_ctx_model_path_cfg.empty()) {
|
||||
ep_ctx_model_file_loc = ToPathString(ep_ctx_model_path_cfg);
|
||||
} else if (!model_path_str.empty()) {
|
||||
if (is_ep_ctx_model) {
|
||||
ep_ctx_model_file_loc = model_path_str;
|
||||
} else {
|
||||
// Two alternatives for this case.
|
||||
// Alternative 1:
|
||||
// 1) Implement/override the method `IExecutionProvider::GetEpContextNodes()`.
|
||||
// 2) And follow how the default path is implemented in `CreateEpContextModel()`
|
||||
// in the file "graph_partitioner.cc".
|
||||
// 3) Model dump is not required.
|
||||
// Alternative 2:
|
||||
// 1) Do NOT implement/override `IExecutionProvider::GetEpContextNodes()`.
|
||||
// 2) No need to follow `CreateEpContextModel()` in the file "graph_partitioner.cc",
|
||||
// freely implement what the default path is like.
|
||||
// 3) Model dump is required.
|
||||
#if 0
|
||||
ep_ctx_model_file_loc = model_path_str + ToPathString("_ctx.onnx");
|
||||
#endif
|
||||
#if 1
|
||||
fs::path model_fs_path(model_path_str);
|
||||
fs::path ep_ctx_model_fs_path(model_fs_path.parent_path() / model_fs_path.stem());
|
||||
ep_ctx_model_fs_path += fs::path("_ctx.onnx");
|
||||
ep_ctx_model_file_loc = ToPathString(ep_ctx_model_fs_path.string());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return !ep_ctx_model_file_loc.empty();
|
||||
}
|
||||
|
||||
// The file for EP context cache is in the same folder as the EP context model file.
|
||||
PathString GetEPContextCacheFileLocation(
|
||||
const PathString& ep_ctx_model_file_loc, const PathString& model_path_str) {
|
||||
if (!ep_ctx_model_file_loc.empty()) {
|
||||
fs::path ep_ctx_model_fs_path(ep_ctx_model_file_loc);
|
||||
fs::path ep_ctx_cache_fs_path(ep_ctx_model_fs_path.parent_path() / ep_ctx_model_fs_path.stem());
|
||||
ep_ctx_cache_fs_path += fs::path("__ep_ctx_cache.bin");
|
||||
return ToPathString(ep_ctx_cache_fs_path.string());
|
||||
}
|
||||
fs::path model_fs_path(model_path_str);
|
||||
fs::path ep_ctx_cache_fs_path(model_fs_path.parent_path() / model_fs_path.stem());
|
||||
ep_ctx_cache_fs_path += fs::path("__ep_ctx_cache.bin");
|
||||
return ToPathString(ep_ctx_cache_fs_path.string());
|
||||
}
|
||||
|
||||
std::string Slurp(const fs::path& file_location, bool binary_mode) {
|
||||
// std::filesystem::value_type == onnxruntime::PathChar == ORTCHAR_T
|
||||
// std::filesystem::string_type == onnxruntime::PathString
|
||||
// const char* location_str = PathToUTF8String(file_location.native()).c_str();
|
||||
std::ifstream ifs;
|
||||
ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
|
||||
std::stringstream ss;
|
||||
try {
|
||||
auto open_mode = binary_mode ? (std::ios::in | std::ios::binary) : std::ios::in;
|
||||
ifs.open(file_location.string().c_str(), open_mode);
|
||||
ss << ifs.rdbuf();
|
||||
if (!ss.good()) {
|
||||
LOGS_DEFAULT(WARNING) << "Failed to write to stream";
|
||||
}
|
||||
ifs.close();
|
||||
} catch (std::system_error& se) {
|
||||
LOGS_DEFAULT(WARNING) << "Failed to read " << file_location << ": " << se.code().message();
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -53,6 +53,8 @@ struct OrtVitisAIEpAPI {
|
|||
std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>* (*compile_onnx_model_with_options)(
|
||||
const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options);
|
||||
uint32_t (*vaip_get_version)();
|
||||
void (*get_backend_compilation_cache)(const std::string& model_path, const onnxruntime::Graph& graph, const char* json_config, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data);
|
||||
void (*restore_backend_compilation_cache)(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path);
|
||||
void Ensure() {
|
||||
if (handle_)
|
||||
return;
|
||||
|
|
@ -77,6 +79,8 @@ struct OrtVitisAIEpAPI {
|
|||
}
|
||||
std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version",
|
||||
(void**)&vaip_get_version);
|
||||
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "get_compilation_cache", (void**)&get_backend_compilation_cache));
|
||||
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "restore_compilation_cache", (void**)&restore_backend_compilation_cache));
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -122,13 +126,7 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config
|
|||
|
||||
vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(
|
||||
const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) {
|
||||
#ifndef _WIN32
|
||||
auto model_path = graph_viewer.ModelPath().string();
|
||||
#else
|
||||
using convert_t = std::codecvt_utf8<wchar_t>;
|
||||
std::wstring_convert<convert_t, wchar_t> strconverter;
|
||||
auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().string());
|
||||
#endif
|
||||
auto model_path = PathToUTF8String(ToPathString(graph_viewer.ModelPath().string()));
|
||||
if (s_library_vitisaiep.compile_onnx_model_with_options) {
|
||||
return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options));
|
||||
} else {
|
||||
|
|
@ -137,6 +135,17 @@ vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> c
|
|||
}
|
||||
}
|
||||
|
||||
void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data) {
|
||||
const std::string& model_path = PathToUTF8String(model_path_str);
|
||||
const onnxruntime::Graph& graph = graph_viewer.GetGraph();
|
||||
const auto json_str = config_to_json_str(options);
|
||||
s_library_vitisaiep.get_backend_compilation_cache(model_path, graph, json_str.c_str(), compiler_codes, cache_dir, cache_key, cache_data);
|
||||
}
|
||||
|
||||
void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path) {
|
||||
s_library_vitisaiep.restore_backend_compilation_cache(cache_dir, cache_key, cache_data, model_path);
|
||||
}
|
||||
|
||||
struct MyCustomOpKernel : OpKernel {
|
||||
MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) {
|
||||
op_kernel_ =
|
||||
|
|
@ -218,7 +227,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
|
|||
auto& logger = logging::LoggingManager::DefaultLogger();
|
||||
auto& model = const_cast<onnxruntime::Model&>(const_model);
|
||||
auto model_proto = model.ToProto();
|
||||
auto file_path = model.MainGraph().ModelPath().string();
|
||||
auto file_path = ToPathString(model.MainGraph().ModelPath().string());
|
||||
auto local_registries = IOnnxRuntimeOpSchemaRegistryList{model.MainGraph().GetSchemaRegistry()};
|
||||
auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger);
|
||||
auto status = ret->MainGraph().Resolve();
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
#pragma once
|
||||
|
||||
// Standard headers/libs.
|
||||
#include <filesystem>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
// 1st-party headers/libs.
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
constexpr const uint8_t kXCCode = 1;
|
||||
constexpr const uint8_t kDDCode = 2;
|
||||
constexpr const uint8_t kVCode = 4;
|
||||
|
||||
static constexpr const char* kEPContextOp = "EPContext";
|
||||
static constexpr const char* kMainContextAttr = "main_context";
|
||||
static constexpr const char* kEPCacheContextAttr = "ep_cache_context";
|
||||
static constexpr const char* kEmbedModeAttr = "embed_mode";
|
||||
static constexpr const char* kPartitionNameAttr = "partition_name";
|
||||
static constexpr const char* kSourceAttr = "source";
|
||||
static constexpr const char* kEPSDKVersionAttr = "ep_sdk_version";
|
||||
static constexpr const char* kONNXModelFileNameAttr = "onnx_model_filename";
|
||||
static constexpr const char* kNotesAttr = "notes";
|
||||
static constexpr const char* kEPContextOpDomain = "com.microsoft";
|
||||
static constexpr const char* kEPContextOpName = "VitisAIEPContextOp";
|
||||
|
||||
std::unique_ptr<ONNX_NAMESPACE::FunctionProto>
|
||||
ConvertIndexedSubGraphToFunctionProto(const IndexedSubGraph&, const Graph&);
|
||||
|
||||
std::unique_ptr<IndexedSubGraph> ConvertFunctionProtoToIndexedSubGraph(
|
||||
const std::unique_ptr<ONNX_NAMESPACE::FunctionProto>&);
|
||||
|
||||
std::string SerializeCapabilities(
|
||||
const std::vector<std::unique_ptr<ComputeCapability>>&, const Graph&);
|
||||
|
||||
void DeserializeCapabilities(
|
||||
const std::string&, std::vector<std::unique_ptr<ComputeCapability>>&);
|
||||
|
||||
std::string SerializeOrigialGraph(const GraphViewer&);
|
||||
|
||||
// Ref.: `CreateEpContextModel()` in the file "graph_partitioner.cc".
|
||||
ONNX_NAMESPACE::ModelProto* CreateEPContexModel(const GraphViewer&, const std::string&, const std::string&, const int64_t,
|
||||
const std::string&, const std::string&, bool, const logging::Logger*);
|
||||
|
||||
// Ref.: `static common::Status Save(Model& model, int fd)` in the file "model.h".
|
||||
void DumpEPContextModel(const std::unique_ptr<ONNX_NAMESPACE::ModelProto>&, const std::string&);
|
||||
|
||||
const Node* GetEPContextNodePtr(const Graph&);
|
||||
|
||||
bool ValidateEPContextNode(const Graph&);
|
||||
|
||||
void CreateEPContexNodes(Graph*, const std::vector<IExecutionProvider::FusedNodeAndGraph>&, const std::string&, const std::string&,
|
||||
const int64_t, const std::string&, const std::string&, bool, const logging::Logger*);
|
||||
|
||||
std::string RetrieveEPContextCache(const Graph&, const PathString&, bool binary_mode = true);
|
||||
|
||||
void RetrieveBackendCacheInfo(const Graph&, std::string&, std::string&);
|
||||
|
||||
std::unique_ptr<GraphViewer> RetrieveOriginalGraph(const Graph&);
|
||||
|
||||
bool GraphHasEPContextNode(const Graph&);
|
||||
|
||||
bool FusedGraphHasEPContextNode(
|
||||
const std::vector<IExecutionProvider::FusedNodeAndGraph>&);
|
||||
|
||||
const fs::path& GetTopLevelModelPath(const GraphViewer&);
|
||||
|
||||
bool GetEPContextModelFileLocation(
|
||||
const std::string&, const PathString&, bool, PathString&);
|
||||
|
||||
// The file for EP context cache is in the same folder as the EP context model file.
|
||||
PathString GetEPContextCacheFileLocation(const PathString&, const PathString&);
|
||||
|
||||
std::string Slurp(const fs::path&, bool binary_mode = false);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -14,3 +14,5 @@ void initialize_vitisai_ep();
|
|||
vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options);
|
||||
std::shared_ptr<onnxruntime::KernelRegistry> get_kernel_registry_vitisaiep();
|
||||
const std::vector<OrtCustomOpDomain*>& get_domains_vitisaiep();
|
||||
void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data);
|
||||
void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path);
|
||||
|
|
|
|||
|
|
@ -2,22 +2,43 @@
|
|||
// Licensed under the MIT License.
|
||||
#include "vitisai_execution_provider.h"
|
||||
|
||||
// Standard headers/libs.
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <istream>
|
||||
#include <filesystem>
|
||||
|
||||
// 1st-party headers/libs.
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "core/common/exceptions.h"
|
||||
|
||||
#include "vaip/capability.h"
|
||||
#include "vaip/global_api.h"
|
||||
#include "ep_context_utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace onnxruntime {
|
||||
constexpr const char* VITISAI = "VITISAI";
|
||||
|
||||
VitisAIExecutionProvider::VitisAIExecutionProvider(
|
||||
const ProviderOptions& info)
|
||||
// const ProviderOptions& info, const SessionOptions* p_sess_opts)
|
||||
: IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) {
|
||||
CreateKernelRegistry();
|
||||
|
||||
auto it = info_.find("ep_context_enable");
|
||||
ep_ctx_enabled_ = it != info_.end() && it->second == "1";
|
||||
it = info_.find("ep_context_embed_mode");
|
||||
ep_ctx_embed_mode_ = it != info_.end() && it->second != "0";
|
||||
// ep_ctx_embed_mode_ = it == info_.end() || it->second != "0";
|
||||
it = info_.find("ep_context_file_path");
|
||||
ep_ctx_model_path_cfg_ = it == info_.end() ? "" : it->second;
|
||||
LOGS_DEFAULT(VERBOSE) << "EP Context cache enabled: " << ep_ctx_enabled_;
|
||||
LOGS_DEFAULT(VERBOSE) << "EP context cache embed mode: " << ep_ctx_embed_mode_;
|
||||
LOGS_DEFAULT(VERBOSE) << "User specified EP context cache path: " << ep_ctx_model_path_cfg_;
|
||||
}
|
||||
|
||||
void VitisAIExecutionProvider::CreateKernelRegistry() {
|
||||
|
|
@ -30,9 +51,115 @@ void VitisAIExecutionProvider::CreateKernelRegistry() {
|
|||
|
||||
std::shared_ptr<KernelRegistry> VitisAIExecutionProvider::GetKernelRegistry() const { return get_kernel_registry_vitisaiep(); }
|
||||
|
||||
// This method is called after both `GetComputeCapabilityOps()` and `Compile()`.
|
||||
// This timing is required to work with both compilation-based EPs and non-compilation-based EPs.
|
||||
const InlinedVector<const Node*> VitisAIExecutionProvider::GetEpContextNodes() const {
|
||||
InlinedVector<const Node*> ep_context_node_ptrs;
|
||||
// All preconditions are supposed to have happened.
|
||||
if (p_ep_ctx_model_) {
|
||||
auto& graph = p_ep_ctx_model_->MainGraph();
|
||||
for (const auto* p_node : graph.Nodes()) {
|
||||
ep_context_node_ptrs.push_back(p_node);
|
||||
}
|
||||
}
|
||||
return ep_context_node_ptrs;
|
||||
}
|
||||
|
||||
void VitisAIExecutionProvider::LoadEPContexModelFromFile() const {
|
||||
// XXX: should "p_ep_ctx_model_" be checked or not?
|
||||
if (!p_ep_ctx_model_ && !ep_ctx_model_file_loc_.empty()) {
|
||||
auto status = Model::Load(ep_ctx_model_file_loc_, *p_ep_ctx_model_proto_);
|
||||
if (!status.IsOK()) {
|
||||
ORT_THROW("Loading EP context model failed from ", PathToUTF8String(ep_ctx_model_file_loc_));
|
||||
}
|
||||
p_ep_ctx_model_ = Model::Create(std::move(*p_ep_ctx_model_proto_), ep_ctx_model_file_loc_, nullptr, *GetLogger());
|
||||
LOGS_DEFAULT(VERBOSE) << "Loaded EP context model from: " << PathToUTF8String(ep_ctx_model_file_loc_);
|
||||
} else if (ep_ctx_model_file_loc_.empty()) {
|
||||
LOGS_DEFAULT(WARNING) << "Cannot load an EP-context model due to bad file path";
|
||||
}
|
||||
}
|
||||
|
||||
void VitisAIExecutionProvider::PrepareEPContextEnablement(
|
||||
const onnxruntime::GraphViewer& graph_viewer) const {
|
||||
if (model_path_str_.empty()) {
|
||||
// TODO: platform dependency (Linux vs Windows).
|
||||
model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string());
|
||||
}
|
||||
std::string backend_cache_dir, backend_cache_key;
|
||||
get_backend_compilation_cache(model_path_str_, graph_viewer, info_, kXCCode, backend_cache_dir, backend_cache_key, backend_cache_data_);
|
||||
info_["cacheDir"] = backend_cache_dir;
|
||||
info_["cacheKey"] = backend_cache_key;
|
||||
// Create a new model, reusing the graph name, the op-domain-to-opset-version map,
|
||||
// the op schema registry of the current graph, etc.
|
||||
p_ep_ctx_model_ = graph_viewer.CreateModel(*GetLogger());
|
||||
LOGS_DEFAULT(VERBOSE) << "Container model created";
|
||||
}
|
||||
|
||||
void VitisAIExecutionProvider::FulfillEPContextEnablement(
|
||||
const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs) {
|
||||
auto& ep_ctx_graph = p_ep_ctx_model_->MainGraph();
|
||||
if (!ep_ctx_embed_mode_) {
|
||||
auto ep_ctx_cache_path_str = GetEPContextCacheFileLocation(ep_ctx_model_file_loc_, model_path_str_);
|
||||
std::ofstream ep_ctx_cache_ofs(ep_ctx_cache_path_str.c_str(), std::ios::trunc);
|
||||
if (!ep_ctx_cache_ofs.is_open()) {
|
||||
ORT_THROW("Failed to open a file to write EP context cache: ", ep_ctx_cache_path_str.c_str());
|
||||
}
|
||||
ep_ctx_cache_ofs.write(backend_cache_data_.c_str(), backend_cache_data_.length());
|
||||
if (!ep_ctx_cache_ofs.good()) {
|
||||
ep_ctx_cache_ofs.close();
|
||||
ORT_THROW("Exception writing EP context cache file: ", ep_ctx_cache_path_str.c_str());
|
||||
}
|
||||
ep_ctx_cache_ofs.close();
|
||||
CreateEPContexNodes(&ep_ctx_graph, fused_nodes_and_graphs, "", PathToUTF8String(ep_ctx_cache_path_str), 0, info_.at("cacheDir"), info_.at("cacheKey"), false, GetLogger());
|
||||
} else {
|
||||
CreateEPContexNodes(&ep_ctx_graph, fused_nodes_and_graphs, backend_cache_data_, "", 1, info_["cacheDir"], info_["cacheKey"], false, GetLogger());
|
||||
}
|
||||
if (GraphHasEPContextNode(ep_ctx_graph)) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Created model has EP context nodes";
|
||||
} else {
|
||||
LOGS_DEFAULT(WARNING) << "No EP eontext nodes created";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>> VitisAIExecutionProvider::GetCapability(
|
||||
const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const {
|
||||
if (graph.IsSubgraph()) {
|
||||
const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const {
|
||||
bool is_ep_ctx_model = GraphHasEPContextNode(graph_viewer.GetGraph());
|
||||
// TODO: platform dependency (Linux vs Windows).
|
||||
model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string());
|
||||
if (GetEPContextModelFileLocation(
|
||||
ep_ctx_model_path_cfg_, model_path_str_, is_ep_ctx_model, ep_ctx_model_file_loc_)) {
|
||||
if (is_ep_ctx_model) {
|
||||
LOGS_DEFAULT(VERBOSE) << "An EP context model passed in";
|
||||
ValidateEPContextNode(graph_viewer.GetGraph());
|
||||
std::string cache_dir, cache_key;
|
||||
RetrieveBackendCacheInfo(graph_viewer.GetGraph(), cache_dir, cache_key);
|
||||
info_["cacheDir"] = cache_dir;
|
||||
info_["cacheKey"] = cache_key;
|
||||
LOGS_DEFAULT(VERBOSE) << "Trying getting compilation cache from " << PathToUTF8String(ep_ctx_model_file_loc_);
|
||||
auto ep_ctx_payload = RetrieveEPContextCache(graph_viewer.GetGraph(), ep_ctx_model_file_loc_, false);
|
||||
restore_backend_compilation_cache(cache_dir, cache_key, ep_ctx_payload, graph_viewer.ModelPath().string());
|
||||
} else {
|
||||
if (fs::exists(ep_ctx_model_file_loc_) && fs::is_regular_file(ep_ctx_model_file_loc_) && ep_ctx_enabled_) {
|
||||
ORT_THROW("The inference session was created with a normal ONNX model but a model file with EP context cache exists at ",
|
||||
PathToUTF8String(ep_ctx_model_file_loc_), ". Please remove the EP context model manually if you want to re-generate it.");
|
||||
// Disable the flexibility implemented below by throwing an exception.
|
||||
// Now the code below is unreachable but DCE will take care of it.
|
||||
// We might want to re-enable it in future, so we keep it as is.
|
||||
LoadEPContexModelFromFile();
|
||||
ValidateEPContextNode(p_ep_ctx_model_->MainGraph());
|
||||
std::string cache_dir, cache_key;
|
||||
RetrieveBackendCacheInfo(p_ep_ctx_model_->MainGraph(), cache_dir, cache_key);
|
||||
info_["cacheDir"] = cache_dir;
|
||||
info_["cacheKey"] = cache_key;
|
||||
auto ep_ctx_payload = RetrieveEPContextCache(p_ep_ctx_model_->MainGraph(), ep_ctx_model_file_loc_, false);
|
||||
restore_backend_compilation_cache(cache_dir, cache_key, ep_ctx_payload, graph_viewer.ModelPath().string());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOGS_DEFAULT(WARNING) << "Failed to get EP context model file location";
|
||||
}
|
||||
|
||||
if (graph_viewer.IsSubgraph()) {
|
||||
// VITIS AI EP not support sungraph. Assigned to CPU.
|
||||
return {};
|
||||
}
|
||||
|
|
@ -40,13 +167,16 @@ std::vector<std::unique_ptr<ComputeCapability>> VitisAIExecutionProvider::GetCap
|
|||
// Only compiling a model once is currently supported
|
||||
return {};
|
||||
}
|
||||
execution_providers_ = std::make_unique<my_ep_t>(compile_onnx_model(graph, *GetLogger(), info_));
|
||||
auto result = vaip::GetComputeCapabilityOps(graph, execution_providers_.get(), vitisai_optypes_);
|
||||
execution_providers_ = std::make_unique<my_ep_t>(compile_onnx_model(graph_viewer, *GetLogger(), info_));
|
||||
auto result = vaip::GetComputeCapabilityOps(graph_viewer, execution_providers_.get(), vitisai_optypes_);
|
||||
size_t index = 0u;
|
||||
for (auto& ep : **execution_providers_) {
|
||||
result.emplace_back(vaip::XirSubgraphToComputeCapability1(graph, ep.get(), index));
|
||||
result.emplace_back(vaip::XirSubgraphToComputeCapability1(graph_viewer, ep.get(), index));
|
||||
index = index + 1;
|
||||
}
|
||||
if (ep_ctx_enabled_ && !is_ep_ctx_model) {
|
||||
PrepareEPContextEnablement(graph_viewer);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
@ -74,6 +204,10 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector<FusedNodeAndG
|
|||
};
|
||||
node_compute_funcs.push_back(compute_info);
|
||||
}
|
||||
if (ep_ctx_enabled_ && p_ep_ctx_model_) {
|
||||
FulfillEPContextEnablement(fused_nodes_and_graphs);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,14 +3,18 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
// Standard headers/libs.
|
||||
#include <ctime>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
// 1st-party headers/libs.
|
||||
// #include "core/framework/session_options.h"
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/common/inlined_containers_fwd.h"
|
||||
|
||||
// we cannot include vaip/vaip.hpp here because header file referred by
|
||||
// onnxruntime_pybind_state_common.cc
|
||||
|
|
@ -24,9 +28,11 @@ namespace onnxruntime {
|
|||
class VitisAIExecutionProvider : public IExecutionProvider {
|
||||
public:
|
||||
explicit VitisAIExecutionProvider(const ProviderOptions& info);
|
||||
// explicit VitisAIExecutionProvider(const ProviderOptions& info,
|
||||
// const SessionOptions* p_sess_opts = nullptr);
|
||||
~VitisAIExecutionProvider() = default;
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>> GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
std::vector<std::unique_ptr<ComputeCapability>> GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const IKernelLookup& /*kernel_lookup*/) const override;
|
||||
|
||||
int GetDeviceId() const { return 0; }
|
||||
|
|
@ -35,16 +41,34 @@ class VitisAIExecutionProvider : public IExecutionProvider {
|
|||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
|
||||
|
||||
// This method is called after both `GetComputeCapabilityOps()` and `Compile()`.
|
||||
// This timing is required to work with both compliation-based EPs and non-compilation-based EPs.
|
||||
const InlinedVector<const Node*> GetEpContextNodes() const override;
|
||||
|
||||
private:
|
||||
void CreateKernelRegistry();
|
||||
using my_ep_t = vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>>;
|
||||
using my_ep_uptr_t = std::shared_ptr<my_ep_t>;
|
||||
// we have to hide the implementation by forward declaration.
|
||||
mutable my_ep_uptr_t execution_providers_;
|
||||
ProviderOptions info_;
|
||||
mutable ProviderOptions info_;
|
||||
std::vector<OrtCustomOpDomain*> custom_op_domains_;
|
||||
std::shared_ptr<KernelRegistry> registry_;
|
||||
std::set<std::string> vitisai_optypes_;
|
||||
// EP context related.
|
||||
bool ep_ctx_enabled_ = false;
|
||||
bool ep_ctx_embed_mode_ = true;
|
||||
std::string ep_ctx_model_path_cfg_{""};
|
||||
mutable std::string backend_cache_data_{""};
|
||||
mutable PathString model_path_str_{};
|
||||
mutable PathString ep_ctx_model_file_loc_{};
|
||||
mutable std::unique_ptr<onnxruntime::Model> p_ep_ctx_model_;
|
||||
mutable std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_ep_ctx_model_proto_;
|
||||
// It might need to be called before loading
|
||||
// the EP context model that is compiled AOT/offline.
|
||||
void LoadEPContexModelFromFile() const;
|
||||
void PrepareEPContextEnablement(const onnxruntime::GraphViewer&) const;
|
||||
void FulfillEPContextEnablement(const std::vector<FusedNodeAndGraph>&);
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
0
onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc
Executable file → Normal file
0
onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc
Executable file → Normal file
|
|
@ -28,6 +28,7 @@
|
|||
#include "core/session/inference_session.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/session/provider_bridge_ort.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/framework/sparse_utils.h"
|
||||
|
|
@ -68,10 +69,12 @@ using StringStringEntryProtos = google::protobuf::RepeatedPtrField<StringStringE
|
|||
using TensorProtos = google::protobuf::RepeatedPtrField<TensorProto>;
|
||||
using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField<TensorShapeProto_Dimension>;
|
||||
using ValueInfoProtos = google::protobuf::RepeatedPtrField<ValueInfoProto>;
|
||||
using FunctionProtos = google::protobuf::RepeatedPtrField<FunctionProto>;
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
||||
namespace onnxruntime {
|
||||
using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef;
|
||||
using IndexedSubGraph_SourceOfSchema = IndexedSubGraph::SourceOfSchema;
|
||||
} // namespace onnxruntime
|
||||
|
||||
#include "core/common/cpuid_info.h"
|
||||
|
|
@ -400,6 +403,11 @@ struct ProviderHostImpl : ProviderHost {
|
|||
int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->size(); }
|
||||
ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) override { return p->at(index); };
|
||||
|
||||
// OperatorSetIdProto
|
||||
std::string* OperatorSetIdProto__mutable_domain(ONNX_NAMESPACE::OperatorSetIdProto* p) override { return p->mutable_domain(); }
|
||||
void OperatorSetIdProto__set_version(ONNX_NAMESPACE::OperatorSetIdProto* p, int64_t version) override { return p->set_version(version); }
|
||||
int64_t OperatorSetIdProto__version(const ONNX_NAMESPACE::OperatorSetIdProto* p) override { return p->version(); }
|
||||
|
||||
#if !defined(DISABLE_OPTIONAL_TYPE)
|
||||
// TypeProto_Optional (wrapped)
|
||||
const ONNX_NAMESPACE::TypeProto& TypeProto_Optional__elem_type(const ONNX_NAMESPACE::TypeProto_Optional* p) override { return p->elem_type(); }
|
||||
|
|
@ -528,6 +536,11 @@ struct ProviderHostImpl : ProviderHost {
|
|||
void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) override { p->set_ir_version(value); }
|
||||
ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_metadata_props(); };
|
||||
|
||||
const ONNX_NAMESPACE::OperatorSetIdProto& ModelProto__opset_import(const ONNX_NAMESPACE::ModelProto* p, int index) override { return p->opset_import(index); }
|
||||
ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__mutable_opset_import(ONNX_NAMESPACE::ModelProto* p, int index) override { return p->mutable_opset_import(index); }
|
||||
int ModelProto__opset_import_size(const ONNX_NAMESPACE::ModelProto* p) override { return p->opset_import_size(); }
|
||||
ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__add_opset_import(ONNX_NAMESPACE::ModelProto* p) override { return p->add_opset_import(); }
|
||||
|
||||
// NodeProto (wrapped)
|
||||
std::unique_ptr<ONNX_NAMESPACE::NodeProto> NodeProto__construct() override { return std::make_unique<ONNX_NAMESPACE::NodeProto>(); }
|
||||
void NodeProto__operator_delete(ONNX_NAMESPACE::NodeProto* p) override { delete p; }
|
||||
|
|
@ -535,6 +548,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) override { return p->attribute_size(); }
|
||||
const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const override { return p->attribute(index); }
|
||||
ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) override { return p->mutable_attribute(index); }
|
||||
ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) override { return p->add_attribute(); }
|
||||
|
||||
// TensorProto (wrapped)
|
||||
std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() override { return std::make_unique<ONNX_NAMESPACE::TensorProto>(); }
|
||||
|
|
@ -609,6 +623,64 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; }
|
||||
|
||||
// FunctionProto (wrapped)
|
||||
std::unique_ptr<ONNX_NAMESPACE::FunctionProto> FunctionProto__construct() override { return std::make_unique<ONNX_NAMESPACE::FunctionProto>(); }
|
||||
void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) override { delete p; }
|
||||
|
||||
bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) override { return p->SerializeToString(&string); }
|
||||
bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) override { return p->SerializeToOstream(&output); }
|
||||
bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) override { return p->ParseFromString(data); }
|
||||
std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) override { return p->SerializeAsString(); }
|
||||
|
||||
bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_name(); }
|
||||
const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->name(); }
|
||||
void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const std::string& name) override { p->set_name(name); }
|
||||
|
||||
bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_doc_string(); }
|
||||
const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->doc_string(); }
|
||||
void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const std::string& doc_string) override { p->set_doc_string(doc_string); }
|
||||
|
||||
bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_domain(); }
|
||||
const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->domain(); }
|
||||
void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const std::string& domain) override { p->set_domain(domain); }
|
||||
|
||||
const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->input(index); }
|
||||
std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_input(index); }
|
||||
int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->input_size(); }
|
||||
void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_input(value); }
|
||||
|
||||
const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->output(index); }
|
||||
std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_output(index); }
|
||||
int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->output_size(); }
|
||||
void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_output(value); }
|
||||
|
||||
const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->attribute(index); }
|
||||
std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_attribute(index); }
|
||||
int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->attribute_size(); }
|
||||
void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_attribute(value); }
|
||||
|
||||
const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->attribute_proto(index); }
|
||||
ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_attribute_proto(index); }
|
||||
int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->attribute_proto_size(); }
|
||||
ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_attribute_proto(); }
|
||||
|
||||
const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->node(index); }
|
||||
ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_node(index); }
|
||||
int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->node_size(); }
|
||||
ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_node(); }
|
||||
|
||||
const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->value_info(index); }
|
||||
ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_value_info(index); }
|
||||
ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) override { return p->mutable_value_info(); }
|
||||
int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->value_info_size(); }
|
||||
ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_value_info(); }
|
||||
|
||||
const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->metadata_props(index); }
|
||||
ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_metadata_props(index); }
|
||||
ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->mutable_metadata_props(); }
|
||||
int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); }
|
||||
ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); }
|
||||
|
||||
static int32_t convert_elem_type(const ONNX_NAMESPACE::AttributeProto* data_type) {
|
||||
int32_t elemType = 0;
|
||||
if (data_type->s() == "float32") {
|
||||
|
|
@ -791,9 +863,12 @@ struct ProviderHostImpl : ProviderHost {
|
|||
|
||||
std::vector<onnxruntime::NodeIndex>& IndexedSubGraph__Nodes(IndexedSubGraph* p) override { return p->nodes; }
|
||||
|
||||
void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr<IndexedSubGraph_MetaDef>&& meta_def_) override { return p->SetMetaDef(std::move(meta_def_)); }
|
||||
void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr<IndexedSubGraph_MetaDef>&& meta_def_) override { p->SetMetaDef(std::move(meta_def_)); }
|
||||
const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) override { return p->GetMetaDef(); }
|
||||
|
||||
void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) override { p->schema_source = schema_source; }
|
||||
IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) override { return p->schema_source; }
|
||||
|
||||
// KernelDef (wrapped)
|
||||
void KernelDef__operator_delete(KernelDef* p) override { delete p; }
|
||||
void KernelDef__SinceVersion(const KernelDef* p, int* start, int* end) override { return p->SinceVersion(start, end); }
|
||||
|
|
@ -2842,6 +2917,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_
|
|||
|
||||
provider_options[provider_options_keys[i]] = provider_options_values[i];
|
||||
}
|
||||
// EP context related session config options.
|
||||
provider_options["ep_context_enable"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0");
|
||||
provider_options["ep_context_embed_mode"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1");
|
||||
provider_options["ep_context_file_path"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
|
||||
|
||||
auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options);
|
||||
if (!factory) {
|
||||
return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library");
|
||||
|
|
|
|||
|
|
@ -1114,6 +1114,9 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
if (it != provider_options_map.end()) {
|
||||
info = it->second;
|
||||
}
|
||||
info["ep_context_enable"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0");
|
||||
info["ep_context_embed_mode"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1");
|
||||
info["ep_context_file_path"] = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
|
||||
return onnxruntime::VitisAIProviderFactoryCreator::Create(info)->CreateProvider();
|
||||
#endif
|
||||
} else if (type == kAclExecutionProvider) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue