diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index dbc939bce2..ae320279d7 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -218,9 +218,10 @@ class SymbolicShapeInference: "_adaptive_avg_pool2d": self._infer_aten_pool2d, "numpy_T": self._infer_Transpose, "native_group_norm": self._infer_aten_group_norm, - "upsample_nearest1d": self._infer_aten_upsample_nearest, - "upsample_nearest2d": self._infer_aten_upsample_nearest, - "upsample_nearest3d": self._infer_aten_upsample_nearest, + "upsample_nearest1d": self._infer_aten_upsample, + "upsample_nearest2d": self._infer_aten_upsample, + "upsample_nearest3d": self._infer_aten_upsample, + "upsample_bilinear2d": self._infer_aten_upsample, } self.run_ = True self.suggested_merge_ = {} @@ -1389,14 +1390,14 @@ class SymbolicShapeInference: ) ) - def _infer_aten_upsample_nearest(self, node): + def _infer_aten_upsample(self, node): new_shape = None input_shape = self._get_shape(node, 0) if input_shape is not None: new_shape = input_shape[:2] output_size = self._try_get_value(node, 1) if output_size is not None: - new_shape += [dim_size.item() for dim_size in output_size] + new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size] else: rank = len(input_shape) new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index e1d3d5fcf5..89a766bd36 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -239,8 +239,10 @@ def native_group_norm_gradient(): # PyTorch removed related backward functions with "vec" overload name since 1.13. The functions with no overload name # are available for all versions, though they are not that convienent to use. -def _upsample_nearest_gradient(backward_fn, dims): +def _upsample_gradient(backward_fn, dims): scales = ["" for _ in range(dims)] + if "bilinear" in backward_fn: + scales = ["I(2)"] + scales return [ ("Shape", ["I(0)"], ["Shape_X"]), ("Shape", ["O(0)"], ["Shape_Y"]), @@ -258,14 +260,19 @@ def _upsample_nearest_gradient(backward_fn, dims): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest1d", "vec") def upsample_nearest1d_gradient(): - return _upsample_nearest_gradient("upsample_nearest1d_backward", 1) + return _upsample_gradient("upsample_nearest1d_backward", 1) @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest2d", "vec") def upsample_nearest2d_gradient(): - return _upsample_nearest_gradient("upsample_nearest2d_backward", 2) + return _upsample_gradient("upsample_nearest2d_backward", 2) @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): - return _upsample_nearest_gradient("upsample_nearest3d_backward", 3) + return _upsample_gradient("upsample_nearest3d_backward", 3) + + +@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec") +def upsample_bilinear2d_gradient(): + return _upsample_gradient("upsample_bilinear2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 17076d862a..7cd889a156 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -799,3 +799,16 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") + + +@register_symbolic("upsample_bilinear2d") +def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors): + return g.op( + "org.pytorch.aten::ATen", + input, + output_size, + align_corners, + scale_factors, + operator_s="upsample_bilinear2d", + overload_name_s="vec", + ) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7758603c48..3cbdbd9139 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1782,6 +1782,34 @@ def test_aten_upsample_nearest(input_rank, use_factor): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +def test_aten_upsample_bilinear(): + class _NeuralNetUpsampleBilinear(torch.nn.Module): + def __init__(self): + super(_NeuralNetUpsampleBilinear, self).__init__() + + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear") + + device = "cuda" + pt_model = _NeuralNetUpsampleBilinear().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input): + prediction = model(input) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_prediction = run_step(pt_model, pt_input) + ort_prediction = run_step(ort_model, ort_input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + + def test_gradient_correctness_cast_chain(): class NeuralNetCast(torch.nn.Module): def __init__(self, D):