mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Split `IsTunableOpEnable` semantics into **enable tunable op for using** and **enable tunable op for tuning**. They remain disabled in general for safety purposes. But - if the session is created with an ONNX model with tuning results embedded - and the embedded tuning results are set to the EP without an error `Status` then we automatically enable the using, while tuning remains disabled. The planned options will be - `tunable_op_enable`: The top-level switch of `TunableOp`, indicating whether we will run into `TunableOp` related logic. **NOTE:** most of our impls have a bottom impl that acts as a fallback and is set as the default. In this case, we still call into the `TunableOp`, but no kernel selection, kernel tuning or caching is involved. This reduces our maintenance burden of a duplicate code path. - `tunable_op_tuning_enable`: The secondary switch of `TunableOp`, indicating whether we will run into the tuning related logic of `TunableOp`. Then for the possible future options: - `tunable_op_tuning_max_iteration`: blahblah - `tunable_op_tuning_max_duration_ms`: blahblah - `tunable_op_flash_attention_enable`: blahblah, for example only, we will not have this. The developer-oriented envvars are for developers' convenience, to inspect the performance impact of tuning. So there are only `ORT_ROCM_TUNABLE_OP_ENABLE` and `ORT_ROCM_TUNABLE_OP_TUNING_ENABLE` to take fine-grained control of the combinations.
270 lines
10 KiB
C++
270 lines
10 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
|
|
#include "core/session/inference_session_utils.h"
|
|
|
|
namespace onnxruntime {
|
|
|
|
//---------------------
|
|
//--- local helpers ---
|
|
//---------------------
|
|
|
|
//--------------------------------------------
|
|
//--- session options related helpers ---
|
|
//--------------------------------------------
|
|
// Below are some helpers that will be used to set corresponding session option values
|
|
|
|
// Applies the `intra_op_num_threads` value read from the ORT config to `session_options`.
//
// A non-negative `value` is written to `session_options.intra_op_param.thread_pool_size` and
// logged at INFO. A negative `value` is logged at ERROR and rejected with INVALID_ARGUMENT,
// leaving `session_options` untouched.
static Status SetIntraOpNumThreads(SessionOptions& session_options,
                                   int value,
                                   const logging::Logger& logger) {
  if (value >= 0) {
    LOGS(logger, INFO) << "Setting intra_op_num_threads to " << value;
    session_options.intra_op_param.thread_pool_size = value;
    return Status::OK();
  }

  // Only negative values reach this point.
  LOGS(logger, ERROR) << "Unsupported value for intra_op_num_threads: " << value;
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported value for intra_op_num_threads: ", value);
}
|
|
|
|
// Applies the `inter_op_num_threads` value read from the ORT config to `session_options`.
//
// Returns INVALID_ARGUMENT (after logging at ERROR) when `value` is negative; otherwise stores
// `value` in `session_options.inter_op_param.thread_pool_size` and returns OK.
static Status SetInterOpNumThreads(SessionOptions& session_options,
                                   int value,
                                   const logging::Logger& logger) {
  const bool is_negative = value < 0;
  if (is_negative) {
    LOGS(logger, ERROR) << "Unsupported value for inter_op_num_threads: " << value;
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported value for inter_op_num_threads: ", value);
  }

  session_options.inter_op_param.thread_pool_size = value;
  LOGS(logger, INFO) << "Setting inter_op_num_threads to " << value;
  return Status::OK();
}
|
|
|
|
// Applies the `execution_mode` value read from the ORT config to `session_options`.
//
// Accepted values:
//   0 -> ExecutionMode::ORT_SEQUENTIAL
//   1 -> ExecutionMode::ORT_PARALLEL
// Any other value is logged at ERROR and rejected with INVALID_ARGUMENT.
static Status SetExecutionMode(SessionOptions& session_options,
                               int value,
                               const logging::Logger& logger) {
  const bool wants_sequential = (value == 0);
  const bool wants_parallel = (value == 1);

  if (!(wants_sequential || wants_parallel)) {
    LOGS(logger, ERROR) << "Unsupported execution_mode value in ORT config: " << value;
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported execution_mode value in ORT config: ", value);
  }

  const char* mode_name = wants_sequential ? "Sequential mode" : "Parallel mode";
  LOGS(logger, INFO) << "Setting execution_mode to " << mode_name;

  if (wants_sequential) {
    session_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
  } else {
    session_options.execution_mode = ExecutionMode::ORT_PARALLEL;
  }
  return Status::OK();
}
|
|
|
|
// Applies the `graph_optimization_level` value read from the ORT config to `session_options`.
//
// Mapping from the config value to the stored TransformerLevel:
//   ORT_DISABLE_ALL     -> TransformerLevel::Default
//   ORT_ENABLE_BASIC    -> TransformerLevel::Level1
//   ORT_ENABLE_EXTENDED -> TransformerLevel::Level2
//   ORT_ENABLE_ALL      -> TransformerLevel::MaxLevel
// Any other value is logged at ERROR and rejected with INVALID_ARGUMENT.
static Status SetGraphOptimizationLevel(SessionOptions& session_options,
                                        int value,
                                        const logging::Logger& logger) {
  // The switch only selects what to apply; logging and assignment happen once, below.
  TransformerLevel selected_level = TransformerLevel::Default;
  const char* info_message = nullptr;

  switch (value) {
    case ORT_DISABLE_ALL:
      selected_level = TransformerLevel::Default;
      info_message = "Setting graph_optimization_level to ORT_DISABLE_ALL";
      break;

    case ORT_ENABLE_BASIC:
      selected_level = TransformerLevel::Level1;
      info_message = "Setting graph_optimization_level to ORT_ENABLE_BASIC";
      break;

    case ORT_ENABLE_EXTENDED:
      selected_level = TransformerLevel::Level2;
      info_message = "Setting graph_optimization_level to ORT_ENABLE_EXTENDED";
      break;

    case ORT_ENABLE_ALL:
      selected_level = TransformerLevel::MaxLevel;
      info_message = "Setting graph_optimization_level to ORT_ENABLE_ALL";
      break;

    default: {
      // Build the text once so the log entry and the returned status carry the same message.
      std::ostringstream error_stream;
      error_stream << "Unsupported graph_optimization_level value in ORT config: " << value;
      const std::string error_text = error_stream.str();

      LOGS(logger, ERROR) << error_text;
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, error_text);
    }
  }

