mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Support for country-specific holidays in the DateTimeTransformer (#3701)
* Support for country-specific holidays in the DateTimeTransformer Updates the DateTimeTransformer featurizer to support holidays, where holiday information is read from country-specific json files. * Addressed build breaks * Enhanced Windows strategies for scenarios when tests run from root dir * Skipping test for nuget installations
This commit is contained in:
parent
bf1caba2b2
commit
a917023f94
2 changed files with 188 additions and 3 deletions
|
|
@ -9,9 +9,131 @@
|
|||
#include "Featurizers/DateTimeFeaturizer.h"
|
||||
#include "Featurizers/../Archive.h"
|
||||
|
||||
#ifndef _WIN32
|
||||
# include <dirent.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace featurizers {
|
||||
|
||||
std::string GetDateTimeTransformerDataDir(void) {
|
||||
// This code can be run in a variety of different environments, and the data directory could
|
||||
// be impacted by the environment. Attempt to account for those different environments here.
|
||||
|
||||
// Production environment
|
||||
if(Microsoft::Featurizer::Featurizers::IsValidDirectory("./FeaturizersLibrary"))
|
||||
return "./FeaturizersLibrary";
|
||||
|
||||
// Get the direname (as this will be used by the strategies below)
|
||||
std::string const exe(Microsoft::Featurizer::Featurizers::GetExecutable());
|
||||
std::string::size_type const lastSlash(
|
||||
[&exe](void) -> std::string::size_type {
|
||||
std::string::size_type slash;
|
||||
|
||||
// Linux-style
|
||||
slash = exe.find_last_of('/');
|
||||
|
||||
if(slash != std::string::npos)
|
||||
return slash;
|
||||
|
||||
// Windows-style
|
||||
slash = exe.find_last_of('\\');
|
||||
if(slash != std::string::npos)
|
||||
return slash;
|
||||
|
||||
return std::string::npos;
|
||||
}()
|
||||
);
|
||||
|
||||
std::string const dirname(
|
||||
[&exe, &lastSlash](void) -> std::string {
|
||||
if(lastSlash == std::string::npos)
|
||||
return "";
|
||||
|
||||
// Include the slash in the dirname
|
||||
return std::string(exe.c_str(), exe.c_str() + lastSlash + 1);
|
||||
}()
|
||||
);
|
||||
|
||||
if(Microsoft::Featurizer::Featurizers::IsValidDirectory(dirname + "FeaturizersLibrary"))
|
||||
return dirname + "FeaturizersLibrary";
|
||||
|
||||
// Python environment
|
||||
{
|
||||
// Is the executable python?
|
||||
std::string const basename(lastSlash != std::string::npos ? &exe[lastSlash + 1] : exe.c_str());
|
||||
|
||||
if(strncmp(basename.c_str(), "python", 6) == 0) {
|
||||
|
||||
#if (defined _WIN32)
|
||||
// Get the directory relative to python's executable
|
||||
std::string const potentialDataDir(dirname + "Lib\\site-packages\\onnxruntime\\FeaturizersLibrary");
|
||||
|
||||
if(Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDataDir))
|
||||
return potentialDataDir;
|
||||
#else
|
||||
// The site packages dir is lib/python<version/site-packages. Because we don't
|
||||
// know the exact version of python, enumerate through the directories under ./lib
|
||||
// and return the first one that begins with python.
|
||||
//
|
||||
// This is a huge HACK, and we should figure out a better way to do this. The python
|
||||
// version number is available in Python.h, but I don't think that that header file
|
||||
// is available for inclusion when this file is compiled.
|
||||
std::vector<std::string> const potentialDirs{
|
||||
// Search relative to the executable
|
||||
dirname + "lib",
|
||||
|
||||
// Search in the user's local path
|
||||
[](void) -> std::string {
|
||||
char const * const var(std::getenv("HOME"));
|
||||
|
||||
if(var)
|
||||
return var;
|
||||
|
||||
return "";
|
||||
}() + "/.local/lib"
|
||||
};
|
||||
|
||||
for(auto const &potentialDir : potentialDirs) {
|
||||
if(Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDir)) {
|
||||
DIR * dir(opendir(potentialDir.c_str()));
|
||||
|
||||
assert(dir != nullptr);
|
||||
|
||||
// (Ab)Using std::unique_ptr to take advantage of the custom deletion functionality
|
||||
std::unique_ptr<DIR, std::function<void (DIR *)>> autoCloseDir(dir, [](DIR *d) { closedir(d); });
|
||||
|
||||
dirent * info(nullptr);
|
||||
|
||||
while((info = readdir(dir)) != nullptr) {
|
||||
if(info->d_type != DT_DIR)
|
||||
continue;
|
||||
|
||||
if(strncmp(info->d_name, "python", 6) == 0) {
|
||||
std::string const potentialDataDir(potentialDir + "/" + info->d_name + "/site-packages/onnxruntime/FeaturizersLibrary");
|
||||
|
||||
if(Microsoft::Featurizer::Featurizers::IsValidDirectory(potentialDataDir))
|
||||
return potentialDataDir;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// Dev environment
|
||||
if(Microsoft::Featurizer::Featurizers::IsValidDirectory("./external/FeaturizersLibrary"))
|
||||
return "./external/FeaturizersLibrary";
|
||||
|
||||
if(Microsoft::Featurizer::Featurizers::IsValidDirectory(dirname + "external/FeaturizersLibrary"))
|
||||
return dirname + "external/FeaturizersLibrary";
|
||||
|
||||
// Use the default logic
|
||||
return "";
|
||||
}
|
||||
|
||||
class DateTimeTransformer final : public OpKernel {
|
||||
public:
|
||||
explicit DateTimeTransformer(const OpKernelInfo& info) : OpKernel(info) {
|
||||
|
|
@ -25,7 +147,8 @@ class DateTimeTransformer final : public OpKernel {
|
|||
const uint8_t* const state_data(state_tensor->Data<uint8_t>());
|
||||
|
||||
Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size());
|
||||
return Microsoft::Featurizer::Featurizers::DateTimeTransformer(archive);
|
||||
|
||||
return Microsoft::Featurizer::Featurizers::DateTimeTransformer(archive, GetDateTimeTransformerDataDir());
|
||||
}());
|
||||
|
||||
// Get the input
|
||||
|
|
|
|||
|
|
@ -12,11 +12,20 @@ namespace dft = Microsoft::Featurizer::Featurizers;
|
|||
using SysClock = std::chrono::system_clock;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace featurizers {

// Forward declaration only: the implementation lives in date_time_transformer.cc.
// It is repeated here so that the tests can resolve the same data directory.
std::string GetDateTimeTransformerDataDir();

}  // namespace featurizers
|
||||
|
||||
namespace test {
|
||||
|
||||
namespace {
|
||||
std::vector<uint8_t> GetStream() {
|
||||
dft::DateTimeTransformer dt("", "");
|
||||
|
||||
std::vector<uint8_t> GetStream(std::string const &optionalCountryCode=std::string()) {
|
||||
dft::DateTimeTransformer dt(optionalCountryCode, onnxruntime::featurizers::GetDateTimeTransformerDataDir());
|
||||
Microsoft::Featurizer::Archive ar;
|
||||
dt.save(ar);
|
||||
return ar.commit();
|
||||
|
|
@ -306,5 +315,58 @@ TEST(FeaturizersTests, DateTimeTransformer_future_2025_june_30) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess);
|
||||
}
|
||||
|
||||
// Runs the DateTimeTransformer kernel with a transformer state created for "Canada" and
// checks every kernel output against the values produced by invoking the featurizer
// directly. The test is skipped when the holiday data directory cannot be located.
TEST(FeaturizersTests, DateTimeTransformer_Country_Canada) {
  const std::string holiday_data_dir = onnxruntime::featurizers::GetDateTimeTransformerDataDir();

  if (holiday_data_dir.empty()) {
    GTEST_SKIP() << "Skipping country-based tests, as the data directory could not be found. This likely indicates that\n"
                    "the test is being invoked from a nuget installation, which isn't a scenario that is supported by\n"
                    "featurizers (featurizers will only be used via the Python ORT wrappers and data is installed as\n"
                    "part of the wheel).\n";
  }

  // 157161600 seconds after the Unix epoch is 1974-12-25 00:00:00 UTC
  const time_t christmas_seconds = 157161600;
  const auto christmas_time_point = std::chrono::system_clock::from_time_t(christmas_seconds);

  OpTester test("DateTimeTransformer", 1, onnxruntime::kMSFeaturizersDomain);

  // State input: the serialized transformer, passed as a 1-D byte tensor
  auto serialized_state = GetStream("Canada");
  test.AddInput<uint8_t>("State", {static_cast<int64_t>(serialized_state.size())}, serialized_state);

  // Date input: a single-element tensor
  test.AddInput<int64_t>("Date", {1}, {christmas_seconds});

  // Reference values are computed by calling the featurizer directly
  dft::DateTimeTransformer reference_transformer("Canada", holiday_data_dir);
  dft::TimePoint expected = reference_transformer.execute(christmas_time_point);

  ASSERT_EQ(expected.holidayName, "Christmas Day");

  // Numeric calendar components
  test.AddOutput<int32_t>("year", {1}, {expected.year});
  test.AddOutput<uint8_t>("month", {1}, {expected.month});
  test.AddOutput<uint8_t>("day", {1}, {expected.day});
  test.AddOutput<uint8_t>("hour", {1}, {expected.hour});
  test.AddOutput<uint8_t>("minute", {1}, {expected.minute});
  test.AddOutput<uint8_t>("second", {1}, {expected.second});
  test.AddOutput<uint8_t>("amPm", {1}, {expected.amPm});
  test.AddOutput<uint8_t>("hour12", {1}, {expected.hour12});
  test.AddOutput<uint8_t>("dayOfWeek", {1}, {expected.dayOfWeek});
  test.AddOutput<uint8_t>("dayOfQuarter", {1}, {expected.dayOfQuarter});
  test.AddOutput<uint16_t>("dayOfYear", {1}, {expected.dayOfYear});
  test.AddOutput<uint16_t>("weekOfMonth", {1}, {expected.weekOfMonth});
  test.AddOutput<uint8_t>("quarterOfYear", {1}, {expected.quarterOfYear});
  test.AddOutput<uint8_t>("halfOfYear", {1}, {expected.halfOfYear});
  test.AddOutput<uint8_t>("weekIso", {1}, {expected.weekIso});
  test.AddOutput<int32_t>("yearIso", {1}, {expected.yearIso});

  // Text labels
  test.AddOutput<std::string>("monthLabel", {1}, {expected.monthLabel});
  test.AddOutput<std::string>("amPmLabel", {1}, {expected.amPmLabel});
  test.AddOutput<std::string>("dayOfWeekLabel", {1}, {expected.dayOfWeekLabel});

  // Holiday information
  test.AddOutput<std::string>("holidayName", {1}, {expected.holidayName});
  test.AddOutput<uint8_t>("isPaidTimeOff", {1}, {expected.isPaidTimeOff});

  test.Run(OpTester::ExpectResult::kExpectSuccess);
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue