From a08036da099ec512bdc6e115c86fce228ad9c87c Mon Sep 17 00:00:00 2001 From: Daigo HIROOKA Date: Tue, 8 Mar 2022 05:49:12 +0900 Subject: [PATCH] correct symbolic name of GridSample operation (#10782) Function name needs to match PyTorch ATen op name, which is `aten::grid_sampler`. --- onnxruntime/python/tools/pytorch_export_contrib_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/pytorch_export_contrib_ops.py b/onnxruntime/python/tools/pytorch_export_contrib_ops.py index d217c1086f..e3ac06c47d 100644 --- a/onnxruntime/python/tools/pytorch_export_contrib_ops.py +++ b/onnxruntime/python/tools/pytorch_export_contrib_ops.py @@ -33,7 +33,7 @@ def register(): Should be run before torch.onnx.export(). """ - def grid_sample(g, input, grid, mode, padding_mode, align_corners): + def grid_sampler(g, input, grid, mode, padding_mode, align_corners): # mode # 'bilinear' : onnx::Constant[value={0}] # 'nearest' : onnx::Constant[value={1}] @@ -59,7 +59,7 @@ def register(): mode_s=mode_str, padding_mode_s=padding_mode_str, align_corners_i=align_corners) - _reg(grid_sample) + _reg(grid_sampler) def inverse(g, self): return g.op("com.microsoft::Inverse", self).setType(self.type())