mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
[ORTModule] ATen Support for upsample_bilinear (#14519)
It's required by model MobileViT.
This commit is contained in:
parent
c1a0fc55e7
commit
3d7518762a
4 changed files with 58 additions and 9 deletions
|
|
@ -218,9 +218,10 @@ class SymbolicShapeInference:
|
|||
"_adaptive_avg_pool2d": self._infer_aten_pool2d,
|
||||
"numpy_T": self._infer_Transpose,
|
||||
"native_group_norm": self._infer_aten_group_norm,
|
||||
"upsample_nearest1d": self._infer_aten_upsample_nearest,
|
||||
"upsample_nearest2d": self._infer_aten_upsample_nearest,
|
||||
"upsample_nearest3d": self._infer_aten_upsample_nearest,
|
||||
"upsample_nearest1d": self._infer_aten_upsample,
|
||||
"upsample_nearest2d": self._infer_aten_upsample,
|
||||
"upsample_nearest3d": self._infer_aten_upsample,
|
||||
"upsample_bilinear2d": self._infer_aten_upsample,
|
||||
}
|
||||
self.run_ = True
|
||||
self.suggested_merge_ = {}
|
||||
|
|
@ -1389,14 +1390,14 @@ class SymbolicShapeInference:
|
|||
)
|
||||
)
|
||||
|
||||
def _infer_aten_upsample_nearest(self, node):
|
||||
def _infer_aten_upsample(self, node):
|
||||
new_shape = None
|
||||
input_shape = self._get_shape(node, 0)
|
||||
if input_shape is not None:
|
||||
new_shape = input_shape[:2]
|
||||
output_size = self._try_get_value(node, 1)
|
||||
if output_size is not None:
|
||||
new_shape += [dim_size.item() for dim_size in output_size]
|
||||
new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size]
|
||||
else:
|
||||
rank = len(input_shape)
|
||||
new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
|
||||
|
|
|
|||
|
|
@ -239,8 +239,10 @@ def native_group_norm_gradient():
|
|||
|
||||
# PyTorch removed related backward functions with "vec" overload name since 1.13. The functions with no overload name
|
||||
# are available for all versions, though they are not that convenient to use.
|
||||
def _upsample_nearest_gradient(backward_fn, dims):
|
||||
def _upsample_gradient(backward_fn, dims):
|
||||
scales = ["" for _ in range(dims)]
|
||||
if "bilinear" in backward_fn:
|
||||
scales = ["I(2)"] + scales
|
||||
return [
|
||||
("Shape", ["I(0)"], ["Shape_X"]),
|
||||
("Shape", ["O(0)"], ["Shape_Y"]),
|
||||
|
|
@ -258,14 +260,19 @@ def _upsample_nearest_gradient(backward_fn, dims):
|
|||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest1d", "vec")
|
||||
def upsample_nearest1d_gradient():
|
||||
return _upsample_nearest_gradient("upsample_nearest1d_backward", 1)
|
||||
return _upsample_gradient("upsample_nearest1d_backward", 1)
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest2d", "vec")
|
||||
def upsample_nearest2d_gradient():
|
||||
return _upsample_nearest_gradient("upsample_nearest2d_backward", 2)
|
||||
return _upsample_gradient("upsample_nearest2d_backward", 2)
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec")
|
||||
def upsample_nearest3d_gradient():
|
||||
return _upsample_nearest_gradient("upsample_nearest3d_backward", 3)
|
||||
return _upsample_gradient("upsample_nearest3d_backward", 3)
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec")
def upsample_bilinear2d_gradient():
    """Return the gradient definition for ATen ``upsample_bilinear2d`` ("vec" overload).

    Delegates to ``_upsample_gradient`` with the name of the backward
    function and ``dims=2``.
    """
    backward_fn = "upsample_bilinear2d_backward"
    dims = 2
    return _upsample_gradient(backward_fn, dims)
|
||||
|
|
|
|||
|
|
@ -799,3 +799,16 @@ def upsample_nearest2d(g, input, output_size, scale_factors):
|
|||
@register_symbolic("upsample_nearest3d")
def upsample_nearest3d(g, input, output_size, scale_factors):
    """Symbolic for ``upsample_nearest3d``; forwards all arguments to ``_upsample_nearest``."""
    operator_name = "upsample_nearest3d"
    return _upsample_nearest(g, input, output_size, scale_factors, operator_name)
|
||||
|
||||
|
||||
@register_symbolic("upsample_bilinear2d")
def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors):
    """Symbolic for ``upsample_bilinear2d`` that emits an ``org.pytorch.aten::ATen`` node.

    The four inputs are passed through unchanged, in order, and the node is
    tagged with the operator name ``upsample_bilinear2d`` and the ``vec``
    overload.
    """
    aten_inputs = (input, output_size, align_corners, scale_factors)
    node_attrs = {
        "operator_s": "upsample_bilinear2d",
        "overload_name_s": "vec",
    }
    return g.op("org.pytorch.aten::ATen", *aten_inputs, **node_attrs)
|
||||
|
|
|
|||
|
|
@ -1782,6 +1782,34 @@ def test_aten_upsample_nearest(input_rank, use_factor):
|
|||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
|
||||
def test_aten_upsample_bilinear():
    """Compare ORTModule with PyTorch for bilinear ``interpolate`` to size (8, 12).

    A plain module and an ORTModule-wrapped deep copy each run one forward
    and backward step on identical CUDA inputs; the predictions and the
    input gradients must be close.
    """

    class _NeuralNetUpsampleBilinear(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input):
            return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear")

    def run_step(model, input):
        # Forward pass, then backpropagate the summed output so that
        # ``input.grad`` is populated for the comparison below.
        output = model(input)
        output.sum().backward()
        return output

    device = "cuda"
    pt_model = _NeuralNetUpsampleBilinear().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    # reset manual seed to reset the generator
    torch.manual_seed(2333)
    pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True)
    ort_input = copy.deepcopy(pt_input)

    pt_prediction = run_step(pt_model, pt_input)
    ort_prediction = run_step(ort_model, ort_input)

    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
    _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
|
||||
def test_gradient_correctness_cast_chain():
|
||||
class NeuralNetCast(torch.nn.Module):
|
||||
def __init__(self, D):
|
||||
|
|
|
|||
Loading…
Reference in a new issue