Register Custom Symbolic of torch.einsum for ORTModule (#9590)

* register custom symbolic for einsum

* bugfix for case needs permute at the end

* refactor

* refactor equation parser

* support new case, use ReduceProd

* optimize perf and graph

* remove some Gather node

* add more ut, fix gemm trans fusion
This commit is contained in:
Vincent Wang 2021-11-18 10:13:58 +08:00 committed by GitHub
parent 6545e24b60
commit 3654a5d60e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 522 additions and 1 deletions

View file

@ -68,8 +68,9 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
Node& output_node = *graph.GetNode(output_node_ptr->Index());
// (AB)' = B'A' : reverse the inputs
std::reverse(new_gemm_input_defs.begin(), new_gemm_input_defs.end());
bool new_transB = !transA;
transA = !transB;
transB = !transA;
transB = new_transB;
nodes_to_remove.push_back(output_node);
}

View file

@ -1368,6 +1368,35 @@ TEST_F(GraphTransformationTests, GemmTransposeFusionInputOutput) {
ASSERT_TRUE(new_input_defs[1]->Name() == "A");
}
// (A'(B'))' = BA
TEST_F(GraphTransformationTests, GemmTransposeFusionInputOutput2) {
auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_inputs_output_transposed_2.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Transpose"], 2);
ASSERT_EQ(op_to_count["Gemm"], 1);
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Transpose"], 0);
ASSERT_EQ(op_to_count["Gemm"], 1);
auto& node = *graph.Nodes().begin();
ASSERT_TRUE(node.OpType() == "Gemm");
ASSERT_FALSE(static_cast<bool>(node.GetAttributes().at("transA").i()));
ASSERT_FALSE(static_cast<bool>(node.GetAttributes().at("transB").i()));
auto new_input_defs = node.InputDefs();
ASSERT_TRUE(new_input_defs[0]->Name() == "B");
ASSERT_TRUE(new_input_defs[1]->Name() == "A");
}
// Sum(Gemm(A, B, _), C) -> Gemm(A, B, C)
TEST_F(GraphTransformationTests, GemmSumFusionBasic) {
auto model_uri = MODEL_FOLDER "fusion/gemm_sum_basic.onnx";

View file

@ -85,3 +85,23 @@ gen_gemm_2inputs_transposed("gemm_transpose_2inputs_transposed.onnx")
gen_gemm_output_transposed("gemm_transpose_output_transposed.onnx")
gen_gemm_inputs_output_transposed("gemm_transpose_inputs_output_transposed.onnx")
# (A'(B')) = BA
def gen_gemm_inputs_output_transposed_2(model_path):
nodes = [
helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"),
helper.make_node("Gemm", ["tp0", "B"], ["out"], "Gemm", alpha=3.0, transB=1),
helper.make_node("Transpose", ["out"], ["output"], "TransposeOut"),
]
inputs = [
helper.make_tensor_value_info("A", TensorProto.FLOAT, ['K', 'M']),
helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K'])
]
outputs = [
helper.make_tensor_value_info("output", TensorProto.FLOAT, ['N', 'M'])
]
save(model_path, nodes, inputs, outputs, [])
gen_gemm_inputs_output_transposed_2("gemm_transpose_inputs_output_transposed_2.onnx")

View file

@ -6,6 +6,7 @@
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args, _get_tensor_dim_size, _get_tensor_sizes
import torch.onnx.symbolic_helper as sym_help
import torch
class CustomOpSymbolicRegistry:
@ -78,6 +79,7 @@ def diagonal(g, self, offset, dim1, dim2):
return g.op("com.microsoft::ATenOp", self, offset, dim1, dim2,
name_s='aten::diagonal')
@register_symbolic('multinomial')
def multinomial(g, self, num_samples, replacement=False, generator=None):
if generator is not None and not sym_help._is_none(generator):
@ -85,6 +87,7 @@ def multinomial(g, self, num_samples, replacement=False, generator=None):
return g.op("com.microsoft::ATenOp", self, num_samples, replacement, generator,
name_s='aten::multinomial')
@register_symbolic('max_pool2d')
def max_pool2d(g, self, kernel_size, stride, padding, dilation, ceil_mode):
stride_val = sym_help._maybe_get_const(stride, 'is')
@ -128,3 +131,338 @@ def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduct
name_s='aten::binary_cross_entropy_with_logits')
from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce
return bce(g, self, target, weight, pos_weight, reduction)
# For torch.einsum.
def parse_equation(equation):
pos_comma = equation.find(',')
pos_arrow = equation.find('->')
assert pos_comma != -1 and pos_arrow > pos_comma
lhs_labels = [label for label in equation[:pos_comma] if label != ' ']
rhs_labels = [label for label in equation[pos_comma + 1:pos_arrow] if label != ' ']
result_labels = [label for label in equation[pos_arrow + 2:] if label != ' ']
# Two operands and result are not empty, and are all alpha characters.
assert lhs_labels and rhs_labels and result_labels
assert all(label.isalpha() for label in lhs_labels + rhs_labels + result_labels)
# Output has no repeated label, each label must be in at least one operand.
assert len(result_labels) == len(set(result_labels))
assert all(label in lhs_labels or label in rhs_labels for label in result_labels)
return lhs_labels, rhs_labels, result_labels
def need_permute(perm):
return any(idx != axis for idx, axis in enumerate(perm))
def map_labels_to_output(input_labels, label_perm_map):
output_len = len(label_perm_map)
perm = [-1] * output_len
unsqueeze_axes = []
idx = 0
for label in input_labels:
# Lookup output index for label.
perm[label_perm_map[label]] = idx
idx += 1
# Add dimensions for missing labels.
for i in range(output_len):
if perm[i] == -1:
unsqueeze_axes.append(idx)
perm[i] = idx
idx += 1
return perm, unsqueeze_axes
def unsqueeze_and_permute_for_mul(g, tensor, unsqueeze_axes, perm):
# If perm is sorted after removing unsqueeze axes, then permute is not needed.
# For example, a.unsqueeze(2).permute([0, 2, 1]) is same as a.unsqueeze(1).
if unsqueeze_axes:
new_perm = [v for v in perm if v not in unsqueeze_axes]
sorted = all(new_perm[i] < new_perm[i + 1] for i in range(len(new_perm) - 1))
if sorted:
return sym_help._unsqueeze_helper(g, tensor, [perm.index(axis) for axis in unsqueeze_axes])
if len(unsqueeze_axes) > 0:
tensor = sym_help._unsqueeze_helper(g, tensor, unsqueeze_axes)
if need_permute(perm):
tensor = g.op("Transpose", tensor, perm_i=perm)
return tensor
def combine_unsqueeze_and_permute_for_matmul(unsqueeze_axes, perm1, perm2):
# When going here, the unsqueeze axes must be some axes at the end.
# We can combine two permutes and remove unsqueeze axes, because we will reshape it after this.
# For example, a.unsqueeze([2,3]).permute([2,3,1,0]).permute([0,1,3,2])
# = a.unsqueeze([2,3]).permute([2,3,0,1]) = a.permute([0,1]) = a.
new_perm = [perm1[axis] for axis in perm2]
new_perm = [axis for axis in new_perm if axis not in unsqueeze_axes]
return new_perm
def is_axes_contiguous(axes):
return len(axes) < 2 or all(axes[axis] + 1 == axes[axis + 1] for axis in range(len(axes) - 1))
def get_shape_tensor_by_axes(g, input, input_shape, axes, need_numel_shape):
if input_shape is None:
input_shape = g.op("Shape", input)
shape_tensor = g.op("Gather", input_shape, g.op("Constant", value_t=torch.tensor(axes, dtype=torch.int64)), axis_i=0)
numel_shape_tensor = None
if need_numel_shape:
assert len(axes) > 1
numel_shape_tensor = g.op("ReduceProd", shape_tensor)
return shape_tensor, numel_shape_tensor, input_shape
def reshape_tensor(g, input, shape_tensors):
shape_tensor = g.op("Concat", *shape_tensors, axis_i=0) if len(shape_tensors) > 1 else shape_tensors[0]
return g.op("Reshape", input, shape_tensor)
def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes, contraction_axes,
batch_length, matmul_output_numel_tensor, contraction_numel_tensor, shape_tensor):
# If matmul_output_axes and contraction_axes are contiguous in input tensor,
# we can move Reshape to before Transpose, so it's possible that the Transpoase is fused to MatMul.
# Otherwise, we have to Transpose first to move those axes together and then Reshape.
is_matmul_output_axes_contiguous = is_axes_contiguous(matmul_output_axes)
is_contraction_axes_contiguous = is_axes_contiguous(contraction_axes)
if is_matmul_output_axes_contiguous and is_contraction_axes_contiguous:
# Combine contiguous axes to one axis.
first_matmul_output_axis = matmul_output_axes[0] if len(matmul_output_axes) > 1 else -1
first_contraction_axis = contraction_axes[0] if len(contraction_axes) > 1 else -1
# If length of matmul_output_axes and contraction_axes are less than 2, no need to Reshape,
# it needs an Unsqueeze and a Transpose if needed.
if first_matmul_output_axis == -1 and first_contraction_axis == -1:
assert not matmul_output_axes and len(contraction_axes) == 1
if need_permute(perm):
new_tensor = sym_help._unsqueeze_helper(g, tensor, [-1])
pos = batch_length if is_lhs else len(perm)
perm = perm[:pos] + [len(perm)] + perm[pos:]
new_tensor = g.op("Transpose", new_tensor, perm_i=perm)
else:
new_tensor = sym_help._unsqueeze_helper(g, tensor, [batch_length if is_lhs else -1])
else:
axes_to_remove = contraction_axes[1:] # contraction_axes can't be empty.
if len(matmul_output_axes) > 1:
axes_to_remove = axes_to_remove + matmul_output_axes[1:]
remaining_axes = [axis for axis in range(rank) if axis not in axes_to_remove]
# Calculate the new shape, use 0 or -1 if possible.
shape_tensors = []
all_zeros = True
for axis in remaining_axes:
if axis == first_matmul_output_axis:
shape_tensors.append(matmul_output_numel_tensor)
all_zeros = False
elif axis == first_contraction_axis:
shape_tensors.append(contraction_numel_tensor)
all_zeros = False
elif all_zeros:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)))
elif axis == remaining_axes[-1]:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
else:
single_axis_shape_tensor, _, shape_tensor = get_shape_tensor_by_axes(
g, tensor, shape_tensor, [axis], False)
shape_tensors.append(single_axis_shape_tensor)
# Adjust the perm.
perm = [axis for axis in perm if axis not in axes_to_remove]
new_axis = 0
for axis in remaining_axes:
perm[perm.index(axis)] = new_axis
new_axis += 1
# If matmul_output_axes is empty, need to add a dim-1 axis.
if not matmul_output_axes:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)))
pos = batch_length if is_lhs else len(perm)
perm = perm[:pos] + [new_axis] + perm[pos:]
new_tensor = reshape_tensor(g, tensor, shape_tensors)
if need_permute(perm):
new_tensor = g.op("Transpose", new_tensor, perm_i=perm)
else:
if need_permute(perm):
new_tensor = g.op("Transpose", tensor, perm_i=perm)
# Calculate the new shape, use 0 or -1 if possible.
shape_tensors = [g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))] * batch_length
if is_lhs:
if matmul_output_numel_tensor is None:
matmul_output_numel_tensor = g.op("Constant", value_t=torch.tensor([1 - len(matmul_output_axes)], dtype=torch.int64))
shape_tensors.append(matmul_output_numel_tensor)
shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
else:
if contraction_numel_tensor is None: # contraction_axes can't be empty, None here means only one contraction axis.
contraction_numel_tensor = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
shape_tensors.append(contraction_numel_tensor)
shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
new_tensor = reshape_tensor(g, new_tensor, shape_tensors)
return new_tensor, shape_tensor
@register_symbolic('einsum')
@parse_args('s', 'v')
def einsum(g, equation, tensor_list):
tensors = sym_help._unpack_list(tensor_list)
num_ops = len(tensors)
assert num_ops > 0
# Doesn't support implicit output is ellipsis or more than 2 oprands for now.
# Doesn't support ellipsis ('...') for now as not easy to get sizes of oprands.
if num_ops != 2 or equation.find('->') == -1 or '.' in equation:
return g.op("Einsum", *tensors, equation_s=equation)
# Take "ks,ksm->sm" as example. After prcoess inputs,
# lhs_labels = [k,s], rhs_labels = [k,s,m], result_labels = [s,m].
lhs_labels, rhs_labels, result_labels = parse_equation(equation)
# Doesn't support repeated label in operand for now as it needs to take extra diagonal.
if len(lhs_labels) != len(set(lhs_labels)) or len(rhs_labels) != len(set(rhs_labels)):
return g.op("Einsum", *tensors, equation_s=equation)
# Add contraction labels (labels not present in output).
# After process contraction labels, contraction_labels = [k],
# label_perm_map = {(s, 0), (m, 1), (k, 2)}, out_size = 2, perm_size = 3.
out_size = len(result_labels)
label_perm_map = dict([(label, idx) for idx, label in enumerate(result_labels)])
perm_size = out_size
contraction_labels = []
lhs_reduce_sum_axes = []
rhs_reduce_sum_axes = []
for label in lhs_labels + rhs_labels:
if label not in label_perm_map:
if label in lhs_labels and label in rhs_labels:
label_perm_map[label] = perm_size
contraction_labels.append(label)
perm_size += 1
elif label in lhs_labels:
lhs_reduce_sum_axes.append(lhs_labels.index(label))
else:
rhs_reduce_sum_axes.append(rhs_labels.index(label))
lhs_tensor = tensors[0]
rhs_tensor = tensors[1]
# If lhs_reduce_sum_axes/rhs_reduce_sum_axes is not empty, ReduceSum on that axes, update lhs_labels/rhs_labels,
# and use the output as original_lhs_tensor/original_rhs_tensor.
if lhs_reduce_sum_axes:
lhs_tensor = sym_help._reducesum_helper(g, lhs_tensor, lhs_reduce_sum_axes, keepdims_i=False)
lhs_labels = [lhs_labels[axis] for axis in range(len(lhs_labels)) if axis not in lhs_reduce_sum_axes]
if rhs_reduce_sum_axes:
rhs_tensor = sym_help._reducesum_helper(g, rhs_tensor, rhs_reduce_sum_axes, keepdims_i=False)
rhs_labels = [rhs_labels[axis] for axis in range(len(rhs_labels)) if axis not in rhs_reduce_sum_axes]
# Need to unsqueeze and permute the inputs to order of output with contraction labels.
# lhs_perm = [1,2,0], lhs_unsqueeze_axes = [2].
# rhs_perm = [1,2,0], rhs_unsqueeze_axes = [].
lhs_perm, lhs_unsqueeze_axes = map_labels_to_output(lhs_labels, label_perm_map)
rhs_perm, rhs_unsqueeze_axes = map_labels_to_output(rhs_labels, label_perm_map)
# If there is no contraction labels, unsqueeze and permute the inputs and Mul them to get final result.
if not contraction_labels:
lhs_tensor = unsqueeze_and_permute_for_mul(g, lhs_tensor, lhs_unsqueeze_axes, lhs_perm)
rhs_tensor = unsqueeze_and_permute_for_mul(g, rhs_tensor, rhs_unsqueeze_axes, rhs_perm)
return g.op("Mul", lhs_tensor, rhs_tensor)
# If contraction_labels is not empty, need a BatchedMatMul.
# Batched labels are those in all inputs and output. Below axes are based on output.
# batched_labels = [s], batched_axes = [0] for the example.
# Matmul output labels are those in one of inputs and output.
# matmul_output_labels = [m], matmul_output_axes = [1] for the example.
# contraction_labels = [k], contraction_axes = [2] for the example.
batched_axes = []
matmul_output_axes = []
contraction_axes = [axis for axis in range(out_size, perm_size)]
for axis in range(out_size):
label = result_labels[axis]
if label in lhs_labels and label in rhs_labels:
batched_axes.append(axis)
else:
matmul_output_axes.append(axis)
# Based on above unsqueeze and permute on inputs, need to permute again.
# For lhs input, the new permute is batched_axes + matmul_output_axes + contraction_axes: [0, 1, 2],
# i.e., a.unsqueeze([2]).permute([1,2,0]).permute([0,1,2]) = [s,1,k] for the example.
# For rhs input, the new permute is batched_axes + contraction_axes + matmul_output_axes: [0, 2, 1].
# i.e., b.unsqueeze([]).permute([1,2,0]).permute([0,2,1]) = [s,k,m] for the example.
lhs_perm = combine_unsqueeze_and_permute_for_matmul(lhs_unsqueeze_axes, lhs_perm, batched_axes + matmul_output_axes + contraction_axes)
rhs_perm = combine_unsqueeze_and_permute_for_matmul(rhs_unsqueeze_axes, rhs_perm, batched_axes + contraction_axes + matmul_output_axes)
# Need to Reshape two input tensors before the BatchedMatMul and Reshape result to output shape.
# Reshape lhs input to [[batched_shapes], Mul(lhs_matmul_output_shapes), Mul(contraction_shapes)].
# Reshape rhs input to [[batched_shapes], Mul(contraction_shapes), Mul(rhs_matmul_output_shapes)]
# Convert all axes based on inputs.
# lhs_contraction_axes = [0], rhs_contraction_axes = [0], lhs_matmul_output_axes = [], rhs_matmul_output_axes = [2] for the example.
lhs_contraction_axes = [lhs_labels.index(label) for label in contraction_labels]
rhs_contraction_axes = [rhs_labels.index(label) for label in contraction_labels]
lhs_matmul_output_axes = [lhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in lhs_labels]
rhs_matmul_output_axes = [rhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in rhs_labels]
# Caches of input shape tensors to avoid generating duplicated graph.
lhs_shape_tensor = None
rhs_shape_tensor = None
# contraction_numel_tensor should be tensor([size(k)]) for the example, but since length is 1, it's None here.
contraction_numel_tensor = None
if len(lhs_contraction_axes) > 1:
_, contraction_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
g, lhs_tensor, lhs_shape_tensor, lhs_contraction_axes, True)
# Prepare some shape tensors for Reshape if needed.
# Both lhs_matmul_output_shape_tensor and lhs_matmul_output_numel_tensor is None for the example.
lhs_matmul_output_shape_tensor = None
lhs_matmul_output_numel_tensor = None
if len(lhs_matmul_output_axes) > 1:
lhs_matmul_output_shape_tensor, lhs_matmul_output_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
g, lhs_tensor, lhs_shape_tensor, lhs_matmul_output_axes, True)
# Both rhs_matmul_output_shape_tensor and rhs_matmul_output_numel_tensor is None for the example.
rhs_matmul_output_shape_tensor = None
rhs_matmul_output_numel_tensor = None
if len(rhs_matmul_output_axes) > 1:
rhs_matmul_output_shape_tensor, rhs_matmul_output_numel_tensor, rhs_shape_tensor = get_shape_tensor_by_axes(
g, rhs_tensor, rhs_shape_tensor, rhs_matmul_output_axes, True)
new_lhs_tensor = lhs_tensor
# Need to Reshape lhs_tensor if lhs_matmul_output_axes or lhs_contraction_axes is not 1, otherwise permute it directly.
# Need to Reshape the lhs_tensor for the example, the new shape is [size(s), 1, size(k)].
if len(lhs_matmul_output_axes) != 1 or len(lhs_contraction_axes) != 1:
new_lhs_tensor, lhs_shape_tensor = permute_and_reshape_tensor(
g, lhs_tensor, True, len(lhs_labels), lhs_perm, lhs_matmul_output_axes, lhs_contraction_axes,
len(batched_axes), lhs_matmul_output_numel_tensor, contraction_numel_tensor, lhs_shape_tensor)
else:
if need_permute(lhs_perm):
new_lhs_tensor = g.op("Transpose", lhs_tensor, perm_i=lhs_perm)
# Need to Reshape rhs_tensor if rhs_matmul_output_axes or rhs_contraction_axes is not 1, otherwise permute it directly.
# rhs_tensor's new shape should be [size(s), size(k), size(m)], but doesn't need to Reshape for the example.
new_rhs_tensor = rhs_tensor
if len(rhs_matmul_output_axes) != 1 or len(rhs_contraction_axes) != 1:
new_rhs_tensor, rhs_shape_tensor = permute_and_reshape_tensor(
g, rhs_tensor, False, len(rhs_labels), rhs_perm, rhs_matmul_output_axes, rhs_contraction_axes,
len(batched_axes), rhs_matmul_output_numel_tensor, contraction_numel_tensor, rhs_shape_tensor)
else:
if need_permute(rhs_perm):
new_rhs_tensor = g.op("Transpose", rhs_tensor, perm_i=rhs_perm)
# Perform final BatchedMatMul. Output is shape [size(s), 1, size(m)] for the example.
result = g.op("MatMul", new_lhs_tensor, new_rhs_tensor)
# Need to Reshape the result if lhs_matmul_output_axes or rhs_matmul_output_axes is not 1.
# Need to Reshape the result for the example, the new shape is [size(s), size(m)].
if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1:
shape_tensors = [g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))] * len(batched_axes)
if lhs_matmul_output_axes:
if len(lhs_matmul_output_axes) == 1:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)))
else:
shape_tensors.append(lhs_matmul_output_shape_tensor)
if rhs_matmul_output_axes:
if len(rhs_matmul_output_axes) == 1:
shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
else:
shape_tensors.append(rhs_matmul_output_shape_tensor)
result = reshape_tensor(g, result, shape_tensors)
# Now output axes is ordered by [batched_axes, lhs_matmul_output_axes, rhs_matmut_output_axes],
# if this is not same as output, need one permute.
labels = [result_labels[axis] for axis in batched_axes] + [
lhs_labels[axis] for axis in lhs_matmul_output_axes] + [rhs_labels[axis] for axis in rhs_matmul_output_axes]
assert len(labels) == out_size
output_perm = [labels.index(label) for label in result_labels]
assert all(axis in output_perm for axis in range(out_size))
if need_permute(output_perm):
result = g.op("Transpose", result, perm_i=output_perm)
return result
# End of torch.einsum.

