mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
register custom_op_symbolic for squeeze (#10970)
* register custom_op_symbolic for squeeze * remove misleading warning msg from symbolic_opset9
This commit is contained in:
parent
7ee52fb8a0
commit
3c5853dcbc
2 changed files with 41 additions and 0 deletions
|
|
@ -149,6 +149,18 @@ def numpy_T(g, self):
|
|||
# output a permute so use ATen instead
|
||||
return g.op("com.microsoft::ATenOp", self, name_s='aten::numpy_T')
|
||||
|
||||
|
||||
@register_symbolic('squeeze')
|
||||
def squeeze(g, self, dim=None):
|
||||
# Current _infer_If does not correctly infer shapes from its then- and else- branches, and will
|
||||
# cause error in shape inference of following nodes, here we choose to export it as `Squeeze.`
|
||||
from torch.onnx.symbolic_opset11 import squeeze as squeeze_with_if
|
||||
if dim is None:
|
||||
return squeeze_with_if(g, self, dim)
|
||||
squeeze_dim = sym_help._get_const(dim, 'i', 'dim')
|
||||
return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim])
|
||||
|
||||
|
||||
# For torch.einsum.
|
||||
def parse_equation(equation):
|
||||
pos_comma = equation.find(',')
|
||||
|
|
|
|||
|
|
@ -4867,3 +4867,32 @@ def test_random_states_unchanged_for_ortmodule():
|
|||
assert random_state_equal(ori_random_states, new_random_states)
|
||||
|
||||
del os.environ['ORTMODULE_FALLBACK_RETRY']
|
||||
|
||||
|
||||
def test_squeeze_custom_symbolic_registry():
|
||||
class SqueezeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(SqueezeModel, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=14, stride=14, bias=False)
|
||||
def forward(self, x):
|
||||
x = x.squeeze(1)
|
||||
return self.conv(x)
|
||||
|
||||
def run_step(model, x):
|
||||
prediction = model(x)
|
||||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
return prediction, loss
|
||||
|
||||
device = 'cuda'
|
||||
pt_model = SqueezeModel().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
pt_x = torch.randn(1, 1, 3, 224, 224, requires_grad=True, device=device)
|
||||
ort_x = copy.deepcopy(pt_x)
|
||||
|
||||
pt_prediction, pt_loss = run_step(pt_model, pt_x)
|
||||
ort_prediction, ort_loss = run_step(ort_model, ort_x)
|
||||
_test_helpers.assert_values_are_close(pt_prediction, ort_prediction)
|
||||
_test_helpers.assert_values_are_close(pt_loss, ort_loss)
|
||||
_test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad)
|
||||
|
|
|
|||
Loading…
Reference in a new issue