[quant][docs] quantized model save/load instructions (#69789)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69789

Add details on how to save and load quantized models without hitting errors

Test Plan:
CI autogenerated docs

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D33030991

fbshipit-source-id: 8ec4610ae6d5bcbdd3c5e3bb725f2b06af960d52
Supriya Rao 2021-12-13 20:22:08 -08:00 committed by Facebook GitHub Bot
parent 2b81ea4f9a
commit b1ef56d646


@@ -816,6 +816,63 @@ An e2e example::
# turn off quantization for conv2
m.conv2.qconfig = None
Saving and Loading Quantized Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Calling ``torch.load`` on a quantized model may fail with an error like::

  AttributeError: 'LinearPackedParams' object has no attribute '_modules'

This happens because directly saving and loading a quantized model with ``torch.save`` and ``torch.load``
is not supported. To save and load quantized models, use one of the following approaches:
1. Saving/Loading the quantized model state_dict
An example::

  import io
  import torch
  import torch.nn as nn
  from torch.quantization import default_qconfig
  from torch.quantization.quantize_fx import prepare_fx, convert_fx

  class M(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.linear = nn.Linear(5, 5)
          self.relu = nn.ReLU()

      def forward(self, x):
          x = self.linear(x)
          x = self.relu(x)
          return x

  m = M().eval()
  prepare_orig = prepare_fx(m, {'': default_qconfig})
  # Calibrate with sample data
  prepare_orig(torch.rand(5, 5))
  quantized_orig = convert_fx(prepare_orig)

  # Save/load using state_dict
  b = io.BytesIO()
  torch.save(quantized_orig.state_dict(), b)

  # Rebuild the same quantized module structure, then load the saved values into it
  m2 = M().eval()
  prepared = prepare_fx(m2, {'': default_qconfig})
  quantized = convert_fx(prepared)

  b.seek(0)
  quantized.load_state_dict(torch.load(b))
2. Saving/Loading scripted quantized models using ``torch.jit.save`` and ``torch.jit.load``
An example::

  # Note: using the same model M from the previous example
  m = M().eval()
  prepare_orig = prepare_fx(m, {'': default_qconfig})
  # Calibrate with sample data
  prepare_orig(torch.rand(5, 5))
  quantized_orig = convert_fx(prepare_orig)

  # Save/load using the scripted model
  scripted = torch.jit.script(quantized_orig)
  b = io.BytesIO()
  torch.jit.save(scripted, b)
  b.seek(0)
  scripted_quantized = torch.jit.load(b)
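
One way to confirm that a save/load round trip preserved the model is to compare outputs before and after loading. The following is a minimal, self-contained sketch and not part of the FX workflow above. It assumes dynamic quantization via ``torch.quantization.quantize_dynamic`` as a stand-in for ``prepare_fx``/``convert_fx``, so it needs no calibration step. The names ``expected``, ``actual`` and ``outputs_match`` are illustrative only:

```python
import io

import torch
import torch.nn as nn


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))


# Assumption: dynamic quantization stands in for the FX workflow shown above.
quantized_orig = torch.quantization.quantize_dynamic(
    M().eval(), {nn.Linear}, dtype=torch.qint8
)

# Script the quantized model, then save it to and load it from an in-memory buffer.
scripted = torch.jit.script(quantized_orig)
b = io.BytesIO()
torch.jit.save(scripted, b)
b.seek(0)
scripted_quantized = torch.jit.load(b)

# Run the original and the reloaded model on the same input and compare.
x = torch.rand(5, 5)
with torch.no_grad():
    expected = quantized_orig(x)
    actual = scripted_quantized(x)
outputs_match = torch.allclose(expected, actual)
print("outputs match:", outputs_match)
```

The same comparison should apply to the ``state_dict`` approach: run the model that received ``load_state_dict`` and the original quantized model on one input and check that the results agree.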
Numerical Debugging (prototype)
-------------------------------