mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[VitisAI] Remove shape infer from bridge ort (#21331)
### Description Vitis AI EP's custom op are completely self contained within Vitis AI EP implementation (rather than needing to add static functions in provider_bridge). --------- Co-authored-by: liumingyue <mingyue@xilinx.com>
This commit is contained in:
parent
509cb54d6f
commit
047f32c79d
4 changed files with 4 additions and 133 deletions
|
|
@ -567,7 +567,7 @@ struct ProviderHost {
|
|||
virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
|
||||
|
||||
virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0;
|
||||
virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0;
|
||||
virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0;
|
||||
virtual const std::string& OpSchema__inputs__GetName(const ONNX_NAMESPACE::OpSchema* p, const size_t i) = 0;
|
||||
virtual const std::string& OpSchema__inputs__GetTypeStr(const ONNX_NAMESPACE::OpSchema* p, const size_t i) = 0;
|
||||
|
|
|
|||
|
|
@ -13,14 +13,7 @@ void register_xir_ops(const std::vector<OrtCustomOpDomain*>& domains) {
|
|||
for (auto domain : domains) {
|
||||
for (auto op : domain->custom_ops_) {
|
||||
if (Provider_GetHost()->GetSchema(op->GetName(op), op->GetStartVersion(op), domain->domain_) == nullptr) {
|
||||
auto name = op->GetName(op);
|
||||
if ((std::string)name == "super_layer") {
|
||||
Provider_GetHost()->RegisterSchema(domain->domain_, op, 1);
|
||||
} else if ((std::string)name == "FixNeuron") {
|
||||
Provider_GetHost()->RegisterSchema(domain->domain_, op, 2);
|
||||
} else {
|
||||
Provider_GetHost()->RegisterSchema(domain->domain_, op, 3);
|
||||
}
|
||||
Provider_GetHost()->RegisterSchema(domain->domain_, op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ struct OrtApi;
|
|||
|
||||
namespace vaip_core {
|
||||
|
||||
#define VAIP_ORT_API_MAJOR (7u)
|
||||
#define VAIP_ORT_API_MAJOR (8u)
|
||||
#define VAIP_ORT_API_MINOR (0u)
|
||||
#define VAIP_ORT_API_PATCH (0u)
|
||||
struct OrtApiForVaip {
|
||||
|
|
|
|||
|
|
@ -682,135 +682,13 @@ struct ProviderHostImpl : ProviderHost {
|
|||
int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); }
|
||||
ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); }
|
||||
|
||||
static int32_t convert_elem_type(const ONNX_NAMESPACE::AttributeProto* data_type) {
|
||||
int32_t elemType = 0;
|
||||
if (data_type->s() == "float32") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
|
||||
} else if (data_type->s() == "int8") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT8;
|
||||
} else if (data_type->s() == "uint8") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
|
||||
} else if (data_type->s() == "int32") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32;
|
||||
} else if (data_type->s() == "uint32") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
|
||||
} else if (data_type->s() == "int64") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64;
|
||||
} else if (data_type->s() == "uint64") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
|
||||
} else if (data_type->s() == "int1") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
|
||||
} else if (data_type->s() == "bfloat16") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
|
||||
} else if (data_type->s() == "float16") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
|
||||
} else if (data_type->s() == "uint16") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
|
||||
} else if (data_type->s() == "int16") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16;
|
||||
} else if (data_type->s() == "double") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
|
||||
} else if (data_type->s() == "string") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_STRING;
|
||||
} else if (data_type->s() == "complex64") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
|
||||
} else if (data_type->s() == "complex128") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
|
||||
} else if (data_type->s() == "float8e4m3fn") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
|
||||
} else if (data_type->s() == "float8e4m3fnuz") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
|
||||
} else if (data_type->s() == "float8e5m2") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
|
||||
} else if (data_type->s() == "float8e5m2funz") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
|
||||
} else if (data_type->s() == "uint4") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
|
||||
} else if (data_type->s() == "int4") {
|
||||
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4;
|
||||
}
|
||||
return elemType;
|
||||
}
|
||||
|
||||
static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
auto num_output = ctx.getNumOutputs();
|
||||
if (num_output == 1) {
|
||||
auto* shape = ctx.getAttribute("shape");
|
||||
auto* data_type = ctx.getAttribute("data_type");
|
||||
if (data_type == nullptr) {
|
||||
std::cerr << "Custom op is missing `data_type` attr." << std::endl;
|
||||
return;
|
||||
}
|
||||
int32_t elemType = convert_elem_type(data_type);
|
||||
ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType);
|
||||
if (shape != nullptr) {
|
||||
for (auto i = 0; i < shape->ints_size(); ++i) {
|
||||
ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i));
|
||||
}
|
||||
} else {
|
||||
// set scalar type.
|
||||
ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim();
|
||||
}
|
||||
} else {
|
||||
for (auto idx = 0u; idx < num_output; idx++) {
|
||||
auto* shape = ctx.getAttribute("shape_" + std::to_string(idx));
|
||||
auto* data_type = ctx.getAttribute("data_type_" + std::to_string(idx));
|
||||
if (shape == nullptr || data_type == nullptr) {
|
||||
// this output is optional
|
||||
} else {
|
||||
int32_t elemType = convert_elem_type(data_type);
|
||||
ONNX_NAMESPACE::updateOutputElemType(ctx, idx, elemType);
|
||||
for (auto i = 0; i < shape->ints_size(); ++i) {
|
||||
ONNX_NAMESPACE::getOutputShape(ctx, idx, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
auto num_inputs = ctx.getNumInputs();
|
||||
|
||||
// Run inferencing on the subgraph
|
||||
auto* graphInferencer = ctx.getGraphAttributeInferencer("body");
|
||||
|
||||
std::vector<const ONNX_NAMESPACE::TensorProto*> input_data;
|
||||
std::vector<const ONNX_NAMESPACE::TypeProto*> subgraph_input_types;
|
||||
for (size_t i = 0; i < num_inputs; ++i) {
|
||||
input_data.push_back(ctx.getInputData(i));
|
||||
subgraph_input_types.push_back(ctx.getInputType(i));
|
||||
}
|
||||
|
||||
auto output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
|
||||
for (size_t i = 0, end = output_types.size(); i < end; ++i) {
|
||||
*ctx.getOutputType(i) = *output_types[i];
|
||||
}
|
||||
}
|
||||
void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) override {
|
||||
void RegisterSchema(const std::string& domain, const OrtCustomOp* op) override {
|
||||
auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
|
||||
const auto& domain_to_version_map = domain_instance.Map();
|
||||
if (domain_to_version_map.find(domain) == domain_to_version_map.end()) {
|
||||
domain_instance.AddDomainToVersion(domain, 1, 1000);
|
||||
}
|
||||
auto schema = CreateSchema(domain, {op});
|
||||
switch (type) {
|
||||
case 1:
|
||||
schema.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference);
|
||||
break;
|
||||
case 2:
|
||||
schema.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference);
|
||||
break;
|
||||
case 3:
|
||||
schema.TypeAndShapeInferenceFunction(xir_shape_infer);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION);
|
||||
}
|
||||
const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) override {
|
||||
|
|
|
|||
Loading…
Reference in a new issue