use map with case insensitive hash and equals functions for learningmodel metadata. add test to verify case insensitive functionality. (#3671)

Co-authored-by: Ori Levari <orlevari@microsoft.com>
This commit is contained in:
Ori Levari 2020-04-24 10:46:10 -07:00 committed by GitHub
parent 5c7f616431
commit 66343e2fcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 2 deletions

View file

@ -148,8 +148,28 @@ STDMETHODIMP ModelInfo::GetVersion(int64_t* out) {
return S_OK;
}
struct CaseInsensitiveHash {
size_t operator()(const winrt::hstring& key) const {
size_t h = 0, i = 0;
std::for_each(key.begin(), key.end(), [&](wchar_t c) {
i++;
h += i * towlower(c);
});
return h;
}
};
struct CaseInsensitiveEqual {
bool operator()(const winrt::hstring& left, const winrt::hstring& right) const {
return left.size() == right.size() && std::equal(left.begin(), left.end(), right.begin(),
[](wchar_t a, wchar_t b) {
return towlower(a) == towlower(b);
});
}
};
STDMETHODIMP ModelInfo::GetModelMetadata(ABI::Windows::Foundation::Collections::IMapView<HSTRING, HSTRING>** metadata) {
std::unordered_map<winrt::hstring, winrt::hstring> map_copy;
std::unordered_map<winrt::hstring, winrt::hstring, CaseInsensitiveHash, CaseInsensitiveEqual> map_copy;
for (auto& pair : model_metadata_) {
auto metadata_key = _winml::Strings::HStringFromUTF8(pair.first);
auto metadata_value = _winml::Strings::HStringFromUTF8(pair.second);

View file

@ -260,6 +260,14 @@ static void CloseModelNoNewSessions() {
});
}
static void CheckMetadataCaseInsensitive() {
LearningModel learningModel = nullptr;
WINML_EXPECT_NO_THROW(APITest::LoadModel(L"modelWithMetaData.onnx", learningModel));
IMapView metadata = learningModel.Metadata();
WINML_EXPECT_TRUE(metadata.HasKey(L"tHiSiSaLoNgKeY"));
WINML_EXPECT_EQUAL(metadata.Lookup(L"tHiSiSaLoNgKeY"), L"thisisalongvalue");
}
const LearningModelApiTestsApi& getapi() {
static constexpr LearningModelApiTestsApi api =
{
@ -279,7 +287,8 @@ const LearningModelApiTestsApi& getapi() {
EnumerateOutputs,
CloseModelCheckMetadata,
CloseModelCheckEval,
CloseModelNoNewSessions
CloseModelNoNewSessions,
CheckMetadataCaseInsensitive
};
return api;
}

View file

@ -21,6 +21,7 @@ struct LearningModelApiTestsApi
VoidTest CloseModelCheckMetadata;
VoidTest CloseModelCheckEval;
VoidTest CloseModelNoNewSessions;
VoidTest CheckMetadataCaseInsensitive;
};
const LearningModelApiTestsApi& getapi();
@ -41,6 +42,7 @@ WINML_TEST(LearningModelAPITests, EnumerateInputs)
WINML_TEST(LearningModelAPITests, EnumerateOutputs)
WINML_TEST(LearningModelAPITests, CloseModelCheckMetadata)
WINML_TEST(LearningModelAPITests, CloseModelNoNewSessions)
WINML_TEST(LearningModelAPITests, CheckMetadataCaseInsensitive)
WINML_TEST_CLASS_END()
WINML_TEST_CLASS_BEGIN(LearningModelAPITestsGpu)