#include #include #include #include #include #include #include namespace torch { namespace jit { void testSaveExtraFilesHook() { // no secrets { std::stringstream ss; { Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, c10::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], ""); } } // some secret { std::stringstream ss; { SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { return {{"secret.json", "topsecret"}}; }); Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); SetExportModuleExtraFilesHook(nullptr); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, c10::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], "topsecret"); } } } void testTypeTags() { auto list = c10::List>(); list.push_back(c10::List({1, 2, 3})); list.push_back(c10::List({4, 5, 6})); auto dict = c10::Dict(); dict.insert("Hello", torch::ones({2, 2})); auto dict_list = c10::List>(); for (size_t i = 0; i < 5; i++) { auto another_dict = c10::Dict(); another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2})); dict_list.push_back(another_dict); } auto tuple = std::tuple(2, "hi"); struct TestItem { IValue value; TypePtr expected_type; }; std::vector items = { {list, ListType::create(ListType::create(IntType::get()))}, {2, IntType::get()}, {dict, DictType::create(StringType::get(), TensorType::get())}, {dict_list, ListType::create(DictType::create(StringType::get(), TensorType::get()))}, {tuple, TupleType::create({IntType::get(), StringType::get()})} }; for (auto item : items) { auto bytes = torch::pickle_save(item.value); auto loaded = torch::pickle_load(bytes); ASSERT_TRUE(loaded.type()->isSubtypeOf(item.expected_type)); ASSERT_TRUE(item.expected_type->isSubtypeOf(loaded.type())); } } } // namespace jit } // namespace torch