Initialize generated tensor data in onnxruntime_perf_test. (#12275)

Initialize generated tensor data in onnxruntime_perf_test to zeroes instead of leaving it uninitialized. String tensors were already being initialized.
This commit is contained in:
Edward Chen 2022-07-22 16:26:13 -07:00 committed by GitHub
parent 89ac61f4d4
commit 564dc32304
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,4 +1,5 @@
#include "ort_test_session.h"
#include <algorithm>
#include <set>
#include <core/session/onnxruntime_cxx_api.h>
#include "core/session/onnxruntime_session_options_config_keys.h"
@ -386,7 +387,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
if (key == "runtime") {
std::set<std::string> supported_runtime = {"CPU", "GPU_FP32", "GPU", "GPU_FLOAT16", "DSP", "AIP_FIXED_TF"};
if (supported_runtime.find(value) == supported_runtime.end()) {
ORT_THROW(R"(Wrong configuration value for the key 'runtime'.
ORT_THROW(R"(Wrong configuration value for the key 'runtime'.
select from 'CPU', 'GPU_FP32', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n)");
}
} else if (key == "priority") {
@ -394,7 +395,7 @@ select from 'CPU', 'GPU_FP32', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n)"
} else if (key == "buffer_type") {
std::set<std::string> supported_buffer_type = {"TF8", "TF16", "UINT8", "FLOAT", "ITENSOR"};
if (supported_buffer_type.find(value) == supported_buffer_type.end()) {
ORT_THROW(R"(Wrong configuration value for the key 'buffer_type'.
ORT_THROW(R"(Wrong configuration value for the key 'buffer_type'.
select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
}
} else {
@ -563,6 +564,47 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
}
}
template <typename T>
static void FillTensorDataTyped(Ort::Value& tensor, size_t count, T value = T{}) {
T* data = tensor.GetTensorMutableData<T>();
std::fill_n(data, count, value);
}
static void InitializeTensorData(Ort::Value& tensor) {
const auto type_and_shape = tensor.GetTensorTypeAndShapeInfo();
const auto count = type_and_shape.GetElementCount();
const auto element_type = type_and_shape.GetElementType();
#define CASE_FOR_TYPE(T) \
case Ort::TypeToTensorType<T>::type: { \
FillTensorDataTyped<T>(tensor, count); \
} break
switch (element_type) {
CASE_FOR_TYPE(Ort::Float16_t);
CASE_FOR_TYPE(Ort::BFloat16_t);
CASE_FOR_TYPE(float);
CASE_FOR_TYPE(double);
CASE_FOR_TYPE(int8_t);
CASE_FOR_TYPE(int16_t);
CASE_FOR_TYPE(int32_t);
CASE_FOR_TYPE(int64_t);
CASE_FOR_TYPE(uint8_t);
CASE_FOR_TYPE(uint16_t);
CASE_FOR_TYPE(uint32_t);
CASE_FOR_TYPE(uint64_t);
CASE_FOR_TYPE(bool);
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
// string tensors are already initialized to contain empty strings
// see onnxruntime::Tensor::Init()
break;
default:
ORT_THROW("Unsupported tensor data type: ", element_type);
}
#undef CASE_FOR_TYPE
}
bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData() {
// iterate over all input nodes
for (size_t i = 0; i < static_cast<size_t>(input_length_); i++) {
@ -582,6 +624,7 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData() {
auto allocator = static_cast<OrtAllocator*>(Ort::AllocatorWithDefaultOptions());
Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(),
input_node_dim.size(), tensor_info.GetElementType());
InitializeTensorData(input_tensor);
PreLoadTestData(0, i, std::move(input_tensor));
}
}