diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 9b15ebb394..04e9655340 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -1,4 +1,5 @@ #include "ort_test_session.h" +#include #include #include #include "core/session/onnxruntime_session_options_config_keys.h" @@ -386,7 +387,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (key == "runtime") { std::set supported_runtime = {"CPU", "GPU_FP32", "GPU", "GPU_FLOAT16", "DSP", "AIP_FIXED_TF"}; if (supported_runtime.find(value) == supported_runtime.end()) { - ORT_THROW(R"(Wrong configuration value for the key 'runtime'. + ORT_THROW(R"(Wrong configuration value for the key 'runtime'. select from 'CPU', 'GPU_FP32', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n)"); } } else if (key == "priority") { @@ -394,7 +395,7 @@ select from 'CPU', 'GPU_FP32', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n)" } else if (key == "buffer_type") { std::set supported_buffer_type = {"TF8", "TF16", "UINT8", "FLOAT", "ITENSOR"}; if (supported_buffer_type.find(value) == supported_buffer_type.end()) { - ORT_THROW(R"(Wrong configuration value for the key 'buffer_type'. + ORT_THROW(R"(Wrong configuration value for the key 'buffer_type'. select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } } else { @@ -563,6 +564,47 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } } +template +static void FillTensorDataTyped(Ort::Value& tensor, size_t count, T value = T{}) { + T* data = tensor.GetTensorMutableData(); + std::fill_n(data, count, value); +} + +static void InitializeTensorData(Ort::Value& tensor) { + const auto type_and_shape = tensor.GetTensorTypeAndShapeInfo(); + const auto count = type_and_shape.GetElementCount(); + const auto element_type = type_and_shape.GetElementType(); + +#define CASE_FOR_TYPE(T) \ + case Ort::TypeToTensorType::type: { \ + FillTensorDataTyped(tensor, count); \ + } break + + switch (element_type) { + CASE_FOR_TYPE(Ort::Float16_t); + CASE_FOR_TYPE(Ort::BFloat16_t); + CASE_FOR_TYPE(float); + CASE_FOR_TYPE(double); + CASE_FOR_TYPE(int8_t); + CASE_FOR_TYPE(int16_t); + CASE_FOR_TYPE(int32_t); + CASE_FOR_TYPE(int64_t); + CASE_FOR_TYPE(uint8_t); + CASE_FOR_TYPE(uint16_t); + CASE_FOR_TYPE(uint32_t); + CASE_FOR_TYPE(uint64_t); + CASE_FOR_TYPE(bool); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + // string tensors are already initialized to contain empty strings + // see onnxruntime::Tensor::Init() + break; + default: + ORT_THROW("Unsupported tensor data type: ", element_type); + } + +#undef CASE_FOR_TYPE +} + bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData() { // iterate over all input nodes for (size_t i = 0; i < static_cast(input_length_); i++) { @@ -582,6 +624,7 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData() { auto allocator = static_cast(Ort::AllocatorWithDefaultOptions()); Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorData(input_tensor); PreLoadTestData(0, i, std::move(input_tensor)); } }