Fix export of aten op for Max and Avg Pool 2D (#9330)

This commit is contained in:
ashbhandare 2021-10-12 09:03:14 -07:00 committed by GitHub
parent f9cf62912a
commit 52c021d1f3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 0 deletions

View file

@ -87,6 +87,9 @@ def multinomial(g, self, num_samples, replacement=False, generator=None):
@register_symbolic('max_pool2d')
def max_pool2d(g, self, kernel_size, stride, padding, dilation, ceil_mode):
stride_val = sym_help._maybe_get_const(stride, 'is')
if not stride_val:
stride = kernel_size
return g.op("com.microsoft::ATenOp", self, kernel_size, stride, padding, dilation, ceil_mode,
name_s='aten::max_pool2d_with_indices', outputs=2)[0]
@ -103,6 +106,9 @@ def argmax(g, input, dim, keepdim):
@register_symbolic('avg_pool2d')
def avg_pool2d(g, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override):
stride_val = sym_help._maybe_get_const(stride, 'is')
if not stride_val:
stride = kernel_size
return g.op("com.microsoft::ATenOp", self, kernel_size, stride, padding, ceil_mode,
count_include_pad, divisor_override, name_s='aten::avg_pool2d')

View file

@ -916,6 +916,40 @@ def test_gradient_correctness_pool2d(pool_type):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-3, atol=4e-3)
@pytest.mark.parametrize("pool_type", ['MaxPool', 'AvgPool'])
@pytest.mark.parametrize("stride", [None, 2])
def test_export_correctness_pool2d(pool_type, stride):
class NeuralNetPool2d(torch.nn.Module):
def __init__(self):
super(NeuralNetPool2d, self).__init__()
self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.pool_type = pool_type
def forward(self, input):
x = self.conv(input)
if pool_type == 'MaxPool':
output = torch.nn.functional.max_pool2d(x, kernel_size=3, stride=stride)
elif pool_type == 'AvgPool':
output = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=stride)
return output
N, C, H, W = 8, 3, 224, 224
device = 'cuda'
pt_model = NeuralNetPool2d().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, input):
prediction = model(input)
return prediction
for _ in range(10):
input = torch.randn(N, C, H, W, device=device)
pt_prediction = run_step(pt_model, input)
ort_prediction = run_step(ort_model, input)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
def test_gradient_correctness_argmax_unfold():
class NeuralNetUnfold(torch.nn.Module):
def __init__(self, input_size, hidden_size, unfold_dim, unfold_size, unfold_step):