From b1ef56d64669a90cb0e3584e1cb4e50f68fc2218 Mon Sep 17 00:00:00 2001
From: Supriya Rao
Date: Mon, 13 Dec 2021 20:22:08 -0800
Subject: [PATCH] [quant][docs] quantized model save/load instructions
 (#69789)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69789

Add details on how to save and load quantized models without hitting errors

Test Plan:
CI

autogenerated docs

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D33030991

fbshipit-source-id: 8ec4610ae6d5bcbdd3c5e3bb725f2b06af960d52
---
 docs/source/quantization.rst | 57 ++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst
index 5e1635e00e3..69d0abc0271 100644
--- a/docs/source/quantization.rst
+++ b/docs/source/quantization.rst
@@ -816,6 +816,63 @@ An e2e example::
     # turn off quantization for conv2
     m.conv2.qconfig = None
 
+Saving and Loading Quantized models
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When calling ``torch.load`` on a quantized model, if you see an error like::
+
+    AttributeError: 'LinearPackedParams' object has no attribute '_modules'
+
+This is because directly saving and loading a quantized model using ``torch.save`` and ``torch.load``
+is not supported. To save/load quantized models, the following ways can be used:
+
+1. Saving/Loading the quantized model state_dict
+
+An example::
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = nn.Linear(5, 5)
+            self.relu = nn.ReLU()
+
+        def forward(self, x):
+            x = self.linear(x)
+            x = self.relu(x)
+            return x
+
+    m = M().eval()
+    prepare_orig = prepare_fx(m, {'' : default_qconfig})
+    prepare_orig(torch.rand(5, 5))
+    quantized_orig = convert_fx(prepare_orig)
+
+    # Save/load using state_dict
+    b = io.BytesIO()
+    torch.save(quantized_orig.state_dict(), b)
+
+    m2 = M().eval()
+    prepared = prepare_fx(m2, {'' : default_qconfig})
+    quantized = convert_fx(prepared)
+    b.seek(0)
+    quantized.load_state_dict(torch.load(b))
+
+2. Saving/Loading scripted quantized models using ``torch.jit.save`` and ``torch.jit.load``
+
+An example::
+
+    # Note: using the same model M from previous example
+    m = M().eval()
+    prepare_orig = prepare_fx(m, {'' : default_qconfig})
+    prepare_orig(torch.rand(5, 5))
+    quantized_orig = convert_fx(prepare_orig)
+
+    # save/load using scripted model
+    scripted = torch.jit.script(quantized_orig)
+    b = io.BytesIO()
+    torch.jit.save(scripted, b)
+    b.seek(0)
+    scripted_quantized = torch.jit.load(b)
+
 Numerical Debugging (prototype)
 -------------------------------