From a9e7e787f8108ff2b08deb435fe4da91833e67f2 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Fri, 31 Jul 2020 10:21:43 -0700
Subject: [PATCH] [jit] make clone work for interface type (#42121)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42121

This PR changes the Module API to allow registering a module with a module
interface type, which in turn lets Module::clone work in the case where a
module interface type is shared by two submodules. The interface type is
shared with the new cloned instance in the same compilation unit because it
only contains a list of FunctionSchema and, unlike a ClassType, carries no
attributes. (A minimal repro sketch of the previously failing pattern is
included after the patch.)

fixes https://github.com/pytorch/pytorch/issues/41882

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D22781205

Pulled By: wanchaol

fbshipit-source-id: f97f4b75970f0b434e38b5a1f778eda2c4e5109b
---
 test/cpp/jit/test_module_api.cpp             | 81 +++++++++++++++++++
 test/cpp/jit/tests.h                         |  1 +
 test/quantization/test_quantize_jit.py       | 74 +++++++++++++++++
 torch/csrc/jit/api/module.cpp                | 17 +++-
 .../passes/quantization/insert_observers.cpp | 17 +++-
 5 files changed, 185 insertions(+), 5 deletions(-)

diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp
index d0fade63c8a..386addd9fbe 100644
--- a/test/cpp/jit/test_module_api.cpp
+++ b/test/cpp/jit/test_module_api.cpp
@@ -1,10 +1,47 @@
 #include <test/cpp/jit/test_base.h>
 #include <test/cpp/jit/test_utils.h>
+
+#include <ATen/core/qualified_name.h>
+#include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/frontend/resolver.h>
+#include <torch/csrc/jit/serialization/import_source.h>
 #include <torch/torch.h>
 
 namespace torch {
 namespace jit {
 
+static const auto moduleInterfaceSrc = R"JIT(
+class OneInterface(ModuleInterface):
+    def one(self, x: Tensor, y: Tensor) -> Tensor:
+        pass
+)JIT";
+
+static const std::vector<std::string> subModuleMethodsSrc = {R"JIT(
+def one(self, x: Tensor, y: Tensor) -> Tensor:
+    return self.attr * x + y + 1
+
+def forward(self, x: Tensor) -> Tensor:
+    return self.attr + x
+)JIT"};
+
+static const auto parentForward = R"JIT(
+def forward(self, x: Tensor) -> Tensor:
+    return self.subMod1.one(x, x) + self.subMod2.one(x, x)
+)JIT";
+
+static void import_libs(
+    std::shared_ptr<CompilationUnit> cu,
+    const std::string& class_name,
+    const std::shared_ptr<Source>& src,
+    const std::vector<at::IValue>& tensor_table) {
+  SourceImporter si(
+      cu,
+      &tensor_table,
+      [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
+      /*version=*/2);
+  si.loadType(QualifiedName(class_name));
+}
+
 void testModuleClone() {
   auto cu = std::make_shared<CompilationUnit>();
   // creating child module
@@ -34,6 +71,50 @@ void testModuleClone() {
   ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3);
 }
 
+void testModuleCloneWithModuleInterface() {
+  auto cu = std::make_shared<CompilationUnit>();
+
+  // define an initial module with two submodules sharing the same interface
+  Module parentMod("parentMod", cu);
+  Module subMod1("subMod1", cu);
+  Module subMod2("subMod2", cu);
+
+  std::vector<at::IValue> constantTable;
+  import_libs(
+      cu,
+      "__torch__.OneInterface",
+      std::make_shared<Source>(moduleInterfaceSrc),
+      constantTable);
+
+  auto v1 = IValue(2);
+  subMod1.register_attribute("attr", IntType::get(), v1, false);
+
+  auto v2 = IValue(4);
+  subMod2.register_attribute("attr", IntType::get(), v2, false);
+
+  for (const std::string& method : subModuleMethodsSrc) {
+    subMod1.define(method, nativeResolver());
+    subMod2.define(method, nativeResolver());
+  }
+
+  parentMod.register_attribute(
+      "subMod1",
+      cu->get_interface("__torch__.OneInterface"),
+      subMod1._ivalue());
+  parentMod.register_attribute(
+      "subMod2",
+      cu->get_interface("__torch__.OneInterface"),
+      subMod2._ivalue());
+
+  parentMod.define(parentForward, nativeResolver());
+
+  Module clonedMod = parentMod.clone();
+
+  // clone copies both type and data, so the cloned module ends up
+  // with a different type from the original
+  ASSERT_NE(clonedMod.type(), parentMod.type());
+}
+
 void testModuleCopy() {
   auto cu = std::make_shared<CompilationUnit>();
   auto cls = ClassType::create("foo.bar", cu, true);
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index c43e83d0275..089aab699a3 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -75,6 +75,7 @@ namespace jit {
   _(ClassDerive)                      \
   _(SaveLoadTorchbind)                \
   _(ModuleInterfaceSerialization)     \
+  _(ModuleCloneWithModuleInterface)   \
   _(ClassTypeAddRemoveAttr)           \
   _(Inliner)                          \
   _(LiteInterpreterAdd)               \
diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index 838ffbdc065..969e39210d2 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -395,6 +395,80 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         # for weight
         assert len(attrs_with_prefix(m.conv, '_observer_')) == 1
 
+    def test_insert_observers_interface(self):
+        @torch.jit.interface
+        class SubInterface(torch.nn.Module):
+            def addOne(self, inp) -> torch.Tensor:
+                pass
+
+        class Sub(torch.nn.Module):
+            def __init__(self):
+                super(Sub, self).__init__()
+                self.fc = torch.nn.Linear(5, 5)
+
+            def addOne(self, inp):
+                return self.fc(inp) + 1
+
+            def forward(self, x):
+                return self.addOne(x)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv = torch.nn.Conv2d(3, 5, 3)
+                self.sub = Sub()
+
+            def forward(self, x):
+                return self.sub(self.conv(x))
+
+        m = torch.jit.script(M())
+        qconfig_dict = {'sub.conv': default_qconfig}
+        m = prepare_jit(m, qconfig_dict)
+
+    def test_insert_observers_interface_unshare_type(self):
+        @torch.jit.interface
+        class OperatorIf(nn.Module):
+            def forward(self, inp: torch.Tensor) -> torch.Tensor:
+                pass
+
+        class Operator(nn.Module):
+            def __init__(self, a):
+                super().__init__()
+                self.a = a
+
+            def forward(self, inp: torch.Tensor) -> torch.Tensor:
+                return self.a * (inp + self.a)
+
+        class Inner(nn.Module):
+            op: OperatorIf
+
+            def __init__(self, op):
+                super().__init__()
+                self.op = op
+
+            def forward(self, inp):
+                return self.op(inp)
+
+        class Outer(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.inner_a = Inner(Operator(1))
+                self.inner_b = Inner(Operator(3.0))
+
+            def forward(self, inp):
+                return self.inner_a(inp) + self.inner_b(inp)
+
+        qconfig_dict = {'inner_a': default_qconfig, 'inner_b': default_qconfig}
+
+        eager_model = Outer()
+        for tracing in [True, False]:
+            x = torch.rand(3)
+            script_model = get_script_module(eager_model, tracing, x)
+            # make sure it runs
+            prepare_jit(script_model, qconfig_dict)
+
     def test_insert_observers_child_qconfig(self):
         class Sub(torch.nn.Module):
             def __init__(self):
diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp
index 632fc815ab1..f645f73c7a9 100644
--- a/torch/csrc/jit/api/module.cpp
+++ b/torch/csrc/jit/api/module.cpp
@@ -203,18 +203,29 @@ Module Module::clone_impl(
   size_t N = type()->numAttributes();
   for (size_t i = 0; i < N; ++i) {
     IValue s = _ivalue()->getSlot(i);
-    if (type()->getAttribute(i)->is_module()) {
+    std::string attr_name = type()->getAttributeName(i);
+    TypePtr attr_type = type()->getAttribute(i);
+    if (attr_type->is_module()) {
       const Module& orig = Module(s.toObject());
       Module cloned = orig.clone_impl(type_remap, inplace, memo);
       type_remap[orig.type()] = cloned.type();
-      r.register_module(type()->getAttributeName(i), cloned);
+      // NOTE: why do we need to manually setattr on the object instead of
+      // using register_module here? Because the attr can be a module
+      // interface type and still hold a Module object. register_module
+      // will not let us correctly set up the type for this attr, so we
+      // have to do it manually. In the case of an interface type, the type
+      // will be shared by the new cloned instance in the same compilation
+      // unit because it only contains a list of FunctionSchema.
+      r.type()->addOrCheckAttribute(
+          attr_name, attr_type->cast<ClassType>() ? cloned.type() : attr_type);
+      r._ivalue()->setAttr(attr_name, cloned._ivalue());
     } else {
       // this adds a new slot and creates a new attribute for the underlying
       // type if the type is not already cloned, otherwise it will only add
       // a new slot and typecheck
       r.register_attribute(
           type()->getAttributeName(i),
-          type()->getAttribute(i),
+          attr_type,
           // we'll deepcopy the IValue in the non-inplace option
           inplace ? s : s.deepcopy(memo),
           type()->is_parameter(i),
diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp
index 23c37910768..768915de371 100644
--- a/torch/csrc/jit/passes/quantization/insert_observers.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp
@@ -124,11 +124,24 @@ class ModuleCloneHelper {
     size_t N = type->numAttributes();
     for (size_t i = 0; i < N; ++i) {
       IValue s = module._ivalue()->getSlot(i);
-      if (type->getAttribute(i)->is_module()) {
+      std::string attr_name = type->getAttributeName(i);
+      TypePtr attr_type = type->getAttribute(i);
+      if (attr_type->is_module()) {
         const Module& orig = Module(s.toObject());
         Module cloned =
             clone_impl(orig, module_qconfig_map, type_remap, inplace, memo);
-        r.register_module(type->getAttributeName(i), cloned);
+
+        // NOTE: why do we need to manually setattr on the object instead of
+        // using register_module here? Because the attr can be a module
+        // interface type and still hold a Module object. register_module
+        // will not let us correctly set up the type for this attr, so we
+        // have to do it manually. In the case of an interface type, the
+        // type will be shared by the new cloned instance in the same
+        // compilation unit because it only contains a list of FunctionSchema.
+        r.type()->addOrCheckAttribute(
+            attr_name,
+            attr_type->cast<ClassType>() ? cloned.type() : attr_type);
+        r._ivalue()->setAttr(attr_name, cloned._ivalue());
       } else {
         // we'll deepcopy the IValue in the non-inplace option
         r.register_attribute(
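
---

Note for reviewers: below is a minimal, illustrative Python sketch of the
previously failing pattern, mirroring test_insert_observers_interface_unshare_type
above. It is not part of the patch; it assumes prepare_jit and default_qconfig
from torch.quantization as imported in test_quantize_jit.py, and the module
names (OpIf, Op, Wrapper, Model) are hypothetical.

import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization.quantize_jit import prepare_jit

@torch.jit.interface
class OpIf(nn.Module):
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        pass

class Op(nn.Module):
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return inp + 1.0

class Wrapper(nn.Module):
    op: OpIf  # interface-typed attribute, shared by both wrapper instances

    def __init__(self):
        super().__init__()
        self.op = Op()

    def forward(self, inp):
        return self.op(inp)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = Wrapper()
        self.b = Wrapper()

    def forward(self, inp):
        return self.a(inp) + self.b(inp)

m = torch.jit.script(Model())
# prepare_jit clones submodules per qconfig; before this patch the clone
# failed when an interface-typed attribute had to be re-registered.
m = prepare_jit(m, {'a': default_qconfig, 'b': default_qconfig})
m(torch.rand(4))  # make sure it runs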