mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
134 lines
4.9 KiB
C++
134 lines
4.9 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include <thread>
|
|
#include <fstream>
|
|
#include <unordered_map>
|
|
|
|
#include "boost/program_options.hpp"
|
|
#include "onnxruntime_cxx_api.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace server {
|
|
|
|
namespace po = boost::program_options;
|
|
|
|
// Enumerates the different type of results which can occur
|
|
// The three different types are:
|
|
// 0. ExitSuccess which is when the program should exit with EXIT_SUCCESS
|
|
// 1. ExitFailure when program should exit with EXIT_FAILURE
|
|
// 2. No need for exiting the program, continue
|
|
enum class Result {
|
|
ExitSuccess,
|
|
ExitFailure,
|
|
ContinueSuccess
|
|
};
|
|
|
|
static std::unordered_map<std::string, OrtLoggingLevel> supported_log_levels{
|
|
{"verbose", ORT_LOGGING_LEVEL_VERBOSE},
|
|
{"info", ORT_LOGGING_LEVEL_INFO},
|
|
{"warning", ORT_LOGGING_LEVEL_WARNING},
|
|
{"error", ORT_LOGGING_LEVEL_ERROR},
|
|
{"fatal", ORT_LOGGING_LEVEL_FATAL}};
|
|
|
|
// Wrapper around Boost program_options and should provide all the functionality for options parsing
|
|
// Provides sane default values
|
|
class ServerConfiguration {
|
|
public:
|
|
const std::string full_desc = "ONNX Server: host an ONNX model with ONNX Runtime";
|
|
std::string model_path;
|
|
std::string model_name = "default";
|
|
std::string model_version = "1";
|
|
std::string address = "0.0.0.0";
|
|
unsigned short http_port = 8001;
|
|
unsigned short grpc_port = 50051;
|
|
int num_http_threads = std::thread::hardware_concurrency();
|
|
OrtLoggingLevel logging_level{};
|
|
|
|
ServerConfiguration() {
|
|
desc.add_options()("help,h", "Shows a help message and exits");
|
|
desc.add_options()("log_level", po::value(&log_level_str)->default_value(log_level_str), "Logging level. Allowed options (case sensitive): verbose, info, warning, error, fatal");
|
|
desc.add_options()("model_path", po::value(&model_path)->required(), "Path to ONNX model");
|
|
desc.add_options()("model_name", po::value(&model_name)->default_value(model_name), "ONNX model name");
|
|
desc.add_options()("model_version", po::value(&model_version)->default_value(model_version), "ONNX model version");
|
|
desc.add_options()("address", po::value(&address)->default_value(address), "The base HTTP address");
|
|
desc.add_options()("http_port", po::value(&http_port)->default_value(http_port), "HTTP port to listen to requests");
|
|
desc.add_options()("num_http_threads", po::value(&num_http_threads)->default_value(num_http_threads), "Number of http threads");
|
|
desc.add_options()("grpc_port", po::value(&grpc_port)->default_value(grpc_port), "GRPC port to listen to requests");
|
|
}
|
|
|
|
// Parses argc and argv and sets the values for the class
|
|
// Returns an enum with three options: ExitSuccess, ExitFailure, ContinueSuccess
|
|
// ExitSuccess and ExitFailure means the program should exit but is left to the caller
|
|
Result ParseInput(int argc, char** argv) {
|
|
try {
|
|
po::store(po::command_line_parser(argc, argv).options(desc).run(), vm); // can throw
|
|
|
|
if (ContainsHelp()) {
|
|
PrintHelp(std::cout, full_desc);
|
|
return Result::ExitSuccess;
|
|
}
|
|
|
|
po::notify(vm); // throws on error, so do after help
|
|
} catch (const po::error& e) {
|
|
PrintHelp(std::cerr, e.what());
|
|
return Result::ExitFailure;
|
|
} catch (const std::exception& e) {
|
|
PrintHelp(std::cerr, e.what());
|
|
return Result::ExitFailure;
|
|
}
|
|
|
|
Result result = ValidateOptions();
|
|
|
|
if (result == Result::ContinueSuccess) {
|
|
logging_level = supported_log_levels[log_level_str];
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
private:
|
|
po::options_description desc{"Allowed options"};
|
|
po::variables_map vm{};
|
|
std::string log_level_str = "info";
|
|
|
|
// Print help and return if there is a bad value
|
|
Result ValidateOptions() {
|
|
if (vm.count("log_level") &&
|
|
supported_log_levels.find(log_level_str) == supported_log_levels.end()) {
|
|
PrintHelp(std::cerr, "log_level must be one of verbose, info, warning, error, or fatal");
|
|
return Result::ExitFailure;
|
|
} else if (num_http_threads <= 0) {
|
|
PrintHelp(std::cerr, "num_http_threads must be greater than 0");
|
|
return Result::ExitFailure;
|
|
} else if (!file_exists(model_path)) {
|
|
PrintHelp(std::cerr, "model_path must be the location of a valid file");
|
|
return Result::ExitFailure;
|
|
} else {
|
|
return Result::ContinueSuccess;
|
|
}
|
|
}
|
|
|
|
// Checks if program options contains help
|
|
bool ContainsHelp() const {
|
|
return vm.count("help") || vm.count("h");
|
|
}
|
|
|
|
// Prints a helpful message (param: what) to the user and then the program options
|
|
// Example: config.PrintHelp(std::cout, "Non-negative values not allowed")
|
|
// Which will print that message and then all publicly available options
|
|
void PrintHelp(std::ostream& out, const std::string& what) const {
|
|
out << what << std::endl
|
|
<< desc << std::endl;
|
|
}
|
|
|
|
inline bool file_exists(const std::string& fileName) {
|
|
std::ifstream infile(fileName.c_str());
|
|
return infile.good();
|
|
}
|
|
};
|
|
|
|
} // namespace server
|
|
} // namespace onnxruntime
|