Disable Some Einsum ORTModule Tests Due to Issue from PyTorch Exporter (#10906)

* disable some einsum tests due to pytorch issue

* disable tests on specific torch versions

* use skipif
This commit is contained in:
Vincent Wang 2022-03-18 21:28:18 +08:00 committed by GitHub
parent 5ed2f4ad5f
commit 8860fded02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 6 deletions

View file

@ -257,22 +257,28 @@ def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes
remaining_axes = [axis for axis in range(rank) if axis not in axes_to_remove]
# Calculate the new shape, use 0 or -1 if possible.
shape_tensors = []
all_zeros = True
before_contiguous_axes = True
last_zero_dim = -1
has_neg_one_dim = False
for axis in remaining_axes:
if axis == first_matmul_output_axis:
shape_tensors.append(matmul_output_numel_tensor)
all_zeros = False
before_contiguous_axes = False
elif axis == first_contraction_axis:
shape_tensors.append(contraction_numel_tensor)
all_zeros = False
elif all_zeros:
before_contiguous_axes = False
elif before_contiguous_axes:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)))
last_zero_dim = len(shape_tensors) - 1
elif axis == remaining_axes[-1]:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
has_neg_one_dim = True
else:
single_axis_shape_tensor, _, shape_tensor = get_shape_tensor_by_axes(
g, tensor, shape_tensor, [axis], False)
shape_tensors.append(single_axis_shape_tensor)
if not has_neg_one_dim and last_zero_dim >= 0:
shape_tensors[last_zero_dim] = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
# Adjust the perm.
perm = [axis for axis in perm if axis not in axes_to_remove]
new_axis = 0
@ -458,16 +464,22 @@ def einsum(g, equation, tensor_list):
# Need to Reshape the result for the example, the new shape is [size(s), size(m)].
if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1:
shape_tensors = [g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))] * len(batched_axes)
last_zero_dim = len(shape_tensors) - 1
has_neg_one_dim = False
if lhs_matmul_output_axes:
if len(lhs_matmul_output_axes) == 1:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)))
last_zero_dim = len(shape_tensors) - 1
else:
shape_tensors.append(lhs_matmul_output_shape_tensor)
if rhs_matmul_output_axes:
if len(rhs_matmul_output_axes) == 1:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
has_neg_one_dim = True
else:
shape_tensors.append(rhs_matmul_output_shape_tensor)
if not has_neg_one_dim and last_zero_dim >= 0:
shape_tensors[last_zero_dim] = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
result = reshape_tensor(g, result, shape_tensors)
# Now output axes is ordered by [batched_axes, lhs_matmul_output_axes, rhs_matmut_output_axes],

View file

@ -1134,8 +1134,16 @@ def test_gradient_correctness_reducesum(dim, keepdim):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
@pytest.mark.parametrize("equation", ["s,se->se", "se,sc->sec", "se,se->s", "sec,sm->ecm",
"sec,ecm->sm", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms"])
# In PyTorch 1.11.0, there is issue during reduce node shape handling for exporter, so any sub-graph that
# contains ReduceProd will fail to run, for example, "sec,sm->ecm", "sec,ecm->sm".
# Currently skip these cases and test_gradient_correctness_einsum_2,
# will enable these tests again once the issue in PyTorch is fixed.
skip_torch_1_11 = pytest.mark.skipif(LooseVersion(torch.__version__) >= LooseVersion('1.11.0'), reason="PyTorch 1.11 incompatible")
@pytest.mark.parametrize("equation", [
"s,se->se", "se,sc->sec", "se,se->s", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms",
pytest.param("sec,sm->ecm", marks=[skip_torch_1_11]),
pytest.param("sec,ecm->sm", marks=[skip_torch_1_11])
])
def test_gradient_correctness_einsum(equation):
class NeuralNetEinsum(torch.nn.Module):
def __init__(self, bias_size):
@ -1183,6 +1191,7 @@ def test_gradient_correctness_einsum(equation):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-3, rtol=1e-3)
@skip_torch_1_11
def test_gradient_correctness_einsum_2():
class NeuralNetEinsum(torch.nn.Module):
def __init__(self, bias_size):