mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Save quantization_tag in export graph serialization (#127473)
Summary: `quantization_tag` is a first class citizen metadata in quantization flows that is preserved by it. As we'll want to store the quantized exported graphs we also need to preserve this metadata as it's used in later flows. Only json supported metadata will be allowed to be serialized. Test Plan: Added test case Differential Revision: D57939282 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127473 Approved by: https://github.com/angelayi
This commit is contained in:
parent
b7d287fbec
commit
ff25dfca5a
2 changed files with 28 additions and 0 deletions
|
|
@ -1190,6 +1190,27 @@ class TestSerializeCustomClass(TestCase):
|
|||
ep = deserialize(serialized_vals)
|
||||
self.assertTrue(isinstance(ep.constants["custom_obj"].get(), FakeTensor))
|
||||
|
||||
def test_quantization_tag_metadata(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + x
|
||||
|
||||
f = Foo()
|
||||
|
||||
inputs = (torch.zeros(4, 4),)
|
||||
ep = export(f, inputs)
|
||||
|
||||
for node in ep.graph.nodes:
|
||||
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
|
||||
node.meta["quantization_tag"] = "foo"
|
||||
|
||||
serialized_vals = serialize(ep)
|
||||
ep = deserialize(serialized_vals)
|
||||
|
||||
for node in ep.graph.nodes:
|
||||
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
|
||||
self.assertTrue(node.meta["quantization_tag"] == "foo")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -594,6 +594,9 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
if torch_fn := node.meta.get("torch_fn"):
|
||||
ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn))
|
||||
|
||||
if quantization_tag := node.meta.get("quantization_tag"):
|
||||
ret["quantization_tag"] = json.dumps(quantization_tag)
|
||||
|
||||
return ret
|
||||
|
||||
def serialize_script_obj_meta(
|
||||
|
|
@ -2149,6 +2152,10 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
|
||||
if torch_fn_str := metadata.get("torch_fn"):
|
||||
ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER))
|
||||
|
||||
if quantization_tag_str := metadata.get("quantization_tag"):
|
||||
ret["quantization_tag"] = json.loads(quantization_tag_str)
|
||||
|
||||
return ret
|
||||
|
||||
def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
|
||||
|
|
|
|||
Loading…
Reference in a new issue