diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index 4b2b7610cf..872d022e85 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -9,9 +9,44 @@ #include "core/providers/shared_library/provider_api.h" namespace vaip { using namespace onnxruntime; + +static gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { + auto tensor_proto = const_cast(&tensor); + auto file = std::string(); + uintptr_t offset = 0; + size_t size = 0; + if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) { + auto external_data = tensor_proto->mutable_external_data(); + auto external_data_size = external_data->size(); + for (auto i = 0; i < external_data_size; ++i) { + auto& data = external_data->at(i); + char* end = nullptr; + if (*data.mutable_key() == "location") { + file = *data.mutable_value(); + } else if (*data.mutable_key() == "offset") { + offset = (uintptr_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "length") { + size = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "checksum") { + // checksum = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } + } + if (file == "*/_ORT_MEM_ADDR_/*") { + auto addr = reinterpret_cast(offset); + return {addr, size}; + } + } + return {}; +} + gsl::span tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor) { auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { + auto maybe_external_memory_address = process_ext_address(tensor); + if (!maybe_external_memory_address.empty()) { + return maybe_external_memory_address; + } + std::vector unpacked_tensor; auto path = graph.ModelPath(); auto s = onnxruntime::utils::UnpackInitializerData(tensor, path, unpacked_tensor);