diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index 7cc42983f3..544e183506 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -26,6 +26,8 @@ static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT); } else if (data_type->s() == "int8") { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT8); + } else if (data_type->s() == "uint8") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::UINT8); } else if (data_type->s() == "int32") { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); } else if (data_type->s() == "int64") { diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index 73317bf520..edb14eda14 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "./tensor_proto.h" #include "./vai_assert.h" +#include "core/framework/tensorprotoutils.h" #include #include @@ -10,36 +11,14 @@ namespace vaip { gsl::span tensor_proto_as_raw( const ONNX_NAMESPACE::TensorProto& tensor) { - auto data_type = tensor.data_type(); auto& mut_tensor = const_cast(tensor); - if (tensor.has_raw_data()) { - return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); - } else if (tensor.float_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::FLOAT) { - return gsl::span((char*)tensor.float_data().data(), tensor.float_data().size() * sizeof(float)); - } else if (tensor.int32_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::INT32) { - return gsl::span((char*)tensor.int32_data().data(), tensor.int32_data().size() * sizeof(int)); - // test case: graph_opt model #43 - } else if (tensor.int64_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::INT64) { - return gsl::span((char*)tensor.int64_data().data(), tensor.int64_data().size() * sizeof(int64_t)); - } else if (data_type == ONNX_NAMESPACE::TensorProto::INT8) { - auto size = tensor.int32_data_size(); - assert(size > 0); - mut_tensor.mutable_raw_data()->resize(sizeof(char) * size); - char* base = &(*mut_tensor.mutable_raw_data())[0]; - for (auto i = 0; i < size; ++i) { - auto value = (char)tensor.int32_data(i); - assert(value >= std::numeric_limits::min()); - assert(value <= std::numeric_limits::max()); - base[i] = value; - } - return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); - } else { - vai_assert(false, "not support data_type"); + if (!tensor.has_raw_data()) { + std::vector unpacked_tensor; + auto s = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor); + mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size()); + memcpy(mut_tensor.mutable_raw_data()->data(), unpacked_tensor.data(), unpacked_tensor.size()); } -#ifndef _WIN32 return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); -#endif - return gsl::span(); } size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor) {