register custom_op_symbolic for squeeze (#10970)

* register custom_op_symbolic for squeeze

* remove misleading warning msg from symbolic_opset9
This commit is contained in:
mindest 2022-03-24 10:28:21 +08:00 committed by GitHub
parent 7ee52fb8a0
commit 3c5853dcbc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 0 deletions

View file

@ -149,6 +149,18 @@ def numpy_T(g, self):
# output a permute so use ATen instead
return g.op("com.microsoft::ATenOp", self, name_s='aten::numpy_T')
@register_symbolic('squeeze')
def squeeze(g, self, dim=None):
# Current _infer_If does not correctly infer shapes from its then- and else- branches, and will
# cause error in shape inference of following nodes, here we choose to export it as `Squeeze.`
from torch.onnx.symbolic_opset11 import squeeze as squeeze_with_if
if dim is None:
return squeeze_with_if(g, self, dim)
squeeze_dim = sym_help._get_const(dim, 'i', 'dim')
return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim])
# For torch.einsum.
def parse_equation(equation):
pos_comma = equation.find(',')

View file

@ -4867,3 +4867,32 @@ def test_random_states_unchanged_for_ortmodule():
assert random_state_equal(ori_random_states, new_random_states)
del os.environ['ORTMODULE_FALLBACK_RETRY']
def test_squeeze_custom_symbolic_registry():
class SqueezeModel(torch.nn.Module):
def __init__(self):
super(SqueezeModel, self).__init__()
self.conv = torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=14, stride=14, bias=False)
def forward(self, x):
x = x.squeeze(1)
return self.conv(x)
def run_step(model, x):
prediction = model(x)
loss = prediction.sum()
loss.backward()
return prediction, loss
device = 'cuda'
pt_model = SqueezeModel().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
pt_x = torch.randn(1, 1, 3, 224, 224, requires_grad=True, device=device)
ort_x = copy.deepcopy(pt_x)
pt_prediction, pt_loss = run_step(pt_model, pt_x)
ort_prediction, ort_loss = run_step(ort_model, ort_x)
_test_helpers.assert_values_are_close(pt_prediction, ort_prediction)
_test_helpers.assert_values_are_close(pt_loss, ort_loss)
_test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad)