#include #include #include #include #include #include #include namespace torch { namespace jit { using namespace script; void testSaveExtraFilesHook() { // no secrets { std::stringstream ss; { Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, c10::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], ""); } } // some secret { std::stringstream ss; { SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { return {{"secret.json", "topsecret"}}; }); Module m("__torch__.m"); ExtraFilesMap extra; extra["metadata.json"] = "abc"; m.save(ss, extra); SetExportModuleExtraFilesHook(nullptr); } ss.seekg(0); { ExtraFilesMap extra; extra["metadata.json"] = ""; extra["secret.json"] = ""; jit::load(ss, c10::nullopt, extra); ASSERT_EQ(extra["metadata.json"], "abc"); ASSERT_EQ(extra["secret.json"], "topsecret"); } } } static const auto pretty_printed = R"JIT( op_version_set = 1000 def foo(x: Tensor, y: Tensor) -> Tensor: _0 = torch.add(torch.mul(x, 2), y, alpha=1) return _0 )JIT"; void testImportTooNew() { Module m("__torch__.m"); const std::vector constant_table; auto src = std::make_shared(pretty_printed); ASSERT_ANY_THROW(LEGACY_import_methods(m, src, constant_table, nullptr)); } } // namespace jit } // namespace torch