pytorch/test/cpp/jit/test_upgrader_utils.cpp
Tugsbayasgalan (Tugsuu) Manlaibaatar df3cbcff28 Add utility methods to find an upgrader (#68355)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68355

Test Plan: Imported from OSS

Reviewed By: samdow

Differential Revision: D33198156

Pulled By: tugsbayasgalan

fbshipit-source-id: 68380148f0d9bee96d8090bf01c8dfca8e1f8b12
2021-12-24 12:23:04 -08:00

62 lines
1.9 KiB
C++

#include <gtest/gtest.h>
#include <torch/csrc/jit/operator_upgraders/utils.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <test/cpp/jit/test_utils.h>
namespace torch {
namespace jit {
// Exercises findUpgrader against a two-entry upgrader list whose entries were
// bumped at versions 4 and 8, querying below, between, and above the bumps.
TEST(UpgraderUtils, FindCorrectUpgrader) {
  std::vector<UpgraderEntry> dummy_entry = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.bar()"},
  };

  // A version between the two bumps resolves to the entry bumped at 8.
  auto upgrader_at_6 = findUpgrader(dummy_entry, 6);
  // ASSERT (not EXPECT) so .value() is never reached on an empty optional.
  ASSERT_TRUE(upgrader_at_6.has_value());
  EXPECT_EQ(upgrader_at_6.value().upgrader_name, "foo__4_7");

  // A version below the first bump resolves to the entry bumped at 4.
  auto upgrader_at_1 = findUpgrader(dummy_entry, 1);
  ASSERT_TRUE(upgrader_at_1.has_value());
  EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3");

  // A version past every bump in the list should need no upgrader.
  // NOTE(review): assumes findUpgrader returns an empty optional when no
  // entry has bumped_at_version greater than the queried version — confirm
  // against the implementation in operator_upgraders/utils.
  auto upgrader_at_10 = findUpgrader(dummy_entry, 10);
  EXPECT_FALSE(upgrader_at_10.has_value());
}
// Verifies that, for every operator in the version map, its list of
// UpgraderEntry is ordered by ascending bumped_at_version.
TEST(UpgraderUtils, IsVersionMapSorted) {
  auto version_map = get_operator_version_map();

  // Orders two entries by their bumped_at_version, compared as int.
  const auto bumped_before = [](const auto& lhs, const auto& rhs) {
    const int lhs_version = lhs.bumped_at_version;
    const int rhs_version = rhs.bumped_at_version;
    return lhs_version < rhs_version;
  };

  for (const auto& op_and_upgraders : version_map) {
    const auto& upgraders = op_and_upgraders.second;
    const bool sorted =
        std::is_sorted(upgraders.begin(), upgraders.end(), bumped_before);
    EXPECT_TRUE(sorted);
  }
}
// Checks the "is this operator current at version N" queries, first directly
// on a list of upgrader entries and then through the symbol-keyed lookup.
TEST(UpgraderUtils, FindIfOpIsCurrent) {
  std::vector<UpgraderEntry> upgraders = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.bar()"},
  };

  // Entry-list based lookup: version 6 predates the bump at 8, version 8
  // does not.
  EXPECT_FALSE(isOpCurrentBasedOnUpgraderEntries(upgraders, 6));
  EXPECT_TRUE(isOpCurrentBasedOnUpgraderEntries(upgraders, 8));

  // Symbol based lookup: temporarily register the same entries under "foo".
  for (const auto& upgrader : upgraders) {
    test_only_add_entry("foo", upgrader);
  }
  EXPECT_FALSE(isOpSymbolCurrent("foo", 6));
  EXPECT_TRUE(isOpSymbolCurrent("foo", 8));

  // Remove the temporary "foo" registration again.
  test_only_remove_entry("foo");
}
} // namespace jit
} // namespace torch