mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
42 lines
No EOL
1.7 KiB
C++
42 lines
No EOL
1.7 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "prediction_service_impl.h"
|
|
#include "request_id.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace server {
|
|
namespace grpc {
|
|
|
|
PredictionServiceImpl::PredictionServiceImpl(const std::shared_ptr<onnxruntime::server::ServerEnvironment>& env) : environment_(env) {}
|
|
|
|
::grpc::Status PredictionServiceImpl::Predict(::grpc::ServerContext* context, const ::onnxruntime::server::PredictRequest* request, ::onnxruntime::server::PredictResponse* response) {
|
|
auto request_id = SetRequestContext(context);
|
|
onnxruntime::server::Executor executor(environment_.get(), request_id);
|
|
//TODO: (csteegz) Add modelspec for both paths.
|
|
auto status = executor.Predict("default", "1", *request, *response); // Currently only support one model so hard coded.
|
|
if (!status.ok()) {
|
|
return ::grpc::Status(::grpc::StatusCode(status.error_code()), status.error_message());
|
|
}
|
|
return ::grpc::Status::OK;
|
|
}
|
|
|
|
std::string PredictionServiceImpl::SetRequestContext(::grpc::ServerContext* context) {
|
|
auto metadata = context->client_metadata();
|
|
auto request_id = util::InternalRequestId();
|
|
context->AddInitialMetadata(util::MS_REQUEST_ID_HEADER, request_id);
|
|
auto logger = environment_->GetLogger(request_id);
|
|
auto search = metadata.find(util::MS_CLIENT_REQUEST_ID_HEADER);
|
|
if (search != metadata.end()) {
|
|
std::string id{search->second.data(), search->second.length()};
|
|
context->AddInitialMetadata(util::MS_CLIENT_REQUEST_ID_HEADER, id);
|
|
logger->info("{}: [{}]", util::MS_CLIENT_REQUEST_ID_HEADER, id);
|
|
}
|
|
|
|
return request_id;
|
|
}
|
|
|
|
} // namespace grpc
|
|
} // namespace server
|
|
|
|
} // namespace onnxruntime
|