From 8860fded02435e62b63bf69f3d420998f6a00f39 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Fri, 18 Mar 2022 21:28:18 +0800 Subject: [PATCH] Disable Some Einsum ORTModule Tests Due to Issue from PyTorch Exporter (#10906) * disable some einsum tests due to pytorch issue * disable tests on specific torch versions * use skipif --- .../ortmodule/_custom_op_symbolic_registry.py | 20 +++++++++++++++---- .../python/orttraining_test_ortmodule_api.py | 13 ++++++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index a721fc64a2..cd3cd1e66c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -257,22 +257,28 @@ def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes remaining_axes = [axis for axis in range(rank) if axis not in axes_to_remove] # Calculate the new shape, use 0 or -1 if possible. 
shape_tensors = [] - all_zeros = True + before_contiguous_axes = True + last_zero_dim = -1 + has_neg_one_dim = False for axis in remaining_axes: if axis == first_matmul_output_axis: shape_tensors.append(matmul_output_numel_tensor) - all_zeros = False + before_contiguous_axes = False elif axis == first_contraction_axis: shape_tensors.append(contraction_numel_tensor) - all_zeros = False - elif all_zeros: + before_contiguous_axes = False + elif before_contiguous_axes: shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))) + last_zero_dim = len(shape_tensors) - 1 elif axis == remaining_axes[-1]: shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + has_neg_one_dim = True else: single_axis_shape_tensor, _, shape_tensor = get_shape_tensor_by_axes( g, tensor, shape_tensor, [axis], False) shape_tensors.append(single_axis_shape_tensor) + if not has_neg_one_dim and last_zero_dim >= 0: + shape_tensors[last_zero_dim] = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) # Adjust the perm. perm = [axis for axis in perm if axis not in axes_to_remove] new_axis = 0 @@ -458,16 +464,22 @@ def einsum(g, equation, tensor_list): # Need to Reshape the result for the example, the new shape is [size(s), size(m)]. 
if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1: shape_tensors = [g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))] * len(batched_axes) + last_zero_dim = len(shape_tensors) - 1 + has_neg_one_dim = False if lhs_matmul_output_axes: if len(lhs_matmul_output_axes) == 1: shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))) + last_zero_dim = len(shape_tensors) - 1 else: shape_tensors.append(lhs_matmul_output_shape_tensor) if rhs_matmul_output_axes: if len(rhs_matmul_output_axes) == 1: shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + has_neg_one_dim = True else: shape_tensors.append(rhs_matmul_output_shape_tensor) + if not has_neg_one_dim and last_zero_dim >= 0: + shape_tensors[last_zero_dim] = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) result = reshape_tensor(g, result, shape_tensors) # Now output axes is ordered by [batched_axes, lhs_matmul_output_axes, rhs_matmut_output_axes], diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 84ce8405c6..acd1b0b6c3 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1134,8 +1134,16 @@ def test_gradient_correctness_reducesum(dim, keepdim): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) -@pytest.mark.parametrize("equation", ["s,se->se", "se,sc->sec", "se,se->s", "sec,sm->ecm", - "sec,ecm->sm", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms"]) +# In PyTorch 1.11.0, there is issue during reduce node shape handling for exporter, so any sub-graph that +# contains ReduceProd will fail to run, for example, "sec,sm->ecm", "sec,ecm->sm". 
+# For now, skip these cases and test_gradient_correctness_einsum_2; +# re-enable them once the issue in PyTorch is fixed. +skip_torch_1_11 = pytest.mark.skipif(LooseVersion(torch.__version__) >= LooseVersion('1.11.0'), reason="PyTorch 1.11 incompatible") +@pytest.mark.parametrize("equation", [ + "s,se->se", "se,sc->sec", "se,se->s", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms", + pytest.param("sec,sm->ecm", marks=[skip_torch_1_11]), + pytest.param("sec,ecm->sm", marks=[skip_torch_1_11]) +]) def test_gradient_correctness_einsum(equation): class NeuralNetEinsum(torch.nn.Module): def __init__(self, bias_size): @@ -1183,6 +1191,7 @@ def test_gradient_correctness_einsum(equation): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-3, rtol=1e-3) +@skip_torch_1_11 def test_gradient_correctness_einsum_2(): class NeuralNetEinsum(torch.nn.Module): def __init__(self, bias_size):