// Thin C++ convenience wrappers (model, device, session, binding, results) over the
// Microsoft.AI.MachineLearning ABI, written against WRL / COM.
//
// The Microsoft.AI.MachineLearning runtime classes are activated directly from
// microsoft.ai.machinelearning.dll through its DllGetActivationFactory export (see
// GetActivationFactory below) instead of through RoGetActivationFactory.
//
// NOTE(review): this header does not itself include the WRL, roapi, Windows.Storage or
// Microsoft.AI.MachineLearning ABI headers whose names it uses (ComPtr, HStringReference,
// ABI::..., RuntimeClass_*), and it does not define ML_FAIL_FAST_IF or tensor_shape_type;
// presumably the includer provides them — confirm.
// NOTE(review): in this copy many template argument lists appear to be missing (for example
// `template struct Tensor { };`, `static_cast(_hr)`, `Microsoft::WRL::ComPtr learningModel;`),
// which does not compile as written — compare against the upstream header.
#pragma once
#ifndef WINML_H_
#define WINML_H_

// Helper implementations used below (buffer wrappers, stream reference, shape iterable).
// Presumably they wrap caller-owned memory without copying, per the "weak" naming — confirm.
#include "weak_buffer.h"
#include "buffer_backed_random_access_stream_reference.h"
#include "weak_single_threaded_iterable.h"

// Evaluates `expression` exactly once. If the resulting HRESULT is a failure, it is cast and
// returned from the *enclosing* function. Callers in this file return int32_t and use 0 for
// success.
#define RETURN_HR_IF_FAILED(expression) \
  do {                                  \
    auto _hr = expression;              \
    if (FAILED(_hr))                    \
    {                                   \
      return static_cast(_hr);          \
    }                                   \
  } while (0)

// Evaluates `expression` exactly once. If the resulting HRESULT is a failure, the process is
// terminated immediately via __fastfail, with the HRESULT as the fail code; nothing is
// returned to the caller.
#define FAIL_FAST_IF_HR_FAILED(expression) \
  do {                                     \
    auto _hr = expression;                 \
    if (FAILED(_hr))                       \
    {                                      \
      __fastfail(static_cast(_hr));        \
    }                                      \
  } while (0)

// Storage-only 16-bit value type: just the raw bits, with no arithmetic or conversions.
// Presumably used as the element type that selects the Float16Bit tensor traits below.
struct float16 {
  uint16_t value;
};

namespace Microsoft { namespace AI { namespace MachineLearning { namespace Details {

class WinMLLearningModel;
class WinMLLearningModelBinding;
class WinMLLearningModelSession;
class WinMLLearningModelResults;

// Name of the module that GetActivationFactory loads with LoadLibraryExW. selectany lets the
// definition live in a header included by several translation units.
extern const __declspec(selectany) _Null_terminated_ wchar_t MachineLearningDll[] = L"microsoft.ai.machinelearning.dll";

// Type trait: tensor element type -> ABI tensor interface (`Type`).
// There is one specialization per tensor kind (Float, Float16Bit, Int8Bit, UInt8Bit, UInt16Bit,
// Int16Bit, UInt32Bit, Int32Bit, UInt64Bit, Int64Bit, Boolean, Double).
template struct Tensor { };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorFloat; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorFloat16Bit;};
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorInt8Bit; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorUInt8Bit; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorUInt16Bit; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorInt16Bit; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorUInt32Bit; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorInt32Bit; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorUInt64Bit; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorInt64Bit; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorBoolean; };
template <> struct Tensor { using Type = ABI::Microsoft::AI::MachineLearning::ITensorDouble; };

// Type trait: tensor element type -> runtime class name string. The bind* methods pass it to
// GetActivationFactory. Declarations come first; the selectany definitions follow.
template struct TensorRuntimeClassID { };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };
template <> struct TensorRuntimeClassID { static const wchar_t* RuntimeClass_ID; };

__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorFloat;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorFloat16Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorInt8Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorUInt8Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorUInt16Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorInt16Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorUInt32Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorInt32Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorUInt64Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorInt64Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorBoolean;
__declspec(selectany) const wchar_t* TensorRuntimeClassID::RuntimeClass_ID = RuntimeClass_Microsoft_AI_MachineLearning_TensorDouble;

// Type trait: tensor element type -> "Statics" factory interface.
// bind() uses it to call CreateFromArray.
template struct TensorFactory { };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorFloatStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorFloat16BitStatics;};
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt8BitStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt8BitStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt16BitStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt16BitStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt32BitStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt32BitStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt64BitStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt64BitStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorBooleanStatics; };
template <> struct TensorFactory { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorDoubleStatics; };

// Type trait: tensor element type -> "Statics2" factory interface.
// bind_as_reference() uses it to call CreateFromBuffer.
template struct TensorFactory2 { };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorFloatStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorFloat16BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt8BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt8BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt16BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt16BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt32BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt32BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt64BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt64BitStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorBooleanStatics2; };
template <> struct TensorFactory2 { using Factory = ABI::Microsoft::AI::MachineLearning::ITensorDoubleStatics2; };

// Type trait: tensor element type -> IID of its "Statics" factory interface. It is passed to
// GetActivationFactory so the activation factory is QI'd to the right interface.
template struct TensorFactoryIID { };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };
template <> struct TensorFactoryIID { static const GUID IID; };

__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorFloatStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorFloat16BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorInt8BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt8BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt16BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorInt16BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt32BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorInt32BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt64BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorInt64BitStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorBooleanStatics;
__declspec(selectany) const GUID TensorFactoryIID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorDoubleStatics;

// Type trait: tensor element type -> IID of its "Statics2" factory interface
// (used by bind_as_reference).
template struct TensorFactory2IID { };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };
template <> struct TensorFactory2IID { static const GUID IID; };

__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorFloatStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorFloat16BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorInt8BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt8BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt16BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorInt16BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt32BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorInt32BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt64BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorInt64BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorBooleanStatics2;
__declspec(selectany) const GUID TensorFactory2IID::IID = ABI::Microsoft::AI::MachineLearning::IID_ITensorDoubleStatics2;

// Activates runtime class `p_class_id` from MachineLearningDll and QIs the factory for `iid`.
//
// Steps: load the DLL with LoadLibraryExW, resolve its "DllGetActivationFactory" export, call it
// with the class name, then QueryInterface the returned activation factory for `iid` into
// `*factory`.
//
// Returns S_OK on success. Otherwise it returns the HRESULT from the failed step; Win32
// load/lookup errors are converted with HRESULT_FROM_WIN32(GetLastError()).
//
// The module handle is released only when the export lookup or the activation call fails. On
// success (and when the final QueryInterface fails) the DLL stays loaded.
// NOTE(review): keeping the DLL loaded on success is presumably intentional, since the returned
// factory's code lives in that DLL. Each call adds another LoadLibraryExW reference that is never
// released, and the QueryInterface-failure path also keeps it — confirm this is acceptable.
// NOTE(review): a previous comment here claimed a fallback to the OS binary. No such fallback is
// visible: if MachineLearningDll cannot be loaded, the Win32 error is returned.
inline HRESULT GetActivationFactory(
    const wchar_t* p_class_id,
    const IID& iid,
    void** factory) noexcept {
  auto library = LoadLibraryExW(MachineLearningDll, nullptr, 0);
  if (library == nullptr) {
    return HRESULT_FROM_WIN32(GetLastError());
  }

  // Signature of the DllGetActivationFactory export.
  using DllGetActivationFactory = HRESULT __stdcall(HSTRING, void** factory);
  auto call = reinterpret_cast(GetProcAddress(library, "DllGetActivationFactory"));
  if (!call) {
    auto hr = HRESULT_FROM_WIN32(GetLastError());
    FreeLibrary(library);
    return hr;
  }

  Microsoft::WRL::ComPtr activation_factory;
  auto hr = call(
      Microsoft::WRL::Wrappers::HStringReference(p_class_id, static_cast(wcslen(p_class_id))).Get(),
      reinterpret_cast(activation_factory.GetAddressOf()));

  if (FAILED(hr)) {
    FreeLibrary(library);
    return hr;
  }

  return activation_factory->QueryInterface(iid, factory);
}

// Owns a loaded LearningModel (m_learning_model).
// The model is loaded from a file path or from an in-memory byte range. Construction fails fast
// (ML_FAIL_FAST_IF, defined outside this header) if loading fails.
class WinMLLearningModel {
  // The session reads m_learning_model directly.
  friend class WinMLLearningModelSession;

