onnxruntime/winml/lib/Api/LearningModelBinding.h
Paul McDaniel 5350abe19d
LearningModelSession is cleaned up to use the adapter, and parts of b… (#2382)
This is a big PR. We are going to move it up to layer_dev, which is still an L3, so it is still safe to iterate there in an agile way.

We are going to move this into the L3 so that Ryan can start doing integration testing.

We will pause for a full code review and integration test results before moving into the L2.

>>>> raw comments from previous commits >>> 

* LearningModelSession is cleaned up to use the adapter, and parts of binding are.
* moved everything into the winmladapter
made it all nano-COM, using WRL to construct objects on the ORT side.
base interfaces for everything for winml to call
cleaned up a bunch of winml to use the base interfaces.
* more pieces
* GetData across the abi.
* renamed some namespaces
cleaned up OrtValue
cleaned up Tensor
cleaned up custom ops.
everything *but* LearningModel should be clean
* make sure it's building.   winml.dll is still a monolith.
2019-11-14 17:44:07 -08:00

79 lines
3.2 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "LearningModelBinding.g.h"
#include "inc/ILotusValueProviderPrivate.h"
#include "WinMLAdapter.h"
namespace winrt::Windows::AI::MachineLearning::implementation {

// Implementation type for the LearningModelBinding runtime class.
// Derives from the generated LearningModelBindingT base and additionally
// implements the ILearningModelBindingNative COM interface (see the
// STDMETHOD(Bind) overload below).
struct LearningModelBinding : LearningModelBindingT<LearningModelBinding, ILearningModelBindingNative> {
  // Per-name bookkeeping for a bound feature; instances are stored in m_providers.
  struct ProviderInfo {
    // NOTE(review): presumably the object the caller originally passed to Bind — confirm.
    Windows::Foundation::IInspectable CallerSpecifiedFeatureValue = nullptr;
    // Value-provider object for this feature (ILotusValueProviderPrivate).
    winrt::com_ptr<WinML::ILotusValueProviderPrivate> Provider = nullptr;
    // Binding context associated with this provider; value-initialized by default.
    WinML::BindingContext Context = {};
  };

 public:
  // Element type produced by First(): a (name, value) pair.
  using KeyValuePair =
      Windows::Foundation::Collections::IKeyValuePair<hstring, Windows::Foundation::IInspectable>;

  // A binding cannot exist without a session; the default constructor is deleted.
  LearningModelBinding() = delete;
  LearningModelBinding(Windows::AI::MachineLearning::LearningModelSession const& session);

  // Bind a feature value to the given name, optionally with extra binding properties.
  void Bind(hstring const& name, Windows::Foundation::IInspectable const& value);
  void Bind(hstring const& name, Windows::Foundation::IInspectable const& value, Windows::Foundation::Collections::IPropertySet const& properties);
  void Clear();

  // IMapView<hstring, IInspectable>-shaped accessors over the bound values.
  Windows::Foundation::Collections::IIterator<KeyValuePair> First();
  Windows::Foundation::IInspectable Lookup(hstring const& key);
  uint32_t Size();
  bool HasKey(hstring const& key);
  void Split(
      Windows::Foundation::Collections::IMapView<hstring, Windows::Foundation::IInspectable>& first,
      Windows::Foundation::Collections::IMapView<hstring, Windows::Foundation::IInspectable>& second);

  // Produces a (name, IOrtValue*, BindingType) tuple for the given feature.
  // NOTE(review): ownership of the returned IOrtValue* is not visible here — confirm in the .cpp.
  std::tuple<std::string, _winmla::IOrtValue*, WinML::BindingType> CreateBinding(
      const std::string& name,
      const Windows::Foundation::IInspectable& value,
      Windows::Foundation::Collections::IPropertySet const& properties);

  // Returns the adapter-side IO binding as a raw pointer.
  // NOTE(review): presumably non-owning (backed by m_lotusBinding) — confirm.
  _winmla::IIOBinding* BindingCollection();

  // Returns a map from feature name to IInspectable value.
  std::unordered_map<std::string, Windows::Foundation::IInspectable> UpdateProviders();

  // Returns the session held by this binding. The accessor only reads the
  // const member m_session, so it is const-qualified and callable through
  // const references as well.
  const Windows::AI::MachineLearning::LearningModelSession& GetSession() const { return m_session; }

  // ILearningModelBindingNative entry point: bind by wide-character name
  // (with explicit length cchName) to a raw IUnknown value.
  STDMETHOD(Bind)
  (
      const wchar_t* name,
      UINT32 cchName,
      IUnknown* value);

 private:
  void CacheProvider(std::string name, ProviderInfo& spProvider);
  Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, _winmla::IOrtValue* mlValue);
  ILearningModelFeatureValue CreateUnboundOuputFeatureValue(
      _winmla::IOrtValue* mlValue,
      ILearningModelFeatureDescriptor& descriptor);

  // Type predicates over adapter values, parameterized by TensorKind.
  bool IsOfTensorType(_winmla::ITensor* tensorValue, TensorKind kind);
  bool IsOfMapType(_winmla::IOrtValue* mlValue, TensorKind key_kind, TensorKind value_kind);
  bool IsOfVectorMapType(_winmla::IOrtValue* mlValue, TensorKind key_kind, TensorKind value_kind);

 private:
  // Keep this declaration order: non-static data members are initialized in
  // declaration order, and the constructor's initializer list may rely on it.
  const Windows::AI::MachineLearning::LearningModelSession m_session;
  // Feature name -> provider bookkeeping.
  std::unordered_map<std::string, ProviderInfo> m_providers;
  com_ptr<_winmla::IIOBinding> m_lotusBinding;
  com_ptr<_winmla::IWinMLAdapter> adapter_;
};

}  // namespace winrt::Windows::AI::MachineLearning::implementation
namespace winrt::Windows::AI::MachineLearning {
namespace factory_implementation {

// Activation factory for the LearningModelBinding runtime class. The body is
// intentionally empty: everything is inherited from the factory-side
// LearningModelBindingT template (presumably cppwinrt-generated, via the
// included .g.h header), parameterized on the implementation type above.
struct LearningModelBinding
    : LearningModelBindingT<LearningModelBinding, implementation::LearningModelBinding> {};

}  // namespace factory_implementation
}  // namespace winrt::Windows::AI::MachineLearning