// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#pragma once

#include "LearningModelBinding.g.h"

#include "inc/ILotusValueProviderPrivate.h"
#include "WinMLAdapter.h"

// NOTE(review): in this copy of the header the template argument lists of
// several types appear to have been stripped (LearningModelBindingT,
// winrt::com_ptr, IKeyValuePair, IIterator, IMapView, std::tuple,
// std::unordered_map). Class templates such as std::tuple and
// std::unordered_map cannot name a return type or a member type without
// their arguments, so these declarations do not compile as shown. Restore the
// arguments from the matching out-of-line definitions / the generated
// LearningModelBinding.g.h rather than guessing them.

namespace winrt::Windows::AI::MachineLearning::implementation {

// Implementation type for the LearningModelBinding runtime class, built on the
// generated LearningModelBindingT base (see LearningModelBinding.g.h).
// A binding is always created against a LearningModelSession (the default
// constructor is deleted) and keeps, per bound feature, a ProviderInfo record.
struct LearningModelBinding : LearningModelBindingT {
  // Bookkeeping record kept for a bound feature.
  struct ProviderInfo {
    // The object the caller supplied for the feature.
    Windows::Foundation::IInspectable CallerSpecifiedFeatureValue = nullptr;
    // Value provider for the feature. Presumably a
    // com_ptr<ILotusValueProviderPrivate> given the include above — TODO confirm.
    winrt::com_ptr Provider = nullptr;
    // Binding context associated with the provider (value-initialized).
    WinML::BindingContext Context = {};
  };

 public:
  // Element type yielded by First(). Presumably
  // IKeyValuePair<hstring, IInspectable>, matching Lookup()/HasKey() — TODO confirm.
  using KeyValuePair = Windows::Foundation::Collections::IKeyValuePair;

  // A binding cannot exist without a session.
  LearningModelBinding() = delete;

  // Creates a binding associated with `session` (stored in m_session).
  LearningModelBinding(Windows::AI::MachineLearning::LearningModelSession const& session);

  // Binds `value` under the feature name `name`.
  void Bind(hstring const& name, Windows::Foundation::IInspectable const& value);

  // Same as above, with an additional caller-supplied property set.
  void Bind(hstring const& name, Windows::Foundation::IInspectable const& value, Windows::Foundation::Collections::IPropertySet const& properties);

  // Clears the binding's state; exactly what is reset (cached providers,
  // underlying IO binding) is defined out of line — TODO confirm.
  void Clear();

  // Map-style accessors keyed by hstring feature name. This is the same member
  // set as a WinRT IMapView: First / Lookup / Size / HasKey / Split.
  Windows::Foundation::Collections::IIterator First();
  Windows::Foundation::IInspectable Lookup(hstring const& key);
  uint32_t Size();
  bool HasKey(hstring const& key);
  void Split(
      Windows::Foundation::Collections::IMapView& first,
      Windows::Foundation::Collections::IMapView& second);

  // Creates the binding data for feature `name` from `value` and `properties`.
  // The tuple's element types are not visible in this copy (see NOTE above).
  std::tuple CreateBinding(
      const std::string& name,
      const Windows::Foundation::IInspectable& value,
      Windows::Foundation::Collections::IPropertySet const& properties);

  // Returns the underlying IO binding as a raw pointer. Ownership is not
  // visible here; presumably m_lotusBinding.get() with no AddRef — TODO confirm.
  _winmla::IIOBinding* BindingCollection();

  // Presumably updates the cached providers (per its name) and returns a map
  // of results; key/value types are not visible in this copy.
  std::unordered_map UpdateProviders();

  // Returns a reference to the stored session (m_session).
  const Windows::AI::MachineLearning::LearningModelSession& GetSession() { return m_session; }

  // COM-style Bind overload: wide-character name with an explicit character
  // count, and an IUnknown value. STDMETHOD declares a virtual method
  // returning HRESULT.
  STDMETHOD(Bind)(const wchar_t* name, UINT32 cchName, IUnknown* value);

 private:
  // Records `spProvider` for feature `name`. The name is taken by value;
  // presumably the record is stored into m_providers — TODO confirm.
  void CacheProvider(std::string name, ProviderInfo& spProvider);

  // Produces a caller-facing value for output `name` from the runtime value
  // `mlValue`, for an output the caller did not bind.
  Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, _winmla::IOrtValue* mlValue);

  // Produces a feature value for `mlValue` using `descriptor`.
  // NOTE(review): the identifier is spelled "Ouput" (sic). It must match its
  // out-of-line definition, so renaming requires changing both together.
  ILearningModelFeatureValue CreateUnboundOuputFeatureValue(
      _winmla::IOrtValue* mlValue,
      ILearningModelFeatureDescriptor& descriptor);

  // Predicates: presumably test whether a runtime value is a tensor of `kind`,
  // a map of key_kind -> value_kind, or a vector of such maps — TODO confirm.
  bool IsOfTensorType(_winmla::ITensor* tensorValue, TensorKind kind);
  bool IsOfMapType(_winmla::IOrtValue* mlValue, TensorKind key_kind, TensorKind value_kind);
  bool IsOfVectorMapType(_winmla::IOrtValue* mlValue, TensorKind key_kind, TensorKind value_kind);

 private:
  // Session this binding belongs to; const, so fixed after construction.
  const Windows::AI::MachineLearning::LearningModelSession m_session;
  // ProviderInfo records; presumably keyed by feature name (std::string, as
  // taken by CacheProvider) — TODO confirm.
  std::unordered_map m_providers;
  // Underlying adapter-level IO binding (see BindingCollection()).
  com_ptr<_winmla::IIOBinding> m_lotusBinding;
  // WinML adapter interface.
  // NOTE(review): trailing-underscore name, unlike the other m_-prefixed members.
  com_ptr<_winmla::IWinMLAdapter> adapter_;
};
}  // namespace winrt::Windows::AI::MachineLearning::implementation

namespace winrt::Windows::AI::MachineLearning::factory_implementation {
// C++/WinRT factory type for LearningModelBinding; adds nothing beyond its
// generated base.
struct LearningModelBinding : LearningModelBindingT {
};
}  // namespace winrt::Windows::AI::MachineLearning::factory_implementation