 public:
  // Loads the model from `model_path`.
  // `size` is the path length passed to HStringReference — presumably in wchar_t units,
  // excluding the terminator; confirm.
  WinMLLearningModel(const wchar_t* model_path, size_t size) {
    ML_FAIL_FAST_IF(0 != Initialize(model_path, size));
  }

  // Loads the model from the `size` bytes at `bytes`, using the no-copy path.
  // NOTE(review): the no-copy path presumably wraps `bytes` without copying, so the caller must
  // keep the buffer alive as long as the model may read it — confirm in weak_buffer.h.
  WinMLLearningModel(const char* bytes, size_t size) {
    ML_FAIL_FAST_IF(0 != Initialize(bytes, size, false /*dont copy*/));
  }

 private:
  // Loads the model from a file path via ILearningModelStatics::LoadFromFilePath.
  // Returns 0 on success, otherwise the failing HRESULT as int32_t.
  // Unlike the byte overload, this path does not call RoInitialize.
  int32_t Initialize(const wchar_t* model_path, size_t size) {
    Microsoft::WRL::ComPtr learningModel;
    RETURN_HR_IF_FAILED(
        GetActivationFactory(
            RuntimeClass_Microsoft_AI_MachineLearning_LearningModel,
            ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics,
            &learningModel));

    RETURN_HR_IF_FAILED(
        learningModel->LoadFromFilePath(
            Microsoft::WRL::Wrappers::HStringReference(model_path, static_cast(size)).Get(),
            m_learning_model.GetAddressOf()));
    return 0;
  }

  // Completion handler for DataWriter::StoreAsync (used only by the copying path below).
  // Invoke() sets a manual-reset event (CreateEvent bManualReset=true, initially non-signaled);
  // Wait() blocks on that event with no timeout.
  // NOTE(review): the AsyncStatus argument is ignored, so a failed or cancelled store still
  // reports S_OK. The results of CreateEvent and WaitForSingleObject are not checked either.
  struct StoreCompleted : Microsoft::WRL::RuntimeClass<
                              Microsoft::WRL::RuntimeClassFlags,
                              ABI::Windows::Foundation::IAsyncOperationCompletedHandler> {
    HANDLE completed_event_;

    StoreCompleted() : completed_event_(CreateEvent(nullptr, true, false, nullptr)) {}

    ~StoreCompleted() {
      CloseHandle(completed_event_);
    }

    HRESULT STDMETHODCALLTYPE Invoke(
        ABI::Windows::Foundation::IAsyncOperation * /*asyncInfo*/,
        ABI::Windows::Foundation::AsyncStatus /*status*/) {
      SetEvent(completed_event_);
      return S_OK;
    }

    HRESULT Wait() {
      WaitForSingleObject(completed_event_, INFINITE);
      return S_OK;
    }
  };

  // Loads the model from the `size` bytes at `bytes`.
  //
  // with_copy == true:
  //   Copies the bytes into an InMemoryRandomAccessStream through a DataWriter, waits for
  //   StoreAsync to complete, and loads from a RandomAccessStreamReference over that stream.
  // with_copy == false:
  //   Wraps [bytes, bytes + size) in the helper buffer and stream-reference types from the
  //   included headers, with no copy made here.
  //
  // Returns 0 on success, otherwise the failing HRESULT as int32_t.
  // NOTE(review): the only visible caller passes with_copy = false, so the copying branch is
  // unreachable from the constructors in this header.
  // NOTE(review): RoInitialize is called, but no matching RoUninitialize is visible in this
  // header.
  int32_t Initialize(const char* bytes, size_t size, bool with_copy = false) {
    auto hr = RoInitialize(RO_INIT_TYPE::RO_INIT_SINGLETHREADED);
    // https://docs.microsoft.com/en-us/windows/win32/api/roapi/nf-roapi-roinitialize#return-value
    // RPC_E_CHANGED_MODE indicates already initialized as multithreaded
    if (hr < 0 && hr != RPC_E_CHANGED_MODE) {
      return static_cast(hr);
    }

    Microsoft::WRL::ComPtr random_access_stream_ref;
    if (with_copy) {
      // Create in memory stream
      Microsoft::WRL::ComPtr in_memory_random_access_stream_insp;
      RETURN_HR_IF_FAILED(RoActivateInstance(
          Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
          in_memory_random_access_stream_insp.GetAddressOf()));

      // QI memory stream to output stream
      Microsoft::WRL::ComPtr output_stream;
      RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));

