mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Handle exception thrown from all apis in WinMLAdapter (#2539)
This commit is contained in:
parent
2b8d6d3e31
commit
8fb7b88e0a
5 changed files with 105 additions and 48 deletions
|
|
@ -33,7 +33,7 @@ CpuOrtSessionBuilder::CpuOrtSessionBuilder() {
|
|||
|
||||
HRESULT
|
||||
CpuOrtSessionBuilder::CreateSessionOptions(
|
||||
OrtSessionOptions** options) {
|
||||
OrtSessionOptions** options) try {
|
||||
RETURN_HR_IF_NULL(E_POINTER, options);
|
||||
|
||||
Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options));
|
||||
|
|
@ -52,12 +52,13 @@ CpuOrtSessionBuilder::CreateSessionOptions(
|
|||
session_options.release();
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT
|
||||
CpuOrtSessionBuilder::CreateSession(
|
||||
OrtSessionOptions* options,
|
||||
winmla::IInferenceSession** p_session,
|
||||
onnxruntime::IExecutionProvider** pp_provider) {
|
||||
onnxruntime::IExecutionProvider** pp_provider) try {
|
||||
RETURN_HR_IF_NULL(E_POINTER, p_session);
|
||||
RETURN_HR_IF_NULL(E_POINTER, pp_provider);
|
||||
RETURN_HR_IF(E_POINTER, *pp_provider != nullptr);
|
||||
|
|
@ -84,14 +85,16 @@ CpuOrtSessionBuilder::CreateSession(
|
|||
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT
|
||||
CpuOrtSessionBuilder::Initialize(
|
||||
winmla::IInferenceSession* p_session,
|
||||
onnxruntime::IExecutionProvider* /*p_provider*/
|
||||
) {
|
||||
) try {
|
||||
ORT_THROW_IF_ERROR(p_session->get()->Initialize());
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
} // Windows::AI::MachineLearning::Adapter
|
||||
|
|
@ -42,7 +42,7 @@ DmlOrtSessionBuilder::DmlOrtSessionBuilder(
|
|||
|
||||
HRESULT
|
||||
DmlOrtSessionBuilder::CreateSessionOptions(
|
||||
OrtSessionOptions** options) {
|
||||
OrtSessionOptions** options) try {
|
||||
RETURN_HR_IF_NULL(E_POINTER, options);
|
||||
|
||||
Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options));
|
||||
|
|
@ -58,6 +58,7 @@ DmlOrtSessionBuilder::CreateSessionOptions(
|
|||
session_options.release();
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
static HRESULT
|
||||
RegisterCustomRegistry(
|
||||
|
|
@ -109,7 +110,7 @@ Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) {
|
|||
HRESULT DmlOrtSessionBuilder::CreateSession(
|
||||
OrtSessionOptions* options,
|
||||
winmla::IInferenceSession** p_session,
|
||||
onnxruntime::IExecutionProvider** pp_provider) {
|
||||
onnxruntime::IExecutionProvider** pp_provider) try {
|
||||
RETURN_HR_IF_NULL(E_POINTER, p_session);
|
||||
RETURN_HR_IF_NULL(E_POINTER, pp_provider);
|
||||
RETURN_HR_IF(E_POINTER, *pp_provider != nullptr);
|
||||
|
|
@ -133,10 +134,11 @@ HRESULT DmlOrtSessionBuilder::CreateSession(
|
|||
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT DmlOrtSessionBuilder::Initialize(
|
||||
winmla::IInferenceSession* p_session,
|
||||
onnxruntime::IExecutionProvider* p_provider) {
|
||||
onnxruntime::IExecutionProvider* p_provider) try {
|
||||
RETURN_HR_IF_NULL(E_INVALIDARG, p_session);
|
||||
RETURN_HR_IF_NULL(E_INVALIDARG, p_provider);
|
||||
|
||||
|
|
@ -154,6 +156,7 @@ HRESULT DmlOrtSessionBuilder::Initialize(
|
|||
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
} // Windows::AI::MachineLearning::Adapter
|
||||
|
||||
|
|
|
|||
|
|
@ -63,11 +63,11 @@ class ModelProto : public Microsoft::WRL::RuntimeClass<
|
|||
ModelProto::ModelProto(onnx::ModelProto* model_proto) : model_proto_(model_proto) {
|
||||
}
|
||||
|
||||
onnx::ModelProto* STDMETHODCALLTYPE get() override {
|
||||
onnx::ModelProto* STDMETHODCALLTYPE get() noexcept override {
|
||||
return model_proto_.get();
|
||||
}
|
||||
|
||||
onnx::ModelProto* STDMETHODCALLTYPE detach() override {
|
||||
onnx::ModelProto* STDMETHODCALLTYPE detach() noexcept override {
|
||||
return model_proto_.release();
|
||||
}
|
||||
|
||||
|
|
@ -93,23 +93,28 @@ class ModelInfo : public Microsoft::WRL::RuntimeClass<
|
|||
Initialize(model_proto);
|
||||
}
|
||||
|
||||
const char* STDMETHODCALLTYPE author() override {
|
||||
const char* STDMETHODCALLTYPE author() noexcept override {
|
||||
return author_.c_str();
|
||||
}
|
||||
const char* STDMETHODCALLTYPE name() override {
|
||||
|
||||
const char* STDMETHODCALLTYPE name() noexcept override {
|
||||
return name_.c_str();
|
||||
}
|
||||
const char* STDMETHODCALLTYPE domain() override {
|
||||
|
||||
const char* STDMETHODCALLTYPE domain() noexcept override {
|
||||
return domain_.c_str();
|
||||
}
|
||||
const char* STDMETHODCALLTYPE description() override {
|
||||
|
||||
const char* STDMETHODCALLTYPE description() noexcept override {
|
||||
return description_.c_str();
|
||||
}
|
||||
int64_t STDMETHODCALLTYPE version() override {
|
||||
|
||||
int64_t STDMETHODCALLTYPE version() noexcept override {
|
||||
return version_;
|
||||
}
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetModelMetadata(
|
||||
ABI::Windows::Foundation::Collections::IMapView<HSTRING, HSTRING>** metadata) override {
|
||||
ABI::Windows::Foundation::Collections::IMapView<HSTRING, HSTRING>** metadata) override try {
|
||||
*metadata = nullptr;
|
||||
std::unordered_map<winrt::hstring, winrt::hstring> map_copy;
|
||||
for (auto& pair : model_metadata_) {
|
||||
|
|
@ -121,21 +126,25 @@ class ModelInfo : public Microsoft::WRL::RuntimeClass<
|
|||
std::move(map_copy));
|
||||
|
||||
winrt::copy_to_abi(out.GetView(), *(void**)metadata);
|
||||
return S_OK;
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetInputFeatures(
|
||||
ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** features) override{
|
||||
ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** features) override try {
|
||||
*features = nullptr;
|
||||
winrt::copy_to_abi(input_features_.GetView(), *(void**)features);
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetOutputFeatures(
|
||||
ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** features) override {
|
||||
ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** features) override try {
|
||||
*features = nullptr;
|
||||
winrt::copy_to_abi(output_features_.GetView(), *(void**)features);
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
static std::vector<const char*>
|
||||
GetAllNodeOutputs(const onnx::ModelProto& model_proto) {
|
||||
|
|
@ -262,7 +271,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
// factory methods for creating an ort model from a path
|
||||
HRESULT STDMETHODCALLTYPE CreateModelProto(
|
||||
const char* path,
|
||||
IModelProto** model_proto) override {
|
||||
IModelProto** model_proto) override try {
|
||||
int file_descriptor;
|
||||
_set_errno(0); // clear errno
|
||||
_sopen_s(
|
||||
|
|
@ -297,11 +306,12 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
auto model_proto_outer = wil::MakeOrThrow<ModelProto>(model_proto_inner);
|
||||
return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast<void**>(model_proto));
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
// factory methods for creating an ort model from a stream
|
||||
HRESULT STDMETHODCALLTYPE CreateModelProto(
|
||||
ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference,
|
||||
IModelProto** model_proto) override {
|
||||
IModelProto** model_proto) override try {
|
||||
ZeroCopyInputStreamWrapper wrapper(stream_reference);
|
||||
|
||||
auto model_proto_inner = new onnx::ModelProto();
|
||||
|
|
@ -313,22 +323,26 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
auto model_proto_outer = wil::MakeOrThrow<ModelProto>(model_proto_inner);
|
||||
return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast<void**>(model_proto));
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
// factory methods for creating an ort model from a model_proto
|
||||
HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto* model_proto_in, IModelProto** model_proto) override {
|
||||
HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto* model_proto_in, IModelProto** model_proto) override try {
|
||||
auto model_proto_inner = new onnx::ModelProto(*model_proto_in->get());
|
||||
auto model_proto_outer = wil::MakeOrThrow<ModelProto>(model_proto_inner);
|
||||
return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast<void**>(model_proto));
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto* model_proto, IModelInfo** model_info) override {
|
||||
HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto* model_proto, IModelInfo** model_info) override try {
|
||||
auto model_info_outer = wil::MakeOrThrow<ModelInfo>(model_proto->get());
|
||||
return model_info_outer.CopyTo(__uuidof(IModelInfo), reinterpret_cast<void**>(model_info));
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
void STDMETHODCALLTYPE EnableDebugOutput() override {
|
||||
void STDMETHODCALLTYPE EnableDebugOutput() override try {
|
||||
WinML::CWinMLLogSink::EnableDebugOutput();
|
||||
}
|
||||
WINML_CATCH_ALL_DONOTHING
|
||||
|
||||
static bool IsFeatureDescriptorFp16(
|
||||
winml::ILearningModelFeatureDescriptor descriptor) {
|
||||
|
|
@ -346,7 +360,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility(
|
||||
winml::LearningModel const& model,
|
||||
IModelProto* p_model_proto,
|
||||
bool is_float16_supported) override {
|
||||
bool is_float16_supported) override try {
|
||||
if (!is_float16_supported) {
|
||||
auto& graph = p_model_proto->get()->graph();
|
||||
|
||||
|
|
@ -405,8 +419,9 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
}
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider* provider, void* allocation) override {
|
||||
ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider* provider, void* allocation) override try {
|
||||
#ifdef USE_DML
|
||||
auto d3dResource =
|
||||
Dml::GetD3D12ResourceFromAllocation(
|
||||
|
|
@ -416,6 +431,8 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
#else
|
||||
return nullptr;
|
||||
#endif USE_DML
|
||||
} catch (...) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static onnxruntime::MLDataType GetType(winml::TensorKind kind) {
|
||||
|
|
@ -432,7 +449,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder(
|
||||
ID3D12Device* device,
|
||||
ID3D12CommandQueue* queue,
|
||||
IOrtSessionBuilder** session_builder) override {
|
||||
IOrtSessionBuilder** session_builder) override try {
|
||||
if (device == nullptr) {
|
||||
auto builder = wil::MakeOrThrow<CpuOrtSessionBuilder>();
|
||||
return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast<void**>(session_builder));
|
||||
|
|
@ -446,8 +463,9 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
return E_NOTIMPL;
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override {
|
||||
HRESULT STDMETHODCALLTYPE GetMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override try {
|
||||
*key_type = *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
auto type = ort_value->Type();
|
||||
if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapStringToString>()) {
|
||||
|
|
@ -477,8 +495,9 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
}
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetVectorMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override {
|
||||
HRESULT STDMETHODCALLTYPE GetVectorMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override try {
|
||||
*key_type = *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
||||
auto type = ort_value->Type();
|
||||
if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::VectorMapStringToFloat>()) {
|
||||
|
|
@ -490,8 +509,9 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
}
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) override {
|
||||
HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) override try {
|
||||
#ifdef USE_DML
|
||||
auto impl = wil::MakeOrThrow<AbiCustomRegistryImpl>();
|
||||
*registry = impl.Detach();
|
||||
|
|
@ -500,8 +520,9 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
return E_NOTIMPL;
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetOperatorRegistry(ILearningModelOperatorProviderNative* operator_provider_native, IMLOperatorRegistry** registry) override {
|
||||
HRESULT STDMETHODCALLTYPE GetOperatorRegistry(ILearningModelOperatorProviderNative* operator_provider_native, IMLOperatorRegistry** registry) override try {
|
||||
#ifdef USE_DML
|
||||
// Retrieve the "operator abi" registry.
|
||||
winrt::com_ptr<IMLOperatorRegistry> operator_registry;
|
||||
|
|
@ -512,25 +533,29 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
return E_NOTIMPL;
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) override {
|
||||
void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) override try {
|
||||
#ifdef USE_DML
|
||||
return Dml::CreateGPUAllocationFromD3DResource(pResource);
|
||||
#else
|
||||
return nullptr;
|
||||
#endif USE_DML
|
||||
} catch (...) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) override {
|
||||
void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) override try {
|
||||
#ifdef USE_DML
|
||||
Dml::FreeGPUAllocation(ptr);
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_DONOTHING
|
||||
|
||||
HRESULT STDMETHODCALLTYPE CopyTensor(
|
||||
onnxruntime::IExecutionProvider* provider,
|
||||
OrtValue* src,
|
||||
OrtValue* dst) override {
|
||||
OrtValue* dst) override try {
|
||||
#ifdef USE_DML
|
||||
ORT_THROW_IF_ERROR(Dml::CopyTensor(provider, *(src->GetMutable<onnxruntime::Tensor>()), *(dst->GetMutable<onnxruntime::Tensor>())));
|
||||
return S_OK;
|
||||
|
|
@ -538,13 +563,14 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
return E_NOTIMPL;
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
// Override select shape inference functions which are incomplete in ONNX with versions that are complete,
|
||||
// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being
|
||||
// deferred until first evaluation. It also prevents a situation where inference functions in externally
|
||||
// registered schema are reachable only after upstream schema have been revised in a later OS release,
|
||||
// which would be a compatibility risk.
|
||||
HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override {
|
||||
HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override try {
|
||||
#ifdef USE_DML
|
||||
static std::once_flag schema_override_once_flag;
|
||||
std::call_once(schema_override_once_flag, []() {
|
||||
|
|
@ -555,10 +581,11 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
return S_OK; // needs to return S_OK otherwise everything breaks because this gets called from the learningmodel constructor
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetProviderMemoryInfo(
|
||||
onnxruntime::IExecutionProvider* provider,
|
||||
OrtMemoryInfo** memory_info) override {
|
||||
OrtMemoryInfo** memory_info) override try {
|
||||
auto allocator = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault);
|
||||
|
||||
const auto& info = allocator->Info();
|
||||
|
|
@ -568,8 +595,9 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
}
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue* ort_value, OrtMemoryInfo** memory_info) override {
|
||||
HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue* ort_value, OrtMemoryInfo** memory_info) override try {
|
||||
const auto& tensor = ort_value->Get<onnxruntime::Tensor>();
|
||||
auto info = tensor.Location();
|
||||
*memory_info = new OrtMemoryInfo(info.name, info.type, info.device, info.id, info.mem_type);
|
||||
|
|
@ -578,6 +606,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
}
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
struct AllocatorWrapper : public OrtAllocator {
|
||||
public:
|
||||
|
|
@ -604,7 +633,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
|
||||
HRESULT STDMETHODCALLTYPE GetProviderAllocator(
|
||||
onnxruntime::IExecutionProvider* provider,
|
||||
OrtAllocator** allocator) override {
|
||||
OrtAllocator** allocator) override try {
|
||||
auto allocator_ptr = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault);
|
||||
*allocator = new AllocatorWrapper(allocator_ptr);
|
||||
if (*allocator == nullptr) {
|
||||
|
|
@ -613,13 +642,15 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
|||
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
}; // namespace Windows::AI::MachineLearning::Adapter
|
||||
|
||||
extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) {
|
||||
extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) try {
|
||||
// make an adapter instance
|
||||
Microsoft::WRL::ComPtr<WinMLAdapter> adapterptr = wil::MakeOrThrow<WinMLAdapter>();
|
||||
return adapterptr.CopyTo(__uuidof(IWinMLAdapter), reinterpret_cast<void**>(adapter));
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
// InferenceSession
|
||||
// ================
|
||||
|
|
@ -627,25 +658,29 @@ extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter)
|
|||
InferenceSession::InferenceSession(onnxruntime::InferenceSession* session) : session_(session) {
|
||||
}
|
||||
|
||||
void STDMETHODCALLTYPE InferenceSession::RegisterGraphTransformers() {
|
||||
void STDMETHODCALLTYPE InferenceSession::RegisterGraphTransformers() try {
|
||||
#ifdef USE_DML
|
||||
// Bug 22973884 : Fix issues with BatchNorm + Add and BatchNorm + Mul handling implicit inputs, and move from Winml to ORT
|
||||
GraphTransformerHelpers::RegisterGraphTransformers(session_.get());
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_DONOTHING
|
||||
|
||||
HRESULT STDMETHODCALLTYPE InferenceSession::StartProfiling() {
|
||||
HRESULT STDMETHODCALLTYPE InferenceSession::StartProfiling() try {
|
||||
this->session_->StartProfiling(PheonixSingleton<WinML::LotusEnvironment>()->GetDefaultLogger());
|
||||
return S_OK;
|
||||
}
|
||||
HRESULT STDMETHODCALLTYPE InferenceSession::EndProfiling() {
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE InferenceSession::EndProfiling() try {
|
||||
this->session_->EndProfiling();
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE
|
||||
InferenceSession::LoadModel(
|
||||
IModelProto* model_proto) {
|
||||
IModelProto* model_proto) try {
|
||||
auto session_protected_load_accessor =
|
||||
static_cast<InferenceSessionProtectedLoadAccessor*>(session_.get());
|
||||
// session's like to have their very own copy of the model_proto, use detach()
|
||||
|
|
@ -653,10 +688,11 @@ InferenceSession::LoadModel(
|
|||
ORT_THROW_IF_ERROR(session_protected_load_accessor->Load(std::move(model_proto_ptr)));
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
HRESULT STDMETHODCALLTYPE
|
||||
InferenceSession::RegisterCustomRegistry(
|
||||
IMLOperatorRegistry* registry) {
|
||||
IMLOperatorRegistry* registry) try {
|
||||
RETURN_HR_IF(S_OK, registry == nullptr);
|
||||
|
||||
#ifdef USE_DML
|
||||
|
|
@ -670,29 +706,33 @@ InferenceSession::RegisterCustomRegistry(
|
|||
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
void STDMETHODCALLTYPE InferenceSession::FlushContext(onnxruntime::IExecutionProvider* dml_provider) {
|
||||
void STDMETHODCALLTYPE InferenceSession::FlushContext(onnxruntime::IExecutionProvider* dml_provider) try {
|
||||
#ifdef USE_DML
|
||||
Dml::FlushContext(dml_provider);
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_DONOTHING
|
||||
|
||||
void STDMETHODCALLTYPE InferenceSession::TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) {
|
||||
void STDMETHODCALLTYPE InferenceSession::TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) try {
|
||||
#ifdef USE_DML
|
||||
Dml::TrimUploadHeap(dml_provider);
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_DONOTHING
|
||||
|
||||
void STDMETHODCALLTYPE InferenceSession::ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) {
|
||||
void STDMETHODCALLTYPE InferenceSession::ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) try {
|
||||
#ifdef USE_DML
|
||||
Dml::ReleaseCompletedReferences(dml_provider);
|
||||
#endif USE_DML
|
||||
}
|
||||
WINML_CATCH_ALL_DONOTHING
|
||||
|
||||
HRESULT STDMETHODCALLTYPE InferenceSession::CopyOneInputAcrossDevices(
|
||||
const char* input_name,
|
||||
const OrtValue* orig_mlvalue,
|
||||
OrtValue** new_mlvalue) {
|
||||
OrtValue** new_mlvalue) try {
|
||||
auto session_protected_load_accessor =
|
||||
static_cast<InferenceSessionProtectedLoadAccessor*>(session_.get());
|
||||
const onnxruntime::SessionState& sessionState = session_protected_load_accessor->GetSessionState();
|
||||
|
|
@ -701,4 +741,6 @@ HRESULT STDMETHODCALLTYPE InferenceSession::CopyOneInputAcrossDevices(
|
|||
*new_mlvalue = temp_mlvalue.release();
|
||||
return S_OK;
|
||||
}
|
||||
WINML_CATCH_ALL_COM
|
||||
|
||||
} // namespace Windows::AI::MachineLearning::Adapter
|
||||
|
|
@ -119,12 +119,16 @@ public:
|
|||
|
||||
InferenceSession(onnxruntime::InferenceSession * session);
|
||||
|
||||
onnxruntime::InferenceSession* STDMETHODCALLTYPE get() override { return session_.get(); }
|
||||
HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession ** out) override {
|
||||
onnxruntime::InferenceSession* STDMETHODCALLTYPE get() noexcept override {
|
||||
return session_.get();
|
||||
}
|
||||
|
||||
HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession ** out) noexcept override {
|
||||
// (OrtSession *) are really (InferenceSession *) as well
|
||||
*out = reinterpret_cast<OrtSession*>(session_.get());
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
void STDMETHODCALLTYPE RegisterGraphTransformers() override;
|
||||
HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry* registry) override;
|
||||
HRESULT STDMETHODCALLTYPE LoadModel(IModelProto* model_proto) override;
|
||||
|
|
|
|||
|
|
@ -109,3 +109,8 @@ inline __declspec(noinline) winrt::hresult_error _to_hresult() noexcept {
|
|||
catch (...) { \
|
||||
return _to_hresult().to_abi(); \
|
||||
}
|
||||
|
||||
#define WINML_CATCH_ALL_DONOTHING \
|
||||
catch (...) { \
|
||||
return; \
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue