From 66343e2fcfef1be19d2af26429b6fb990e505f75 Mon Sep 17 00:00:00 2001 From: Ori Levari Date: Fri, 24 Apr 2020 10:46:10 -0700 Subject: [PATCH] use map with case insensitive hash and equals functions for learningmodel metadata. add test to verify case insensitive functionality. (#3671) Co-authored-by: Ori Levari --- winml/lib/Api.Ort/OnnxruntimeModel.cpp | 22 +++++++++++++++++++++- winml/test/api/LearningModelAPITest.cpp | 11 ++++++++++- winml/test/api/LearningModelAPITest.h | 2 ++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.cpp b/winml/lib/Api.Ort/OnnxruntimeModel.cpp index 5e0f76ae91..26466e07cf 100644 --- a/winml/lib/Api.Ort/OnnxruntimeModel.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeModel.cpp @@ -148,8 +148,28 @@ STDMETHODIMP ModelInfo::GetVersion(int64_t* out) { return S_OK; } +struct CaseInsensitiveHash { + size_t operator()(const winrt::hstring& key) const { + size_t h = 0, i = 0; + std::for_each(key.begin(), key.end(), [&](wchar_t c) { + i++; + h += i * towlower(c); + }); + return h; + } +}; + +struct CaseInsensitiveEqual { + bool operator()(const winrt::hstring& left, const winrt::hstring& right) const { + return left.size() == right.size() && std::equal(left.begin(), left.end(), right.begin(), + [](wchar_t a, wchar_t b) { + return towlower(a) == towlower(b); + }); + } +}; + STDMETHODIMP ModelInfo::GetModelMetadata(ABI::Windows::Foundation::Collections::IMapView** metadata) { - std::unordered_map map_copy; + std::unordered_map map_copy; for (auto& pair : model_metadata_) { auto metadata_key = _winml::Strings::HStringFromUTF8(pair.first); auto metadata_value = _winml::Strings::HStringFromUTF8(pair.second); diff --git a/winml/test/api/LearningModelAPITest.cpp b/winml/test/api/LearningModelAPITest.cpp index 24abbc361e..d9a514e4be 100644 --- a/winml/test/api/LearningModelAPITest.cpp +++ b/winml/test/api/LearningModelAPITest.cpp @@ -260,6 +260,14 @@ static void CloseModelNoNewSessions() { }); } +static void CheckMetadataCaseInsensitive() { + LearningModel learningModel = nullptr; + WINML_EXPECT_NO_THROW(APITest::LoadModel(L"modelWithMetaData.onnx", learningModel)); + IMapView metadata = learningModel.Metadata(); + WINML_EXPECT_TRUE(metadata.HasKey(L"tHiSiSaLoNgKeY")); + WINML_EXPECT_EQUAL(metadata.Lookup(L"tHiSiSaLoNgKeY"), L"thisisalongvalue"); +} + const LearningModelApiTestsApi& getapi() { static constexpr LearningModelApiTestsApi api = { @@ -279,7 +287,8 @@ const LearningModelApiTestsApi& getapi() { EnumerateOutputs, CloseModelCheckMetadata, CloseModelCheckEval, - CloseModelNoNewSessions + CloseModelNoNewSessions, + CheckMetadataCaseInsensitive }; return api; } \ No newline at end of file diff --git a/winml/test/api/LearningModelAPITest.h b/winml/test/api/LearningModelAPITest.h index 6881c99b31..cd83624620 100644 --- a/winml/test/api/LearningModelAPITest.h +++ b/winml/test/api/LearningModelAPITest.h @@ -21,6 +21,7 @@ struct LearningModelApiTestsApi VoidTest CloseModelCheckMetadata; VoidTest CloseModelCheckEval; VoidTest CloseModelNoNewSessions; + VoidTest CheckMetadataCaseInsensitive; }; const LearningModelApiTestsApi& getapi(); @@ -41,6 +42,7 @@ WINML_TEST(LearningModelAPITests, EnumerateInputs) WINML_TEST(LearningModelAPITests, EnumerateOutputs) WINML_TEST(LearningModelAPITests, CloseModelCheckMetadata) WINML_TEST(LearningModelAPITests, CloseModelNoNewSessions) +WINML_TEST(LearningModelAPITests, CheckMetadataCaseInsensitive) WINML_TEST_CLASS_END() WINML_TEST_CLASS_BEGIN(LearningModelAPITestsGpu)