      // Create data writer factory
      Microsoft::WRL::ComPtr activation_factory;
      RETURN_HR_IF_FAILED(RoGetActivationFactory(
          Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_DataWriter).Get(),
          IID_PPV_ARGS(activation_factory.GetAddressOf())));

      // Create data writer object based on the in memory stream
      Microsoft::WRL::ComPtr data_writer;
      RETURN_HR_IF_FAILED(activation_factory->CreateDataWriter(
          output_stream.Get(),
          data_writer.GetAddressOf()));

      // Write the model to the data writer and thus to the stream
      // (WriteBytes takes a non-const pointer, hence the const_cast; the bytes are only read.)
      RETURN_HR_IF_FAILED(
          data_writer->WriteBytes(static_cast(size), reinterpret_cast(const_cast(bytes))));

      // QI the in memory stream to a random access stream
      Microsoft::WRL::ComPtr random_access_stream;
      RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&random_access_stream));

      // Create a random access stream reference factory
      Microsoft::WRL::ComPtr random_access_stream_ref_statics;
      RETURN_HR_IF_FAILED(RoGetActivationFactory(
          Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_RandomAccessStreamReference).Get(),
          IID_PPV_ARGS(random_access_stream_ref_statics.GetAddressOf())));

      // Create a random access stream reference from the random access stream view on top of
      // the in memory stream
      RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream(
          random_access_stream.Get(),
          random_access_stream_ref.GetAddressOf()));

      // Flush the written bytes into the stream and block until StoreAsync completes.
      Microsoft::WRL::ComPtr> async_operation;
      RETURN_HR_IF_FAILED(data_writer->StoreAsync(&async_operation));
      auto store_completed_handler = Microsoft::WRL::Make();
      RETURN_HR_IF_FAILED(async_operation->put_Completed(store_completed_handler.Get()));
      RETURN_HR_IF_FAILED(store_completed_handler->Wait());
    } else {
      // Wrap the caller's bytes directly in a buffer ...
      Microsoft::WRL::ComPtr> buffer;
      RETURN_HR_IF_FAILED(
          Microsoft::WRL::MakeAndInitialize>(
              &buffer, bytes, bytes + size));

      // ... and expose that buffer as a random access stream reference.
      RETURN_HR_IF_FAILED(
          Microsoft::WRL::MakeAndInitialize(
              &random_access_stream_ref, buffer.Get()));
    }

    // Create a learning model factory
    Microsoft::WRL::ComPtr learning_model;
    RETURN_HR_IF_FAILED(
        GetActivationFactory(
            RuntimeClass_Microsoft_AI_MachineLearning_LearningModel,
            ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics,
            &learning_model));

    // Create a learning model from the factory with the random access stream reference that points
    // to the random access stream view on top of the in memory stream copy of the model
    RETURN_HR_IF_FAILED(
        learning_model->LoadFromStream(
            random_access_stream_ref.Get(),
            m_learning_model.GetAddressOf()));

    return 0;
  }

 private:
  Microsoft::WRL::ComPtr m_learning_model;
};

// Wraps an ILearningModelEvaluationResult returned by WinMLLearningModelSession::evaluate.
// Only the session (a friend) can construct it.
class WinMLLearningModelResults {
  friend class WinMLLearningModelSession;

