Check some return status values that were ignored and add logging of any error messages in onnxruntime_perf_test. (#525)

This commit is contained in:
Scott McKay 2019-02-27 20:20:06 -08:00 committed by GitHub
parent 6c7099a18e
commit 2e6ec07d9a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -94,11 +94,20 @@ bool PerformanceRunner::Initialize() {
sf.enable_sequential_execution = performance_test_config_.run_config.enable_sequential_execution;
sf.session_thread_pool_size = 6;
sf.Create(session_object_, test_case->GetModelUrl(), test_case->GetTestCaseName());
auto status = sf.Create(session_object_, test_case->GetModelUrl(), test_case->GetTestCaseName());
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "Failed to create InferenceSession."
<< " TestCaseName:" << test_case->GetTestCaseName()
<< ", Error:" << status.ErrorMessage();
return false;
}
// Initialize IO Binding
if (!session_object_->NewIOBinding(&io_binding_).IsOK()) {
LOGF_DEFAULT(ERROR, "Failed to init session and IO binding");
status = session_object_->NewIOBinding(&io_binding_);
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "Failed to init session and IO binding. "
<< " TestCaseName:" << test_case->GetTestCaseName()
<< ", Error:" << status.ErrorMessage();
return false;
}
@ -112,21 +121,27 @@ bool PerformanceRunner::Initialize() {
test_case->SetAllocator(cpu_allocator);
if (test_case->GetDataCount() <= 0) {
LOGF_DEFAULT(ERROR, "there is no test data for model %s", test_case->GetTestCaseName().c_str());
LOGS_DEFAULT(ERROR) << "there is no test data for model ", test_case->GetTestCaseName();
return false;
}
std::unordered_map<std::string, ::onnxruntime::MLValue> feeds;
test_case->LoadTestData(0 /* id */, feeds, true);
for (auto feed : feeds) {
io_binding_->BindInput(feed.first, feed.second);
status = io_binding_->BindInput(feed.first, feed.second);
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "BindInput failed for " << feed.first
<< " TestCaseName:" << test_case->GetTestCaseName()
<< ", Error:" << status.ErrorMessage();
return false;
}
}
auto outputs = session_object_->GetModelOutputs();
auto status = outputs.first;
if (!outputs.first.IsOK()) {
LOGF_DEFAULT(ERROR, "GetOutputs failed, TestCaseName:%s, ErrorMessage:%s",
test_case->GetTestCaseName().c_str(),
status.ErrorMessage().c_str());
status = outputs.first;
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "GetOutputs failed. TestCaseName:" << test_case->GetTestCaseName()
<< ", Error:" << status.ErrorMessage();
return false;
}
@ -134,7 +149,13 @@ bool PerformanceRunner::Initialize() {
for (size_t i_output = 0; i_output < outputs.second->size(); ++i_output) {
auto output = outputs.second->at(i_output);
if (!output) continue;
io_binding_->BindOutput(output->Name(), output_mlvalues[i_output]);
status = io_binding_->BindOutput(output->Name(), output_mlvalues[i_output]);
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "BindOutput failed for " << output->Name()
<< " TestCaseName:" << test_case->GetTestCaseName()
<< ", Error:" << status.ErrorMessage();
return false;
}
}
return true;