From ff25dfca5a60452fca5465d31f6e5740cd074d02 Mon Sep 17 00:00:00 2001 From: Tarun Karuturi Date: Fri, 12 Jul 2024 05:06:40 +0000 Subject: [PATCH] Save quantization_tag in export graph serialization (#127473) Summary: `quantization_tag` is a first class citizen metadata in quantization flows that is preserved by it. As we'll want to store the quantized exported graphs we also need to preserve this metadata as it's used in later flows. Only json supported metadata will be allowed to be serialized. Test Plan: Added test case Differential Revision: D57939282 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127473 Approved by: https://github.com/angelayi --- test/export/test_serialize.py | 21 +++++++++++++++++++++ torch/_export/serde/serialize.py | 7 +++++++ 2 files changed, 28 insertions(+) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 65f2e79b3f2..d612a64708d 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -1190,6 +1190,27 @@ class TestSerializeCustomClass(TestCase): ep = deserialize(serialized_vals) self.assertTrue(isinstance(ep.constants["custom_obj"].get(), FakeTensor)) + def test_quantization_tag_metadata(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x + x + + f = Foo() + + inputs = (torch.zeros(4, 4),) + ep = export(f, inputs) + + for node in ep.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: + node.meta["quantization_tag"] = "foo" + + serialized_vals = serialize(ep) + ep = deserialize(serialized_vals) + + for node in ep.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: + self.assertTrue(node.meta["quantization_tag"] == "foo") + if __name__ == "__main__": run_tests() diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 1f6a5d0cc7f..677bded149d 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -594,6 +594,9 @@ class GraphModuleSerializer(metaclass=Final): if torch_fn := node.meta.get("torch_fn"): ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn)) + if quantization_tag := node.meta.get("quantization_tag"): + ret["quantization_tag"] = json.dumps(quantization_tag) + return ret def serialize_script_obj_meta( @@ -2149,6 +2152,10 @@ class GraphModuleDeserializer(metaclass=Final): if torch_fn_str := metadata.get("torch_fn"): ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER)) + + if quantization_tag_str := metadata.get("quantization_tag"): + ret["quantization_tag"] = json.loads(quantization_tag_str) + return ret def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: