mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
* Dynamically link onnxruntime.dll * fixes * add preceding backslash to onnxruntime.dll for inbox builds * remove /d * loadlibrary -> loadlibraryex * use loadlibrary system32 option Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
1355 lines
No EOL
56 KiB
C++
1355 lines
No EOL
56 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
#include "pch.h"
|
|
|
|
#include "OnnxruntimeEngine.h"
|
|
|
|
#include "PheonixSingleton.h"
|
|
#include "OnnxruntimeEnvironment.h"
|
|
#include "OnnxruntimeEngineBuilder.h"
|
|
#include "OnnxruntimeModel.h"
|
|
#include "OnnxruntimeSessionBuilder.h"
|
|
#include "OnnxruntimeErrors.h"
|
|
|
|
using namespace _winml;
|
|
|
|
// Translates a WinML TensorKind into the matching ONNX tensor element type.
// Any kind without an explicit case below maps to
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED.
static ONNXTensorElementDataType
ONNXTensorElementDataTypeFromTensorKind(winml::TensorKind kind) {
  switch (kind) {
    case winml::TensorKind::Boolean:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
    case winml::TensorKind::String:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
    case winml::TensorKind::Float16:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
    case winml::TensorKind::Float:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    case winml::TensorKind::Double:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
    case winml::TensorKind::Int8:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
    case winml::TensorKind::Int16:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
    case winml::TensorKind::Int32:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
    case winml::TensorKind::Int64:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
    case winml::TensorKind::UInt8:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
    case winml::TensorKind::UInt16:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
    case winml::TensorKind::UInt32:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
    case winml::TensorKind::UInt64:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
    case winml::TensorKind::Complex64:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
    case winml::TensorKind::Complex128:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
    default:
      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  }
}
|
|
|
|
// A freshly constructed value holds neither an OrtValue nor an allocator;
// both are supplied later through RuntimeClassInitialize.
OnnxruntimeValue::OnnxruntimeValue()
    : value_(nullptr, nullptr),
      allocator_(nullptr, nullptr) {
}
|
|
|
|
// Drops the held OrtValue first and the allocator second; the order is made
// explicit here instead of relying on member declaration order.
OnnxruntimeValue::~OnnxruntimeValue() {
  value_.reset();
  allocator_.reset();
}
|
|
|
|
// Second-phase initializer: records the owning engine and takes ownership of
// the OrtValue and of the allocator handed in by the caller.
// Always returns S_OK.
HRESULT OnnxruntimeValue::RuntimeClassInitialize(OnnxruntimeEngine* engine,
                                                 UniqueOrtValue&& ort_value,
                                                 UniqueOrtAllocator&& allocator) {
  engine_ = engine;
  value_ = std::move(ort_value);
  allocator_ = std::move(allocator);
  return S_OK;
}
|
|
|
|
// Sets *out to true when no OrtValue is currently held. Always returns S_OK.
HRESULT OnnxruntimeValue::IsEmpty(bool* out) {
  const OrtValue* held_value = UseOrtValue();
  *out = (held_value == nullptr);
  return S_OK;
}
|
|
|
|
// Sets *out to true when the held value's memory info is named "Cpu" or its
// memory type is a CPU input/output type. Fails with the translated ORT
// status if any of the memory-info queries fail.
HRESULT OnnxruntimeValue::IsCpu(bool* out) {
  auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
  auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();

  // The memory info returned for the value is released when this scope exits.
  OrtMemoryInfo* raw_memory_info;
  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(value_.get(), &raw_memory_info), ort_api);
  auto memory_info = UniqueOrtMemoryInfo(raw_memory_info, ort_api->ReleaseMemoryInfo);

  const char* memory_name;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(memory_info.get(), &memory_name), ort_api);

  OrtMemType memory_type;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(memory_info.get(), &memory_type), ort_api);

  const bool is_named_cpu = (strcmp(memory_name, "Cpu") == 0);
  const bool is_cpu_mem_type = (memory_type == OrtMemType::OrtMemTypeCPUOutput) ||
                               (memory_type == OrtMemType::OrtMemTypeCPUInput);
  *out = is_named_cpu || is_cpu_mem_type;
  return S_OK;
}
|
|
|
|
static int64_t ShapeSize(const int64_t* shape, size_t count) {
|
|
// for each dim
|
|
int64_t size = 1;
|
|
for (size_t i = 0; i < count; i++) {
|
|
// find out it's total size
|
|
size *= shape[i];
|
|
// make sure there are no invalid dimensions (-1 or any invalid shape)
|
|
THROW_HR_IF(E_INVALIDARG, shape[i] <= 0);
|
|
}
|
|
return size;
|
|
}
|
|
|
|
static auto GetStrings(const OrtApi* ort_api, const OrtValue* ort_value,
|
|
OrtTensorTypeAndShapeInfo* type_and_shape_info) {
|
|
std::vector<std::string> out;
|
|
|
|
size_t size;
|
|
THROW_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(type_and_shape_info, &size),
|
|
ort_api);
|
|
|
|
std::vector<int64_t> shape(size);
|
|
|
|
if (size > 0) {
|
|
THROW_IF_NOT_OK_MSG(ort_api->GetDimensions(type_and_shape_info, &shape[0], size),
|
|
ort_api);
|
|
}
|
|
auto length = ShapeSize(shape.data(), shape.size());
|
|
|
|
// make a big buffer to hold all the string data
|
|
size_t buffer_length;
|
|
THROW_IF_NOT_OK_MSG(ort_api->GetStringTensorDataLength(ort_value, &buffer_length),
|
|
ort_api);
|
|
|
|
std::vector<std::string_view> strings;
|
|
std::unique_ptr<uint8_t[]> buffer(new uint8_t[buffer_length]);
|
|
std::vector<size_t> offsets(static_cast<size_t>(length));
|
|
|
|
THROW_IF_NOT_OK_MSG(ort_api->GetStringTensorContent(ort_value, buffer.get(), buffer_length, offsets.data(), offsets.size()),
|
|
ort_api);
|
|
|
|
// now go build all the strings
|
|
for (auto i = 0; i < length; ++i) {
|
|
size_t str_len = 0;
|
|
// are we on the last one?
|
|
if (i == (length - 1)) {
|
|
str_len = buffer_length - offsets[i];
|
|
} else {
|
|
str_len = offsets[i + 1] - offsets[i];
|
|
}
|
|
strings.push_back(std::string_view(reinterpret_cast<const char*>(buffer.get() + offsets[i]), str_len));
|
|
}
|
|
|
|
return std::make_shared<std::pair<decltype(strings), decltype(buffer)>>(std::move(strings), std::move(buffer));
|
|
}
|
|
|
|
// Exposes the data behind the held OrtValue as a _winml::Resource.
//
//  - value not on the CPU : the ID3D12Resource obtained from the DML allocation
//  - non-tensor value     : the raw mutable data pointer
//  - string tensor        : pointer to an array of std::string_view whose
//                           storage is kept alive by the resource's deleter
//  - any other tensor     : the raw mutable data pointer
//
// In every case the resource's deleter frees nothing.
HRESULT OnnxruntimeValue::GetResource(_winml::Resource& out) {
  auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
  auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();

  // Deleter for pointers this resource does not own.
  const auto non_owning = [](void*) { /* do nothing, the pointer is owned elsewhere */ };

  void* mutable_data = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(value_.get(), &mutable_data), ort_api);

  OrtExecutionProvider* ort_provider;
  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(engine_->UseOrtSession(), 0, &ort_provider),
                          ort_api);

  // A failed IsCpu query falls through to the CPU handling below.
  bool is_cpu = false;
  const bool is_device_value = SUCCEEDED(IsCpu(&is_cpu)) && !is_cpu;
  if (is_device_value) {
    void* d3d_resource;
    RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data,
                                                                                 reinterpret_cast<ID3D12Resource**>(&d3d_resource)),
                            ort_api);
    // This pointer is actually a com pointer, so the deleter must not free it.
    out = _winml::Resource(d3d_resource, non_owning);
    return S_OK;
  }

  int is_tensor;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->IsTensor(value_.get(), &is_tensor), ort_api);
  if (is_tensor == 0) {
    out = _winml::Resource(mutable_data, non_owning);
    return S_OK;
  }

  OrtTensorTypeAndShapeInfo* info = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info), ort_api);
  auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo);

  ONNXTensorElementDataType data_type;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorElementType(type_and_shape_info.get(), &data_type), ort_api);

  if (data_type != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
    out = _winml::Resource(mutable_data, non_owning);
    return S_OK;
  }

  auto strings = GetStrings(ort_api, value_.get(), info);
  auto string_data = strings->first.data();
  // The deleter only captures `strings`, which extends the lifetime of the
  // returned string views for as long as the resource exists.
  out = _winml::Resource(string_data, [capture_strings = strings](void*) {});
  return S_OK;
}
|
|
|
|
// Sets *out to true when ORT reports the held value's type as ONNX_TYPE_TENSOR.
HRESULT OnnxruntimeValue::IsTensor(bool* out) {
  auto ort_api = engine_->GetEngineFactory()->UseOrtApi();

  ONNXType value_type = ONNXType::ONNX_TYPE_UNKNOWN;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueType(value_.get(), &value_type), ort_api);

  *out = (value_type == ONNXType::ONNX_TYPE_TENSOR);
  return S_OK;
}
|
|
|
|
// Sets *out to true when the held tensor's element type equals the ONNX
// element type corresponding to `kind`.
HRESULT OnnxruntimeValue::IsOfTensorType(winml::TensorKind kind, bool* out) {
  auto ort_api = engine_->GetEngineFactory()->UseOrtApi();

  OrtTensorTypeAndShapeInfo* raw_info = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &raw_info), ort_api);
  auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(raw_info, ort_api->ReleaseTensorTypeAndShapeInfo);

  ONNXTensorElementDataType actual_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorElementType(type_and_shape_info.get(), &actual_type), ort_api);

  const auto expected_type = ONNXTensorElementDataTypeFromTensorKind(kind);
  *out = (actual_type == expected_type);
  return S_OK;
}
|
|
|
|
// Writes the dimensions of the held tensor into `shape_vector`.
// `shape_vector` is only modified when every ORT query succeeds; a tensor
// with zero dimensions yields an empty vector.
HRESULT OnnxruntimeValue::GetTensorShape(std::vector<int64_t>& shape_vector) {
  auto ort_api = engine_->GetEngineFactory()->UseOrtApi();

  OrtTensorTypeAndShapeInfo* raw_info = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &raw_info), ort_api);
  auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(raw_info, ort_api->ReleaseTensorTypeAndShapeInfo);

  size_t rank;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(type_and_shape_info.get(), &rank), ort_api);

  std::vector<int64_t> dims(rank);
  if (rank > 0) {
    RETURN_HR_IF_NOT_OK_MSG(ort_api->GetDimensions(type_and_shape_info.get(), dims.data(), rank), ort_api);
  }

  shape_vector = std::move(dims);
  return S_OK;
}
|
|
|
|
// Returns true only when `type_info` describes a map whose key element type
// matches `key_kind` and whose value is a tensor with element type
// `value_kind` and zero dimensions. Returns false for every other shape of
// type info. Throws when an ORT call fails.
static bool EnsureMapTypeInfo(OnnxruntimeEngine* engine, OrtTypeInfo* type_info, winml::TensorKind key_kind, winml::TensorKind value_kind) {
  auto ort_api = engine->GetEngineFactory()->UseOrtApi();

  const OrtMapTypeInfo* map_info;
  THROW_IF_NOT_OK_MSG(ort_api->CastTypeInfoToMapTypeInfo(type_info, &map_info), ort_api);
  if (map_info == nullptr) {
    // It must be a seq<tensor<*>> type
    return false;
  }

  ONNXTensorElementDataType actual_key_type;
  THROW_IF_NOT_OK_MSG(ort_api->GetMapKeyType(map_info, &actual_key_type), ort_api);
  if (actual_key_type != ONNXTensorElementDataTypeFromTensorKind(key_kind)) {
    return false;
  }

  // The value type info is released when this function returns.
  OrtTypeInfo* raw_value_info;
  THROW_IF_NOT_OK_MSG(ort_api->GetMapValueType(map_info, &raw_value_info), ort_api);
  auto map_value_info = UniqueOrtTypeInfo(raw_value_info, ort_api->ReleaseTypeInfo);

  const OrtTensorTypeAndShapeInfo* value_tensor_info = nullptr;
  THROW_IF_NOT_OK_MSG(ort_api->CastTypeInfoToTensorInfo(map_value_info.get(), &value_tensor_info), ort_api);
  if (value_tensor_info == nullptr) {
    return false;
  }

  ONNXTensorElementDataType actual_value_type;
  THROW_IF_NOT_OK_MSG(ort_api->GetTensorElementType(value_tensor_info, &actual_value_type), ort_api);
  if (actual_value_type != ONNXTensorElementDataTypeFromTensorKind(value_kind)) {
    return false;
  }

  size_t num_dims;
  THROW_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(value_tensor_info, &num_dims), ort_api);
  return num_dims == 0;
}
|
|
|
|
// Sets *out to true when the held value is an ONNX map whose key type matches
// `key_kind` and whose value type matches `value_kind` (as decided by
// EnsureMapTypeInfo); *out is false for every other value.
//
// Returns the translated ORT status if a type query fails; *out is already
// false in that case.
// NOTE(review): EnsureMapTypeInfo reports ORT failures by throwing rather
// than through an HRESULT - confirm callers are prepared for that.
HRESULT OnnxruntimeValue::IsOfMapType(winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) {
  auto ort_api = engine_->GetEngineFactory()->UseOrtApi();

  // Default answer. This used to be assigned after the map check, which
  // overwrote the real result and made the function always report false.
  *out = false;

  OrtTypeInfo* info = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &info),
                          ort_api);
  auto unique_type_info = UniqueOrtTypeInfo(info, ort_api->ReleaseTypeInfo);

  ONNXType type;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type),
                          ort_api);

  if (type == ONNXType::ONNX_TYPE_MAP) {
    *out = EnsureMapTypeInfo(engine_.Get(), unique_type_info.get(), key_kind, value_kind);
  }

  return S_OK;
}
|
|
|
|
// Sets *out to true when the held value is an ONNX sequence whose element
// type is a map with key type `key_kind` and value type `value_kind` (as
// decided by EnsureMapTypeInfo); *out is false for every other value.
//
// Returns the translated ORT status if a type query fails; *out is already
// false in that case.
// NOTE(review): EnsureMapTypeInfo reports ORT failures by throwing rather
// than through an HRESULT - confirm callers are prepared for that.
HRESULT OnnxruntimeValue::IsOfVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) {
  auto ort_api = engine_->GetEngineFactory()->UseOrtApi();

  // Previously *out was left untouched when the value was not a sequence, so
  // callers could read an indeterminate bool.
  *out = false;

  OrtTypeInfo* info = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &info),
                          ort_api);
  auto unique_type_info = UniqueOrtTypeInfo(info, ort_api->ReleaseTypeInfo);

  ONNXType type;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type),
                          ort_api);

  if (type == ONNXType::ONNX_TYPE_SEQUENCE) {
    const OrtSequenceTypeInfo* sequence_info;
    RETURN_HR_IF_NOT_OK_MSG(ort_api->CastTypeInfoToSequenceTypeInfo(unique_type_info.get(), &sequence_info),
                            ort_api);

    // The element type info is released when this scope exits.
    OrtTypeInfo* element_info;
    RETURN_HR_IF_NOT_OK_MSG(ort_api->GetSequenceElementType(sequence_info, &element_info),
                            ort_api);
    auto unique_element_info = UniqueOrtTypeInfo(element_info, ort_api->ReleaseTypeInfo);

    *out = EnsureMapTypeInfo(engine_.Get(), unique_element_info.get(), key_kind, value_kind);
  }
  return S_OK;
}
|
|
|
|
// Sets *out to true when the held value is an ONNX sequence whose elements
// are tensors with the element type corresponding to `kind`. *out is false
// in every other case, including when an ORT query fails.
HRESULT OnnxruntimeValue::IsOfVectorTensorType(winml::TensorKind kind, bool* out) {
  auto ort_api = engine_->GetEngineFactory()->UseOrtApi();

  *out = false;

  OrtTypeInfo* raw_type_info = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &raw_type_info), ort_api);
  auto unique_type_info = UniqueOrtTypeInfo(raw_type_info, ort_api->ReleaseTypeInfo);

  ONNXType value_type;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &value_type), ort_api);
  if (value_type != ONNXType::ONNX_TYPE_SEQUENCE) {
    return S_OK;
  }

  const OrtSequenceTypeInfo* sequence_info;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->CastTypeInfoToSequenceTypeInfo(unique_type_info.get(), &sequence_info), ort_api);

  OrtTypeInfo* raw_element_info;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetSequenceElementType(sequence_info, &raw_element_info), ort_api);
  auto unique_element_info = UniqueOrtTypeInfo(raw_element_info, ort_api->ReleaseTypeInfo);

  ONNXType element_type;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_element_info.get(), &element_type), ort_api);
  if (element_type != ONNXType::ONNX_TYPE_TENSOR) {
    return S_OK;
  }

  const OrtTensorTypeAndShapeInfo* element_tensor_info = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->CastTypeInfoToTensorInfo(unique_element_info.get(), &element_tensor_info), ort_api);
  if (element_tensor_info == nullptr) {
    return S_OK;
  }

  ONNXTensorElementDataType element_tensor_type;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorElementType(element_tensor_info, &element_tensor_type), ort_api);

  *out = (element_tensor_type == ONNXTensorElementDataTypeFromTensorKind(kind));
  return S_OK;
}
|
|
|
|
// Stores `param` in param_, replacing whatever was stored before.
// Always returns S_OK.
HRESULT OnnxruntimeValue::SetParameter(IUnknown* param) {
  param_ = param;

  return S_OK;
}
|
|
|
|
// Returns the held OrtValue without giving up ownership; nullptr when empty.
OrtValue* OnnxruntimeValue::UseOrtValue() {
  OrtValue* const borrowed_value = value_.get();
  return borrowed_value;
}
|
|
|
|
// Takes ownership of `in`, releasing any OrtValue that was held before.
// Always returns S_OK.
HRESULT OnnxruntimeValue::AssignOrtValue(OrtValue* in) {
  value_.reset(in);

  return S_OK;
}
|
|
|
|
// The engine starts without a session; RuntimeClassInitialize supplies one.
OnnxruntimeEngine::OnnxruntimeEngine()
    : session_(nullptr, nullptr) {}
|
|
|
|
// Second-phase initializer: remembers the factory and the session builder
// and takes ownership of the ORT session. Always returns S_OK.
HRESULT OnnxruntimeEngine::RuntimeClassInitialize(
    OnnxruntimeEngineFactory* engine_factory,
    UniqueOrtSession&& session,
    IOrtSessionBuilder* session_builder) {
  engine_factory_ = engine_factory;
  session_ = std::move(session);
  session_builder_ = session_builder;

  return S_OK;
}
|
|
|
|
// Detaches the OrtModel from `model` and hands it to this engine's session.
// Fails if `model` does not implement IOnnxruntimeModel, if the detach
// fails, or if the session rejects the model.
// NOTE(review): what happens to the detached OrtModel when the session call
// fails is not visible from here - confirm ownership.
HRESULT OnnxruntimeEngine::LoadModel(_In_ IModel* model) {
  Microsoft::WRL::ComPtr<IOnnxruntimeModel> onnxruntime_model;
  RETURN_IF_FAILED(model->QueryInterface(IID_PPV_ARGS(&onnxruntime_model)));

  OrtModel* detached_model;
  RETURN_IF_FAILED(onnxruntime_model->DetachOrtModel(&detached_model));

  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionLoadAndPurloinModel(session_.get(), detached_model),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Asks the session builder to initialize this engine's session and forwards
// any failure.
HRESULT OnnxruntimeEngine::Initialize() {
  OrtSession* const session = session_.get();
  RETURN_IF_FAILED(session_builder_->Initialize(session));

  return S_OK;
}
|
|
|
|
// Registers graph transformers on the session through the WinML adapter API.
HRESULT OnnxruntimeEngine::RegisterGraphTransformers() {
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionRegisterGraphTransformers(session_.get()),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Registers the given operator registry on the session through the WinML
// adapter API.
HRESULT OnnxruntimeEngine::RegisterCustomRegistry(IMLOperatorRegistry* registry) {
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionRegisterCustomRegistry(session_.get(), registry),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Ends profiling on the session through the WinML adapter API.
HRESULT OnnxruntimeEngine::EndProfiling() {
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionEndProfiling(session_.get()),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Starts profiling on the session, using the factory's ORT environment.
// NOTE(review): the result of GetOrtEnvironment is not checked and the local
// pointer has no initializer - confirm that call cannot fail here.
HRESULT OnnxruntimeEngine::StartProfiling() {
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  OrtEnv* environment;
  engine_factory_->GetOrtEnvironment(&environment);

  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionStartProfiling(environment, session_.get()),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Looks up the session's execution provider at index 0 and asks the DML
// adapter entry point to flush its context.
HRESULT OnnxruntimeEngine::FlushContext() {
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  OrtExecutionProvider* provider;
  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &provider),
      engine_factory_->UseOrtApi());

  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->DmlExecutionProviderFlushContext(provider),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Looks up the session's execution provider at index 0 and asks the DML
// adapter entry point to release its completed references.
HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() {
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  OrtExecutionProvider* provider;
  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &provider),
      engine_factory_->UseOrtApi());

  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->DmlExecutionProviderReleaseCompletedReferences(provider),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Copies the tensor held by `src` into the tensor held by `dest` using the
// DML copy entry point of the adapter API. Returns E_FAIL when either value
// reports itself as empty.
// Both arguments are downcast to OnnxruntimeValue without a runtime check.
HRESULT OnnxruntimeEngine::CopyValueAcrossDevices(IValue* src, IValue* dest) {
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  OrtExecutionProvider* provider;
  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &provider),
      engine_factory_->UseOrtApi());

  auto source = static_cast<OnnxruntimeValue*>(src);
  auto destination = static_cast<OnnxruntimeValue*>(dest);

  bool empty;

  const bool source_is_null = SUCCEEDED(source->IsEmpty(&empty)) && empty;
  RETURN_HR_IF(E_FAIL, source_is_null);

  const bool destination_is_null = SUCCEEDED(destination->IsEmpty(&empty)) && empty;
  RETURN_HR_IF(E_FAIL, destination_is_null);

  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->DmlCopyTensor(provider, source->UseOrtValue(), destination->UseOrtValue()),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Looks up the session's execution provider at index 0 and calls the
// adapter's ExecutionProviderSync on it.
HRESULT OnnxruntimeEngine::Sync() {
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  OrtExecutionProvider* provider;
  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &provider),
      engine_factory_->UseOrtApi());

  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api->ExecutionProviderSync(provider),
      engine_factory_->UseOrtApi());

  return S_OK;
}
|
|
|
|
// Returns the engine's ORT session without transferring ownership.
OrtSession* OnnxruntimeEngine::UseOrtSession() {
  OrtSession* const borrowed_session = session_.get();
  return borrowed_session;
}
|
|
|
|
// Forwards to the engine factory to obtain the ORT C API table.
const OrtApi* OnnxruntimeEngine::UseOrtApi() {
  const OrtApi* const api = engine_factory_->UseOrtApi();
  return api;
}
|
|
|
|
// Returns the factory this engine was initialized with, as a raw pointer
// obtained from engine_factory_.Get().
OnnxruntimeEngineFactory* OnnxruntimeEngine::GetEngineFactory() {
  OnnxruntimeEngineFactory* const factory = engine_factory_.Get();
  return factory;
}
|
|
|
|
// Creates a tensor of the given shape and kind with ORT's default allocator
// and returns it through `out` wrapped in an OnnxruntimeValue. The wrapper
// does not hold an allocator of its own.
HRESULT OnnxruntimeEngine::CreateTensorValueFromDefaultAllocator(const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) {
  auto ort_api = engine_factory_->UseOrtApi();

  // This should not be freed as this is owned by ort
  OrtAllocator* default_allocator;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&default_allocator), ort_api);

  const auto element_type = ONNXTensorElementDataTypeFromTensorKind(kind);

  OrtValue* raw_value;
  RETURN_HR_IF_NOT_OK_MSG(
      ort_api->CreateTensorAsOrtValue(default_allocator, shape, count, element_type, &raw_value),
      ort_api);
  auto owned_value = UniqueOrtValue(raw_value, ort_api->ReleaseValue);

  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(
      out, this, std::move(owned_value), UniqueOrtAllocator(nullptr, nullptr)));
  return S_OK;
}
|
|
|
|
/*
|
|
* OnnxruntimeEngine::CreateTensorValue
|
|
*
|
|
* Used by callers like ImageFeatureValue to allocate a cpu or gpu OrtValue with ORT owned memory.
|
|
* In the image feature value case, tensorization creates temporary buffers, and will need to copy the value from
|
|
* its source location to the ort value. Since a copy is required, there is need to preserve the caller's memory locations.
|
|
* We simply allocate memory with ORT and copy the tensorized values into it.
|
|
*/
|
|
HRESULT OnnxruntimeEngine::CreateTensorValue(const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) {
|
|
auto ort_api = engine_factory_->UseOrtApi();
|
|
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
|
|
|
|
OrtExecutionProvider* ort_provider;
|
|
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
|
|
engine_factory_->UseOrtApi());
|
|
|
|
OrtAllocator* ort_allocator;
|
|
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &ort_allocator),
|
|
engine_factory_->UseOrtApi());
|
|
|
|
auto unique_allocator = UniqueOrtAllocator(ort_allocator, winml_adapter_api->FreeProviderAllocator); // the release here should probably not return anything
|
|
|
|
OrtValue* ort_value;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorAsOrtValue(unique_allocator.get(), shape, count, ONNXTensorElementDataTypeFromTensorKind(kind), &ort_value),
|
|
ort_api);
|
|
auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue);
|
|
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(out, this, std::move(unique_value), std::move(unique_allocator)));
|
|
return S_OK;
|
|
}
|
|
|
|
using DmlAllocatorResource = std::unique_ptr<void, void (*)(void*)>;
|
|
class DmlAllocatorWrapper : public Microsoft::WRL::RuntimeClass<
|
|
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
|
|
IUnknown> {
|
|
public:
|
|
DmlAllocatorWrapper() : dml_resource_(nullptr, nullptr) {}
|
|
|
|
HRESULT RuntimeClassInitialize(DmlAllocatorResource&& dml_resource) {
|
|
dml_resource_ = std::move(dml_resource);
|
|
return S_OK;
|
|
}
|
|
|
|
private:
|
|
DmlAllocatorResource dml_resource_;
|
|
};
|
|
|
|
/*
|
|
* OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource
|
|
*
|
|
* Used by callers like TensorBase to allocate a gpu OrtValue based on a called owned ID3D12Resource.
|
|
* WinML cannot use ORT allocators here since they will allocate the ID3D12Resource and force a copy from the user provided value.
|
|
*/
|
|
HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource(ID3D12Resource* d3d_resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) {
|
|
auto ort_api = engine_factory_->UseOrtApi();
|
|
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
|
|
|
|
OrtExecutionProvider* ort_provider;
|
|
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
|
|
engine_factory_->UseOrtApi());
|
|
|
|
OrtMemoryInfo* dml_memory = nullptr;
|
|
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &dml_memory),
|
|
engine_factory_->UseOrtApi());
|
|
|
|
void* dml_allocator_resource;
|
|
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource),
|
|
engine_factory_->UseOrtApi());
|
|
|
|
auto unique_dml_allocator_resource =
|
|
DmlAllocatorResource(dml_allocator_resource,
|
|
[](void* ptr) {
|
|
GetVersionedWinmlAdapterApi()->DmlFreeGPUAllocation(ptr);
|
|
});
|
|
|
|
// create the OrtValue as a tensor letting ort know that we own the data buffer
|
|
OrtValue* ort_value;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(
|
|
dml_memory,
|
|
unique_dml_allocator_resource.get(),
|
|
static_cast<size_t>(d3d_resource->GetDesc().Width),
|
|
shape,
|
|
count,
|
|
ONNXTensorElementDataTypeFromTensorKind(kind),
|
|
&ort_value),
|
|
ort_api);
|
|
auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue);
|
|
|
|
Microsoft::WRL::ComPtr<OnnxruntimeValue> out_value;
|
|
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(&out_value, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr)));
|
|
|
|
// Cache the allocator on the value so it destructs appropriately when the value is dropped
|
|
Microsoft::WRL::ComPtr<DmlAllocatorWrapper> dml_allocator_resource_wrapper;
|
|
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<DmlAllocatorWrapper>(&dml_allocator_resource_wrapper, std::move(unique_dml_allocator_resource)));
|
|
|
|
RETURN_IF_FAILED(out_value->SetParameter(dml_allocator_resource_wrapper.Get()));
|
|
|
|
*out = out_value.Detach();
|
|
|
|
return S_OK;
|
|
}
|
|
|
|
/*
|
|
* OnnxruntimeEngine::CreateStringTensorValueFromDataWithCopy
|
|
*
|
|
* Used by callers like TensorString to allocate a cpu OrtValue and populate the contents with use specified data.
|
|
* WinML cannot use CreateTensorWithDataAsOrtValue since externally allocated strings are not supported on the c-abi.
|
|
* The c-abi string implementation requires a copy the external buffer into its own internal std::string copy.
|
|
* In addition, strings have different APIs on the c-abi like FillStringTensor to populate the buffer, and so strings
|
|
* have a different calling pattern than other Tensor<T> types of simple data types.
|
|
*/
|
|
HRESULT OnnxruntimeEngine::CreateStringTensorValueFromDataWithCopy(const char* const* data, size_t num_elements, const int64_t* shape, size_t count, _Out_ IValue** out) {
|
|
auto ort_api = engine_factory_->UseOrtApi();
|
|
|
|
RETURN_IF_FAILED(CreateTensorValueFromDefaultAllocator(shape, count, winml::TensorKind::String, out));
|
|
|
|
auto ort_value = reinterpret_cast<_winml::OnnxruntimeValue*>(*out)->UseOrtValue();
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(ort_value, reinterpret_cast<const char* const*>(data), num_elements),
|
|
ort_api);
|
|
return S_OK;
|
|
}
|
|
|
|
/*
|
|
* OnnxruntimeEngine::CreateTensorValueFromExternalBuffer
|
|
*
|
|
* Used by callers like TensorBase<T> to allocate a cpu OrtValue that is backed by caller owned memory.
|
|
*/
|
|
HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalBuffer(void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) {
|
|
auto ort_api = engine_factory_->UseOrtApi();
|
|
|
|
if (kind == winml::TensorKind::String) {
|
|
// String buffers cannot be passed into the ort api directly because ort c-api tensor strings cannot be backed by external memory
|
|
return E_NOTIMPL;
|
|
}
|
|
|
|
// TODO: what is the difference between the device allocator and the arena allocator?
|
|
OrtMemoryInfo* cpu_memory;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory),
|
|
ort_api);
|
|
|
|
OrtValue* ort_value;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(
|
|
cpu_memory,
|
|
data,
|
|
size_in_bytes,
|
|
shape,
|
|
count,
|
|
ONNXTensorElementDataTypeFromTensorKind(kind),
|
|
&ort_value),
|
|
ort_api);
|
|
auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue);
|
|
|
|
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(out, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr)));
|
|
return S_OK;
|
|
}
|
|
|
|
// Builds an ONNX sequence value out of `size` existing values and returns it
// through `out`. Each entry of `values` is downcast to OnnxruntimeValue
// without a runtime check.
HRESULT OnnxruntimeEngine::CreateSequenceOfValuesValue(IValue** values, size_t size, IValue** out) {
  auto ort_api = engine_factory_->UseOrtApi();

  // Collect the borrowed OrtValue pointers in input order.
  std::vector<OrtValue*> elements;
  elements.reserve(size);
  for (size_t i = 0; i < size; ++i) {
    elements.push_back(static_cast<OnnxruntimeValue*>(values[i])->UseOrtValue());
  }

  OrtValue* raw_sequence;
  RETURN_HR_IF_NOT_OK_MSG(
      ort_api->CreateValue(elements.data(), size, ONNXType::ONNX_TYPE_SEQUENCE, &raw_sequence),
      ort_api);
  UniqueOrtValue owned_sequence(raw_sequence, ort_api->ReleaseValue);

  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(
      out, this, std::move(owned_sequence), UniqueOrtAllocator(nullptr, nullptr)));
  return S_OK;
}
|
|
|
|
/*
|
|
* OnnxruntimeEngine::CreateNullValue
|
|
*
|
|
* Used by callers like TensorBase<T> and the binding object to allocate a cpu OrtValue that is empty.
|
|
* This is used for WinML unbound outputs.
|
|
*/
|
|
HRESULT OnnxruntimeEngine::CreateNullValue(_Out_ IValue** out) {
|
|
auto ort_api = engine_factory_->UseOrtApi();
|
|
auto unique_value = UniqueOrtValue(nullptr, ort_api->ReleaseValue);
|
|
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(out, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr)));
|
|
return S_OK;
|
|
}
|
|
|
|
// Maps an ABI element type to the types used for it in three places:
//   CppWinRTType - the C++/WinRT projection type
//   OrtType      - the type used for it in ORT tensor buffers
//   ResourceType - the type used when it is read back out of a resource
// In this primary template all three are the ABI type itself.
template <typename TAbi>
struct AbiTypeInfo {
  using CppWinRTType = TAbi;
  using OrtType = TAbi;
  using ResourceType = TAbi;
};
|
|
|
|
template <>
|
|
struct AbiTypeInfo<HSTRING> {
|
|
using CppWinRTType = winrt::hstring;
|
|
using OrtType = const char*;
|
|
using ResourceType = std::string_view;
|
|
};
|
|
|
|
// Converts a C++/WinRT value into the representation stored in ORT tensors.
// The primary template is the identity: the value is returned unchanged.
// (`typename auto` is not valid standard C++; the return type is plain `auto`.)
template <typename TCppwinrtType>
auto CppwinrtTypeToOrtType(TCppwinrtType raw) {
  return raw;
}
|
|
|
|
template <>
|
|
typename auto CppwinrtTypeToOrtType<winrt::hstring>(winrt::hstring raw) {
|
|
return _winml::Strings::UTF8FromHString(raw);
|
|
}
|
|
|
|
template <typename TAbiType>
|
|
typename auto ResourceTypeToCppwinrtType(typename AbiTypeInfo<TAbiType>::ResourceType value) {
|
|
return value;
|
|
}
|
|
|
|
template <>
|
|
typename auto ResourceTypeToCppwinrtType<HSTRING>(typename AbiTypeInfo<HSTRING>::ResourceType value) {
|
|
return _winml::Strings::HStringFromUTF8(value.data(), value.size());
|
|
}
|
|
|
|
template <typename TAbiKey, typename TAbiValue>
|
|
auto CastToWinrtMap(IInspectable* map_insp) {
|
|
using cppwinrt_key_type = typename AbiTypeInfo<TAbiKey>::CppWinRTType;
|
|
using cppwinrt_value_type = typename AbiTypeInfo<TAbiValue>::CppWinRTType;
|
|
|
|
wf::IInspectable map_inspectable;
|
|
wfc::IMap<cppwinrt_key_type, cppwinrt_value_type> map;
|
|
winrt::copy_from_abi(map_inspectable, map_insp);
|
|
map_inspectable.as(map);
|
|
return map;
|
|
}
|
|
|
|
template <typename TAbiKey, typename TAbiValue>
|
|
auto CastToWinrtSequenceOfMaps(IInspectable* sequence_insp) {
|
|
using cppwinrt_key_type = typename AbiTypeInfo<TAbiKey>::CppWinRTType;
|
|
using cppwinrt_value_type = typename AbiTypeInfo<TAbiValue>::CppWinRTType;
|
|
|
|
using cppwinrt_element_map_type = wfc::IMap<cppwinrt_key_type, cppwinrt_value_type>;
|
|
using cppwinrt_sequence_type = wfc::IVector<cppwinrt_element_map_type>;
|
|
cppwinrt_sequence_type sequence;
|
|
wf::IInspectable sequence_inspectable;
|
|
winrt::copy_from_abi(sequence_inspectable, sequence_insp);
|
|
sequence_inspectable.as(sequence);
|
|
return sequence;
|
|
}
|
|
|
|
template <typename TAbiKey, typename TAbiValue>
|
|
struct FillMapTensors {
|
|
static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
|
|
AbiTypeInfo<TAbiKey>::OrtType* keys_mutable_data;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast<void**>(&keys_mutable_data)),
|
|
ort_api);
|
|
|
|
AbiTypeInfo<TAbiValue>::OrtType* values_mutable_data;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast<void**>(&values_mutable_data)),
|
|
ort_api);
|
|
|
|
auto map = CastToWinrtMap<TAbiKey, TAbiValue>(map_insp);
|
|
size_t index = 0;
|
|
for (const auto& pair : map) {
|
|
keys_mutable_data[index] = CppwinrtTypeToOrtType(pair.Key());
|
|
values_mutable_data[index] = CppwinrtTypeToOrtType(pair.Value());
|
|
index++;
|
|
}
|
|
return S_OK;
|
|
}
|
|
};
|
|
|
|
template <typename TAbiValue>
|
|
struct FillMapTensors<HSTRING, TAbiValue> {
|
|
static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
|
|
AbiTypeInfo<TAbiValue>::OrtType* values_mutable_data;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast<void**>(&values_mutable_data)),
|
|
ort_api);
|
|
|
|
auto map = CastToWinrtMap<HSTRING, TAbiValue>(map_insp);
|
|
size_t index = 0;
|
|
std::vector<std::string> keys;
|
|
for (const auto& pair : map) {
|
|
keys.push_back(CppwinrtTypeToOrtType(pair.Key()));
|
|
values_mutable_data[index] = CppwinrtTypeToOrtType(pair.Value());
|
|
index++;
|
|
}
|
|
|
|
std::vector<const char*> raw_values;
|
|
std::transform(
|
|
keys.begin(),
|
|
keys.end(),
|
|
std::back_inserter(raw_values),
|
|
[&](auto& str) { return str.c_str(); });
|
|
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_values.data(), raw_values.size()),
|
|
ort_api);
|
|
|
|
return S_OK;
|
|
}
|
|
};
|
|
|
|
template <typename TAbiKey>
|
|
struct FillMapTensors<TAbiKey, HSTRING> {
|
|
static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
|
|
AbiTypeInfo<TAbiKey>::OrtType* keys_mutable_data;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast<void**>(&keys_mutable_data)),
|
|
ort_api);
|
|
|
|
auto map = CastToWinrtMap<TAbiKey, HSTRING>(map_insp);
|
|
size_t index = 0;
|
|
std::vector<std::string> values;
|
|
for (const auto& pair : map) {
|
|
keys_mutable_data[index] = CppwinrtTypeToOrtType(pair.Key());
|
|
values.push_back(CppwinrtTypeToOrtType(pair.Value()));
|
|
index++;
|
|
}
|
|
|
|
std::vector<const char*> raw_values;
|
|
std::transform(
|
|
values.begin(),
|
|
values.end(),
|
|
std::back_inserter(raw_values),
|
|
[&](auto& str) { return str.c_str(); });
|
|
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_values.data(), raw_values.size()),
|
|
ort_api);
|
|
return S_OK;
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct FillMapTensors<HSTRING, HSTRING> {
|
|
static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
|
|
auto map = CastToWinrtMap<HSTRING, HSTRING>(map_insp);
|
|
std::vector<std::string> keys;
|
|
std::vector<std::string> values;
|
|
for (const auto& pair : map) {
|
|
keys.push_back(CppwinrtTypeToOrtType(pair.Key()));
|
|
values.push_back(CppwinrtTypeToOrtType(pair.Value()));
|
|
}
|
|
|
|
std::vector<const char*> raw_keys;
|
|
std::transform(
|
|
keys.begin(),
|
|
keys.end(),
|
|
std::back_inserter(raw_keys),
|
|
[&](auto& str) { return str.c_str(); });
|
|
|
|
std::vector<const char*> raw_values;
|
|
std::transform(
|
|
values.begin(),
|
|
values.end(),
|
|
std::back_inserter(raw_values),
|
|
[&](auto& str) { return str.c_str(); });
|
|
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_keys.data(), raw_keys.size()),
|
|
ort_api);
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(values_ort_value, raw_values.data(), raw_values.size()),
|
|
ort_api);
|
|
return S_OK;
|
|
}
|
|
};
|
|
|
|
template <typename TAbiKey, typename TAbiValue>
|
|
HRESULT CreateMapValue(OnnxruntimeEngine* engine, IInspectable* map_insp, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) {
|
|
auto ort_api = engine->UseOrtApi();
|
|
auto map = CastToWinrtMap<TAbiKey, TAbiValue>(map_insp);
|
|
std::vector<int64_t> shape = {static_cast<int64_t>(map.Size())};
|
|
|
|
winrt::com_ptr<_winml::IValue> key_value;
|
|
RETURN_IF_FAILED(engine->CreateTensorValueFromDefaultAllocator(shape.data(), shape.size(), key_kind, key_value.put()));
|
|
auto keys_ort_value = static_cast<OnnxruntimeValue*>(key_value.get())->UseOrtValue();
|
|
|
|
winrt::com_ptr<_winml::IValue> value_value;
|
|
RETURN_IF_FAILED(engine->CreateTensorValueFromDefaultAllocator(shape.data(), shape.size(), value_kind, value_value.put()));
|
|
auto values_ort_value = static_cast<OnnxruntimeValue*>(value_value.get())->UseOrtValue();
|
|
|
|
auto hr = FillMapTensors<TAbiKey, TAbiValue>::Run(ort_api, map_insp, keys_ort_value, values_ort_value);
|
|
RETURN_IF_FAILED(hr);
|
|
|
|
OrtValue* inputs[2] = {keys_ort_value, values_ort_value};
|
|
|
|
OrtValue* map_value;
|
|
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateValue(inputs, 2, ONNXType::ONNX_TYPE_MAP, &map_value),
|
|
ort_api);
|
|
auto unique_map_ort_value = UniqueOrtValue(map_value, ort_api->ReleaseValue);
|
|
|
|
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(out, engine, std::move(unique_map_ort_value), UniqueOrtAllocator(nullptr, nullptr)));
|
|
return S_OK;
|
|
}
|
|
|
|
// Selects the CreateMapValue<K, V> instantiation matching the requested
// key/value tensor kinds and returns it bound to `engine` and those kinds.
// The returned callable is invoked as (IInspectable* map, IValue** out).
// Supported keys: Int64, String. Supported values: Int64, Float, Double, String.
// Any other combination raises E_NOTIMPL through THROW_HR.
static auto GetMapValueCreator(OnnxruntimeEngine* engine, winml::TensorKind key_kind, winml::TensorKind value_kind) {
  using namespace std::placeholders;

  if (key_kind == winml::TensorKind::Int64) {
    if (value_kind == winml::TensorKind::Int64) {
      return std::bind(&CreateMapValue<int64_t, int64_t>, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Int64, _2);
    }
    if (value_kind == winml::TensorKind::Float) {
      return std::bind(&CreateMapValue<int64_t, float>, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Float, _2);
    }
    if (value_kind == winml::TensorKind::Double) {
      return std::bind(&CreateMapValue<int64_t, double>, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Double, _2);
    }
    if (value_kind == winml::TensorKind::String) {
      return std::bind(&CreateMapValue<int64_t, HSTRING>, engine, _1, winml::TensorKind::Int64, winml::TensorKind::String, _2);
    }
  } else if (key_kind == winml::TensorKind::String) {
    if (value_kind == winml::TensorKind::Int64) {
      return std::bind(&CreateMapValue<HSTRING, int64_t>, engine, _1, winml::TensorKind::String, winml::TensorKind::Int64, _2);
    }
    if (value_kind == winml::TensorKind::Float) {
      return std::bind(&CreateMapValue<HSTRING, float>, engine, _1, winml::TensorKind::String, winml::TensorKind::Float, _2);
    }
    if (value_kind == winml::TensorKind::Double) {
      return std::bind(&CreateMapValue<HSTRING, double>, engine, _1, winml::TensorKind::String, winml::TensorKind::Double, _2);
    }
    if (value_kind == winml::TensorKind::String) {
      return std::bind(&CreateMapValue<HSTRING, HSTRING>, engine, _1, winml::TensorKind::String, winml::TensorKind::String, _2);
    }
  }

  THROW_HR(E_NOTIMPL);
}
|
|
|
|
// Creates an ORT map value from the WinRT map `map`, dispatching on the
// key/value tensor kinds via GetMapValueCreator. `out` receives the result.
// Returns the HRESULT produced by the selected creator.
HRESULT OnnxruntimeEngine::CreateMapValue(IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) {
  auto create_map_value = GetMapValueCreator(this, key_kind, value_kind);
  return create_map_value(map, out);
}
|
|
|
|
template <typename TAbiKey, typename TAbiValue>
|
|
HRESULT CreateSequenceOfMapsValue(OnnxruntimeEngine* engine, IInspectable* sequence_insp, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) {
|
|
auto ort_api = engine->UseOrtApi();
|
|
auto sequence = CastToWinrtSequenceOfMaps<TAbiKey, TAbiValue>(sequence_insp);
|
|
|
|
std::vector<winrt::com_ptr<_winml::IValue>> element_values;
|
|
for (auto element : sequence) {
|
|
winrt::com_ptr<_winml::IValue> element_value;
|
|
engine->CreateMapValue(reinterpret_cast<IInspectable*>(winrt::get_abi(element)), key_kind, value_kind, element_value.put());
|
|
element_values.push_back(element_value);
|
|
}
|
|
|
|
std::vector<OrtValue*> element_ort_values;
|
|
std::transform(element_values.begin(),
|
|
element_values.end(),
|
|
std::back_inserter(element_ort_values),
|
|
[](auto value) { return static_cast<OnnxruntimeValue*>(value.get())->UseOrtValue(); });
|
|
|
|
OrtValue* sequence_value;
|
|
RETURN_HR_IF_NOT_OK_MSG(
|
|
ort_api->CreateValue(element_ort_values.data(), element_ort_values.size(),
|
|
ONNXType::ONNX_TYPE_SEQUENCE, &sequence_value),
|
|
ort_api);
|
|
auto unique_sequence_ort_value = UniqueOrtValue(sequence_value, ort_api->ReleaseValue);
|
|
|
|
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(out, engine, std::move(unique_sequence_ort_value), UniqueOrtAllocator(nullptr, nullptr)));
|
|
return S_OK;
|
|
}
|
|
|
|
// Selects the CreateSequenceOfMapsValue<K, V> instantiation matching the
// requested key/value tensor kinds and returns it bound to `engine` and those
// kinds. The returned callable is invoked as
// (IInspectable* sequence, IValue** out).
// Supported combinations: String->Float and Int64->Float; anything else
// raises E_NOTIMPL through THROW_HR.
static auto GetSequenceOfMapsValueCreator(OnnxruntimeEngine* engine, winml::TensorKind key_kind, winml::TensorKind value_kind) {
  using namespace std::placeholders;
  if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) {
    // The bound kinds must match the template arguments: they are forwarded
    // to CreateMapValue for every element of the sequence.
    return std::bind(&CreateSequenceOfMapsValue<HSTRING, float>, engine, _1, winml::TensorKind::String, winml::TensorKind::Float, _2);
  } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) {
    return std::bind(&CreateSequenceOfMapsValue<int64_t, float>, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Float, _2);
  }

  THROW_HR(E_NOTIMPL);
}
|
|
|
|
// Creates an ORT sequence-of-maps value from the WinRT vector `sequence`,
// dispatching on the key/value tensor kinds via GetSequenceOfMapsValueCreator.
// `out` receives the result. Returns S_OK or the creator's failure code.
HRESULT OnnxruntimeEngine::CreateSequenceOfMapsValue(IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) {
  auto create_sequence_value = GetSequenceOfMapsValueCreator(this, key_kind, value_kind);
  RETURN_IF_FAILED(create_sequence_value(sequence, out));
  return S_OK;
}
|
|
|
|
template <typename TAbiKey, typename TAbiValue>
|
|
static HRESULT FillAbiSequence(IInspectable* sequence_insp, std::vector<wf::IInspectable>& elements) {
|
|
using cppwinrt_key_type = typename AbiTypeInfo<TAbiKey>::CppWinRTType;
|
|
using cppwinrt_value_type = typename AbiTypeInfo<TAbiValue>::CppWinRTType;
|
|
auto sequence = CastToWinrtSequenceOfMaps<TAbiKey, TAbiValue>(sequence_insp);
|
|
for (auto element : elements) {
|
|
wfc::IMap<cppwinrt_key_type, cppwinrt_value_type> map_element;
|
|
element.as(map_element);
|
|
sequence.Append(map_element);
|
|
}
|
|
return S_OK;
|
|
}
|
|
|
|
// Returns a pointer to the FillAbiSequence<K, V> instantiation matching the
// requested key/value tensor kinds.
// Supported combinations: String->Float and Int64->Float; anything else
// raises E_NOTIMPL through THROW_HR.
static auto GetAbiSequenceFiller(winml::TensorKind key_kind, winml::TensorKind value_kind) {
  const bool float_values = (value_kind == winml::TensorKind::Float);

  if (float_values && key_kind == winml::TensorKind::String) {
    return &FillAbiSequence<winrt::hstring, float>;
  }
  if (float_values && key_kind == winml::TensorKind::Int64) {
    return &FillAbiSequence<int64_t, float>;
  }
  THROW_HR(E_NOTIMPL);
}
|
|
|
|
// Creates an empty winrt::single_threaded_map for the given key/value kinds
// and returns it as a wf::IInspectable.
// Supported combinations: String->Float and Int64->Float. For any other
// combination the returned inspectable is left default-constructed.
static wf::IInspectable CreateMap(winml::TensorKind key_kind, winml::TensorKind value_kind) {
  wf::IInspectable result;

  const bool float_values = (value_kind == winml::TensorKind::Float);
  if (float_values && key_kind == winml::TensorKind::String) {
    winrt::single_threaded_map<winrt::hstring, float>().as(result);
  } else if (float_values && key_kind == winml::TensorKind::Int64) {
    winrt::single_threaded_map<int64_t, float>().as(result);
  }

  return result;
}
|
|
|
|
// Copies an ORT sequence-of-maps value into a WinRT vector of maps.
//
//   sequence       - ABI pointer to the destination IVector<IMap<K, V>>
//   key_kind       - tensor kind of the map keys
//   value_kind     - tensor kind of the map values
//   sequence_value - OnnxruntimeValue wrapping the source ORT sequence
//
// Each element of the ORT sequence is extracted with OrtApi::GetValue, copied
// into a freshly created WinRT map via FillFromMapValue, and finally all maps
// are appended to `sequence`. Returns S_OK or the first failure encountered.
HRESULT OnnxruntimeEngine::FillSequenceOfMapsValue(IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* sequence_value) {
  auto ort_api = engine_factory_->UseOrtApi();
  auto onnxruntime_sequence_value = static_cast<OnnxruntimeValue*>(sequence_value);
  auto ort_sequence_value = onnxruntime_sequence_value->UseOrtValue();

  OrtAllocator* ort_allocator;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&ort_allocator), ort_api);  // This should not be freed as this owned by ort

  size_t num_elements;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueCount(ort_sequence_value, &num_elements), ort_api);

  // get the elements
  std::vector<wf::IInspectable> element_map_inspectables;
  for (size_t index = 0; index < num_elements; index++) {
    OrtValue* elements_ort_value = nullptr;
    RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_sequence_value, static_cast<int>(index), ort_allocator, &elements_ort_value), ort_api);
    auto unique_element_value = UniqueOrtValue(elements_ort_value, ort_api->ReleaseValue);

    winrt::com_ptr<IValue> element_value;
    RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(element_value.put(), this, std::move(unique_element_value), UniqueOrtAllocator(nullptr, nullptr)));

    wf::IInspectable map_inspectable = CreateMap(key_kind, value_kind);
    RETURN_IF_FAILED(FillFromMapValue(reinterpret_cast<IInspectable*>(winrt::get_abi(map_inspectable)), key_kind, value_kind, element_value.get()));
    element_map_inspectables.push_back(map_inspectable);
  }

  // Propagate the filler's HRESULT rather than discarding it.
  RETURN_IF_FAILED(GetAbiSequenceFiller(key_kind, value_kind)(sequence, element_map_inspectables));
  return S_OK;
}
|
|
|
|
// Extracts every element of an ORT sequence value as its own OnnxruntimeValue.
//
//   sequence_value - OnnxruntimeValue wrapping the source ORT sequence
//   out_values     - cleared, then filled with one wrapper per element, in order
//
// Elements are obtained with OrtApi::GetValue using ORT's default allocator.
// Returns S_OK or the first failure encountered; on failure `out_values` may
// hold the elements extracted so far.
HRESULT OnnxruntimeEngine::GetSequenceOfTensorValues(_In_ _winml::IValue* sequence_value, _Out_ std::vector<winrt::com_ptr<_winml::IValue>>& out_values) {
  auto ort_api = engine_factory_->UseOrtApi();
  auto ort_sequence = static_cast<OnnxruntimeValue*>(sequence_value)->UseOrtValue();

  // The default allocator is owned by ORT and must not be freed here.
  OrtAllocator* default_allocator;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&default_allocator), ort_api);

  size_t element_count;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueCount(ort_sequence, &element_count), ort_api);

  out_values.clear();
  for (size_t i = 0; i < element_count; ++i) {
    OrtValue* raw_element = nullptr;
    RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_sequence, static_cast<int>(i), default_allocator, &raw_element), ort_api);
    auto owned_element = UniqueOrtValue(raw_element, ort_api->ReleaseValue);

    winrt::com_ptr<IValue> wrapped_element;
    RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(wrapped_element.put(), this, std::move(owned_element), UniqueOrtAllocator(nullptr, nullptr)));
    out_values.push_back(wrapped_element);
  }

  return S_OK;
}
|
|
|
|
// Produces a value for input `name` that lives on the device the session
// requires for that input.
//
//   name - name of the session input
//   src  - the candidate input value
//   out  - receives either a new value copied to the required device, or
//          `src` itself with an added reference
//
// A copy is made only when `src` is a non-empty tensor whose device id differs
// from the id reported by SessionGetInputRequiredDeviceId. In every other case
// `src` is returned unchanged (AddRef'd). Returns S_OK or the first failure.
HRESULT OnnxruntimeEngine::CreateOneInputAcrossDevices(const char* name, IValue* src, IValue** out) {
  auto ort_api = engine_factory_->UseOrtApi();
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
  auto source = static_cast<OnnxruntimeValue*>(src);

  // `flag` is only read when the corresponding query succeeded.
  bool flag;
  const bool source_is_empty = SUCCEEDED(source->IsEmpty(&flag)) && flag;
  const bool source_is_tensor = SUCCEEDED(source->IsTensor(&flag)) && flag;

  if (source_is_tensor && !source_is_empty) {
    int16_t current_device;
    RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(source->UseOrtValue(), &current_device),
                            ort_api);
    int16_t required_device;
    RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetInputRequiredDeviceId(session_.get(), name, &required_device),
                            ort_api);

    if (current_device != required_device) {
      OrtValue* copied_ort_value = nullptr;
      RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionCopyOneInputAcrossDevices(session_.get(), name,
                                                                                  source->UseOrtValue(), &copied_ort_value),
                              ort_api);
      auto owned_copy = UniqueOrtValue(copied_ort_value, ort_api->ReleaseValue);

      RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(out, this, std::move(owned_copy), UniqueOrtAllocator(nullptr, nullptr)));
      return S_OK;
    }
  }

  // No copy needed: hand the caller its own reference to the source value.
  *out = src;
  (*out)->AddRef();
  return S_OK;
}
|
|
|
|
// Runs the session with the given named inputs and outputs.
//
//   input_names / inputs / num_inputs     - parallel arrays describing inputs
//   output_names / outputs / num_outputs  - parallel arrays describing outputs
//
// The OrtValue behind every input and output wrapper is collected and passed
// to OrtApi::Run with freshly created run options. After the run, any output
// slot whose OrtValue pointer differs from the one its wrapper holds is
// assigned to that wrapper with AssignOrtValue.
// Returns S_OK or the first failure encountered.
HRESULT OnnxruntimeEngine::Run(const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) {
  auto ort_api = engine_factory_->UseOrtApi();

  OrtRunOptions* raw_run_options;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateRunOptions(&raw_run_options),
                          ort_api);
  auto owned_run_options = UniqueOrtRunOptions(raw_run_options, ort_api->ReleaseRunOptions);

  std::vector<OrtValue*> input_ort_values;
  for (size_t i = 0; i < num_inputs; ++i) {
    input_ort_values.push_back(static_cast<OnnxruntimeValue*>(inputs[i])->UseOrtValue());
  }

  std::vector<OrtValue*> output_ort_values;
  for (size_t i = 0; i < num_outputs; ++i) {
    output_ort_values.push_back(static_cast<OnnxruntimeValue*>(outputs[i])->UseOrtValue());
  }

  RETURN_HR_IF_NOT_OK_MSG(ort_api->Run(session_.get(),
                                       owned_run_options.get(),
                                       input_names,
                                       input_ort_values.data(),
                                       num_inputs,
                                       output_names,
                                       num_outputs,
                                       output_ort_values.data()),
                          ort_api);

  // Pick up any output OrtValue that changed during the run.
  for (size_t i = 0; i < num_outputs; ++i) {
    auto wrapper = static_cast<OnnxruntimeValue*>(outputs[i]);
    if (wrapper->UseOrtValue() == output_ort_values[i]) {
      continue;
    }
    RETURN_IF_FAILED(wrapper->AssignOrtValue(output_ort_values[i]));
  }

  return S_OK;
}
|
|
|
|
template <typename TAbiKey, typename TAbiValue>
|
|
HRESULT FillAbiMap(IInspectable* map_insp, size_t num_elements, void* keys_data, void* values_data) {
|
|
auto map = CastToWinrtMap<TAbiKey, TAbiValue>(map_insp);
|
|
|
|
auto keys = reinterpret_cast<typename AbiTypeInfo<TAbiKey>::ResourceType*>(keys_data);
|
|
auto values = reinterpret_cast<typename AbiTypeInfo<TAbiValue>::ResourceType*>(values_data);
|
|
|
|
for (size_t i = 0; i < num_elements; ++i) {
|
|
map.Insert(
|
|
ResourceTypeToCppwinrtType<TAbiKey>(keys[i]),
|
|
ResourceTypeToCppwinrtType<TAbiValue>(values[i]));
|
|
}
|
|
return S_OK;
|
|
}
|
|
|
|
// Selects the FillAbiMap<K, V> instantiation matching the requested key/value
// tensor kinds and returns it wrapped with std::bind. The returned callable is
// invoked as (IInspectable* map, size_t num_elements, void* keys, void* values).
// Supported keys: Int64, String. Supported values: Int64, Float, Double, String.
// Any other combination raises E_NOTIMPL through THROW_HR.
static auto GetAbiMapFiller(winml::TensorKind key_kind, winml::TensorKind value_kind) {
  using namespace std::placeholders;

  if (key_kind == winml::TensorKind::Int64) {
    if (value_kind == winml::TensorKind::Int64) {
      return std::bind(&FillAbiMap<int64_t, int64_t>, _1, _2, _3, _4);
    }
    if (value_kind == winml::TensorKind::Float) {
      return std::bind(&FillAbiMap<int64_t, float>, _1, _2, _3, _4);
    }
    if (value_kind == winml::TensorKind::Double) {
      return std::bind(&FillAbiMap<int64_t, double>, _1, _2, _3, _4);
    }
    if (value_kind == winml::TensorKind::String) {
      return std::bind(&FillAbiMap<int64_t, HSTRING>, _1, _2, _3, _4);
    }
  } else if (key_kind == winml::TensorKind::String) {
    if (value_kind == winml::TensorKind::Int64) {
      return std::bind(&FillAbiMap<HSTRING, int64_t>, _1, _2, _3, _4);
    }
    if (value_kind == winml::TensorKind::Float) {
      return std::bind(&FillAbiMap<HSTRING, float>, _1, _2, _3, _4);
    }
    if (value_kind == winml::TensorKind::Double) {
      return std::bind(&FillAbiMap<HSTRING, double>, _1, _2, _3, _4);
    }
    if (value_kind == winml::TensorKind::String) {
      return std::bind(&FillAbiMap<HSTRING, HSTRING>, _1, _2, _3, _4);
    }
  }

  THROW_HR(E_NOTIMPL);
}
|
|
|
|
// Copies an ORT map value into a WinRT map.
//
//   map        - ABI pointer to the destination IMap<K, V>
//   key_kind   - tensor kind of the keys
//   value_kind - tensor kind of the values
//   map_value  - OnnxruntimeValue wrapping the source ORT map
//
// The key tensor (index 0) and value tensor (index 1) are extracted with
// OrtApi::GetValue, their resources are obtained, and the entries are inserted
// by the filler chosen via GetAbiMapFiller. The element count is derived from
// the key tensor's shape. Returns S_OK or the first failure encountered.
HRESULT OnnxruntimeEngine::FillFromMapValue(IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* map_value) {
  auto ort_api = engine_factory_->UseOrtApi();
  auto onnxruntime_map_value = static_cast<OnnxruntimeValue*>(map_value);
  auto ort_map_value = onnxruntime_map_value->UseOrtValue();

  OrtAllocator* ort_allocator;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&ort_allocator),
                          ort_api);  // This should not be freed as this owned by ort

  // get the keys
  OrtValue* keys_ort_value = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_map_value, 0, ort_allocator, &keys_ort_value),
                          ort_api);
  auto unique_keys_value = UniqueOrtValue(keys_ort_value, ort_api->ReleaseValue);
  winrt::com_ptr<IValue> keys_value;
  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(keys_value.put(), this, std::move(unique_keys_value), UniqueOrtAllocator(nullptr, nullptr)));

  // get the values
  OrtValue* values_ort_value = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_map_value, 1, ort_allocator, &values_ort_value),
                          ort_api);
  auto unique_values_value = UniqueOrtValue(values_ort_value, ort_api->ReleaseValue);
  winrt::com_ptr<IValue> values_value;
  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeValue>(values_value.put(), this, std::move(unique_values_value), UniqueOrtAllocator(nullptr, nullptr)));

  // A failed shape query would otherwise leave keys_shape empty and produce a
  // bogus element count, so propagate the failure.
  std::vector<int64_t> keys_shape;
  RETURN_IF_FAILED(keys_value->GetTensorShape(keys_shape));

  _winml::Resource keys_data;
  RETURN_IF_FAILED(keys_value->GetResource(keys_data));
  _winml::Resource values_data;
  RETURN_IF_FAILED(values_value->GetResource(values_data));

  auto num_elements = static_cast<size_t>(ShapeSize(keys_shape.data(), keys_shape.size()));
  // Propagate the filler's HRESULT rather than discarding it.
  RETURN_IF_FAILED(GetAbiMapFiller(key_kind, value_kind)(map, num_elements, keys_data.get(), values_data.get()));

  return S_OK;
}
|
|
|
|
// Initializer for the factory: caches the versioned ORT API table and the
// versioned WinML adapter API table in ort_api_ and winml_adapter_api_.
// Always returns S_OK.
HRESULT OnnxruntimeEngineFactory::RuntimeClassInitialize() {
  // Cache the ORT API first, then the WinML adapter API.
  ort_api_ = GetVersionedOrtApi();
  winml_adapter_api_ = GetVersionedWinmlAdapterApi();

  return S_OK;
}
|
|
|
|
// Makes sure environment_ is populated, creating it on first use through
// PheonixSingleton<OnnxruntimeEnvironment>(ort_api_).
// environment_ is tested once without the lock and again after acquiring
// mutex_, so only one caller performs the assignment. Always returns S_OK.
HRESULT OnnxruntimeEngineFactory::EnsureEnvironment() {
  if (environment_ != nullptr) {
    return S_OK;
  }

  std::lock_guard lock(mutex_);
  if (environment_ == nullptr) {
    environment_ = PheonixSingleton<OnnxruntimeEnvironment>(ort_api_);
  }
  return S_OK;
}
|
|
|
|
// Creates a model from a file path.
//
//   model_path - path of the model file
//   len        - length passed alongside model_path to CreateModelFromPath
//   out        - receives the new OnnruntimeModel
//
// Ensures the environment exists, loads the model through the WinML adapter
// API and wraps it. Returns S_OK or the first failure encountered.
STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) {
  RETURN_IF_FAILED(EnsureEnvironment());

  OrtModel* raw_model = nullptr;
  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api_->CreateModelFromPath(model_path, len, &raw_model),
                          ort_api_);

  // Hold the model in its RAII wrapper until it is moved into the new object.
  auto owned_model = UniqueOrtModel(raw_model, winml_adapter_api_->ReleaseModel);
  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnruntimeModel>(out, this, std::move(owned_model)));
  return S_OK;
}
|
|
|
|
// Creates a model from an in-memory buffer.
//
//   data - pointer to the serialized model bytes
//   size - number of bytes at `data`
//   out  - receives the new OnnruntimeModel
//
// Ensures the environment exists, loads the model through the WinML adapter
// API and wraps it. A load failure is reported as E_INVALIDARG; other
// failures are returned as produced by RETURN_IF_FAILED.
STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) {
  RETURN_IF_FAILED(EnsureEnvironment());
  OrtModel* ort_model = nullptr;
  if (auto status = winml_adapter_api_->CreateModelFromData(data, size, &ort_model)) {
    // A non-null OrtStatus is owned by the caller; release it so the failure
    // path does not leak.
    ort_api_->ReleaseStatus(status);
    return E_INVALIDARG;
  }

  auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel);
  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnruntimeModel>(out, this, std::move(model)));
  return S_OK;
}
|
|
|
|
// Creates an OnnxruntimeEngineBuilder tied to this factory and copies it to
// `out`. The environment is ensured first.
// Returns S_OK or the first failure encountered.
STDMETHODIMP OnnxruntimeEngineFactory::CreateEngineBuilder(_Outptr_ _winml::IEngineBuilder** out) {
  RETURN_IF_FAILED(EnsureEnvironment());

  Microsoft::WRL::ComPtr<OnnxruntimeEngineBuilder> builder;
  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeEngineBuilder>(&builder, this));
  RETURN_IF_FAILED(builder.CopyTo(out));

  return S_OK;
}
|
|
|
|
// Returns the ORT API table cached in ort_api_ (set by RuntimeClassInitialize).
// The pointer is borrowed; nothing here transfers ownership.
const OrtApi* OnnxruntimeEngineFactory::UseOrtApi() {
  return ort_api_;
}
|
|
|
|
// Returns the WinML adapter API table cached in winml_adapter_api_ (set by
// RuntimeClassInitialize). The pointer is borrowed; nothing here transfers
// ownership.
const WinmlAdapterApi* OnnxruntimeEngineFactory::UseWinmlAdapterApi() {
  return winml_adapter_api_;
}
|
|
|
|
// Retrieves the OrtEnv from the factory's environment object, creating the
// environment first if needed. `ort_env` receives the result.
// Returns S_OK or the first failure encountered.
HRESULT OnnxruntimeEngineFactory::GetOrtEnvironment(OrtEnv** ort_env) {
  // The environment must exist before it can be queried.
  RETURN_IF_FAILED(EnsureEnvironment());

  RETURN_IF_FAILED(environment_->GetOrtEnvironment(ort_env));

  return S_OK;
}
|
|
|
|
// Forwards the debug-output switch to the factory's environment object,
// creating the environment first if needed.
// Returns S_OK or the first failure encountered.
HRESULT OnnxruntimeEngineFactory::EnableDebugOutput(bool is_enabled) {
  // The environment must exist before the setting can be applied.
  RETURN_IF_FAILED(EnsureEnvironment());

  RETURN_IF_FAILED(environment_->EnableDebugOutput(is_enabled));

  return S_OK;
}
|
|
|
|
// Creates a custom operator registry through the WinML adapter API and stores
// it in `registry`. Returns S_OK, or the failure produced by
// RETURN_HR_IF_NOT_OK_MSG when the adapter call reports an error status.
HRESULT OnnxruntimeEngineFactory::CreateCustomRegistry(IMLOperatorRegistry** registry) {
  RETURN_HR_IF_NOT_OK_MSG(
      winml_adapter_api_->CreateCustomRegistry(registry),
      ort_api_);

  return S_OK;
}
|
|
|
|
// Exported entry point: creates an OnnxruntimeEngineFactory and copies it to
// `engine_factory`. Returns S_OK or the first failure encountered.
STDAPI CreateOnnxruntimeEngineFactory(_Out_ _winml::IEngineFactory** engine_factory) {
  Microsoft::WRL::ComPtr<OnnxruntimeEngineFactory> factory;
  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeEngineFactory>(&factory));

  RETURN_IF_FAILED(factory.CopyTo(engine_factory));
  return S_OK;
}