pytorch/test/quantization/test_equalize.py
Edmund Williams Jr fd62847eb2 cross_layer_equalization (#41685)
Summary:
The goal is to implement cross layer equalization as described in section 4.1 in this paper: https://arxiv.org/pdf/1906.04721.pdf
Given two adjacent submodules A and B in a trained model, quantization might hurt one of them more than the other. The paper proposes that a loss in accuracy from quantization can be due to a difference in the channel ranges between the two submodules (the output channel range of A can be small, while the input channel range of B can be large). To minimize this source of error, we want to scale the tensors of A and B such that their channel ranges are equal (equal ranges mean there is no mismatch between the two submodules, which minimizes this source of error).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41685

Test Plan: Imported from OSS

Reviewed By: z-a-f

Differential Revision: D22630219

Pulled By: edmundw314

fbshipit-source-id: ccc91ba12c10b652d7275222da8b85455b8a7cd5
2020-07-22 08:39:23 -07:00

95 lines
3.6 KiB
Python

import torch
import torch.nn as nn
from torch.testing._internal.common_quantization import QuantizationTestCase
import torch.quantization._equalize as _equalize
import copy
class TestEqualizeEager(QuantizationTestCase):
def checkChannelsEqualized(self, tensor1, tensor2, output_axis, input_axis):
''' Checks the channel ranges of tensor1, tensor2 are the same,
which is an indication that equalization has been applied correctly
'''
output_channel_tensor1 = _equalize.channel_range(tensor1, output_axis)
input_channel_tensor2 = _equalize.channel_range(tensor2, input_axis)
# ensuring the channels ranges of tensor1's input is the same as
# tensor2's output
self.assertEqual(output_channel_tensor1, input_channel_tensor2)
def getModule(self, model, name):
''' Given the name is a submodule to a model, return the submodule
'''
curr = model
name = name.split('.')
for subname in name:
curr = curr._modules[subname]
return curr
def test_cross_layer_equalization(self):
''' applies _equalize.cross_layer_equalization on two modules and checks
to make sure channels ranges are equivalent
'''
module1 = nn.Conv2d(3, 4, 2)
module2 = nn.Linear(4, 4)
module1_output_channel_axis = 0
module2_input_channel_axis = 1
_equalize.cross_layer_equalization(module1, module2)
mod_tensor1, mod_tensor2 = module1.weight, module2.weight
self.checkChannelsEqualized(mod_tensor1, mod_tensor2, module1_output_channel_axis, module2_input_channel_axis)
def test_converged(self):
''' Sanity checks on _equalize.converged working
identical modules should return true
modules with high difference in weights should return false
'''
module1 = nn.Linear(3, 3)
module2 = nn.Linear(3, 3)
module1.weight = nn.parameter.Parameter(torch.ones(module1.weight.size()))
module2.weight = nn.parameter.Parameter(torch.zeros(module1.weight.size()))
# input is a dictionary
dictionary_1 = {'linear1': module1}
dictionary_2 = {'linear1': module2}
self.assertTrue(_equalize.converged(dictionary_1, dictionary_1, 1e-6))
self.assertFalse(_equalize.converged(dictionary_1, dictionary_2, 1e-6))
def test_equalize(self):
''' First checks to see if _equalize.equalize can handle multiple
pair modules as input
then checks correctness of the function by ensuring the equalized
and unequalized versions of the model yield the same output
given the same input
'''
class ChainModule(nn.Module):
def __init__(self):
super(ChainModule, self).__init__()
self.linear1 = nn.Linear(3, 4)
self.linear2 = nn.Linear(4, 5)
self.linear3 = nn.Linear(5, 6)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x
chain1 = ChainModule()
chain2 = copy.deepcopy(chain1)
_equalize.equalize(chain1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
linear1 = self.getModule(chain1, 'linear1')
linear2 = self.getModule(chain1, 'linear2')
linear3 = self.getModule(chain1, 'linear3')
self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
input = torch.randn(20, 3)
self.assertEqual(chain1(input), chain2(input))