From 52c021d1f3e7d7e2bae15aabd748ac3c9566f6b1 Mon Sep 17 00:00:00 2001 From: ashbhandare Date: Tue, 12 Oct 2021 09:03:14 -0700 Subject: [PATCH] Fix export of aten op for Max and Avg Pool 2D (#9330) --- .../ortmodule/_custom_op_symbolic_registry.py | 6 ++++ .../python/orttraining_test_ortmodule_api.py | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index a29e03300d..b4e54de19f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -87,6 +87,9 @@ def multinomial(g, self, num_samples, replacement=False, generator=None): @register_symbolic('max_pool2d') def max_pool2d(g, self, kernel_size, stride, padding, dilation, ceil_mode): + stride_val = sym_help._maybe_get_const(stride, 'is') + if not stride_val: + stride = kernel_size return g.op("com.microsoft::ATenOp", self, kernel_size, stride, padding, dilation, ceil_mode, name_s='aten::max_pool2d_with_indices', outputs=2)[0] @@ -103,6 +106,9 @@ def argmax(g, input, dim, keepdim): @register_symbolic('avg_pool2d') def avg_pool2d(g, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override): + stride_val = sym_help._maybe_get_const(stride, 'is') + if not stride_val: + stride = kernel_size return g.op("com.microsoft::ATenOp", self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, name_s='aten::avg_pool2d') diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 344556f8bf..894cd3bf8a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -916,6 +916,40 @@ def test_gradient_correctness_pool2d(pool_type): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-3, atol=4e-3) +@pytest.mark.parametrize("pool_type", ['MaxPool', 'AvgPool']) +@pytest.mark.parametrize("stride", [None, 2]) +def test_export_correctness_pool2d(pool_type, stride): + class NeuralNetPool2d(torch.nn.Module): + def __init__(self): + super(NeuralNetPool2d, self).__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.pool_type = pool_type + + + def forward(self, input): + x = self.conv(input) + if pool_type == 'MaxPool': + output = torch.nn.functional.max_pool2d(x, kernel_size=3, stride=stride) + elif pool_type == 'AvgPool': + output = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=stride) + return output + + N, C, H, W = 8, 3, 224, 224 + device = 'cuda' + pt_model = NeuralNetPool2d().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input): + prediction = model(input) + return prediction + + for _ in range(10): + input = torch.randn(N, C, H, W, device=device) + pt_prediction = run_step(pt_model, input) + ort_prediction = run_step(ort_model, input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + def test_gradient_correctness_argmax_unfold(): class NeuralNetUnfold(torch.nn.Module): def __init__(self, input_size, hidden_size, unfold_dim, unfold_size, unfold_step):