correct symbolic name of GridSample operation (#10782)

Function name needs to match PyTorch ATen op name, which is `aten::grid_sampler`.
This commit is contained in:
Daigo HIROOKA 2022-03-08 05:49:12 +09:00 committed by GitHub
parent 3e54f94bb0
commit a08036da09
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -33,7 +33,7 @@ def register():
Should be run before torch.onnx.export().
"""
def grid_sample(g, input, grid, mode, padding_mode, align_corners):
def grid_sampler(g, input, grid, mode, padding_mode, align_corners):
# mode
# 'bilinear' : onnx::Constant[value={0}]
# 'nearest' : onnx::Constant[value={1}]
@ -59,7 +59,7 @@ def register():
mode_s=mode_str,
padding_mode_s=padding_mode_str,
align_corners_i=align_corners)
_reg(grid_sample)
_reg(grid_sampler)
def inverse(g, self):
return g.op("com.microsoft::Inverse", self).setType(self.type())