From a917023f94be664fb82db2adcfacce6f35e32f74 Mon Sep 17 00:00:00 2001 From: David Brownell <38224104+davidbrownellWork@users.noreply.github.com> Date: Sun, 26 Apr 2020 11:12:26 -0700 Subject: [PATCH] Support for country-specific holidays in the DateTimeTransformer (#3701) * Support for country-specific holidays in the DateTimeTransformer Updates the DateTimeTransformer featurizer to support holidays, where holiday information is read from country-specific json files. * Addressed build breaks * Enhanced Windows strategies for scenarios when tests run from root dir * Skipping test for nuget installations --- .../cpu/date_time_transformer.cc | 125 +++++++++++++++++- .../datetimetransformer_test.cc | 66 ++++++++- 2 files changed, 188 insertions(+), 3 deletions(-) diff --git a/onnxruntime/featurizers_ops/cpu/date_time_transformer.cc b/onnxruntime/featurizers_ops/cpu/date_time_transformer.cc index 15c9712ac3..89eb0fd520 100644 --- a/onnxruntime/featurizers_ops/cpu/date_time_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/date_time_transformer.cc @@ -9,9 +9,131 @@ #include "Featurizers/DateTimeFeaturizer.h" #include "Featurizers/../Archive.h" +#ifndef _WIN32 +# include +# include +#endif + namespace onnxruntime { namespace featurizers { +std::string GetDateTimeTransformerDataDir(void) { + // This code can be run in a variety of different environments, and the data directory could + // be impacted by the environment. Attempt to account for those different environments here. 
+ + // Production environment + if(Microsoft::Featurizer::Featurizers::IsValidDirectory("./FeaturizersLibrary")) + return "./FeaturizersLibrary"; + + // Get the dirname (as this will be used by the strategies below) + std::string const exe(Microsoft::Featurizer::Featurizers::GetExecutable()); + std::string::size_type const lastSlash( + [&exe](void) -> std::string::size_type { + std::string::size_type slash; + + // Linux-style + slash = exe.find_last_of('/'); + + if(slash != std::string::npos) + return slash; + + // Windows-style + slash = exe.find_last_of('\\'); + if(slash != std::string::npos) + return slash; + + return std::string::npos; + }() + ); + + std::string const dirname( + [&exe, &lastSlash](void) -> std::string { + if(lastSlash == std::string::npos) + return ""; + + // Include the slash in the dirname + return std::string(exe.c_str(), exe.c_str() + lastSlash + 1); + }() + ); + + if(Microsoft::Featurizer::Featurizers::IsValidDirectory(dirname + "FeaturizersLibrary")) + return dirname + "FeaturizersLibrary"; + + // Python environment + { + // Is the executable python? + std::string const basename(lastSlash != std::string::npos ? 
&exe[lastSlash + 1] : exe.c_str()); + + if(strncmp(basename.c_str(), "python", 6) == 0) { + +#if (defined _WIN32) + // Get the directory relative to python's executable + std::string const potentialDataDir(dirname + "Lib\\site-packages\\onnxruntime\\FeaturizersLibrary"); + + if(Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDataDir)) + return potentialDataDir; +#else + // The site packages dir is lib/python const potentialDirs{ + // Search relative to the executable + dirname + "lib", + + // Search in the user's local path + [](void) -> std::string { + char const * const var(std::getenv("HOME")); + + if(var) + return var; + + return ""; + }() + "/.local/lib" + }; + + for(auto const &potentialDir : potentialDirs) { + if(Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDir)) { + DIR * dir(opendir(potentialDir.c_str())); + + assert(dir != nullptr); + + // (Ab)Using std::unique_ptr to take advantage of the custom deletion functionality + std::unique_ptr> autoCloseDir(dir, [](DIR *d) { closedir(d); }); + + dirent * info(nullptr); + + while((info = readdir(dir)) != nullptr) { + if(info->d_type != DT_DIR) + continue; + + if(strncmp(info->d_name, "python", 6) == 0) { + std::string const potentialDataDir(potentialDir + "/" + info->d_name + "/site-packages/onnxruntime/FeaturizersLibrary"); + + if(Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDataDir)) + return potentialDataDir; + } + } + } + } +#endif + } + } + + // Dev environment + if(Microsoft::Featurizer::Featurizers::IsValidDirectory("./external/FeaturizersLibrary")) + return "./external/FeaturizersLibrary"; + + if(Microsoft::Featurizer::Featurizers::IsValidDirectory(dirname + "external/FeaturizersLibrary")) + return dirname + "external/FeaturizersLibrary"; + + // Use the default logic + return ""; +} + class DateTimeTransformer final : public OpKernel { public: explicit DateTimeTransformer(const OpKernelInfo& info) : OpKernel(info) { @@ -25,7 +147,8 @@ class 
DateTimeTransformer final : public OpKernel { const uint8_t* const state_data(state_tensor->Data()); Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); - return Microsoft::Featurizer::Featurizers::DateTimeTransformer(archive); + + return Microsoft::Featurizer::Featurizers::DateTimeTransformer(archive, GetDateTimeTransformerDataDir()); }()); // Get the input diff --git a/onnxruntime/test/featurizers_ops/datetimetransformer_test.cc b/onnxruntime/test/featurizers_ops/datetimetransformer_test.cc index 5db0aab0a6..42a6dcea49 100644 --- a/onnxruntime/test/featurizers_ops/datetimetransformer_test.cc +++ b/onnxruntime/test/featurizers_ops/datetimetransformer_test.cc @@ -12,11 +12,20 @@ namespace dft = Microsoft::Featurizer::Featurizers; using SysClock = std::chrono::system_clock; namespace onnxruntime { + +namespace featurizers { + +// Defined in date_time_transformer.cc +extern std::string GetDateTimeTransformerDataDir(void); + +} // namespace featurizers + namespace test { namespace { -std::vector GetStream() { - dft::DateTimeTransformer dt("", ""); + +std::vector GetStream(std::string const &optionalCountryCode=std::string()) { + dft::DateTimeTransformer dt(optionalCountryCode, onnxruntime::featurizers::GetDateTimeTransformerDataDir()); Microsoft::Featurizer::Archive ar; dt.save(ar); return ar.commit(); @@ -306,5 +315,58 @@ TEST(FeaturizersTests, DateTimeTransformer_future_2025_june_30) { test.Run(OpTester::ExpectResult::kExpectSuccess); } +TEST(FeaturizersTests, DateTimeTransformer_Country_Canada) { + std::string const dataDir(onnxruntime::featurizers::GetDateTimeTransformerDataDir()); + + if(dataDir.empty()) { + GTEST_SKIP() << + "Skipping country-based tests, as the data directory could not be found. 
This likely indicates that\n" + "the test is being invoked from a nuget installation, which isn't a scenario that is supported by\n" + "featurizers (featurizers will only be used via the Python ORT wrappers and data is installed as\n" + "part of the wheel).\n"; + } + + const time_t date = 157161600; + const auto date_tp = std::chrono::system_clock::from_time_t(date); + + OpTester test("DateTimeTransformer", 1, onnxruntime::kMSFeaturizersDomain); + // Add state input + auto stream = GetStream("Canada"); + auto dim = static_cast(stream.size()); + test.AddInput("State", {dim}, stream); + + // We are adding a scalar Tensor in this instance + test.AddInput("Date", {1}, {date}); + + dft::DateTimeTransformer dt("Canada", dataDir); + dft::TimePoint tp = dt.execute(date_tp); + + ASSERT_EQ(tp.holidayName, "Christmas Day"); + + test.AddOutput("year", {1}, {tp.year}); + test.AddOutput("month", {1}, {tp.month}); + test.AddOutput("day", {1}, {tp.day}); + test.AddOutput("hour", {1}, {tp.hour}); + test.AddOutput("minute", {1}, {tp.minute}); + test.AddOutput("second", {1}, {tp.second}); + test.AddOutput("amPm", {1}, {tp.amPm}); + test.AddOutput("hour12", {1}, {tp.hour12}); + test.AddOutput("dayOfWeek", {1}, {tp.dayOfWeek}); + test.AddOutput("dayOfQuarter", {1}, {tp.dayOfQuarter}); + test.AddOutput("dayOfYear", {1}, {tp.dayOfYear}); + test.AddOutput("weekOfMonth", {1}, {tp.weekOfMonth}); + test.AddOutput("quarterOfYear", {1}, {tp.quarterOfYear}); + test.AddOutput("halfOfYear", {1}, {tp.halfOfYear}); + test.AddOutput("weekIso", {1}, {tp.weekIso}); + test.AddOutput("yearIso", {1}, {tp.yearIso}); + test.AddOutput("monthLabel", {1}, {tp.monthLabel}); + test.AddOutput("amPmLabel", {1}, {tp.amPmLabel}); + test.AddOutput("dayOfWeekLabel", {1}, {tp.dayOfWeekLabel}); + test.AddOutput("holidayName", {1}, {tp.holidayName}); + test.AddOutput("isPaidTimeOff", {1}, {tp.isPaidTimeOff}); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + } // namespace test } // namespace onnxruntime