// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #pragma once #include "TensorKindFrom.h" #include "MapFeatureDescriptor.h" #include "TensorFeatureDescriptor.h" namespace Windows::AI::MachineLearning { // // MapBase // // This is the base class for all data based Map types. // // Supported derived classes: // , , , // , , , // template < typename TDerived, typename TKey, typename TValue> struct MapBase : winrt::implements< MapBase, winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue, WinML::IMapFeatureValue, WinML::ILotusValueProviderPrivate> { static_assert( std::is_same::value || std::is_same::value, "Map keys must be int64_t or winrt::hstring!"); static_assert( std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Map values must be int64_t, double, float, or winrt::hstring!"); template struct ValidLotusType { using Type = T; }; template <> struct ValidLotusType { using Type = std::string; }; using LotusKey = typename ValidLotusType::Type; using LotusValue = typename ValidLotusType::Type; using LotusMap = std::map; using ABIMap = ::winrt::Windows::Foundation::Collections::IMap; using ABIMapView = ::winrt::Windows::Foundation::Collections::IMapView; template static typename ValidLotusType::Type ConvertToValidLotusType(TRawType raw) { return raw; } template <> static typename ValidLotusType::Type ConvertToValidLotusType(winrt::hstring raw) { return WinML::Strings::UTF8FromHString(raw); } template static TRawType ConvertToABIType(typename ValidLotusType::Type lotusValue) { return lotusValue; } template <> static typename winrt::hstring ConvertToABIType(typename ValidLotusType::Type lotusValue) { return WinML::Strings::HStringFromUTF8(lotusValue); } MapBase(ABIMap const& data) : m_data(data) {} static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create() { auto abiMap = winrt::single_threaded_map(); return winrt::make(abiMap); } static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create(const ABIMap& data) { return winrt::make(data); } static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create(const ABIMapView& data) { auto abiMap = winrt::single_threaded_map(); for (const auto& pair : data) { auto key = pair.Key(); auto value = pair.Value(); abiMap.Insert(key, value); } return winrt::make(abiMap); } // ILearningModelFeatureValue implementation winrt::Windows::AI::MachineLearning::LearningModelFeatureKind Kind() { return winrt::Windows::AI::MachineLearning::LearningModelFeatureKind::Map; } STDMETHOD(get_KeyKind) (winrt::Windows::AI::MachineLearning::TensorKind* kind) { FAIL_FAST_IF_NULL(kind); *kind = TensorKindFrom::Type; return S_OK; } STDMETHOD(get_ValueDescriptor) (winrt::Windows::AI::MachineLearning::ILearningModelFeatureDescriptor* result) { FAIL_FAST_IF_NULL(result); *result = TensorFeatureDescriptorFrom::CreateAnonymous(std::vector{}); return S_OK; } static LotusMap ConvertToLotusMap(const ABIMap& map) { LotusMap lotusMap; for (const auto& pair : map) { auto key = ConvertToValidLotusType(pair.Key()); auto value = ConvertToValidLotusType(pair.Value()); lotusMap[key] = value; } return lotusMap; } template static onnxruntime::MLDataType GetLotusType(_winmla::IWinMLAdapter* adapter) { return adapter->GetMapType(TensorKindFrom::Type, TensorKindFrom::Type); } STDMETHOD(GetOrtValue)(WinML::BindingContext& context, _winmla::IOrtValue** mlValue) { // TODO: Tensorized data should be cached so multiple bindings work more efficiently // Create a copy of the map auto map = context.type == WinML::BindingType::kInput ? std::make_unique(ConvertToLotusMap(m_data)) : std::make_unique(); winrt::com_ptr<_winmla::IWinMLAdapter> adapter; RETURN_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); auto lotus_type = GetLotusType(adapter.get()); winrt::com_ptr<_winmla::IOrtValue> ml_value_out; adapter->CreateOrtValue(map.release(), lotus_type, ml_value_out.put()); *mlValue = ml_value_out.detach(); return S_OK; } STDMETHOD(IsPlaceholder) (bool* pIsPlaceHolder) { FAIL_FAST_IF_NULL(pIsPlaceHolder); *pIsPlaceHolder = false; return S_OK; } STDMETHOD(UpdateSourceResourceData)(BindingContext& context, _winmla::IOrtValue* mlValue) { m_data.Clear(); winrt::com_ptr<_winmla::IWinMLAdapter> adapter; RETURN_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); const LotusMap& map = *static_cast(adapter->GetMapData( mlValue, TensorKindFrom::Type, TensorKindFrom::Type)); for (const auto& pair : map) { auto key = ConvertToABIType(pair.first); auto value = ConvertToABIType(pair.second); m_data.Insert(key, value); } return S_OK; } STDMETHOD(AbiRepresentation) ( winrt::Windows::Foundation::IInspectable& abiRepresentation) { m_data.as(abiRepresentation); return S_OK; } private: ABIMap m_data; }; } // namespace Windows::AI::MachineLearning