  LOGS(logger, INFO) << info_message;
  session_options.graph_optimization_level = selected_level;
  return Status::OK();
}
|
|
|
|
// Applies the `enable_profiling` value read from the ORT config to `session_options`.
//
// The option is an integer flag: 0 stores `false` and 1 stores `true` in
// `session_options.enable_profiling`. Any other value is logged at ERROR and rejected with
// INVALID_ARGUMENT.
static Status SetEnableProfiling(SessionOptions& session_options,
                                 int value,
                                 const logging::Logger& logger) {
  const bool enable = (value != 0);

  // `enable` is true for every non-zero value, so reject anything that is not exactly 1.
  if (enable && value != 1) {
    LOGS(logger, ERROR) << "Unsupported value for enable_profiling option: " << value;
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported value for enable_profiling option: ", value);
  }

  LOGS(logger, INFO) << "Setting enable_profiling to " << (enable ? "true" : "false");
  session_options.enable_profiling = enable;
  return Status::OK();
}
|
|
|
|
// This function is called by nlohmann/json
|
|
void from_json(const json& j, TuningResults& trs) {
|
|
j.at("ep").get_to(trs.ep);
|
|
j.at("results").get_to(trs.results);
|
|
j.at("validators").get_to(trs.validators);
|
|
}
|
|
|
|
//---------------------------------------------------
|
|
//--- end of session options related helpers ---
|
|
//---------------------------------------------------
|
|
|
|
//---------------------
|
|
//--- end of local helpers ---
|
|
//---------------------
|
|
|
|
namespace inference_session_utils {
|
|
// Looks for the ORT config json in the metadata of `model_proto` and parses it.
//
// Only the first metadata entry whose key equals `kOrtConfigKey` is considered. On a successful
// parse the result is stored in `parsed_json_` and `is_ort_config_json_available_` is set.
//
// Returns:
//   - FAIL if this parser has already checked a model proto.
//   - INVALID_ARGUMENT if the stored json cannot be parsed; in that case the "checked" flag is
//     NOT set.
//   - OK otherwise (including when no config entry exists), with
//     `is_model_checked_for_ort_config_json_` set.
Status JsonConfigParser::ParseOrtConfigJsonInModelProto(const ONNX_NAMESPACE::ModelProto& model_proto) {
  if (is_model_checked_for_ort_config_json_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The Model Proto has already been checked for the ORT config json.");
  }

  for (const auto& prop : model_proto.metadata_props()) {
    const bool is_ort_config_entry = prop.has_key() && prop.key() == inference_session_utils::kOrtConfigKey;
    if (!is_ort_config_entry) {
      continue;
    }

    LOGS(logger_, INFO)
        << "Found session/run/environment configuration in the model file to be used while running the model";

    Status parse_status = Status::OK();
    ORT_TRY {
      const auto& config_text = prop.value();
      LOGS(logger_, INFO) << "ORT config json from the model: " << config_text;

      parsed_json_ = json::parse(config_text);
      // Remember that this model carries an ORT config json.
      is_ort_config_json_available_ = true;
    }
    ORT_CATCH(const std::exception& e) {
      ORT_HANDLE_EXCEPTION([&]() {
        std::ostringstream error_stream;
        error_stream << "Json stored in the `ort_config` key cannot be parsed. Error message: " << e.what();
        const std::string error_text = error_stream.str();

        LOGS(logger_, ERROR) << error_text;
        parse_status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, error_text);
      });
    }
    ORT_RETURN_IF_ERROR(parse_status);

    // Only the first matching entry is used.
    break;
  }

  // Every step succeeded: record that the model has been checked for the ORT config json.
  is_model_checked_for_ort_config_json_ = true;
  return Status::OK();
}
|
|
|
|
// Applies the session options found in the model's ORT config json to `session_options`.
//
// Must be called after ParseOrtConfigJsonInModelProto; otherwise FAIL is returned. If no config
// json was found, or it has no `kSessionOptionsKey` member, nothing is changed and OK is returned.
//
// Supported keys (each must hold an integer, otherwise INVALID_ARGUMENT is returned):
//   intra_op_num_threads, inter_op_num_threads, execution_mode, graph_optimization_level,
//   enable_profiling
// Any other key is logged at INFO and ignored. Processing stops at the first error, so options
// applied before the failing entry remain set.
Status JsonConfigParser::ParseSessionOptionsFromModelProto(SessionOptions& session_options) {
  if (!is_model_checked_for_ort_config_json_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The Model Proto hasn't been checked for the ORT config json.");
  }

  const bool has_session_options =
      is_ort_config_json_available_ && parsed_json_.contains(inference_session_utils::kSessionOptionsKey);
  if (!has_session_options) {
    LOGS(logger_, INFO) << "Did not find session options in the model file to be used while running the model";
    return Status::OK();
  }

  // TODO: Support all valid session options
  // Only the config options listed in the table below are supported in this version.
  // Any other option in the json (even if it is a valid session option) is ignored.
  using OptionSetter = Status (*)(SessionOptions&, int, const logging::Logger&);
  struct SupportedOption {
    const char* name;                 // key expected in the json
    OptionSetter apply;               // validates and stores the integer value
    const char* not_integer_message;  // error text when the json value is not an integer
  };
  static const SupportedOption supported_options[] = {
      {"intra_op_num_threads", &SetIntraOpNumThreads,
       "intra_op_num_threads option in the model file must be an integer"},
      {"inter_op_num_threads", &SetInterOpNumThreads,
       "inter_op_num_threads option in the model file must be an integer"},
      {"execution_mode", &SetExecutionMode,
       "execution_mode option in the model file must be an integer"},
      {"graph_optimization_level", &SetGraphOptimizationLevel,
       "graph_optimization_level option in the model file must be an integer"},
      {"enable_profiling", &SetEnableProfiling,
       "enable_profiling option in the model file must be an integer"},
  };

  const auto& options_json = parsed_json_.at(inference_session_utils::kSessionOptionsKey);

  for (const auto& entry : options_json.items()) {
    const auto& key = entry.key();
    const auto& value = entry.value();

    const SupportedOption* matched = nullptr;
    for (const auto& candidate : supported_options) {
      if (key == candidate.name) {
        matched = &candidate;
        break;
      }
    }

    if (matched == nullptr) {
      LOGS(logger_, INFO) << "Ignoring unsupported session option in ORT config: " << key;
      continue;
    }

    if (!value.is_number_integer()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, matched->not_integer_message);
    }

    ORT_RETURN_IF_ERROR(matched->apply(session_options, value.get<int>(), logger_));
  }

  return Status::OK();
}
|
|
|
|
// Placeholder for reading RunOptions from the model's ORT config json.
//
// The argument is ignored and NOT_IMPLEMENTED is always returned.
Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options*/) {
  Status not_supported = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                                         "Parsing RunOptions from ModelProto is not supported yet");
  return not_supported;
}
|
|
|
|
// Extracts the tuning results stored in the custom metadata of a model.
//
// Parameters:
//   metadata  - model metadata whose `custom_metadata_map` is searched for `kTuningResultsKeys`.
//   results   - output; always cleared first, then filled with the parsed entries on success.
//   key_found - output; true iff the tuning results key exists in the metadata, regardless of
//               whether its value could be parsed.
//
// Returns OK when the key is absent or its value parses successfully, and FAIL when the stored
// json cannot be parsed or converted to a list of TuningResults.
Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata,
                                           std::vector<TuningResults>& results,
                                           bool& key_found) {
  results.clear();

  const auto& custom_metadata = metadata.custom_metadata_map;
  const auto entry = custom_metadata.find(kTuningResultsKeys);
  key_found = (entry != custom_metadata.end());
  if (!key_found) {
    return Status::OK();
  }

  LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while loading the model";

  Status parse_status;
  ORT_TRY {
    results = json::parse(entry->second).get<std::vector<TuningResults>>();
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      parse_status = ORT_MAKE_STATUS(
          ONNXRUNTIME, FAIL,
          "Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". Ignoring...");
    });
  }
  // `parse_status` is only non-OK when the catch handler above ran.
  ORT_RETURN_IF_ERROR(parse_status);

  return Status::OK();
}
|
|
|
|
} // namespace inference_session_utils
|
|
} // namespace onnxruntime
|
|
|
|
#endif // !defined(ORT_MINIMAL_BUILD)
|