[ORTModule] ATen Support for upsample_bilinear (#14519)

It's required by model MobileViT.
This commit is contained in:
Vincent Wang 2023-02-04 15:20:18 +08:00 committed by GitHub
parent c1a0fc55e7
commit 3d7518762a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 9 deletions

View file

@ -218,9 +218,10 @@ class SymbolicShapeInference:
"_adaptive_avg_pool2d": self._infer_aten_pool2d,
"numpy_T": self._infer_Transpose,
"native_group_norm": self._infer_aten_group_norm,
"upsample_nearest1d": self._infer_aten_upsample_nearest,
"upsample_nearest2d": self._infer_aten_upsample_nearest,
"upsample_nearest3d": self._infer_aten_upsample_nearest,
"upsample_nearest1d": self._infer_aten_upsample,
"upsample_nearest2d": self._infer_aten_upsample,
"upsample_nearest3d": self._infer_aten_upsample,
"upsample_bilinear2d": self._infer_aten_upsample,
}
self.run_ = True
self.suggested_merge_ = {}
@ -1389,14 +1390,14 @@ class SymbolicShapeInference:
)
)
def _infer_aten_upsample_nearest(self, node):
def _infer_aten_upsample(self, node):
new_shape = None
input_shape = self._get_shape(node, 0)
if input_shape is not None:
new_shape = input_shape[:2]
output_size = self._try_get_value(node, 1)
if output_size is not None:
new_shape += [dim_size.item() for dim_size in output_size]
new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size]
else:
rank = len(input_shape)
new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]

View file

@ -239,8 +239,10 @@ def native_group_norm_gradient():
# PyTorch removed related backward functions with "vec" overload name since 1.13. The functions with no overload name
# are available for all versions, though they are not that convienent to use.
def _upsample_nearest_gradient(backward_fn, dims):
def _upsample_gradient(backward_fn, dims):
scales = ["" for _ in range(dims)]
if "bilinear" in backward_fn:
scales = ["I(2)"] + scales
return [
("Shape", ["I(0)"], ["Shape_X"]),
("Shape", ["O(0)"], ["Shape_Y"]),
@ -258,14 +260,19 @@ def _upsample_nearest_gradient(backward_fn, dims):
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest1d", "vec")
def upsample_nearest1d_gradient():
return _upsample_nearest_gradient("upsample_nearest1d_backward", 1)
return _upsample_gradient("upsample_nearest1d_backward", 1)
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest2d", "vec")
def upsample_nearest2d_gradient():
return _upsample_nearest_gradient("upsample_nearest2d_backward", 2)
return _upsample_gradient("upsample_nearest2d_backward", 2)
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec")
def upsample_nearest3d_gradient():
return _upsample_nearest_gradient("upsample_nearest3d_backward", 3)
return _upsample_gradient("upsample_nearest3d_backward", 3)
@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec")
def upsample_bilinear2d_gradient():
return _upsample_gradient("upsample_bilinear2d_backward", 2)

View file

@ -799,3 +799,16 @@ def upsample_nearest2d(g, input, output_size, scale_factors):
@register_symbolic("upsample_nearest3d")
def upsample_nearest3d(g, input, output_size, scale_factors):
return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d")
@register_symbolic("upsample_bilinear2d")
def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors):
return g.op(
"org.pytorch.aten::ATen",
input,
output_size,
align_corners,
scale_factors,
operator_s="upsample_bilinear2d",
overload_name_s="vec",
)

View file

@ -1782,6 +1782,34 @@ def test_aten_upsample_nearest(input_rank, use_factor):
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
def test_aten_upsample_bilinear():
class _NeuralNetUpsampleBilinear(torch.nn.Module):
def __init__(self):
super(_NeuralNetUpsampleBilinear, self).__init__()
def forward(self, input):
return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear")
device = "cuda"
pt_model = _NeuralNetUpsampleBilinear().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, input):
prediction = model(input)
prediction.sum().backward()
return prediction
# reset manual seed to reset the generator
torch.manual_seed(2333)
pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)
pt_prediction = run_step(pt_model, pt_input)
ort_prediction = run_step(ort_model, ort_input)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
def test_gradient_correctness_cast_chain():
class NeuralNetCast(torch.nn.Module):
def __init__(self, D):