 public:
  // Looks up output `feature_name` (length `feature_name_size`) in the result's Outputs map.
  // The value is QI'd to a feature value and then to a native tensor interface. GetBuffer writes
  // the data pointer to `*pp_buffer` and the size it reports to `*p_capacity`.
  // Returns 0 on success, otherwise the failing HRESULT as int32_t; outputs are not written when
  // a step fails.
  // NOTE(review): whether the size is in bytes or elements, and how long the buffer stays valid,
  // depend on the native GetBuffer contract (not visible here). Presumably the buffer is only
  // valid while this result object is alive — confirm.
  int32_t get_output(
      const wchar_t* feature_name,
      size_t feature_name_size,
      void** pp_buffer,
      size_t* p_capacity) {
    Microsoft::WRL::ComPtr> output_map;
    RETURN_HR_IF_FAILED(m_result->get_Outputs(&output_map));

    Microsoft::WRL::ComPtr inspectable;
    RETURN_HR_IF_FAILED(output_map->Lookup(
        Microsoft::WRL::Wrappers::HStringReference(feature_name, static_cast(feature_name_size)).Get(),
        inspectable.GetAddressOf()));

    Microsoft::WRL::ComPtr output_feature_value;
    RETURN_HR_IF_FAILED(inspectable.As(&output_feature_value));

    Microsoft::WRL::ComPtr native_tensor_float_feature_value;
    RETURN_HR_IF_FAILED(output_feature_value.As(&native_tensor_float_feature_value));

    uint32_t size;
    RETURN_HR_IF_FAILED(native_tensor_float_feature_value->GetBuffer(reinterpret_cast(pp_buffer), &size));
    *p_capacity = size;

    return 0;
  }

 private:
  // Takes a reference to `p_result` (ComPtr assignment AddRefs).
  WinMLLearningModelResults(ABI::Microsoft::AI::MachineLearning::ILearningModelEvaluationResult* p_result) {
    m_result = p_result;
  }

 private:
  Microsoft::WRL::ComPtr<
      ABI::Microsoft::AI::MachineLearning::ILearningModelEvaluationResult>
      m_result;
};

// Wraps an ILearningModelBinding created for a session; it binds named input tensors.
// Construction fails fast if the binding cannot be created.
class WinMLLearningModelBinding {
  friend class WinMLLearningModelSession;

 public:
  WinMLLearningModelBinding(const WinMLLearningModelSession& session) {
    ML_FAIL_FAST_IF(0 != Initialize(session));
  }

  // Binds `feature_name` to a new tensor created with the "Statics" factory's CreateFromArray.
  // The tensor takes its shape from [p_shape, p_shape + shape_size) and `data_size` elements of
  // data from `p_data`.
  // Returns 0 on success, otherwise the failing HRESULT as int32_t.
  // NOTE(review): presumably CreateFromArray copies `p_data`, in contrast to bind_as_reference —
  // confirm against the API docs.
  template
  int32_t bind(
      const wchar_t* feature_name, size_t feature_name_size,
      tensor_shape_type* p_shape, size_t shape_size,
      T* p_data, size_t data_size) {
    using ITensor = typename Tensor::Type;
    using ITensorFactory = typename TensorFactory::Factory;

    Microsoft::WRL::ComPtr tensor_factory;
    RETURN_HR_IF_FAILED(
        GetActivationFactory(
            TensorRuntimeClassID::RuntimeClass_ID,
            TensorFactoryIID::IID,
            &tensor_factory));

    // Expose the shape array as an iterable for CreateFromArray.
    Microsoft::WRL::ComPtr> input_shape_iterable;
    RETURN_HR_IF_FAILED(
        Microsoft::WRL::MakeAndInitialize>(
            &input_shape_iterable, p_shape, p_shape + shape_size));

    Microsoft::WRL::ComPtr tensor;
    RETURN_HR_IF_FAILED(
        tensor_factory->CreateFromArray(
            input_shape_iterable.Get(),
            static_cast(data_size),
            p_data,
            tensor.GetAddressOf()));

    Microsoft::WRL::ComPtr inspectable_tensor;
    RETURN_HR_IF_FAILED(tensor.As(&inspectable_tensor));

    RETURN_HR_IF_FAILED(
        m_learning_model_binding->Bind(
            Microsoft::WRL::Wrappers::HStringReference(feature_name, static_cast(feature_name_size)).Get(),
            inspectable_tensor.Get()));
    return 0;
  }

  // Overload without data: it ignores every argument and returns 0.
  // NOTE(review): nothing is bound here; callers may assume an output/placeholder binding was
  // made — confirm this no-op is intentional.
  template
  int32_t bind(
      const wchar_t* /*feature_name*/, size_t /*feature_name_size*/,
      tensor_shape_type* /*p_shape*/, size_t /*shape_size*/) {
    return 0;
  }

