mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
148 lines
No EOL
4.4 KiB
C++
148 lines
No EOL
4.4 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "session.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace server {
|
|
|
|
namespace net = boost::asio; // from <boost/asio.hpp>
|
|
namespace beast = boost::beast; // from <boost/beast.hpp>
|
|
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
|
|
|
|
HttpSession::HttpSession(const Routes& routes, tcp::socket socket)
|
|
: routes_(routes), socket_(std::move(socket)), strand_(socket_.get_executor()) {
|
|
}
|
|
|
|
void HttpSession::DoRead() {
|
|
req_.emplace();
|
|
|
|
// TODO: make the max request size configable.
|
|
req_->body_limit(25 * 1024 * 1024); // Max request size: 25 MiB
|
|
|
|
http::async_read(socket_, buffer_, *req_,
|
|
net::bind_executor(
|
|
strand_,
|
|
std::bind(
|
|
&HttpSession::OnRead,
|
|
shared_from_this(),
|
|
std::placeholders::_1,
|
|
std::placeholders::_2)));
|
|
}
|
|
|
|
// Completion handler for the async_read started in DoRead.
//
// On success, the parsed request is released from the parser and handed to
// HandleRequest, which builds and sends the response. If the peer closed the
// connection (http::error::end_of_stream), the session shuts down its side.
// Any other error is reported through ErrorHandling with the "read" context.
void HttpSession::OnRead(beast::error_code ec, std::size_t bytes_transferred) {
  boost::ignore_unused(bytes_transferred);

  if (!ec) {
    // Request fully parsed: produce and send the response.
    HandleRequest(req_->release());
    return;
  }

  if (ec == http::error::end_of_stream) {
    // The peer closed the connection.
    DoClose();
    return;
  }

  ErrorHandling(ec, "read");
}
|
|
|
|
// Completion handler for the async_write started in Send.
//
// `close` is the need_eof() value captured from the response before it was
// written. On error, ErrorHandling is called with the "write" context. If the
// response required ending the connection, DoClose shuts down the send side.
// Otherwise the stored response is released and the next request is read.
void HttpSession::OnWrite(beast::error_code ec, std::size_t bytes_transferred, bool close) {
  boost::ignore_unused(bytes_transferred);

  if (ec) {
    ErrorHandling(ec, "write");
  } else if (close) {
    // Typically the response carried "Connection: close" semantics.
    DoClose();
  } else {
    // The response is fully written; drop our reference to it, then wait for
    // the next request on this kept-alive connection.
    res_ = nullptr;
    DoRead();
  }
}
|
|
|
|
void HttpSession::DoClose() {
|
|
// Send a TCP shutdown
|
|
beast::error_code ec;
|
|
socket_.shutdown(tcp::socket::shutdown_send, ec);
|
|
|
|
// At this point the connection is closed gracefully
|
|
}
|
|
|
|
template <class Msg>
|
|
void HttpSession::Send(Msg&& msg) {
|
|
using item_type = std::remove_reference_t<decltype(msg)>;
|
|
|
|
auto ptr = std::make_shared<item_type>(std::move(msg));
|
|
auto self_ = shared_from_this();
|
|
self_->res_ = ptr;
|
|
|
|
http::async_write(self_->socket_, *ptr,
|
|
net::bind_executor(strand_,
|
|
[self_, close = ptr->need_eof()](beast::error_code ec, std::size_t bytes) {
|
|
self_->OnWrite(ec, bytes, close);
|
|
}));
|
|
}
|
|
|
|
template <typename Body, typename Allocator>
|
|
void HttpSession::HandleRequest(http::request<Body, http::basic_fields<Allocator> >&& req) {
|
|
HttpContext context{};
|
|
context.request = std::move(req);
|
|
|
|
// Special handle the liveness probe endpoint for orchestration systems like Kubernetes.
|
|
if (context.request.method() == http::verb::get && context.request.target().to_string() == "/") {
|
|
context.response.body() = "Healthy";
|
|
} else {
|
|
auto status = ExecuteUserFunction(context);
|
|
|
|
if (status != http::status::ok) {
|
|
routes_.on_error(context);
|
|
}
|
|
}
|
|
|
|
context.response.keep_alive(context.request.keep_alive());
|
|
context.response.prepare_payload();
|
|
return Send(std::move(context.response));
|
|
}
|
|
|
|
// Routes the request held in `context` to its handler and runs it.
//
// If the request carries util::MS_CLIENT_REQUEST_ID_HEADER, its value is
// copied into context.client_request_id. routes_.ParseUrl resolves the HTTP
// method and path into a handler and fills model_name, model_version and
// action, which are passed to that handler.
//
// Returns:
//   - the ParseUrl status when routing fails. context.error_code is set to it,
//     and context.error_message describes the method and path;
//   - http::status::internal_server_error when the handler throws. The
//     exception message (or a generic one for non-std exceptions) is stored
//     in context.error_message;
//   - http::status::ok otherwise.
http::status HttpSession::ExecuteUserFunction(HttpContext& context) {
  std::string path = context.request.target().to_string();
  std::string model_name;
  std::string model_version;
  std::string action;
  HandlerFn func;

  // Single header lookup: reuse the iterator instead of find() + operator[].
  const auto client_id = context.request.find(util::MS_CLIENT_REQUEST_ID_HEADER);
  if (client_id != context.request.end()) {
    context.client_request_id = client_id->value().to_string();
  }

  const auto status = routes_.ParseUrl(context.request.method(), path, model_name, model_version, action, func);

  if (status != http::status::ok) {
    context.error_code = status;
    context.error_message = std::string(http::obsolete_reason(status)) +
                            ". For HTTP method: " +
                            std::string(http::to_string(context.request.method())) +
                            " and request path: " +
                            context.request.target().to_string();
    return status;
  }

  try {
    func(model_name, model_version, action, context);
  } catch (const std::exception& ex) {
    context.error_message = std::string(ex.what());
    return http::status::internal_server_error;
  } catch (...) {
    // Without this, an exception not derived from std::exception would
    // propagate out of the asio completion handler that ends up calling us.
    context.error_message = std::string("Unknown exception thrown while handling the request");
    return http::status::internal_server_error;
  }

  return http::status::ok;
}
|
|
|
|
} // namespace server
|
|
} // namespace onnxruntime
|