Support for country-specific holidays in the DateTimeTransformer (#3701)

* Support for country-specific holidays in the DateTimeTransformer

Updates the DateTimeTransformer featurizer to support holidays, where holiday information is read from country-specific json files.

* Addressed build breaks

* Enhanced Windows strategies for scenarios when tests run from root dir

* Skipping test for nuget installations
This commit is contained in:
David Brownell 2020-04-26 11:12:26 -07:00 committed by GitHub
parent bf1caba2b2
commit a917023f94
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 188 additions and 3 deletions

View file

@ -9,9 +9,131 @@
#include "Featurizers/DateTimeFeaturizer.h"
#include "Featurizers/../Archive.h"
#ifndef _WIN32
# include <dirent.h>
# include <unistd.h>
#endif
namespace onnxruntime {
namespace featurizers {
// Returns the directory that contains the FeaturizersLibrary data files, or an
// empty string when none of the known locations exists (the empty string asks
// the consumer to use its default logic).
//
// This code can be run in a variety of different environments, and the data
// directory could be impacted by the environment. The candidates are probed in
// this order:
//   1. ./FeaturizersLibrary                                    (production)
//   2. <executable dir>/FeaturizersLibrary
//   3. <python>/.../site-packages/onnxruntime/FeaturizersLibrary, only when the
//      name of the running executable starts with "python"
//   4. ./external/FeaturizersLibrary, then
//      <executable dir>/external/FeaturizersLibrary              (dev)
std::string GetDateTimeTransformerDataDir(void) {
  // Production environment
  if (Microsoft::Featurizer::Featurizers::IsValidDirectory("./FeaturizersLibrary"))
    return "./FeaturizersLibrary";

  // Get the dirname (as this will be used by the strategies below)
  std::string const exe(Microsoft::Featurizer::Featurizers::GetExecutable());

  // Position of the last path separator in `exe`, or npos when there is none.
  // A Linux-style separator takes precedence over a Windows-style one.
  std::string::size_type const lastSlash(
      [&exe](void) -> std::string::size_type {
        std::string::size_type const slash(exe.find_last_of('/'));

        if (slash != std::string::npos)
          return slash;

        return exe.find_last_of('\\');
      }());

  // Directory of the executable, including the trailing separator so that
  // relative names can simply be appended. Empty when `exe` has no separator.
  std::string const dirname(
      lastSlash == std::string::npos ? std::string() : exe.substr(0, lastSlash + 1));

  if (Microsoft::Featurizer::Featurizers::IsValidDirectory(dirname + "FeaturizersLibrary"))
    return dirname + "FeaturizersLibrary";

  // Python environment
  {
    // Is the executable python?
    std::string const basename(lastSlash != std::string::npos ? &exe[lastSlash + 1] : exe.c_str());

    if (strncmp(basename.c_str(), "python", 6) == 0) {
#if (defined _WIN32)
      // Get the directory relative to python's executable
      std::string const potentialDataDir(dirname + "Lib\\site-packages\\onnxruntime\\FeaturizersLibrary");

      if (Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDataDir))
        return potentialDataDir;
#else
      // The site packages dir is lib/python<version>/site-packages. Because we don't
      // know the exact version of python, enumerate through the directories under ./lib
      // and return the first one that begins with python.
      //
      // This is a huge HACK, and we should figure out a better way to do this. The python
      // version number is available in Python.h, but I don't think that that header file
      // is available for inclusion when this file is compiled.
      std::vector<std::string> const potentialDirs{
          // Search relative to the executable
          dirname + "lib",
          // Search in the user's local path
          [](void) -> std::string {
            char const* const var(std::getenv("HOME"));

            if (var)
              return var;
            return "";
          }() + "/.local/lib"};

      // Closes the directory stream when the owning unique_ptr goes out of scope.
      struct DirCloser {
        void operator()(DIR* d) const { closedir(d); }
      };

      for (auto const& potentialDir : potentialDirs) {
        if (!Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDir))
          continue;

        std::unique_ptr<DIR, DirCloser> const dir(opendir(potentialDir.c_str()));

        // The directory can exist and still fail to open (e.g. insufficient
        // permissions); move on to the next candidate rather than handing a
        // null stream to readdir.
        if (!dir)
          continue;

        dirent* info(nullptr);

        while ((info = readdir(dir.get())) != nullptr) {
          // Not every filesystem fills in d_type (DT_UNKNOWN), and the versioned
          // directory may be a symlink (DT_LNK). Only reject entries that are
          // known to be something else; the IsValidDirectory check on the full
          // path below is the authoritative test.
          if (info->d_type != DT_DIR && info->d_type != DT_LNK && info->d_type != DT_UNKNOWN)
            continue;

          if (strncmp(info->d_name, "python", 6) != 0)
            continue;

          std::string const potentialDataDir(potentialDir + "/" + info->d_name + "/site-packages/onnxruntime/FeaturizersLibrary");

          if (Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDataDir))
            return potentialDataDir;
        }
      }
#endif
    }
  }

  // Dev environment
  if (Microsoft::Featurizer::Featurizers::IsValidDirectory("./external/FeaturizersLibrary"))
    return "./external/FeaturizersLibrary";

  if (Microsoft::Featurizer::Featurizers::IsValidDirectory(dirname + "external/FeaturizersLibrary"))
    return dirname + "external/FeaturizersLibrary";

  // Use the default logic
  return "";
}
class DateTimeTransformer final : public OpKernel {
public:
explicit DateTimeTransformer(const OpKernelInfo& info) : OpKernel(info) {
@ -25,7 +147,8 @@ class DateTimeTransformer final : public OpKernel {
const uint8_t* const state_data(state_tensor->Data<uint8_t>());
Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size());
return Microsoft::Featurizer::Featurizers::DateTimeTransformer(archive);
return Microsoft::Featurizer::Featurizers::DateTimeTransformer(archive, GetDateTimeTransformerDataDir());
}());
// Get the input

View file

@ -12,11 +12,20 @@ namespace dft = Microsoft::Featurizer::Featurizers;
using SysClock = std::chrono::system_clock;
namespace onnxruntime {
namespace featurizers {
// Defined in date_time_transformer.cc
//
// Returns the directory that holds the FeaturizersLibrary data files. An empty
// string means that no data directory could be located; the country-based test
// in this file skips itself in that case.
extern std::string GetDateTimeTransformerDataDir(void);
} // namespace featurizers
namespace test {
namespace {
std::vector<uint8_t> GetStream() {
dft::DateTimeTransformer dt("", "");
std::vector<uint8_t> GetStream(std::string const &optionalCountryCode=std::string()) {
dft::DateTimeTransformer dt(optionalCountryCode, onnxruntime::featurizers::GetDateTimeTransformerDataDir());
Microsoft::Featurizer::Archive ar;
dt.save(ar);
return ar.commit();
@ -306,5 +315,58 @@ TEST(FeaturizersTests, DateTimeTransformer_future_2025_june_30) {
test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Runs the DateTimeTransformer kernel with a state created for "Canada" and
// checks every output against the values produced by calling the featurizer
// directly; the chosen date is expected to be reported as "Christmas Day".
TEST(FeaturizersTests, DateTimeTransformer_Country_Canada) {
  // Holiday information is read from data files on disk, so the test cannot run
  // when the data directory is not available.
  std::string const data_dir(onnxruntime::featurizers::GetDateTimeTransformerDataDir());

  if (data_dir.empty()) {
    GTEST_SKIP() <<
        "Skipping country-based tests, as the data directory could not be found. This likely indicates that\n"
        "the test is being invoked from a nuget installation, which isn't a scenario that is supported by\n"
        "featurizers (featurizers will only be used via the Python ORT wrappers and data is installed as\n"
        "part of the wheel).\n";
  }

  // Seconds since the epoch; 157161600 == 1819 * 86400, i.e. 1974-12-25 00:00:00 UTC.
  const time_t holiday_time = 157161600;
  const auto holiday_time_point = std::chrono::system_clock::from_time_t(holiday_time);

  OpTester tester("DateTimeTransformer", 1, onnxruntime::kMSFeaturizersDomain);

  // State input: the serialized transformer configured for Canada
  auto state = GetStream("Canada");
  tester.AddInput<uint8_t>("State", {static_cast<int64_t>(state.size())}, state);

  // Date input: a tensor holding the single timestamp
  tester.AddInput<int64_t>("Date", {1}, {holiday_time});

  // Reference values come straight from the featurizer
  dft::DateTimeTransformer reference("Canada", data_dir);
  dft::TimePoint expected = reference.execute(holiday_time_point);
  ASSERT_EQ(expected.holidayName, "Christmas Day");

  // Expected outputs, one single-element tensor per TimePoint field
  tester.AddOutput<int32_t>("year", {1}, {expected.year});
  tester.AddOutput<uint8_t>("month", {1}, {expected.month});
  tester.AddOutput<uint8_t>("day", {1}, {expected.day});
  tester.AddOutput<uint8_t>("hour", {1}, {expected.hour});
  tester.AddOutput<uint8_t>("minute", {1}, {expected.minute});
  tester.AddOutput<uint8_t>("second", {1}, {expected.second});
  tester.AddOutput<uint8_t>("amPm", {1}, {expected.amPm});
  tester.AddOutput<uint8_t>("hour12", {1}, {expected.hour12});
  tester.AddOutput<uint8_t>("dayOfWeek", {1}, {expected.dayOfWeek});
  tester.AddOutput<uint8_t>("dayOfQuarter", {1}, {expected.dayOfQuarter});
  tester.AddOutput<uint16_t>("dayOfYear", {1}, {expected.dayOfYear});
  tester.AddOutput<uint16_t>("weekOfMonth", {1}, {expected.weekOfMonth});
  tester.AddOutput<uint8_t>("quarterOfYear", {1}, {expected.quarterOfYear});
  tester.AddOutput<uint8_t>("halfOfYear", {1}, {expected.halfOfYear});
  tester.AddOutput<uint8_t>("weekIso", {1}, {expected.weekIso});
  tester.AddOutput<int32_t>("yearIso", {1}, {expected.yearIso});
  tester.AddOutput<std::string>("monthLabel", {1}, {expected.monthLabel});
  tester.AddOutput<std::string>("amPmLabel", {1}, {expected.amPmLabel});
  tester.AddOutput<std::string>("dayOfWeekLabel", {1}, {expected.dayOfWeekLabel});
  tester.AddOutput<std::string>("holidayName", {1}, {expected.holidayName});
  tester.AddOutput<uint8_t>("isPaidTimeOff", {1}, {expected.isPaidTimeOff});

  tester.Run(OpTester::ExpectResult::kExpectSuccess);
}
} // namespace test
} // namespace onnxruntime