View file

@ -1072,6 +1072,139 @@ def test_gradient_correctness_reducesum(dim, keepdim):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
@pytest.mark.parametrize("equation", ["s,se->se", "se,sc->sec", "se,se->s", "sec,sm->ecm",
"sec,ecm->sm", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms"])
def test_gradient_correctness_einsum(equation):
class NeuralNetEinsum(torch.nn.Module):
def __init__(self, bias_size):
super(NeuralNetEinsum, self).__init__()
self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(bias_size)))
def forward(self, left, right):
left = left + self.bias
return torch.einsum(equation, left, right)
device = 'cuda'
K, S, M, E = 16, 1024, 768, 64
C = int(S/E*2)
SIZE_MAP = { 'K': K, 'S': S, 'E': E, 'C': C, 'M': M }
pos1 = equation.find(',')
pos2 = equation.find('->')
lhs_op = equation[0:pos1]
rhs_op = equation[pos1 + 1:pos2]
lhs_shape = []
for c in lhs_op:
lhs_shape.append(SIZE_MAP[c.upper()])
rhs_shape = []
for c in rhs_op:
rhs_shape.append(SIZE_MAP[c.upper()])
pt_model = NeuralNetEinsum(lhs_shape[-1]).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, input_left, input_right):
prediction = model(input_left, input_right)
loss = prediction.sum()
loss.backward()
return prediction
for _ in range(10):
pt_input_left = torch.rand(lhs_shape, device=device)
pt_input_right = torch.rand(rhs_shape, device=device)
ort_input_left = copy.deepcopy(pt_input_left)
ort_input_right = copy.deepcopy(pt_input_right)
pt_prediction = run_step(pt_model, pt_input_left, pt_input_right)
ort_prediction = run_step(ort_model, ort_input_left, ort_input_right)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
def test_gradient_correctness_einsum_2():
class NeuralNetEinsum(torch.nn.Module):
def __init__(self, bias_size):
super(NeuralNetEinsum, self).__init__()
self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(bias_size)))
def forward(self, left, right):
left = left + self.bias
return torch.einsum(equation, left, right)
device = 'cuda'
A, B, C, D = 16, 32, 8, 64
SIZE_MAP = { 'A': A, 'B': B, 'C': C, 'D': D }
def to_string(perm):
result = ''
for v in perm:
result += chr(ord('a') + v)
return result
lhs_candidates = [[0], [0,1], [0,1,2]]
perm = [0,1,2,3]
combs = list(itertools.combinations(perm, 1)) + list(itertools.combinations(perm, 2)) + list(itertools.combinations(perm, 3))
rhs_candidates = []
for comb in combs:
rhs_candidates += list(itertools.permutations(comb))
all_cases = []
for lhs_candidate in lhs_candidates:
for rhs_candidate in [list(candidate) for candidate in rhs_candidates]:
union = list(set(lhs_candidate + rhs_candidate))
# Union should contains contiguous numbers from 0, otherwise it's same as another case.
if any(v >= len(union) for v in union):
continue
# Numbers in right but not in left should be sorted, otherwise it's same as another case.
only_in_right = [v for v in rhs_candidate if v not in lhs_candidate]
if any(only_in_right[i] > only_in_right[i + 1] for i in range(len(only_in_right) - 1)):
continue
combs = []
for i in range(1, len(union) + 1):
combs += list(itertools.combinations(union, i))
output_candidates = []
for comb in combs:
output_candidates += list(itertools.permutations(comb))
# When output_candidates is too many, it will take long time to run. Sample part of them.
random.shuffle(output_candidates)
output_candidates = output_candidates[:8]
for output_candidate in [list(candidate) for candidate in output_candidates]:
all_cases.append((lhs_candidate, rhs_candidate, output_candidate))
for case in all_cases:
equation = to_string(case[0]) + ',' + to_string(case[1]) + '->' + to_string(case[2])
pos1 = equation.find(',')
pos2 = equation.find('->')
lhs_op = equation[0:pos1]
rhs_op = equation[pos1 + 1:pos2]
lhs_shape = []
for c in lhs_op:
lhs_shape.append(SIZE_MAP[c.upper()])
rhs_shape = []
for c in rhs_op:
rhs_shape.append(SIZE_MAP[c.upper()])
pt_model = NeuralNetEinsum(lhs_shape[-1]).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, input_left, input_right):
prediction = model(input_left, input_right)
loss = prediction.sum()
loss.backward()
return prediction
for _ in range(5):
pt_input_left = torch.rand(lhs_shape, device=device)
pt_input_right = torch.rand(rhs_shape, device=device)
ort_input_left = copy.deepcopy(pt_input_left)
ort_input_right = copy.deepcopy(pt_input_right)
pt_prediction = run_step(pt_model, pt_input_left, pt_input_right)
ort_prediction = run_step(ort_model, ort_input_left, ort_input_right)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-4, rtol=1e-5)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
# Since multinomial is a generator function, we do not have to test for gradient
# Two consecutive calls on the torch.multinomail on a probability distribution with more
# than one index with non-zero probability(eg, [0, 10, 3, 0]) will not result in