  // Binds `feature_name` to a tensor created with the "Statics2" factory's CreateFromBuffer.
  // The tensor's buffer wraps [p_data, p_data + data_size) via the helper buffer type.
  // Returns 0 on success, otherwise the failing HRESULT as int32_t.
  // NOTE(review): the name and the buffer wrapper suggest `p_data` is referenced, not copied, so
  // it must stay alive while the binding/evaluation uses it — confirm in weak_buffer.h.
  template
  int32_t bind_as_reference(
      const wchar_t* feature_name, size_t feature_name_size,
      tensor_shape_type* p_shape, size_t shape_size,
      T* p_data, size_t data_size) {
    using ITensor = typename Tensor::Type;
    using ITensorFactory = typename TensorFactory2::Factory;

    Microsoft::WRL::ComPtr tensor_factory;
    RETURN_HR_IF_FAILED(
        GetActivationFactory(
            TensorRuntimeClassID::RuntimeClass_ID,
            TensorFactory2IID::IID,
            &tensor_factory));

    Microsoft::WRL::ComPtr> buffer;
    RETURN_HR_IF_FAILED(
        Microsoft::WRL::MakeAndInitialize>(
            &buffer, p_data, p_data + data_size));

    Microsoft::WRL::ComPtr tensor;
    RETURN_HR_IF_FAILED(
        tensor_factory->CreateFromBuffer(
            static_cast(shape_size),
            p_shape,
            buffer.Get(),
            tensor.GetAddressOf()));

    Microsoft::WRL::ComPtr inspectable_tensor;
    RETURN_HR_IF_FAILED(tensor.As(&inspectable_tensor));

    RETURN_HR_IF_FAILED(
        m_learning_model_binding->Bind(
            Microsoft::WRL::Wrappers::HStringReference(feature_name, static_cast(feature_name_size)).Get(),
            inspectable_tensor.Get()));
    return 0;
  }

  // Binds `feature_name` directly to an iterable of `num_buffers` buffers; no tensor is created.
  // Buffer i wraps [p_data[i], p_data[i] + data_sizes[i]).
  // Returns 0 on success, otherwise the failing HRESULT as int32_t.
  // The ITensor / ITensorFactory aliases are declared but not used in this function.
  // NOTE(review): the transform lambda takes each ComPtr by value (AddRef) and Detach()es the
  // copy. Each raw pointer therefore carries an extra reference that is only balanced if the
  // iterable takes ownership, and it is not released on the early-return paths that follow —
  // confirm in weak_single_threaded_iterable.h.
  template
  int32_t bind_as_references(
      const wchar_t* feature_name, size_t feature_name_size,
      T** p_data, size_t* data_sizes,
      size_t num_buffers) {
    using ITensor = typename Tensor::Type;
    using ITensorFactory = typename TensorFactory2::Factory;

    std::vector> vec_buffers(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      RETURN_HR_IF_FAILED(
          Microsoft::WRL::MakeAndInitialize>(
              &vec_buffers.at(i), p_data[i], p_data[i] + data_sizes[i]));
    }

    std::vector raw_buffers(num_buffers);
    std::transform(std::begin(vec_buffers), std::end(vec_buffers), std::begin(raw_buffers),
        [](auto buffer) { return buffer.Detach(); });

    Microsoft::WRL::ComPtr> buffers;
    RETURN_HR_IF_FAILED(
        Microsoft::WRL::MakeAndInitialize>(
            &buffers, raw_buffers.data(), raw_buffers.data() + num_buffers));

    Microsoft::WRL::ComPtr inspectable_tensor;
    RETURN_HR_IF_FAILED(buffers.As(&inspectable_tensor));

    RETURN_HR_IF_FAILED(
        m_learning_model_binding->Bind(
            Microsoft::WRL::Wrappers::HStringReference(feature_name, static_cast(feature_name_size)).Get(),
            inspectable_tensor.Get()));
    return 0;
  }

 private:
  // Defined after WinMLLearningModelSession, because it needs that class to be complete
  // (it reads session.m_learning_model_session).
  inline int32_t Initialize(const WinMLLearningModelSession& session);

 private:
  Microsoft::WRL::ComPtr m_learning_model_binding;
};

