// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #pragma once #include "TensorKindFrom.h" #include "MapFeatureDescriptor.h" #include "TensorFeatureDescriptor.h" namespace Windows::AI::MachineLearning { // // MapBase // // This is the base class for all data based Map types. // // Supported derived classes: // , , , // , , , // template < typename TDerived, typename TKey, typename TValue> struct MapBase : winrt::implements< MapBase, winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue, WinML::IMapFeatureValue, WinML::ILotusValueProviderPrivate> { static_assert( std::is_same::value || std::is_same::value, "Map keys must be int64_t or winrt::hstring!"); static_assert( std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Map values must be int64_t, double, float, or winrt::hstring!"); template struct ValidLotusType { using Type = T; }; template <> struct ValidLotusType { using Type = std::string; }; using LotusKey = typename ValidLotusType::Type; using LotusValue = typename ValidLotusType::Type; using LotusMap = std::map; using ABIMap = ::winrt::Windows::Foundation::Collections::IMap; using ABIMapView = ::winrt::Windows::Foundation::Collections::IMapView; template static typename ValidLotusType::Type ConvertToValidLotusType(TRawType raw) { return raw; } template <> static typename ValidLotusType::Type ConvertToValidLotusType(winrt::hstring raw) { return WinML::Strings::UTF8FromHString(raw); } template static TRawType ConvertToABIType(typename ValidLotusType::Type lotusValue) { return lotusValue; } template <> static typename winrt::hstring ConvertToABIType(typename ValidLotusType::Type lotusValue) { return WinML::Strings::HStringFromUTF8(lotusValue); } MapBase(ABIMap const& data) : m_data(data) {} static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create() { auto abiMap = winrt::single_threaded_map(); return winrt::make(abiMap); } static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create(const ABIMap& data) { return winrt::make(data); } static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create(const ABIMapView& data) { auto abiMap = winrt::single_threaded_map(); for (const auto& pair : data) { auto key = pair.Key(); auto value = pair.Value(); abiMap.Insert(key, value); } return winrt::make(abiMap); } // ILearningModelFeatureValue implementation winrt::Windows::AI::MachineLearning::LearningModelFeatureKind Kind() { return winrt::Windows::AI::MachineLearning::LearningModelFeatureKind::Map; } STDMETHOD(get_KeyKind) (winrt::Windows::AI::MachineLearning::TensorKind* kind) { FAIL_FAST_IF_NULL(kind); *kind = TensorKindFrom::Type; return S_OK; } STDMETHOD(get_ValueDescriptor) (winrt::Windows::AI::MachineLearning::ILearningModelFeatureDescriptor* result) { FAIL_FAST_IF_NULL(result); *result = TensorFeatureDescriptorFrom::CreateAnonymous(std::vector{}); return S_OK; } static LotusMap ConvertToLotusMap(const ABIMap& map) { LotusMap lotusMap; for (const auto& pair : map) { auto key = ConvertToValidLotusType(pair.Key()); auto value = ConvertToValidLotusType(pair.Value()); lotusMap[key] = value; } return lotusMap; } STDMETHOD(GetOrtValue) (WinML::BindingContext& context, OrtValue* mlValue) { // TODO: Tensorized data should be cached so multiple bindings work more efficiently // Create a copy of the map auto map = context.type == WinML::BindingType::kInput ? std::make_unique(ConvertToLotusMap(m_data)) : std::make_unique(); OrtValue value; value.Init( map.release(), onnxruntime::DataTypeImpl::GetType(), onnxruntime::DataTypeImpl::GetType()->GetDeleteFunc()); *mlValue = value; return S_OK; } STDMETHOD(IsPlaceholder) (bool* pIsPlaceHolder) { FAIL_FAST_IF_NULL(pIsPlaceHolder); *pIsPlaceHolder = false; return S_OK; } STDMETHOD(UpdateSourceResourceData) (BindingContext& context, OrtValue& mlValue) { m_data.Clear(); const auto& map = mlValue.Get(); for (const auto& pair : map) { auto key = ConvertToABIType(pair.first); auto value = ConvertToABIType(pair.second); m_data.Insert(key, value); } return S_OK; } STDMETHOD(AbiRepresentation) ( winrt::Windows::Foundation::IInspectable& abiRepresentation) { m_data.as(abiRepresentation); return S_OK; } private: ABIMap m_data; }; } // namespace Windows::AI::MachineLearning