mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
### Description
Added support for the int64 and uint64 tensor element types in the Objective-C library.

### Motivation and Context
Int64 is rarely used, but we needed it. Int64 inference works after this change (tested).
133 lines
5 KiB
Text
133 lines
5 KiB
Text
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#import "ort_enums_internal.h"

#include <algorithm>
#include <iterator>

#import "src/cxx_api.h"
|
|
|
|
namespace {
|
|
|
|
struct LoggingLevelInfo {
|
|
ORTLoggingLevel logging_level;
|
|
OrtLoggingLevel capi_logging_level;
|
|
};
|
|
|
|
// supported ORT logging levels
|
|
// define the mapping from ORTLoggingLevel to C API OrtLoggingLevel here
|
|
constexpr LoggingLevelInfo kLoggingLevelInfos[]{
|
|
{ORTLoggingLevelVerbose, ORT_LOGGING_LEVEL_VERBOSE},
|
|
{ORTLoggingLevelInfo, ORT_LOGGING_LEVEL_INFO},
|
|
{ORTLoggingLevelWarning, ORT_LOGGING_LEVEL_WARNING},
|
|
{ORTLoggingLevelError, ORT_LOGGING_LEVEL_ERROR},
|
|
{ORTLoggingLevelFatal, ORT_LOGGING_LEVEL_FATAL},
|
|
};
|
|
|
|
struct ValueTypeInfo {
|
|
ORTValueType type;
|
|
ONNXType capi_type;
|
|
};
|
|
|
|
// supported ORT value types
|
|
// define the mapping from ORTValueType to C API ONNXType here
|
|
constexpr ValueTypeInfo kValueTypeInfos[]{
|
|
{ORTValueTypeUnknown, ONNX_TYPE_UNKNOWN},
|
|
{ORTValueTypeTensor, ONNX_TYPE_TENSOR},
|
|
};
|
|
|
|
struct TensorElementTypeInfo {
|
|
ORTTensorElementDataType type;
|
|
ONNXTensorElementDataType capi_type;
|
|
size_t element_size;
|
|
};
|
|
|
|
// supported ORT tensor element data types
|
|
// define the mapping from ORTTensorElementDataType to C API ONNXTensorElementDataType here
|
|
constexpr TensorElementTypeInfo kElementTypeInfos[]{
|
|
{ORTTensorElementDataTypeUndefined, ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, 0},
|
|
{ORTTensorElementDataTypeFloat, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)},
|
|
{ORTTensorElementDataTypeInt8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)},
|
|
{ORTTensorElementDataTypeUInt8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, sizeof(uint8_t)},
|
|
{ORTTensorElementDataTypeInt32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)},
|
|
{ORTTensorElementDataTypeUInt32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, sizeof(uint32_t)},
|
|
{ORTTensorElementDataTypeInt64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, sizeof(int64_t)},
|
|
{ORTTensorElementDataTypeUInt64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, sizeof(uint64_t)},
|
|
};
|
|
|
|
struct GraphOptimizationLevelInfo {
|
|
ORTGraphOptimizationLevel opt_level;
|
|
GraphOptimizationLevel capi_opt_level;
|
|
};
|
|
|
|
// ORT graph optimization levels
|
|
// define the mapping from ORTGraphOptimizationLevel to C API GraphOptimizationLevel here
|
|
constexpr GraphOptimizationLevelInfo kGraphOptimizationLevelInfos[]{
|
|
{ORTGraphOptimizationLevelNone, ORT_DISABLE_ALL},
|
|
{ORTGraphOptimizationLevelBasic, ORT_ENABLE_BASIC},
|
|
{ORTGraphOptimizationLevelExtended, ORT_ENABLE_EXTENDED},
|
|
{ORTGraphOptimizationLevelAll, ORT_ENABLE_ALL},
|
|
};
|
|
|
|
template <typename Container, typename SelectFn, typename TransformFn>
|
|
auto SelectAndTransform(
|
|
const Container& container, SelectFn select_fn, TransformFn transform_fn,
|
|
const char* not_found_msg)
|
|
-> decltype(transform_fn(*std::begin(container))) {
|
|
const auto it = std::find_if(
|
|
std::begin(container), std::end(container), select_fn);
|
|
if (it == std::end(container)) {
|
|
ORT_CXX_API_THROW(not_found_msg, ORT_NOT_IMPLEMENTED);
|
|
}
|
|
return transform_fn(*it);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Maps a public ORTLoggingLevel to its C API OrtLoggingLevel counterpart using
// kLoggingLevelInfos. A level without a table entry is reported by
// SelectAndTransform with the message "unsupported logging level".
OrtLoggingLevel PublicToCAPILoggingLevel(ORTLoggingLevel logging_level) {
  const auto has_requested_level = [logging_level](const LoggingLevelInfo& entry) {
    return entry.logging_level == logging_level;
  };
  const auto capi_level_of = [](const LoggingLevelInfo& entry) {
    return entry.capi_logging_level;
  };
  return SelectAndTransform(kLoggingLevelInfos, has_requested_level, capi_level_of,
                            "unsupported logging level");
}
|
|
|
|
// Maps a C API ONNXType to the public ORTValueType using kValueTypeInfos.
// A type without a table entry is reported by SelectAndTransform with the
// message "unsupported value type".
ORTValueType CAPIToPublicValueType(ONNXType capi_type) {
  const auto has_capi_type = [capi_type](const ValueTypeInfo& entry) {
    return entry.capi_type == capi_type;
  };
  const auto public_type_of = [](const ValueTypeInfo& entry) {
    return entry.type;
  };
  return SelectAndTransform(kValueTypeInfos, has_capi_type, public_type_of,
                            "unsupported value type");
}
|
|
|
|
// Maps a public ORTTensorElementDataType to the C API ONNXTensorElementDataType
// using kElementTypeInfos. A type without a table entry is reported by
// SelectAndTransform with the message "unsupported tensor element type".
ONNXTensorElementDataType PublicToCAPITensorElementType(ORTTensorElementDataType type) {
  const auto has_public_type = [type](const TensorElementTypeInfo& entry) {
    return entry.type == type;
  };
  const auto capi_type_of = [](const TensorElementTypeInfo& entry) {
    return entry.capi_type;
  };
  return SelectAndTransform(kElementTypeInfos, has_public_type, capi_type_of,
                            "unsupported tensor element type");
}
|
|
|
|
// Reverse of PublicToCAPITensorElementType: maps a C API
// ONNXTensorElementDataType to the public ORTTensorElementDataType using
// kElementTypeInfos. A type without a table entry is reported by
// SelectAndTransform with the message "unsupported tensor element type".
ORTTensorElementDataType CAPIToPublicTensorElementType(ONNXTensorElementDataType capi_type) {
  const auto has_capi_type = [capi_type](const TensorElementTypeInfo& entry) {
    return entry.capi_type == capi_type;
  };
  const auto public_type_of = [](const TensorElementTypeInfo& entry) {
    return entry.type;
  };
  return SelectAndTransform(kElementTypeInfos, has_capi_type, public_type_of,
                            "unsupported tensor element type");
}
|
|
|
|
// Returns the per-element size in bytes recorded in kElementTypeInfos for the
// given C API element type. The size is 0 for
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED. A type without a table entry is
// reported by SelectAndTransform with the message
// "unsupported tensor element type".
size_t SizeOfCAPITensorElementType(ONNXTensorElementDataType capi_type) {
  const auto has_capi_type = [capi_type](const TensorElementTypeInfo& entry) {
    return entry.capi_type == capi_type;
  };
  const auto byte_size_of = [](const TensorElementTypeInfo& entry) {
    return entry.element_size;
  };
  return SelectAndTransform(kElementTypeInfos, has_capi_type, byte_size_of,
                            "unsupported tensor element type");
}
|
|
|
|
// Maps a public ORTGraphOptimizationLevel to the C API GraphOptimizationLevel
// using kGraphOptimizationLevelInfos. A level without a table entry is reported
// by SelectAndTransform with the message
// "unsupported graph optimization level".
GraphOptimizationLevel PublicToCAPIGraphOptimizationLevel(ORTGraphOptimizationLevel opt_level) {
  const auto has_requested_level = [opt_level](const GraphOptimizationLevelInfo& entry) {
    return entry.opt_level == opt_level;
  };
  const auto capi_level_of = [](const GraphOptimizationLevelInfo& entry) {
    return entry.capi_opt_level;
  };
  return SelectAndTransform(kGraphOptimizationLevelInfos, has_requested_level, capi_level_of,
                            "unsupported graph optimization level");
}
|