// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #pragma once #include "LearningModelBinding.g.h" #include "inc/ILotusValueProviderPrivate.h" #include "WinMLAdapter.h" namespace winrt::Windows::AI::MachineLearning::implementation { struct LearningModelBinding : LearningModelBindingT { struct ProviderInfo { Windows::Foundation::IInspectable CallerSpecifiedFeatureValue = nullptr; winrt::com_ptr Provider = nullptr; WinML::BindingContext Context = {}; }; public: using KeyValuePair = Windows::Foundation::Collections::IKeyValuePair; LearningModelBinding() = delete; LearningModelBinding(Windows::AI::MachineLearning::LearningModelSession const& session); void Bind(hstring const& name, Windows::Foundation::IInspectable const& value); void Bind(hstring const& name, Windows::Foundation::IInspectable const& value, Windows::Foundation::Collections::IPropertySet const& properties); void Clear(); Windows::Foundation::Collections::IIterator First(); Windows::Foundation::IInspectable Lookup(hstring const& key); uint32_t Size(); bool HasKey(hstring const& key); void Split( Windows::Foundation::Collections::IMapView& first, Windows::Foundation::Collections::IMapView& second); std::tuple CreateBinding( const std::string& name, const Windows::Foundation::IInspectable& value, Windows::Foundation::Collections::IPropertySet const& properties); std::unordered_map UpdateProviders(); const Windows::AI::MachineLearning::LearningModelSession& GetSession() { return m_session; } STDMETHOD(Bind) ( const wchar_t* name, UINT32 cchName, IUnknown* value); const std::vector& LearningModelBinding::GetOutputNames() const; std::vector& LearningModelBinding::GetOutputs(); const std::vector& LearningModelBinding::GetInputNames() const; const std::vector& LearningModelBinding::GetInputs() const; HRESULT BindOutput(const std::string& name, Ort::Value& ml_value); void BindUnboundOutputs(); private: void CacheProvider(std::string name, ProviderInfo& spProvider); Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, Ort::Value& ort_value); ILearningModelFeatureValue CreateUnboundOuputFeatureValue( const Ort::Value& ort_value, ILearningModelFeatureDescriptor& descriptor); bool IsOfTensorType(const Ort::Value& ort_value, TensorKind kind); bool IsOfMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind); bool IsOfVectorMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind); HRESULT BindInput(const std::string& name, Ort::Value& ml_value); private: const Windows::AI::MachineLearning::LearningModelSession m_session; std::unordered_map m_providers; com_ptr adapter_; std::vector input_names_; std::vector inputs_; std::vector output_names_; std::vector outputs_; }; } // namespace winrt::Windows::AI::MachineLearning::implementation namespace winrt::Windows::AI::MachineLearning::factory_implementation { struct LearningModelBinding : LearningModelBindingT { }; } // namespace winrt::Windows::AI::MachineLearning::factory_implementation