vitis ai support generic data type (#16902)

### Description
<!-- Describe your changes. -->
Support more data types for vitis ai.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
It is required because the models we are testing now have uint8 data
type. To solve this once for all, we changed the code to support generic
data type.
This commit is contained in:
BoarQing 2023-08-03 06:56:39 +08:00 committed by GitHub
parent d399648869
commit 6361b22103
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 27 deletions

View file

@ -26,6 +26,8 @@ static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) {
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT);
} else if (data_type->s() == "int8") {
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT8);
} else if (data_type->s() == "uint8") {
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::UINT8);
} else if (data_type->s() == "int32") {
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32);
} else if (data_type->s() == "int64") {

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "./tensor_proto.h"
#include "./vai_assert.h"
#include "core/framework/tensorprotoutils.h"
#include <cstdint>
#include <limits>
@ -10,36 +11,14 @@ namespace vaip {
gsl::span<const char> tensor_proto_as_raw(
const ONNX_NAMESPACE::TensorProto& tensor) {
auto data_type = tensor.data_type();
auto& mut_tensor = const_cast<ONNX_NAMESPACE::TensorProto&>(tensor);
if (tensor.has_raw_data()) {
return gsl::span<const char>(tensor.raw_data().data(), tensor.raw_data().size());
} else if (tensor.float_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::FLOAT) {
return gsl::span<const char>((char*)tensor.float_data().data(), tensor.float_data().size() * sizeof(float));
} else if (tensor.int32_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::INT32) {
return gsl::span<const char>((char*)tensor.int32_data().data(), tensor.int32_data().size() * sizeof(int));
// test case: graph_opt model #43
} else if (tensor.int64_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::INT64) {
return gsl::span<const char>((char*)tensor.int64_data().data(), tensor.int64_data().size() * sizeof(int64_t));
} else if (data_type == ONNX_NAMESPACE::TensorProto::INT8) {
auto size = tensor.int32_data_size();
assert(size > 0);
mut_tensor.mutable_raw_data()->resize(sizeof(char) * size);
char* base = &(*mut_tensor.mutable_raw_data())[0];
for (auto i = 0; i < size; ++i) {
auto value = (char)tensor.int32_data(i);
assert(value >= std::numeric_limits<char>::min());
assert(value <= std::numeric_limits<char>::max());
base[i] = value;
}
return gsl::span<const char>(tensor.raw_data().data(), tensor.raw_data().size());
} else {
vai_assert(false, "not support data_type");
if (!tensor.has_raw_data()) {
std::vector<uint8_t> unpacked_tensor;
auto s = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor);
mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size());
memcpy(mut_tensor.mutable_raw_data()->data(), unpacked_tensor.data(), unpacked_tensor.size());
}
#ifndef _WIN32
return gsl::span<const char>(tensor.raw_data().data(), tensor.raw_data().size());
#endif
return gsl::span<const char>();
}
size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor) {