// Wraps an ILearningModelDevice.
// The device is created from a device kind, a D3D11 device, or a D3D12 command queue.
// Constructors fail fast if creation fails. Copies share the same underlying COM device.
class WinMLLearningModelDevice {
  friend class WinMLLearningModelSession;

 public:
  // Uses LearningModelDeviceKind_Default.
  WinMLLearningModelDevice() :
      WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_Default) {}

  WinMLLearningModelDevice(WinMLLearningModelDevice&& device) :
      m_learning_model_device(std::move(device.m_learning_model_device)) {}

  WinMLLearningModelDevice(const WinMLLearningModelDevice& device) :
      m_learning_model_device(device.m_learning_model_device) {}

  // NOTE(review): returns void, so chained assignment (a = b = c) does not compile. No move
  // assignment is declared, so assigning from an rvalue uses this copy assignment.
  void operator=(const WinMLLearningModelDevice& device) {
    m_learning_model_device = device.m_learning_model_device;
  }

  WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind kind) {
    ML_FAIL_FAST_IF(0 != Initialize(kind));
  }

  WinMLLearningModelDevice(ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice* d3dDevice) {
    ML_FAIL_FAST_IF(0 != Initialize(d3dDevice));
  }

  WinMLLearningModelDevice(ID3D12CommandQueue* queue) {
    ML_FAIL_FAST_IF(0 != Initialize(queue));
  }

  // Named factory helpers: each forwards to the matching constructor.
  static WinMLLearningModelDevice create_cpu_device() {
    return WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_Cpu);
  }

  static WinMLLearningModelDevice create_directx_device() {
    return WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_DirectX);
  }

  static WinMLLearningModelDevice create_directx_high_power_device() {
    return WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_DirectXHighPerformance);
  }

  static WinMLLearningModelDevice create_directx_min_power_device() {
    return WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_DirectXMinPower);
  }

  static WinMLLearningModelDevice create_directx_device(ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice* d3dDevice) {
    return WinMLLearningModelDevice(d3dDevice);
  }

  static WinMLLearningModelDevice create_directx_device(ID3D12CommandQueue* queue) {
    return WinMLLearningModelDevice(queue);
  }

 private:
  // Each Initialize overload returns 0 on success, otherwise the failing HRESULT as int32_t.

  // Creates the device through ILearningModelDeviceFactory::Create(kind).
  int32_t Initialize(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind kind) {
    Microsoft::WRL::ComPtr learning_model_device_factory;
    RETURN_HR_IF_FAILED(
        GetActivationFactory(
            RuntimeClass_Microsoft_AI_MachineLearning_LearningModelDevice,
            ABI::Microsoft::AI::MachineLearning::IID_ILearningModelDeviceFactory,
            &learning_model_device_factory));

    RETURN_HR_IF_FAILED(learning_model_device_factory->Create(kind, &m_learning_model_device));

    return 0;
  }

  // Creates the device through ILearningModelDeviceStatics::CreateFromDirect3D11Device.
  int32_t Initialize(ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice* d3dDevice) {
    Microsoft::WRL::ComPtr learning_model_device_factory;
    RETURN_HR_IF_FAILED(
        GetActivationFactory(
            RuntimeClass_Microsoft_AI_MachineLearning_LearningModelDevice,
            ABI::Microsoft::AI::MachineLearning::IID_ILearningModelDeviceStatics,
            &learning_model_device_factory));

    RETURN_HR_IF_FAILED(learning_model_device_factory->CreateFromDirect3D11Device(d3dDevice, &m_learning_model_device));

    return 0;
  }

  // Creates the device through the native ILearningModelDeviceFactoryNative interface
  // (CreateFromD3D12CommandQueue).
  int32_t Initialize(ID3D12CommandQueue* queue) {
    Microsoft::WRL::ComPtr learning_model_device_factory;
    RETURN_HR_IF_FAILED(
        GetActivationFactory(
            RuntimeClass_Microsoft_AI_MachineLearning_LearningModelDevice,
            __uuidof(ILearningModelDeviceFactoryNative),
            &learning_model_device_factory));

    RETURN_HR_IF_FAILED(learning_model_device_factory->CreateFromD3D12CommandQueue(queue, &m_learning_model_device));

    return 0;
  }

 private:
  Microsoft::WRL::ComPtr m_learning_model_device;
};

