Save quantization_tag in export graph serialization (#127473)

Summary: `quantization_tag` is first-class node metadata in quantization flows, and those flows preserve it. Because we want to store quantized exported graphs, serialization must preserve this metadata as well, since later flows rely on it. Only JSON-serializable metadata values can be serialized.

Test Plan: Added test case

Differential Revision: D57939282

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127473
Approved by: https://github.com/angelayi
This commit is contained in:
Tarun Karuturi 2024-07-12 05:06:40 +00:00 committed by PyTorch MergeBot
parent b7d287fbec
commit ff25dfca5a
2 changed files with 28 additions and 0 deletions

View file

@ -1190,6 +1190,27 @@ class TestSerializeCustomClass(TestCase):
ep = deserialize(serialized_vals)
self.assertTrue(isinstance(ep.constants["custom_obj"].get(), FakeTensor))
def test_quantization_tag_metadata(self):
    """Check that ``quantization_tag`` node metadata survives serialization.

    An ``aten.add.Tensor`` node of a tiny exported module is tagged, the
    exported program is serialized and deserialized, and the tag is then
    expected on the corresponding node(s) of the restored graph.
    """

    class Foo(torch.nn.Module):
        def forward(self, x):
            return x + x

    def find_add_nodes(exported_program):
        # Collect the call_function nodes that target aten.add.Tensor.
        return [
            node
            for node in exported_program.graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten.add.Tensor
        ]

    f = Foo()
    inputs = (torch.zeros(4, 4),)
    ep = export(f, inputs)

    tagged_nodes = find_add_nodes(ep)
    # Without this check the test would pass vacuously if the exported
    # graph contained no add node to tag.
    self.assertGreater(len(tagged_nodes), 0)
    for node in tagged_nodes:
        node.meta["quantization_tag"] = "foo"

    serialized_vals = serialize(ep)
    ep = deserialize(serialized_vals)

    restored_nodes = find_add_nodes(ep)
    # Every tagged node must still be present after the round trip;
    # otherwise the loop below could silently check nothing.
    self.assertEqual(len(restored_nodes), len(tagged_nodes))
    for node in restored_nodes:
        # .get() + assertEqual reports the actual value on failure instead
        # of a bare "False is not true" or a KeyError.
        self.assertEqual(node.meta.get("quantization_tag"), "foo")
# Script entry point: call run_tests() only when this file is executed directly.
if __name__ == "__main__":
run_tests()

View file

@ -594,6 +594,9 @@ class GraphModuleSerializer(metaclass=Final):
if torch_fn := node.meta.get("torch_fn"):
ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn))
if quantization_tag := node.meta.get("quantization_tag"):
ret["quantization_tag"] = json.dumps(quantization_tag)
return ret
def serialize_script_obj_meta(
@ -2149,6 +2152,10 @@ class GraphModuleDeserializer(metaclass=Final):
if torch_fn_str := metadata.get("torch_fn"):
ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER))
if quantization_tag_str := metadata.get("quantization_tag"):
ret["quantization_tag"] = json.loads(quantization_tag_str)
return ret
def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: