mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
vitis ai support generic data type (#16902)
### Description <!-- Describe your changes. --> Support more data types for vitis ai. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> It is required because the models we are testing now have uint8 data type. To solve this once for all, we changed the code to support generic data type.
This commit is contained in:
parent
d399648869
commit
6361b22103
2 changed files with 8 additions and 27 deletions
|
|
@ -26,6 +26,8 @@ static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) {
|
|||
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT);
|
||||
} else if (data_type->s() == "int8") {
|
||||
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT8);
|
||||
} else if (data_type->s() == "uint8") {
|
||||
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::UINT8);
|
||||
} else if (data_type->s() == "int32") {
|
||||
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32);
|
||||
} else if (data_type->s() == "int64") {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
#include "./tensor_proto.h"
|
||||
#include "./vai_assert.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
|
@ -10,36 +11,14 @@ namespace vaip {
|
|||
|
||||
gsl::span<const char> tensor_proto_as_raw(
|
||||
const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
auto data_type = tensor.data_type();
|
||||
auto& mut_tensor = const_cast<ONNX_NAMESPACE::TensorProto&>(tensor);
|
||||
if (tensor.has_raw_data()) {
|
||||
return gsl::span<const char>(tensor.raw_data().data(), tensor.raw_data().size());
|
||||
} else if (tensor.float_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::FLOAT) {
|
||||
return gsl::span<const char>((char*)tensor.float_data().data(), tensor.float_data().size() * sizeof(float));
|
||||
} else if (tensor.int32_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::INT32) {
|
||||
return gsl::span<const char>((char*)tensor.int32_data().data(), tensor.int32_data().size() * sizeof(int));
|
||||
// test case: graph_opt model #43
|
||||
} else if (tensor.int64_data_size() > 0 && data_type == ONNX_NAMESPACE::TensorProto::INT64) {
|
||||
return gsl::span<const char>((char*)tensor.int64_data().data(), tensor.int64_data().size() * sizeof(int64_t));
|
||||
} else if (data_type == ONNX_NAMESPACE::TensorProto::INT8) {
|
||||
auto size = tensor.int32_data_size();
|
||||
assert(size > 0);
|
||||
mut_tensor.mutable_raw_data()->resize(sizeof(char) * size);
|
||||
char* base = &(*mut_tensor.mutable_raw_data())[0];
|
||||
for (auto i = 0; i < size; ++i) {
|
||||
auto value = (char)tensor.int32_data(i);
|
||||
assert(value >= std::numeric_limits<char>::min());
|
||||
assert(value <= std::numeric_limits<char>::max());
|
||||
base[i] = value;
|
||||
}
|
||||
return gsl::span<const char>(tensor.raw_data().data(), tensor.raw_data().size());
|
||||
} else {
|
||||
vai_assert(false, "not support data_type");
|
||||
if (!tensor.has_raw_data()) {
|
||||
std::vector<uint8_t> unpacked_tensor;
|
||||
auto s = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor);
|
||||
mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size());
|
||||
memcpy(mut_tensor.mutable_raw_data()->data(), unpacked_tensor.data(), unpacked_tensor.size());
|
||||
}
|
||||
#ifndef _WIN32
|
||||
return gsl::span<const char>(tensor.raw_data().data(), tensor.raw_data().size());
|
||||
#endif
|
||||
return gsl::span<const char>();
|
||||
}
|
||||
|
||||
size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue