2019-05-01 01:21:23 +00:00
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
|
|
|
|
#include <stdio.h>
#include <memory>
|
2019-07-04 08:08:14 +00:00
|
|
|
#include "serializing/mem_buffer.h"
|
|
|
|
|
#include "serializing/tensorprotoutils.h"
|
2019-05-01 01:21:23 +00:00
|
|
|
|
|
|
|
|
#include "onnx-ml.pb.h"
|
|
|
|
|
#include "predict.pb.h"
|
|
|
|
|
|
|
|
|
|
#include "converter.h"
|
|
|
|
|
#include "executor.h"
|
|
|
|
|
#include "util.h"
|
|
|
|
|
|
|
|
|
|
namespace onnxruntime {
|
|
|
|
|
namespace server {
|
|
|
|
|
|
|
|
|
|
namespace protobufutil = google::protobuf::util;
|
|
|
|
|
|
|
|
|
|
// Converts a single ONNX TensorProto into an Ort::Value.
//
// The tensor's CPU byte size is computed first. A buffer of that size is then obtained from
// `buffers`, and the proto is converted into `ml_value` using that buffer together with
// `cpu_memory_info`.
//
// Params:
//   input_tensor    - tensor received in the request.
//   buffers         - buffer array that AllocNewBuffer() is called on for the tensor storage.
//   cpu_memory_info - memory info describing the buffer; dereferenced unconditionally, so the
//                     caller must pass a non-null pointer.
//   ml_value        - [out] receives the converted value on success.
// Returns:
//   Status::OK on success. If either ORT step throws Ort::Exception, returns the status produced
//   by GenerateProtobufStatus from the exception's ORT error code and message.
protobufutil::Status Executor::SetMLValue(const onnx::TensorProto& input_tensor,
                                          MemBufferArray& buffers,
                                          OrtMemoryInfo* cpu_memory_info,
                                          /* out */ Ort::Value& ml_value) {
  auto logger = env_->GetLogger(request_id_);

  // Stage 1: size the CPU copy of the tensor.
  size_t byte_count = 0;
  try {
    onnxruntime::server::GetSizeInBytesFromTensorProto<0>(input_tensor, &byte_count);
  } catch (const Ort::Exception& ex) {
    logger->error("GetSizeInBytesFromTensorProto() failed. Error Message: {}", ex.what());
    return GenerateProtobufStatus(ex.GetOrtErrorCode(), ex.what());
  }

  // Stage 2: obtain backing storage from the caller-provided buffer array. This stays outside
  // any try block, so failures here propagate to the caller unchanged.
  auto* tensor_storage = buffers.AllocNewBuffer(byte_count);

  // Stage 3: materialize the Ort::Value over that storage.
  try {
    onnxruntime::server::TensorProtoToMLValue(
        input_tensor,
        onnxruntime::server::MemBuffer(tensor_storage, byte_count, *cpu_memory_info),
        ml_value);
    return protobufutil::Status::OK;
  } catch (const Ort::Exception& ex) {
    logger->error("TensorProtoToMLValue() failed. Message: {}", ex.what());
    return GenerateProtobufStatus(ex.GetOrtErrorCode(), ex.what());
  }
}
|
|
|
|
|
|
2019-07-04 08:08:14 +00:00
|
|
|
// Converts every input of a PredictRequest into an Ort::Value.
//
// For each entry of request.inputs(), the input name is appended to `input_names` and the
// converted value to `input_values`, keeping both vectors index-aligned. `using_raw_data_` stays
// true only if every input tensor carries raw_data.
//
// Params:
//   input_names  - [out] appended with the request's input names.
//   input_values - [out] appended with the converted values, in the same order.
//   request      - incoming prediction request.
//   buffers      - buffer array passed through to SetMLValue for tensor storage.
// Returns:
//   Status::OK on success.
//   RESOURCE_EXHAUSTED if the CPU memory info cannot be created.
//   Otherwise the first failing SetMLValue status.
protobufutil::Status Executor::SetNameMLValueMap(std::vector<std::string>& input_names,
                                                 std::vector<Ort::Value>& input_values,
                                                 const onnxruntime::server::PredictRequest& request,
                                                 MemBufferArray& buffers) {
  auto logger = env_->GetLogger(request_id_);
  const auto& ort_api = Ort::GetApi();

  OrtMemoryInfo* memory_info = nullptr;
  auto ort_status = ort_api.CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info);

  // Take ownership right away so the memory info is released on every exit path, including
  // exceptions other than Ort::Exception escaping SetMLValue. unique_ptr skips the deleter when
  // the pointer is null.
  auto release_memory_info = [](OrtMemoryInfo* info) { Ort::GetApi().ReleaseMemoryInfo(info); };
  std::unique_ptr<OrtMemoryInfo, decltype(release_memory_info)> memory_info_guard{memory_info,
                                                                                  release_memory_info};

  if (ort_status != nullptr || memory_info == nullptr) {
    if (ort_status != nullptr) {
      logger->error("OrtCreateCpuMemoryInfo failed. Error Message: {}", ort_api.GetErrorMessage(ort_status));
      // A non-null OrtStatus is owned by the caller; release it to avoid leaking it.
      ort_api.ReleaseStatus(ort_status);
    } else {
      logger->error("OrtCreateCpuMemoryInfo failed");
    }
    return protobufutil::Status(protobufutil::error::Code::RESOURCE_EXHAUSTED, "OrtCreateCpuMemoryInfo() failed");
  }

  // Prepare the Value object
  for (const auto& input : request.inputs()) {
    using_raw_data_ = using_raw_data_ && input.second.has_raw_data();

    Ort::Value ml_value{nullptr};
    auto status = SetMLValue(input.second, buffers, memory_info_guard.get(), ml_value);
    if (status != protobufutil::Status::OK) {
      logger->error("SetMLValue() failed! Input name: {}", input.first);
      return status;
    }

    input_names.push_back(input.first);
    input_values.push_back(std::move(ml_value));
  }

  return protobufutil::Status::OK;
}
|
|
|
|
|
|
2019-07-04 08:08:14 +00:00
|
|
|
// Runs `session` on the given named inputs and returns the requested outputs.
//
// `input_names` and `input_values` are expected to be index-aligned. The returned vector comes
// straight from Ort::Session::Run, and any Ort::Exception it throws propagates to the caller.
std::vector<Ort::Value> Run(const Ort::Session& session, const Ort::RunOptions& options, const std::vector<std::string>& input_names, const std::vector<Ort::Value>& input_values, const std::vector<std::string>& output_names) {
  // Session::Run takes arrays of C strings. Build pointer views over the caller's strings; they
  // stay valid as long as the caller's vectors are not modified during this call.
  auto as_c_strings = [](const std::vector<std::string>& names) {
    std::vector<const char*> pointers;
    pointers.reserve(names.size());
    for (const std::string& name : names) {
      pointers.push_back(name.data());
    }
    return pointers;
  };

  auto input_ptrs = as_c_strings(input_names);
  auto output_ptrs = as_c_strings(output_names);

  // The casts exist because Session::Run and its Value* parameter are presumably non-const in
  // the ORT API version this targets. TODO(review): drop the casts if Run becomes const-correct.
  auto& runnable_session = const_cast<Ort::Session&>(session);
  auto* run_inputs = const_cast<Ort::Value*>(input_values.data());

  return runnable_session.Run(options,
                              input_ptrs.data(), run_inputs, input_names.size(),
                              output_ptrs.data(), output_names.size());
}
|
|
|
|
|
|
2019-05-01 01:21:23 +00:00
|
|
|
// Runs inference for `request` against the session for (model_name, model_version) and
// fills `response.outputs` keyed by output name.
//
// Outputs are those listed in request.output_filter(). If the filter is empty, all model
// outputs reported by the environment are returned.
//
// Returns:
//   Status::OK on success.
//   The conversion status if the inputs cannot be converted.
//   A status generated from the Ort::Exception if Run() or output conversion fails.
//   INVALID_ARGUMENT if two outputs share a name.
protobufutil::Status Executor::Predict(const std::string& model_name,
                                       const std::string& model_version,
                                       const onnxruntime::server::PredictRequest& request,
                                       /* out */ onnxruntime::server::PredictResponse& response) {
  auto logger = env_->GetLogger(request_id_);

  // Convert PredictRequest to parallel name/value vectors. buffer_array supplies the buffers the
  // input values are built over (via SetMLValue), so it is kept alive until Run() has finished.
  MemBufferArray buffer_array;
  std::vector<std::string> input_names;
  std::vector<Ort::Value> input_values;
  auto conversion_status = SetNameMLValueMap(input_names, input_values, request, buffer_array);
  if (conversion_status != protobufutil::Status::OK) {
    return conversion_status;
  }

  Ort::RunOptions run_options{};
  run_options.SetRunLogVerbosityLevel(static_cast<int>(env_->GetLogSeverity()));
  run_options.SetRunTag(request_id_.c_str());

  // Prepare the output names: honor the request's filter, otherwise request every model output.
  std::vector<std::string> output_names;
  if (!request.output_filter().empty()) {
    output_names.reserve(request.output_filter_size());
    for (const auto& name : request.output_filter()) {
      output_names.push_back(name);
    }
  } else {
    output_names = env_->GetModelOutputNames(model_name, model_version);
  }

  std::vector<Ort::Value> outputs;
  try {
    outputs = Run(env_->GetSession(model_name, model_version), run_options, input_names, input_values, output_names);
  } catch (const Ort::Exception& e) {
    // Log like every other failure path in this file before surfacing the error.
    logger->error("Run() failed. Error Message: {}", e.what());
    return GenerateProtobufStatus(e.GetOrtErrorCode(), e.what());
  }

  // Build the response
  auto* response_outputs = response.mutable_outputs();
  for (size_t i = 0, sz = outputs.size(); i < sz; ++i) {
    onnx::TensorProto output_tensor{};
    try {
      MLValueToTensorProto(outputs[i], using_raw_data_, logger, output_tensor);
    } catch (const Ort::Exception& e) {
      logger->error("MLValueToTensorProto() failed. Output name: {}. Error Message: {}", output_names[i], e.what());
      return GenerateProtobufStatus(e.GetOrtErrorCode(), e.what());
    }

    // Insert an empty placeholder so duplicate names are rejected exactly as before, then swap
    // the converted tensor into the map slot instead of deep-copying a possibly large tensor.
    auto insertion_result = response_outputs->insert({output_names[i], onnx::TensorProto{}});
    if (!insertion_result.second) {
      logger->error("SetNameMLValueMap() failed. Output name: {}. Trying to overwrite existing output value", output_names[i]);
      return protobufutil::Status(protobufutil::error::Code::INVALID_ARGUMENT, "SetNameMLValueMap() failed: Cannot have two outputs with the same name");
    }
    insertion_result.first->second.Swap(&output_tensor);
  }

  return protobufutil::Status::OK;
}
|
|
|
|
|
|
|
|
|
|
} // namespace server
|
2019-09-20 20:39:11 +00:00
|
|
|
} // namespace onnxruntime
|