mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Fix sample tests (#1926)
This commit is contained in:
parent
09cdbe9d76
commit
7e22ed41b9
2 changed files with 39 additions and 36 deletions
|
|
@ -6,6 +6,8 @@
|
|||
#include <vector>
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
const OrtApi* Ort::g_api = OrtGetApi(ORT_API_VERSION);
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
//*************************************************************************
|
||||
// initialize enviroment...one enviroment per process
|
||||
|
|
|
|||
|
|
@ -8,33 +8,34 @@
|
|||
#include <stdio.h>
|
||||
#include <vector>
|
||||
|
||||
const OrtApi* g_ort = OrtGetApi(ORT_API_VERSION);
|
||||
|
||||
//*****************************************************************************
|
||||
// helper function to check for status
|
||||
#define CHECK_STATUS(expr) \
|
||||
{ \
|
||||
OrtStatus* onnx_status = (expr); \
|
||||
if (onnx_status != NULL) { \
|
||||
const char* msg = OrtGetErrorMessage(onnx_status); \
|
||||
fprintf(stderr, "%s\n", msg); \
|
||||
OrtReleaseStatus(onnx_status); \
|
||||
exit(1); \
|
||||
} \
|
||||
}
|
||||
void CheckStatus(OrtStatus* status)
|
||||
{
|
||||
if (status != NULL) {
|
||||
const char* msg = g_ort->GetErrorMessage(status);
|
||||
fprintf(stderr, "%s\n", msg);
|
||||
g_ort->ReleaseStatus(status);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
//*************************************************************************
|
||||
// initialize enviroment...one enviroment per process
|
||||
// enviroment maintains thread pools and other state info
|
||||
OrtEnv* env;
|
||||
CHECK_STATUS(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env));
|
||||
CheckStatus(g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env));
|
||||
|
||||
// initialize session options if needed
|
||||
OrtSessionOptions* session_options;
|
||||
CHECK_STATUS(OrtCreateSessionOptions(&session_options));
|
||||
OrtSetIntraOpNumThreads(session_options, 1);
|
||||
CheckStatus(g_ort->CreateSessionOptions(&session_options));
|
||||
g_ort->SetIntraOpNumThreads(session_options, 1);
|
||||
|
||||
// Sets graph optimization level
|
||||
OrtSetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC);
|
||||
g_ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC);
|
||||
|
||||
// Optionally add more execution providers via session_options
|
||||
// E.g. for CUDA include cuda_provider_factory.h and uncomment the following line:
|
||||
|
|
@ -52,17 +53,17 @@ int main(int argc, char* argv[]) {
|
|||
#endif
|
||||
|
||||
printf("Using Onnxruntime C API\n");
|
||||
CHECK_STATUS(OrtCreateSession(env, model_path, session_options, &session));
|
||||
CheckStatus(g_ort->CreateSession(env, model_path, session_options, &session));
|
||||
|
||||
//*************************************************************************
|
||||
// print model input layer (node names, types, shape etc.)
|
||||
size_t num_input_nodes;
|
||||
OrtStatus* status;
|
||||
OrtAllocator* allocator;
|
||||
CHECK_STATUS(OrtGetAllocatorWithDefaultOptions(&allocator));
|
||||
CheckStatus(g_ort->GetAllocatorWithDefaultOptions(&allocator));
|
||||
|
||||
// print number of model input nodes
|
||||
status = OrtSessionGetInputCount(session, &num_input_nodes);
|
||||
status = g_ort->SessionGetInputCount(session, &num_input_nodes);
|
||||
std::vector<const char*> input_node_names(num_input_nodes);
|
||||
std::vector<int64_t> input_node_dims; // simplify... this model has only 1 input node {1, 3, 224, 224}.
|
||||
// Otherwise need vector<vector<>>
|
||||
|
|
@ -73,29 +74,29 @@ int main(int argc, char* argv[]) {
|
|||
for (size_t i = 0; i < num_input_nodes; i++) {
|
||||
// print input node names
|
||||
char* input_name;
|
||||
status = OrtSessionGetInputName(session, i, allocator, &input_name);
|
||||
status = g_ort->SessionGetInputName(session, i, allocator, &input_name);
|
||||
printf("Input %zu : name=%s\n", i, input_name);
|
||||
input_node_names[i] = input_name;
|
||||
|
||||
// print input node types
|
||||
OrtTypeInfo* typeinfo;
|
||||
status = OrtSessionGetInputTypeInfo(session, i, &typeinfo);
|
||||
status = g_ort->SessionGetInputTypeInfo(session, i, &typeinfo);
|
||||
const OrtTensorTypeAndShapeInfo* tensor_info;
|
||||
CHECK_STATUS(OrtCastTypeInfoToTensorInfo(typeinfo, &tensor_info));
|
||||
CheckStatus(g_ort->CastTypeInfoToTensorInfo(typeinfo, &tensor_info));
|
||||
ONNXTensorElementDataType type;
|
||||
CHECK_STATUS(OrtGetTensorElementType(tensor_info, &type));
|
||||
CheckStatus(g_ort->GetTensorElementType(tensor_info, &type));
|
||||
printf("Input %zu : type=%d\n", i, type);
|
||||
|
||||
// print input shapes/dims
|
||||
size_t num_dims;
|
||||
CHECK_STATUS(OrtGetDimensionsCount(tensor_info, &num_dims));
|
||||
CheckStatus(g_ort->GetDimensionsCount(tensor_info, &num_dims));
|
||||
printf("Input %zu : num_dims=%zu\n", i, num_dims);
|
||||
input_node_dims.resize(num_dims);
|
||||
OrtGetDimensions(tensor_info, (int64_t*)input_node_dims.data(), num_dims);
|
||||
g_ort->GetDimensions(tensor_info, (int64_t*)input_node_dims.data(), num_dims);
|
||||
for (size_t j = 0; j < num_dims; j++)
|
||||
printf("Input %zu : dim %zu=%jd\n", i, j, input_node_dims[j]);
|
||||
|
||||
OrtReleaseTypeInfo(typeinfo);
|
||||
g_ort->ReleaseTypeInfo(typeinfo);
|
||||
}
|
||||
|
||||
// Results should be...
|
||||
|
|
@ -128,23 +129,23 @@ int main(int argc, char* argv[]) {
|
|||
|
||||
// create input tensor object from data values
|
||||
OrtMemoryInfo* memory_info;
|
||||
CHECK_STATUS(OrtCreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
|
||||
CheckStatus(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
|
||||
OrtValue* input_tensor = NULL;
|
||||
CHECK_STATUS(OrtCreateTensorWithDataAsOrtValue(memory_info, input_tensor_values.data(), input_tensor_size * sizeof(float), input_node_dims.data(), 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor));
|
||||
CheckStatus(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_tensor_values.data(), input_tensor_size * sizeof(float), input_node_dims.data(), 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor));
|
||||
int is_tensor;
|
||||
CHECK_STATUS(OrtIsTensor(input_tensor, &is_tensor));
|
||||
CheckStatus(g_ort->IsTensor(input_tensor, &is_tensor));
|
||||
assert(is_tensor);
|
||||
OrtReleaseMemoryInfo(memory_info);
|
||||
g_ort->ReleaseMemoryInfo(memory_info);
|
||||
|
||||
// score model & input tensor, get back output tensor
|
||||
OrtValue* output_tensor = NULL;
|
||||
CHECK_STATUS(OrtRun(session, NULL, input_node_names.data(), (const OrtValue* const*)&input_tensor, 1, output_node_names.data(), 1, &output_tensor));
|
||||
CHECK_STATUS(OrtIsTensor(output_tensor, &is_tensor));
|
||||
CheckStatus(g_ort->Run(session, NULL, input_node_names.data(), (const OrtValue* const*)&input_tensor, 1, output_node_names.data(), 1, &output_tensor));
|
||||
CheckStatus(g_ort->IsTensor(output_tensor, &is_tensor));
|
||||
assert(is_tensor);
|
||||
|
||||
// Get pointer to output tensor float values
|
||||
float* floatarr;
|
||||
CHECK_STATUS(OrtGetTensorMutableData(output_tensor, (void**)&floatarr));
|
||||
CheckStatus(g_ort->GetTensorMutableData(output_tensor, (void**)&floatarr));
|
||||
assert(abs(floatarr[0] - 0.000045) < 1e-6);
|
||||
|
||||
// score the model, and print scores for first 5 classes
|
||||
|
|
@ -158,11 +159,11 @@ int main(int argc, char* argv[]) {
|
|||
// Score for class[3] = 0.001180
|
||||
// Score for class[4] = 0.001317
|
||||
|
||||
OrtReleaseValue(output_tensor);
|
||||
OrtReleaseValue(input_tensor);
|
||||
OrtReleaseSession(session);
|
||||
OrtReleaseSessionOptions(session_options);
|
||||
OrtReleaseEnv(env);
|
||||
g_ort->ReleaseValue(output_tensor);
|
||||
g_ort->ReleaseValue(input_tensor);
|
||||
g_ort->ReleaseSession(session);
|
||||
g_ort->ReleaseSessionOptions(session_options);
|
||||
g_ort->ReleaseEnv(env);
|
||||
printf("Done!\n");
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue