diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index cd3cd1e66c..725425e38e 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -149,6 +149,18 @@ def numpy_T(g, self): # output a permute so use ATen instead return g.op("com.microsoft::ATenOp", self, name_s='aten::numpy_T') + +@register_symbolic('squeeze') +def squeeze(g, self, dim=None): + # Current _infer_If does not correctly infer shapes from its then- and else- branches, and will + # cause error in shape inference of following nodes, here we choose to export it as `Squeeze.` + from torch.onnx.symbolic_opset11 import squeeze as squeeze_with_if + if dim is None: + return squeeze_with_if(g, self, dim) + squeeze_dim = sym_help._get_const(dim, 'i', 'dim') + return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim]) + + # For torch.einsum. def parse_equation(equation): pos_comma = equation.find(',') diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index acd1b0b6c3..dbe9f61072 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -4867,3 +4867,32 @@ def test_random_states_unchanged_for_ortmodule(): assert random_state_equal(ori_random_states, new_random_states) del os.environ['ORTMODULE_FALLBACK_RETRY'] + + +def test_squeeze_custom_symbolic_registry(): + class SqueezeModel(torch.nn.Module): + def __init__(self): + super(SqueezeModel, self).__init__() + self.conv = torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=14, stride=14, bias=False) + def forward(self, x): + x = x.squeeze(1) + return self.conv(x) + + def run_step(model, x): + prediction = model(x) + loss = prediction.sum() + loss.backward() + return prediction, loss + + device = 'cuda' + pt_model = SqueezeModel().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(1, 1, 3, 224, 224, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_prediction, pt_loss = run_step(pt_model, pt_x) + ort_prediction, ort_loss = run_step(ort_model, ort_x) + _test_helpers.assert_values_are_close(pt_prediction, ort_prediction) + _test_helpers.assert_values_are_close(pt_loss, ort_loss) + _test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad)