// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once #include "TensorKindFrom.h" #include "MapFeatureDescriptor.h" #include "TensorFeatureDescriptor.h" namespace _winml { // // MapBase // // This is the base class for all data based Map types. // // Supported derived classes: // , , , // , , , // template struct MapBase : winrt::implements< MapBase, winml::ILearningModelFeatureValue, _winml::IMapFeatureValue, _winml::ILotusValueProviderPrivate> { static_assert( std::is_same::value || std::is_same::value, "Map keys must be int64_t or winrt::hstring!" ); static_assert( std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Map values must be int64_t, double, float, or winrt::hstring!" ); using ABIMap = wfc::IMap; using ABIMapView = wfc::IMapView; MapBase(ABIMap const& data) : data_(data) {} static winml::ILearningModelFeatureValue Create() { auto abiMap = winrt::single_threaded_map(); return winrt::make(abiMap); } static winml::ILearningModelFeatureValue Create(const ABIMap& data) { return winrt::make(data); } static winml::ILearningModelFeatureValue Create(const ABIMapView& data) { auto abiMap = winrt::single_threaded_map(); for (const auto& pair : data) { auto key = pair.Key(); auto value = pair.Value(); abiMap.Insert(key, value); } return winrt::make(abiMap); } // ILearningModelFeatureValue implementation winml::LearningModelFeatureKind Kind() { return winml::LearningModelFeatureKind::Map; } STDMETHOD(get_KeyKind) (winml::TensorKind* kind) { FAIL_FAST_IF_NULL(kind); *kind = TensorKindFrom::Type; return S_OK; } STDMETHOD(get_ValueDescriptor) (winml::ILearningModelFeatureDescriptor* result) { FAIL_FAST_IF_NULL(result); *result = TensorFeatureDescriptorFrom::CreateAnonymous(std::vector{}); return S_OK; } STDMETHOD(GetValue) (_winml::BindingContext& context, IValue** out) { auto session = context.session.as(); auto engine = session->GetEngine(); if (context.type == _winml::BindingType::kInput) { RETURN_IF_FAILED(engine->CreateMapValue( reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), TensorKindFrom::Type, TensorKindFrom::Type, out )); } else { RETURN_IF_FAILED(engine->CreateNullValue(out)); } return S_OK; } STDMETHOD(IsPlaceholder) (bool* pIsPlaceHolder) { FAIL_FAST_IF_NULL(pIsPlaceHolder); *pIsPlaceHolder = false; return S_OK; } STDMETHOD(UpdateSourceResourceData) (BindingContext& context, IValue* value) { data_.Clear(); auto session = context.session.as(); auto engine = session->GetEngine(); RETURN_IF_FAILED(engine->FillFromMapValue( reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), TensorKindFrom::Type, TensorKindFrom::Type, value )); return S_OK; } STDMETHOD(AbiRepresentation) (wf::IInspectable& abiRepresentation) { data_.as(abiRepresentation); return S_OK; } private: ABIMap data_; }; } // namespace _winml