// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #include "pch.h" #include "inc/WinMLAdapter.h" #include "inc/CustomRegistryHelper.h" #include "inc/LotusEnvironment.h" #include "inc/AbiCustomRegistryImpl.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" #include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" #include "LearningModelDevice.h" #include "TensorFeatureDescriptor.h" #include "ImageFeatureDescriptor.h" #include "api.image/inc/D3DDeviceCache.h" #include "PheonixSingleton.h" #include "DmlOrtSessionBuilder.h" #include "CpuOrtSessionBuilder.h" #include #include #include "ZeroCopyInputStreamWrapper.h" #include "google/protobuf/io/zero_copy_stream_impl.h" using namespace winrt::Windows::AI::MachineLearning; namespace Windows::AI::MachineLearning::Adapter { // ORT intentionally requires callers derive from their session class to access // the protected Load method used below. class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSession { public: onnxruntime::common::Status Load(std::unique_ptr p_model_proto) { return onnxruntime::InferenceSession::Load(std::move(p_model_proto)); } }; // class AbiSafeTensor // class AbiSafeTensor : public Microsoft::WRL::RuntimeClass < Microsoft::WRL::RuntimeClassFlags, ITensor> { private: onnxruntime::Tensor& tensor_; // weak ref ComPtr value_; // strong ref public: AbiSafeTensor(onnxruntime::Tensor* tensor, IOrtValue * value_in) : tensor_(*tensor), value_(value_in) { } const onnxruntime::Tensor& STDMETHODCALLTYPE get() override { return tensor_; } onnxruntime::Tensor* STDMETHODCALLTYPE getMutable() override { return &tensor_; } onnxruntime::MLDataType STDMETHODCALLTYPE DataType() override { return tensor_.DataType(); } const void* STDMETHODCALLTYPE DataRaw() override { return tensor_.DataRaw(); } const std::vector& STDMETHODCALLTYPE ShapeGetDims() override { return tensor_.Shape().GetDims(); } int64_t STDMETHODCALLTYPE ShapeSize() override { return tensor_.Shape().Size(); } const char * STDMETHODCALLTYPE LocationName() override { return tensor_.Location().name; } OrtMemType STDMETHODCALLTYPE LocationMemType() override { return tensor_.Location().mem_type; } }; // class OrtValue // class AbiSafeOrtValue : public Microsoft::WRL::RuntimeClass < Microsoft::WRL::RuntimeClassFlags, IOrtValue> { private: OrtValue ort_value_; public: AbiSafeOrtValue() {} AbiSafeOrtValue(OrtValue value_in) : ort_value_(value_in) {} OrtValue& STDMETHODCALLTYPE get() override { return ort_value_; } onnxruntime::MLDataType STDMETHODCALLTYPE Type() override { return ort_value_.Type(); } bool STDMETHODCALLTYPE IsTensor() override { return ort_value_.IsTensor(); } // end HRESULT STDMETHODCALLTYPE GetTensor(ITensor ** tensor) override { auto tensor_inner = ort_value_.GetMutable(); auto tensor_outer = wil::MakeOrThrow(tensor_inner, this); *tensor = tensor_outer.Detach(); return S_OK; } }; class ModelProto : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, IModelProto> { public: ModelProto::ModelProto(onnx::ModelProto* model_proto) : model_proto_(model_proto) { } onnx::ModelProto* STDMETHODCALLTYPE get() override { return model_proto_.get(); } private: std::shared_ptr model_proto_; }; class WinMLAdapter : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, IWinMLAdapter> { public: // factory methods for creating an ort model from a path HRESULT STDMETHODCALLTYPE CreateModelProto( const char* path, IModelProto** model_proto) override { int file_descriptor; _sopen_s( &file_descriptor, path, O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE); THROW_HR_IF_MSG( E_FAIL, 0 > file_descriptor, "Failed"); //errno auto stream = google::protobuf::io::FileInputStream(file_descriptor); stream.SetCloseOnDelete(true); auto model_proto_inner = new onnx::ModelProto(); THROW_HR_IF_MSG( E_INVALIDARG, !model_proto_inner->ParseFromZeroCopyStream(&stream) == false, "The stream failed to parse."); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto); } // factory methods for creating an ort model from a stream HRESULT STDMETHODCALLTYPE CreateModelProto( ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, IModelProto** model_proto) override { ZeroCopyInputStreamWrapper wrapper(stream_reference); auto model_proto_inner = new onnx::ModelProto(); THROW_HR_IF_MSG( E_INVALIDARG, model_proto_inner->ParseFromZeroCopyStream(&wrapper) == false, "The stream failed to parse."); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto); } // factory methods for creating an ort model from a model_proto HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) override { auto model_proto_inner = new onnx::ModelProto(*model_proto_in->get()); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto); } void STDMETHODCALLTYPE EnableDebugOutput() override { WinML::CWinMLLogSink::EnableDebugOutput(); } static bool IsFeatureDescriptorFp16( winml::ILearningModelFeatureDescriptor descriptor) { if (auto imageFeatureDescriptor = descriptor.try_as()) { return TensorKind::Float16 == imageFeatureDescriptor.TensorKind(); } if (auto tensorFeatureDescriptor = descriptor.try_as()) { return TensorKind::Float16 == tensorFeatureDescriptor.TensorKind(); } return false; } HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility( winml::LearningModel const& model, IModelProto* p_model_proto, bool is_float16_supported) override { if (!is_float16_supported) { auto& graph = p_model_proto->get()->graph(); // The model will not contain fp16 operations if: // 1. The model has no fp16 inputs // 2. The model has no fp16 initializers // 3. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator // 4. The model does not have any fp16 outputs // 1. Ensure that The model has no fp16 inputs for (auto descriptor : model.InputFeatures()) { THROW_HR_IF_MSG( DXGI_ERROR_UNSUPPORTED, IsFeatureDescriptorFp16(descriptor), "The model contains a 16-bit input (%ls), but the current device does not support 16-bit float.", descriptor.Name().c_str()); } // 2. Ensure that the model has no fp16 initializers for (int i = 0; i < graph.node_size(); i++) { auto node = graph.node(i); if (node.op_type() == "Cast" && node.domain().empty()) { for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) { auto attribute = node.attribute(attribIndex); if (attribute.name() == "to") { THROW_HR_IF_MSG( DXGI_ERROR_UNSUPPORTED, attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, "The model contains a 16-bit float Cast Op (%s), but the current device does not support 16-bit float.", node.name().c_str()); } } } } // 3. Ensure that the model does not create any fp16 intermediary // tensors via the Cast (to float16) operator for (int i = 0; i < graph.initializer_size(); i++) { auto initializer = graph.initializer(i); THROW_HR_IF_MSG( DXGI_ERROR_UNSUPPORTED, initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, "The model contains a 16-bit float initializer (%s), but the current device does not support 16-bit float.", initializer.name().c_str()); } // 4. Ensure that the model does not have any fp16 outputs for (auto descriptor : model.OutputFeatures()) { THROW_HR_IF_MSG( DXGI_ERROR_UNSUPPORTED, IsFeatureDescriptorFp16(descriptor), "The model contains a 16-bit output (%ls), but the current device does not support 16-bit float.", descriptor.Name().c_str()); } } return S_OK; } ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider* provider, void* allocation) override { auto d3dResource = Dml::GetD3D12ResourceFromAllocation( provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(), allocation); return d3dResource; } static onnxruntime::MLDataType GetType(winml::TensorKind kind) { switch (kind) { case winml::TensorKind::Float: return onnxruntime::DataTypeImpl::GetType(); case winml::TensorKind::Float16: return onnxruntime::DataTypeImpl::GetType(); }; return nullptr; } // factory method for creating an ortsessionbuilder from a device HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder( ID3D12Device* device, ID3D12CommandQueue* queue, IOrtSessionBuilder** session_builder) override { if (device == nullptr) { auto builder = wil::MakeOrThrow(); return builder.CopyTo(__uuidof(IOrtSessionBuilder), (void**)session_builder); } else { auto builder = wil::MakeOrThrow(device, queue); return builder.CopyTo(__uuidof(IOrtSessionBuilder), (void**)session_builder); } } onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType() override { return onnxruntime::DataTypeImpl::GetType(); } onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType(winml::TensorKind kind) override { if (kind == TensorKind::Float) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::Double) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::String) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::UInt8) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::Int8) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::UInt16) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::Int16) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::UInt32) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::Int32) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::UInt64) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::Int64) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::Boolean) { return onnxruntime::DataTypeImpl::GetType(); } else if (kind == TensorKind::Float16) { return onnxruntime::DataTypeImpl::GetType(); } return nullptr; } onnxruntime::MLDataType STDMETHODCALLTYPE GetMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) override { if (key_kind == TensorKind::String) { if (value_kind == TensorKind::String) { return onnxruntime::DataTypeImpl::GetType(); } else if (value_kind == TensorKind::Int64) { return onnxruntime::DataTypeImpl::GetType(); } else if (value_kind == TensorKind::Float) { return onnxruntime::DataTypeImpl::GetType(); } else if (value_kind == TensorKind::Double) { return onnxruntime::DataTypeImpl::GetType(); } } else if (key_kind == TensorKind::Int64) { if (value_kind == TensorKind::String) { return onnxruntime::DataTypeImpl::GetType(); } else if (value_kind == TensorKind::Int64) { return onnxruntime::DataTypeImpl::GetType(); } else if (value_kind == TensorKind::Float) { return onnxruntime::DataTypeImpl::GetType(); } else if (value_kind == TensorKind::Double) { return onnxruntime::DataTypeImpl::GetType(); } } return nullptr; } onnxruntime::MLDataType STDMETHODCALLTYPE GetVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) override { if (key_kind == TensorKind::String) { if (value_kind == TensorKind::Float) { return onnxruntime::DataTypeImpl::GetType(); } } else if (key_kind == TensorKind::Int64) { if (value_kind == TensorKind::Float) { return onnxruntime::DataTypeImpl::GetType(); } } return nullptr; } // returns the raw mutable data. void * STDMETHODCALLTYPE GetTensorData(IOrtValue * ort_Value) override { return nullptr; } void * STDMETHODCALLTYPE GetMapData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) override { return nullptr; } void * STDMETHODCALLTYPE GetVectorData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) override { return nullptr; } HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) override { auto impl = wil::MakeOrThrow(); *registry = impl.Detach(); return S_OK; } void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) override { return Dml::CreateGPUAllocationFromD3DResource(pResource); } void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) override { Dml::FreeGPUAllocation(ptr); } HRESULT STDMETHODCALLTYPE CopyTensor( onnxruntime::IExecutionProvider* provider, ITensor* src, ITensor* dst) override { ORT_THROW_IF_ERROR(Dml::CopyTensor(provider, src->get(), *(dst->getMutable()))); return S_OK; } HRESULT STDMETHODCALLTYPE CreateGPUMLValue( void * execution_provider_allocated_resource, onnxruntime::IExecutionProvider* provider, std::vector* shape, onnxruntime::MLDataType data_type, IOrtValue ** gpu_value) override { THROW_HR_IF_MSG(WINML_ERR_INVALID_BINDING, "DmlExecutionProvider" != provider->Type(), "Cannot creat GPU tensor on CPU device"); onnxruntime::TensorShape tensor_shape(*shape); auto tensor = new onnxruntime::Tensor( data_type, tensor_shape, execution_provider_allocated_resource, provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault)->Info()); auto ort_value = wil::MakeOrThrow(); ort_value->get().Init(tensor, onnxruntime::DataTypeImpl::GetType(), onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); *gpu_value = ort_value.Detach(); return S_OK; } HRESULT STDMETHODCALLTYPE CreateCPUMLValue( std::vector* shape, onnxruntime::MLDataType data_type, onnxruntime::BufferNakedPtr buffer, IOrtValue ** cpu_value) override { auto registrations = onnxruntime::DeviceAllocatorRegistry::Instance().AllRegistrations(); auto alloc = registrations[onnxruntime::CPU].factory(0); onnxruntime::TensorShape tensor_shape(*shape); // Unowned raw tensor pointer passed to engine auto tensor = new onnxruntime::Tensor( data_type, tensor_shape, buffer, alloc->Info()); auto ort_value = wil::MakeOrThrow(); ort_value->get().Init(tensor, onnxruntime::DataTypeImpl::GetType(), onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); *cpu_value = ort_value.Detach(); return S_OK; } HRESULT STDMETHODCALLTYPE CreateMLValue( winml::TensorKind kind, onnxruntime::MLDataType data_type, const int64_t * shape, uint32_t shape_count, onnxruntime::IExecutionProvider* provider, IOrtValue ** ort_value) override { onnxruntime::TensorShape tensor_shape(shape, shape_count); auto tensor = new onnxruntime::Tensor( GetType(kind), tensor_shape, provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault)); auto ort_value_out = wil::MakeOrThrow(); ort_value_out->get().Init(tensor, data_type, data_type->GetDeleteFunc()); *ort_value = ort_value_out.Detach();; return S_OK; } HRESULT STDMETHODCALLTYPE CreateOrtValue( void * data, onnxruntime::MLDataType data_type, IOrtValue ** ort_value) override { auto ort_value_out = wil::MakeOrThrow(); ort_value_out->get().Init( data, data_type, data_type->GetDeleteFunc()); *ort_value = ort_value_out.Detach(); return S_OK; } }; extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) { // make an adapter instance Microsoft::WRL::ComPtr adapterptr = wil::MakeOrThrow(); return adapterptr.CopyTo(__uuidof(IWinMLAdapter), (void **)adapter); } // class IOBinding // =============== class IOBinding : public Microsoft::WRL::RuntimeClass < Microsoft::WRL::RuntimeClassFlags, IIOBinding> { private: std::shared_ptr binding_; std::vector outputs_weak; std::vector> outputs_; public: IOBinding(onnxruntime::IOBinding * binding) : binding_(binding) { } onnxruntime::IOBinding* STDMETHODCALLTYPE get() override { return binding_.get(); } HRESULT STDMETHODCALLTYPE BindInput(const std::string& name, IOrtValue* ml_value) override { ORT_THROW_IF_ERROR(binding_->BindInput(name, ml_value->get())); return S_OK; } HRESULT STDMETHODCALLTYPE BindOutput(const std::string& name, IOrtValue* ml_value) override { // this can be null for unbound outputs if (ml_value == nullptr) { OrtValue empty_value = {}; ORT_THROW_IF_ERROR(binding_->BindOutput(name, empty_value)); } else { ORT_THROW_IF_ERROR(binding_->BindOutput(name, ml_value->get())); } return S_OK; } const std::vector& STDMETHODCALLTYPE GetOutputNames() override { return binding_->GetOutputNames(); } std::vector& STDMETHODCALLTYPE GetOutputs() override { auto output_inner = binding_->GetOutputs(); outputs_.clear(); for (unsigned i = 0; i < output_inner.size(); i++) { auto ort_value = wil::MakeOrThrow(output_inner[i]); outputs_.push_back(ort_value); outputs_weak.push_back(ort_value.Get()); } return outputs_weak; } }; // InferenceSession // ================ InferenceSession::InferenceSession(onnxruntime::InferenceSession * session) : session_(session) { } void STDMETHODCALLTYPE InferenceSession::RegisterGraphTransformers(bool registerLotusTransforms) { GraphTransformerHelpers::RegisterGraphTransformers(session_.get(), registerLotusTransforms); } HRESULT STDMETHODCALLTYPE InferenceSession::NewIOBinding(IIOBinding** io_binding) { std::unique_ptr binding; ORT_THROW_IF_ERROR(this->session_->NewIOBinding(&binding)); auto io_binding_outer = wil::MakeOrThrow(binding.release()); return io_binding_outer.CopyTo(__uuidof(IIOBinding), (void**)io_binding); } HRESULT STDMETHODCALLTYPE InferenceSession::Run(const onnxruntime::RunOptions* run_options, IIOBinding* io_binding) { ORT_THROW_IF_ERROR(this->session_->Run(*run_options, *(io_binding->get()))); return S_OK; } HRESULT STDMETHODCALLTYPE InferenceSession::StartProfiling() { this->session_->StartProfiling(PheonixSingleton()->GetDefaultLogger()); return S_OK; } HRESULT STDMETHODCALLTYPE InferenceSession::EndProfiling() { this->session_->EndProfiling(); return S_OK; } HRESULT STDMETHODCALLTYPE InferenceSession::LoadModel( IModelProto* model_proto) { auto session_protected_load_accessor = static_cast(session_.get()); std::unique_ptr model_proto_ptr(model_proto->get()); ORT_THROW_IF_ERROR(session_protected_load_accessor->Load(std::move(model_proto_ptr))); return S_OK; } HRESULT STDMETHODCALLTYPE InferenceSession::RegisterCustomRegistry( IMLOperatorRegistry* registry) { RETURN_HR_IF(S_OK, registry == nullptr); auto custom_registries = GetLotusCustomRegistries(registry); // Register for (auto& custom_registry : custom_registries) { ORT_THROW_IF_ERROR(session_->RegisterCustomRegistry(custom_registry)); } return S_OK; } void STDMETHODCALLTYPE InferenceSession::FlushContext(onnxruntime::IExecutionProvider* dml_provider) { Dml::FlushContext(dml_provider); } void STDMETHODCALLTYPE InferenceSession::TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) { Dml::TrimUploadHeap(dml_provider); } void STDMETHODCALLTYPE InferenceSession::ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) { Dml::ReleaseCompletedReferences(dml_provider); } } // namespace Windows::AI::MachineLearning::Adapter