mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
Disable Some Einsum ORTModule Tests Due to Issue from PyTorch Exporter (#10906)
* disable some einsum tests due to pytorch issue * disable tests on specific torch versions * use skipif
This commit is contained in:
parent
5ed2f4ad5f
commit
8860fded02
2 changed files with 27 additions and 6 deletions
|
|
@ -257,22 +257,28 @@ def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes
|
|||
remaining_axes = [axis for axis in range(rank) if axis not in axes_to_remove]
|
||||
# Calculate the new shape, use 0 or -1 if possible.
|
||||
shape_tensors = []
|
||||
all_zeros = True
|
||||
before_contiguous_axes = True
|
||||
last_zero_dim = -1
|
||||
has_neg_one_dim = False
|
||||
for axis in remaining_axes:
|
||||
if axis == first_matmul_output_axis:
|
||||
shape_tensors.append(matmul_output_numel_tensor)
|
||||
all_zeros = False
|
||||
before_contiguous_axes = False
|
||||
elif axis == first_contraction_axis:
|
||||
shape_tensors.append(contraction_numel_tensor)
|
||||
all_zeros = False
|
||||
elif all_zeros:
|
||||
before_contiguous_axes = False
|
||||
elif before_contiguous_axes:
|
||||
shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)))
|
||||
last_zero_dim = len(shape_tensors) - 1
|
||||
elif axis == remaining_axes[-1]:
|
||||
shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
|
||||
has_neg_one_dim = True
|
||||
else:
|
||||
single_axis_shape_tensor, _, shape_tensor = get_shape_tensor_by_axes(
|
||||
g, tensor, shape_tensor, [axis], False)
|
||||
shape_tensors.append(single_axis_shape_tensor)
|
||||
if not has_neg_one_dim and last_zero_dim >= 0:
|
||||
shape_tensors[last_zero_dim] = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
|
||||
# Adjust the perm.
|
||||
perm = [axis for axis in perm if axis not in axes_to_remove]
|
||||
new_axis = 0
|
||||
|
|
@ -458,16 +464,22 @@ def einsum(g, equation, tensor_list):
|
|||
# Need to Reshape the result for the example, the new shape is [size(s), size(m)].
|
||||
if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1:
|
||||
shape_tensors = [g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))] * len(batched_axes)
|
||||
last_zero_dim = len(shape_tensors) - 1
|
||||
has_neg_one_dim = False
|
||||
if lhs_matmul_output_axes:
|
||||
if len(lhs_matmul_output_axes) == 1:
|
||||
shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)))
|
||||
last_zero_dim = len(shape_tensors) - 1
|
||||
else:
|
||||
shape_tensors.append(lhs_matmul_output_shape_tensor)
|
||||
if rhs_matmul_output_axes:
|
||||
if len(rhs_matmul_output_axes) == 1:
|
||||
shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
|
||||
has_neg_one_dim = True
|
||||
else:
|
||||
shape_tensors.append(rhs_matmul_output_shape_tensor)
|
||||
if not has_neg_one_dim and last_zero_dim >= 0:
|
||||
shape_tensors[last_zero_dim] = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
|
||||
result = reshape_tensor(g, result, shape_tensors)
|
||||
|
||||
# Now output axes is ordered by [batched_axes, lhs_matmul_output_axes, rhs_matmut_output_axes],
|
||||
|
|
|
|||
|
|
@ -1134,8 +1134,16 @@ def test_gradient_correctness_reducesum(dim, keepdim):
|
|||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
@pytest.mark.parametrize("equation", ["s,se->se", "se,sc->sec", "se,se->s", "sec,sm->ecm",
|
||||
"sec,ecm->sm", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms"])
|
||||
# In PyTorch 1.11.0, there is issue during reduce node shape handling for exporter, so any sub-graph that
|
||||
# contains ReduceProd will fail to run, for example, "sec,sm->ecm", "sec,ecm->sm".
|
||||
# Currently skip these cases and test_gradient_correctness_einsum_2,
|
||||
# will enable these tests again once the issue in PyTorch is fixed.
|
||||
skip_torch_1_11 = pytest.mark.skipif(LooseVersion(torch.__version__) >= LooseVersion('1.11.0'), reason="PyTorch 1.11 incompatible")
|
||||
@pytest.mark.parametrize("equation", [
|
||||
"s,se->se", "se,sc->sec", "se,se->s", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms",
|
||||
pytest.param("sec,sm->ecm", marks=[skip_torch_1_11]),
|
||||
pytest.param("sec,ecm->sm", marks=[skip_torch_1_11])
|
||||
])
|
||||
def test_gradient_correctness_einsum(equation):
|
||||
class NeuralNetEinsum(torch.nn.Module):
|
||||
def __init__(self, bias_size):
|
||||
|
|
@ -1183,6 +1191,7 @@ def test_gradient_correctness_einsum(equation):
|
|||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@skip_torch_1_11
|
||||
def test_gradient_correctness_einsum_2():
|
||||
class NeuralNetEinsum(torch.nn.Module):
|
||||
def __init__(self, bias_size):
|
||||
|
|
|
|||
Loading…
Reference in a new issue