// Wraps an ILearningModelSession created for a model on a device.
// Construction fails fast if the session cannot be created.
class WinMLLearningModelSession {
  friend class WinMLLearningModelBinding;

 public:
  using Model = WinMLLearningModel;
  using Device = WinMLLearningModelDevice;

 public:
  // Uses a default-constructed Device (LearningModelDeviceKind_Default).
  WinMLLearningModelSession(const Model& model) {
    ML_FAIL_FAST_IF(0 != Initialize(model, Device()));
  }

  WinMLLearningModelSession(const Model& model, const Device& device) {
    ML_FAIL_FAST_IF(0 != Initialize(model, device));
  }

  // Evaluates the session with `binding` and wraps the evaluation result.
  // The second Evaluate argument is passed as nullptr (presumably the correlation id — confirm).
  // Unlike the other methods, a failed Evaluate terminates the process (FAIL_FAST_IF_HR_FAILED)
  // instead of returning an error code.
  // NOTE(review): the local below uses the `m_` member prefix even though it is a local variable.
  WinMLLearningModelResults evaluate(WinMLLearningModelBinding& binding) {
    Microsoft::WRL::ComPtr m_learning_model_evaluation_result;

    FAIL_FAST_IF_HR_FAILED(
        m_learning_model_session->Evaluate(
            binding.m_learning_model_binding.Get(),
            nullptr,
            m_learning_model_evaluation_result.GetAddressOf()));

    return WinMLLearningModelResults(m_learning_model_evaluation_result.Get());
  }

 private:
  // Creates the session via the session factory's CreateFromModelOnDevice.
  // Returns 0 on success, otherwise the failing HRESULT as int32_t.
  int32_t Initialize(const Model& model, const Device& device) {
    // The factory IID is spelled out locally instead of taken from an ABI header constant
    // (presumably none is available to this header — confirm). The bytes match the GUID below.
    // {d7d86c54-d03d-5ae3-a958-fe952b640620}
    static const GUID IID_ILearningModelSessionFactory =
        { 0xd7d86c54, 0xd03d, 0x5ae3, { 0xa9, 0x58, 0xfe, 0x95, 0x2b, 0x64, 0x06, 0x20 } };

    Microsoft::WRL::ComPtr m_learning_model_session_factory;
    RETURN_HR_IF_FAILED(
        GetActivationFactory(
            RuntimeClass_Microsoft_AI_MachineLearning_LearningModelSession,
            IID_ILearningModelSessionFactory,
            &m_learning_model_session_factory));

    RETURN_HR_IF_FAILED(
        m_learning_model_session_factory->CreateFromModelOnDevice(
            model.m_learning_model.Get(),
            device.m_learning_model_device.Get(),
            m_learning_model_session.GetAddressOf()));

    return 0;
  }

 private:
  Microsoft::WRL::ComPtr m_learning_model_session;
};

// Creates the binding through the binding factory's CreateFromSession.
// Returns 0 on success, otherwise the failing HRESULT as int32_t.
inline int32_t WinMLLearningModelBinding::Initialize(const WinMLLearningModelSession& session) {
  // The factory IID is spelled out locally; the bytes match the GUID below.
  // {ae2f1c97-2fd5-55b9-a05f-53b9dbb4f9e2}
  static const GUID IID_ILearningModelBindingFactory =
      { 0xae2f1c97, 0x2fd5, 0x55b9, { 0xa0, 0x5f, 0x53, 0xb9, 0xdb, 0xb4, 0xf9, 0xe2 } };

  Microsoft::WRL::ComPtr learning_model_binding_factory;
  RETURN_HR_IF_FAILED(
      GetActivationFactory(
          RuntimeClass_Microsoft_AI_MachineLearning_LearningModelBinding,
          IID_ILearningModelBindingFactory,
          &learning_model_binding_factory));

  RETURN_HR_IF_FAILED(
      learning_model_binding_factory->CreateFromSession(
          session.m_learning_model_session.Get(),
          m_learning_model_binding.GetAddressOf()));

  return 0;
}

}}}} // namespace Microsoft::AI::MachineLearning::Details

#endif // WINML_H_