Round-trip the "quantization_tag" node metadata through export serialization.

The metadata serializer in GraphModuleSerializer JSON-encodes
node.meta["quantization_tag"] into the string-valued metadata dict it returns,
and the metadata deserializer in GraphModuleDeserializer JSON-decodes it back
into the node's meta dict. A test in TestSerializeCustomClass tags the
aten.add.Tensor node of a tiny exported module and checks that the tag is
still present after serialize() followed by deserialize().

Notes for reviewers:
  * The serializer tests the tag against None instead of relying on
    truthiness, so a falsy tag (for example "") is not silently dropped.
  * The test asserts that at least one add node exists, so it cannot pass
    without executing any check, and uses assertEqual for readable failures.
  * The "index" lines are carried over from the original patch; their
    post-image blob ids describe that patch's result, not this revision.

diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index 65f2e79b3f2..d612a64708d 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -1190,6 +1190,41 @@ class TestSerializeCustomClass(TestCase):
         ep = deserialize(serialized_vals)
         self.assertTrue(isinstance(ep.constants["custom_obj"].get(), FakeTensor))
 
+    def test_quantization_tag_metadata(self):
+        """A node's "quantization_tag" meta entry survives serialize/deserialize."""
+
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                return x + x
+
+        f = Foo()
+
+        inputs = (torch.zeros(4, 4),)
+        ep = export(f, inputs)
+
+        def add_nodes(graph):
+            return [
+                node
+                for node in graph.nodes
+                if node.op == "call_function"
+                and node.target == torch.ops.aten.add.Tensor
+            ]
+
+        tagged = add_nodes(ep.graph)
+        # Guard against a vacuous pass: the checks below only mean something
+        # if the exported graph actually contains an aten.add.Tensor node.
+        self.assertGreater(len(tagged), 0)
+        for node in tagged:
+            node.meta["quantization_tag"] = "foo"
+
+        serialized_vals = serialize(ep)
+        ep = deserialize(serialized_vals)
+
+        restored = add_nodes(ep.graph)
+        self.assertEqual(len(restored), len(tagged))
+        for node in restored:
+            self.assertEqual(node.meta["quantization_tag"], "foo")
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index 1f6a5d0cc7f..677bded149d 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -594,6 +594,11 @@ class GraphModuleSerializer(metaclass=Final):
         if torch_fn := node.meta.get("torch_fn"):
             ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn))
 
+        # Compare against None rather than relying on truthiness so that a
+        # falsy tag (e.g. "") is still recorded and survives a round trip.
+        if (quantization_tag := node.meta.get("quantization_tag")) is not None:
+            ret["quantization_tag"] = json.dumps(quantization_tag)
+
         return ret
 
     def serialize_script_obj_meta(
@@ -2149,6 +2154,12 @@ class GraphModuleDeserializer(metaclass=Final):
 
         if torch_fn_str := metadata.get("torch_fn"):
             ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER))
+
+        # The stored value is the JSON text written by the serializer; an
+        # absent or empty entry is skipped because it is not valid JSON.
+        if quantization_tag_str := metadata.get("quantization_tag"):
+            ret["quantization_tag"] = json.loads(quantization_tag_str)
+
         return ret
 
     def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: