// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "LearningModelBinding.g.h"

#include "inc/ILotusValueProviderPrivate.h"
#include "core/providers/winml/winml_provider_factory.h"

// NOTE(review): in this copy of the header every template argument list appears
// to have been stripped (e.g. `LearningModelBindingT`, `winrt::com_ptr`,
// `std::vector`, `std::unordered_map`, `IMapView`, `IIterator`, and the
// leftover `>` in `std::vector>` / `std::tuple, WinML::BindingType>`).
// Restore the original `<...>` arguments from upstream before building; they
// are intentionally not guessed here.

namespace winrt::Windows::AI::MachineLearning::implementation {

// Implementation type for the WinRT LearningModelBinding runtime class: holds
// the values bound by name to a LearningModelSession's inputs and outputs.
struct LearningModelBinding : LearningModelBindingT {
  // Per-name record cached in m_providers: the value the caller originally
  // passed to Bind, the provider object created for it, and the binding
  // context associated with it.
  struct ProviderInfo {
    Windows::Foundation::IInspectable CallerSpecifiedFeatureValue = nullptr;
    winrt::com_ptr Provider = nullptr;
    WinML::BindingContext Context = {};
  };

 public:
  // Key/value pair type exposed through the map-style accessors below
  // (template arguments stripped in this copy — presumably hstring -> IInspectable; confirm).
  using KeyValuePair = Windows::Foundation::Collections::IKeyValuePair;

  // A binding is only meaningful for a specific session, so default
  // construction is disallowed.
  LearningModelBinding() = delete;
  LearningModelBinding(Windows::AI::MachineLearning::LearningModelSession const& session);

  // Binds `value` to the feature called `name`.
  void Bind(hstring const& name, Windows::Foundation::IInspectable const& value);
  // Same as above, with an additional caller-supplied property set.
  void Bind(hstring const& name, Windows::Foundation::IInspectable const& value, Windows::Foundation::Collections::IPropertySet const& properties);
  // COM-style overload declared via STDMETHOD (returns HRESULT); the name is
  // passed as a wide-character buffer plus its length `cchName`.
  STDMETHOD(Bind)(const wchar_t* name, UINT32 cchName, IUnknown* value);

  // Removes the current bindings (definition not visible here).
  void Clear();

  // Map-view style accessors (First/Lookup/Size/HasKey/Split), presumably so the
  // binding can be enumerated/queried as a name -> value map; confirm against
  // the runtime class definition.
  Windows::Foundation::Collections::IIterator First();
  Windows::Foundation::IInspectable Lookup(hstring const& key);
  uint32_t Size();
  bool HasKey(hstring const& key);
  void Split(
      Windows::Foundation::Collections::IMapView& first,
      Windows::Foundation::Collections::IMapView& second);

  // Builds the binding for a single named value; the last tuple element is
  // the WinML::BindingType (other tuple element types stripped in this copy).
  std::tuple, WinML::BindingType> CreateBinding(
      const std::string& name,
      const Windows::Foundation::IInspectable& value,
      Windows::Foundation::Collections::IPropertySet const& properties);

  // Returns an unordered_map produced from the cached providers
  // (key/value types stripped in this copy).
  std::unordered_map UpdateProviders();

  // Returns the session this binding was created for.
  // NOTE(review): this inline accessor does not modify state and could be
  // const-qualified.
  const Windows::AI::MachineLearning::LearningModelSession& GetSession() { return m_session; }

  // Names of bound inputs/outputs — presumably backed by input_names_ /
  // output_names_ (definitions not visible here).
  const std::vector& GetInputNames() const;
  const std::vector& GetOutputNames() const;

  // Bound input/output values — presumably backed by inputs_ / outputs_.
  // GetOutputs returns a mutable reference.
  const std::vector>& GetInputs() const;
  std::vector>& GetOutputs();

  // Records an output value under `name`; returns an HRESULT.
  HRESULT BindOutput(const std::string& name, winrt::com_ptr value);
  // Handles outputs the caller did not bind explicitly (definition not visible here).
  void BindUnboundOutputs();

 private:
  // Stores `spProvider` under `name` (presumably in m_providers).
  void CacheProvider(std::string name, ProviderInfo& spProvider);
  // Produces the IInspectable returned to callers for an output they did not bind.
  Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, winrt::com_ptr value);
  // NOTE(review): "Ouput" is misspelled, but the name must stay as-is to match
  // its out-of-view definition. Also, the by-value parameter is `const`, which
  // prevents moving from it.
  ILearningModelFeatureValue CreateUnboundOuputFeatureValue(
      const winrt::com_ptr value,
      ILearningModelFeatureDescriptor& descriptor);
  // Records an input value under `name`; returns an HRESULT.
  HRESULT BindInput(const std::string& name, winrt::com_ptr value);

 private:
  // Session this binding belongs to; fixed at construction.
  const Windows::AI::MachineLearning::LearningModelSession m_session;

  // Per-name ProviderInfo cache (see CacheProvider).
  std::unordered_map m_providers;

  // NOTE(review): the name and value vectors look like parallel arrays
  // (names[i] <-> values[i]); confirm the index alignment in the .cpp.
  // Member naming also mixes `m_` prefixes and trailing underscores.
  std::vector input_names_;
  std::vector> inputs_;
  std::vector output_names_;
  std::vector> outputs_;
};

}  // namespace winrt::Windows::AI::MachineLearning::implementation

namespace winrt::Windows::AI::MachineLearning::factory_implementation {

// Activation factory for LearningModelBinding. It has no members of its own;
// everything comes from the LearningModelBindingT base (presumably generated
// in LearningModelBinding.g.h — confirm).
struct LearningModelBinding : LearningModelBindingT {
};

}  // namespace winrt::Windows::AI::MachineLearning::factory_implementation