From 6361b221033f04df2b38fded82030b89fd73ffa0 Mon Sep 17 00:00:00 2001 From: BoarQing Date: Thu, 3 Aug 2023 06:56:39 +0800 Subject: [PATCH] vitis ai support generic data type (#16902) ### Description Support more data types for vitis ai. ### Motivation and Context It is required because the models we are testing now have uint8 data type. To solve this once for all, we changed the code to support generic data type. --- .../providers/vitisai/imp/register_xir_ops.cc | 2 ++ .../providers/vitisai/imp/tensor_proto.cc | 33 ++++--------------- 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index 7cc42983f3..544e183506 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -26,6 +26,8 @@ static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT); } else if (data_type->s() == "int8") { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT8); + } else if (data_type->s() == "uint8") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::UINT8); } else if (data_type->s() == "int32") { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); } else if (data_type->s() == "int64") { diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index 73317bf520..edb14eda14 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "./tensor_proto.h" #include "./vai_assert.h" +#include "core/framework/tensorprotoutils.h" #include #include @@ -10,36 +11,14 @@ namespace vaip { gsl::span tensor_proto_as_raw( const ONNX_NAMESPACE::TensorProto& tensor) { - auto data_type = tensor.data_type(); auto& mut_tensor = const_cast(tensor); - if (tensor.has_raw_data()) { - return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); - } else if (tensor.float_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::FLOAT) { - return gsl::span((char*)tensor.float_data().data(), tensor.float_data().size() * sizeof(float)); - } else if (tensor.int32_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::INT32) { - return gsl::span((char*)tensor.int32_data().data(), tensor.int32_data().size() * sizeof(int)); - // test case: graph_opt model #43 - } else if (tensor.int64_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::INT64) { - return gsl::span((char*)tensor.int64_data().data(), tensor.int64_data().size() * sizeof(int64_t)); - } else if (data_type == ONNX_NAMESPACE::TensorProto::INT8) { - auto size = tensor.int32_data_size(); - assert(size > 0); - mut_tensor.mutable_raw_data()->resize(sizeof(char) * size); - char* base = &(*mut_tensor.mutable_raw_data())[0]; - for (auto i = 0; i < size; ++i) { - auto value = (char)tensor.int32_data(i); - assert(value >= std::numeric_limits::min()); - assert(value <= std::numeric_limits::max()); - base[i] = value; - } - return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); - } else { - vai_assert(false, "not support data_type"); + if (!tensor.has_raw_data()) { + std::vector unpacked_tensor; + auto s = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor); + mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size()); + memcpy(mut_tensor.mutable_raw_data()->data(), unpacked_tensor.data(), unpacked_tensor.size()); } -#ifndef _WIN32 return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); -#endif - return gsl::span(); } size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor) {