mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Fix export of aten op for Max and Avg Pool 2D (#9330)
This commit is contained in:
parent
f9cf62912a
commit
52c021d1f3
2 changed files with 40 additions and 0 deletions
|
|
@ -87,6 +87,9 @@ def multinomial(g, self, num_samples, replacement=False, generator=None):
|
|||
|
||||
@register_symbolic('max_pool2d')
|
||||
def max_pool2d(g, self, kernel_size, stride, padding, dilation, ceil_mode):
|
||||
stride_val = sym_help._maybe_get_const(stride, 'is')
|
||||
if not stride_val:
|
||||
stride = kernel_size
|
||||
return g.op("com.microsoft::ATenOp", self, kernel_size, stride, padding, dilation, ceil_mode,
|
||||
name_s='aten::max_pool2d_with_indices', outputs=2)[0]
|
||||
|
||||
|
|
@ -103,6 +106,9 @@ def argmax(g, input, dim, keepdim):
|
|||
|
||||
@register_symbolic('avg_pool2d')
|
||||
def avg_pool2d(g, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override):
|
||||
stride_val = sym_help._maybe_get_const(stride, 'is')
|
||||
if not stride_val:
|
||||
stride = kernel_size
|
||||
return g.op("com.microsoft::ATenOp", self, kernel_size, stride, padding, ceil_mode,
|
||||
count_include_pad, divisor_override, name_s='aten::avg_pool2d')
|
||||
|
||||
|
|
|
|||
|
|
@ -916,6 +916,40 @@ def test_gradient_correctness_pool2d(pool_type):
|
|||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-3, atol=4e-3)
|
||||
|
||||
@pytest.mark.parametrize("pool_type", ['MaxPool', 'AvgPool'])
|
||||
@pytest.mark.parametrize("stride", [None, 2])
|
||||
def test_export_correctness_pool2d(pool_type, stride):
|
||||
class NeuralNetPool2d(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(NeuralNetPool2d, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.pool_type = pool_type
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
x = self.conv(input)
|
||||
if pool_type == 'MaxPool':
|
||||
output = torch.nn.functional.max_pool2d(x, kernel_size=3, stride=stride)
|
||||
elif pool_type == 'AvgPool':
|
||||
output = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=stride)
|
||||
return output
|
||||
|
||||
N, C, H, W = 8, 3, 224, 224
|
||||
device = 'cuda'
|
||||
pt_model = NeuralNetPool2d().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
def run_step(model, input):
|
||||
prediction = model(input)
|
||||
return prediction
|
||||
|
||||
for _ in range(10):
|
||||
input = torch.randn(N, C, H, W, device=device)
|
||||
pt_prediction = run_step(pt_model, input)
|
||||
ort_prediction = run_step(ort_model, input)
|
||||
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
|
||||
def test_gradient_correctness_argmax_unfold():
|
||||
class NeuralNetUnfold(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, unfold_dim, unfold_size, unfold_step):
|
||||
|
|
|
|||
Loading…
Reference in a new issue