// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once #include "LearningModelBinding.g.h" #include "inc/ILotusValueProviderPrivate.h" #include "core/providers/winml/winml_provider_factory.h" namespace WINMLP { struct LearningModelBinding : LearningModelBindingT { struct ProviderInfo { wf::IInspectable CallerSpecifiedFeatureValue = nullptr; winrt::com_ptr<_winml::ILotusValueProviderPrivate> Provider = nullptr; _winml::BindingContext Context = {}; }; public: using KeyValuePair = wfc::IKeyValuePair; ~LearningModelBinding(); LearningModelBinding() = delete; LearningModelBinding(winml::LearningModelSession const& session); void Bind(hstring const& name, wf::IInspectable const& value); void Bind(hstring const& name, wf::IInspectable const& value, wfc::IPropertySet const& properties); STDMETHOD(Bind)(const wchar_t* name, UINT32 cchName, IUnknown* value); void Clear(); wfc::IIterator First(); wf::IInspectable Lookup(hstring const& key); uint32_t Size(); bool HasKey(hstring const& key); void Split(wfc::IMapView& first, wfc::IMapView& second); std::tuple, _winml::BindingType> CreateBinding( const std::string& name, const wf::IInspectable& value, wfc::IPropertySet const& properties ); std::unordered_map UpdateProviders(); const winml::LearningModelSession& GetSession() { return m_session; } const std::vector& GetInputNames() const; const std::vector& GetOutputNames() const; const std::vector>& GetInputs() const; std::vector>& GetOutputs(); HRESULT BindOutput(const std::string& name, winrt::com_ptr<_winml::IValue> value); void BindUnboundOutputs(); private: void CacheProvider(std::string name, ProviderInfo& spProvider); wf::IInspectable CreateUnboundOutput(const std::string& name, winrt::com_ptr<_winml::IValue> value); ILearningModelFeatureValue CreateUnboundOuputFeatureValue( const winrt::com_ptr<_winml::IValue> value, ILearningModelFeatureDescriptor& descriptor ); HRESULT BindInput(const std::string& name, winrt::com_ptr<_winml::IValue> value); private: const winml::LearningModelSession m_session; std::unordered_map m_providers; std::vector input_names_; std::vector> inputs_; std::vector output_names_; std::vector> outputs_; }; } // namespace WINMLP namespace WINML::factory_implementation { struct LearningModelBinding : LearningModelBindingT {}; } // namespace WINML::factory_implementation