mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68355 Test Plan: Imported from OSS Reviewed By: samdow Differential Revision: D33198156 Pulled By: tugsbayasgalan fbshipit-source-id: 68380148f0d9bee96d8090bf01c8dfca8e1f8b12
62 lines
1.9 KiB
C++
62 lines
1.9 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/csrc/jit/operator_upgraders/utils.h>
|
|
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
|
|
|
#include <test/cpp/jit/test_utils.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
TEST(UpgraderUtils, FindCorrectUpgrader) {
  // Each entry records the version at which the op's schema was bumped.
  // Judging by the upgrader names, "foo__0_3" is meant for versions 0-3 and
  // "foo__4_7" for versions 4-7.
  std::vector<UpgraderEntry> dummy_entry = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.bar()"},
  };

  // ASSERT_* (fatal) rather than EXPECT_* guards the .value() calls: calling
  // value() on an empty optional throws instead of reporting a clean failure.
  auto upgrader_at_6 = findUpgrader(dummy_entry, 6);
  ASSERT_TRUE(upgrader_at_6.has_value());
  EXPECT_EQ(upgrader_at_6.value().upgrader_name, "foo__4_7");

  auto upgrader_at_1 = findUpgrader(dummy_entry, 1);
  ASSERT_TRUE(upgrader_at_1.has_value());
  EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3");

  // Version 10 is past the last bump (8), so the op is already current and
  // no upgrader is expected for it.
  auto upgrader_at_10 = findUpgrader(dummy_entry, 10);
  EXPECT_FALSE(upgrader_at_10.has_value());
}
|
|
|
|
TEST(UpgraderUtils, IsVersionMapSorted) {
  // Every operator's list of UpgraderEntry values in the global version map
  // must be ordered by ascending bumped_at_version.
  const auto& version_map = get_operator_version_map();
  for (const auto& op_and_entries : version_map) {
    const auto& entries = op_and_entries.second;
    const bool ordered = std::is_sorted(
        entries.begin(),
        entries.end(),
        [](const UpgraderEntry& lhs, const UpgraderEntry& rhs) {
          return lhs.bumped_at_version < rhs.bumped_at_version;
        });
    EXPECT_TRUE(ordered);
  }
}
|
|
|
|
TEST(UpgraderUtils, FindIfOpIsCurrent) {
  // Schema bumps for a hypothetical op "foo" at versions 4 and 8.
  std::vector<UpgraderEntry> dummy_entry = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.bar()"},
  };

  // Entry-list based lookup: version 6 predates the bump at 8, version 8
  // does not.
  EXPECT_FALSE(isOpCurrentBasedOnUpgraderEntries(dummy_entry, 6));
  EXPECT_TRUE(isOpCurrentBasedOnUpgraderEntries(dummy_entry, 8));

  // Symbol-based lookup: temporarily register the entries under "foo" in the
  // global map, query by name, then remove them again.
  for (const auto& upgrader : dummy_entry) {
    test_only_add_entry("foo", upgrader);
  }
  EXPECT_FALSE(isOpSymbolCurrent("foo", 6));
  EXPECT_TRUE(isOpSymbolCurrent("foo", 8));
  test_only_remove_entry("foo");
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|