From 3654a5d60e345d662c8a9566cd8e922efd82dbb0 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 18 Nov 2021 10:13:58 +0800 Subject: [PATCH] Register Custom Symbolic of torch.einsum for ORTModule (#9590) * register custom symbolic for einsum * bugfix for case needs permute at the end * refactor * refactor equation parser * support new case, use ReduceProd * optimize perf and graph * remove some Gather node * add more ut, fix gemm trans fusion --- .../core/optimizer/gemm_transpose_fusion.cc | 3 +- .../test/optimizer/graph_transform_test.cc | 29 ++ .../transform/fusion/gemm_transpose_gen.py | 20 ++ ..._transpose_inputs_output_transposed_2.onnx | Bin 0 -> 273 bytes .../ortmodule/_custom_op_symbolic_registry.py | 338 ++++++++++++++++++ .../python/orttraining_test_ortmodule_api.py | 133 +++++++ 6 files changed, 522 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/gemm_transpose_inputs_output_transposed_2.onnx diff --git a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc index 6b4652d3ea..cb3002c5c9 100644 --- a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc @@ -68,8 +68,9 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m Node& output_node = *graph.GetNode(output_node_ptr->Index()); // (AB)' = B'A' : reverse the inputs std::reverse(new_gemm_input_defs.begin(), new_gemm_input_defs.end()); + bool new_transB = !transA; transA = !transB; - transB = !transA; + transB = new_transB; nodes_to_remove.push_back(output_node); } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 05bc3868d3..bcb7be6e1a 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1368,6 +1368,35 @@ TEST_F(GraphTransformationTests, GemmTransposeFusionInputOutput) { ASSERT_TRUE(new_input_defs[1]->Name() == "A"); } +// (A'(B'))' = BA +TEST_F(GraphTransformationTests, GemmTransposeFusionInputOutput2) { + auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_inputs_output_transposed_2.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Transpose"], 2); + ASSERT_EQ(op_to_count["Gemm"], 1); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Transpose"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + auto& node = *graph.Nodes().begin(); + ASSERT_TRUE(node.OpType() == "Gemm"); + ASSERT_FALSE(static_cast(node.GetAttributes().at("transA").i())); + ASSERT_FALSE(static_cast(node.GetAttributes().at("transB").i())); + auto new_input_defs = node.InputDefs(); + ASSERT_TRUE(new_input_defs[0]->Name() == "B"); + ASSERT_TRUE(new_input_defs[1]->Name() == "A"); +} + // Sum(Gemm(A, B, _), C) -> Gemm(A, B, C) TEST_F(GraphTransformationTests, GemmSumFusionBasic) { auto model_uri = MODEL_FOLDER "fusion/gemm_sum_basic.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen.py b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen.py index c5576b19b9..9e86a49823 100644 --- a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen.py @@ -85,3 +85,23 @@ gen_gemm_2inputs_transposed("gemm_transpose_2inputs_transposed.onnx") gen_gemm_output_transposed("gemm_transpose_output_transposed.onnx") gen_gemm_inputs_output_transposed("gemm_transpose_inputs_output_transposed.onnx") +# (A'(B')) = BA +def gen_gemm_inputs_output_transposed_2(model_path): + nodes = [ + helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"), + helper.make_node("Gemm", ["tp0", "B"], ["out"], "Gemm", alpha=3.0, transB=1), + helper.make_node("Transpose", ["out"], ["output"], "TransposeOut"), + ] + + inputs = [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, ['K', 'M']), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']) + ] + + outputs = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, ['N', 'M']) + ] + + save(model_path, nodes, inputs, outputs, []) + +gen_gemm_inputs_output_transposed_2("gemm_transpose_inputs_output_transposed_2.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_inputs_output_transposed_2.onnx b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_inputs_output_transposed_2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..69198573be9497a5e8b8ab9733088f249473001b GIT binary patch literal 273 zcmd;J7ZS+N%d03V%`3^wP1P+)EiSQo$jBwn#po!+TvA{l#T8PNm{(koU!3Zw#0h6} zS#p77xfq>Q') + assert pos_comma != -1 and pos_arrow > pos_comma + lhs_labels = [label for label in equation[:pos_comma] if label != ' '] + rhs_labels = [label for label in equation[pos_comma + 1:pos_arrow] if label != ' '] + result_labels = [label for label in equation[pos_arrow + 2:] if label != ' '] + # Two operands and result are not empty, and are all alpha characters. + assert lhs_labels and rhs_labels and result_labels + assert all(label.isalpha() for label in lhs_labels + rhs_labels + result_labels) + # Output has no repeated label, each label must be in at least one operand. + assert len(result_labels) == len(set(result_labels)) + assert all(label in lhs_labels or label in rhs_labels for label in result_labels) + return lhs_labels, rhs_labels, result_labels + +def need_permute(perm): + return any(idx != axis for idx, axis in enumerate(perm)) + +def map_labels_to_output(input_labels, label_perm_map): + output_len = len(label_perm_map) + perm = [-1] * output_len + unsqueeze_axes = [] + idx = 0 + for label in input_labels: + # Lookup output index for label. + perm[label_perm_map[label]] = idx + idx += 1 + + # Add dimensions for missing labels. + for i in range(output_len): + if perm[i] == -1: + unsqueeze_axes.append(idx) + perm[i] = idx + idx += 1 + + return perm, unsqueeze_axes + +def unsqueeze_and_permute_for_mul(g, tensor, unsqueeze_axes, perm): + # If perm is sorted after removing unsqueeze axes, then permute is not needed. + # For example, a.unsqueeze(2).permute([0, 2, 1]) is same as a.unsqueeze(1). + if unsqueeze_axes: + new_perm = [v for v in perm if v not in unsqueeze_axes] + sorted = all(new_perm[i] < new_perm[i + 1] for i in range(len(new_perm) - 1)) + if sorted: + return sym_help._unsqueeze_helper(g, tensor, [perm.index(axis) for axis in unsqueeze_axes]) + + if len(unsqueeze_axes) > 0: + tensor = sym_help._unsqueeze_helper(g, tensor, unsqueeze_axes) + if need_permute(perm): + tensor = g.op("Transpose", tensor, perm_i=perm) + return tensor + +def combine_unsqueeze_and_permute_for_matmul(unsqueeze_axes, perm1, perm2): + # When going here, the unsqueeze axes must be some axes at the end. + # We can combine two permutes and remove unsqueeze axes, because we will reshape it after this. + # For example, a.unsqueeze([2,3]).permute([2,3,1,0]).permute([0,1,3,2]) + # = a.unsqueeze([2,3]).permute([2,3,0,1]) = a.permute([0,1]) = a. + new_perm = [perm1[axis] for axis in perm2] + new_perm = [axis for axis in new_perm if axis not in unsqueeze_axes] + return new_perm + +def is_axes_contiguous(axes): + return len(axes) < 2 or all(axes[axis] + 1 == axes[axis + 1] for axis in range(len(axes) - 1)) + +def get_shape_tensor_by_axes(g, input, input_shape, axes, need_numel_shape): + if input_shape is None: + input_shape = g.op("Shape", input) + shape_tensor = g.op("Gather", input_shape, g.op("Constant", value_t=torch.tensor(axes, dtype=torch.int64)), axis_i=0) + numel_shape_tensor = None + if need_numel_shape: + assert len(axes) > 1 + numel_shape_tensor = g.op("ReduceProd", shape_tensor) + return shape_tensor, numel_shape_tensor, input_shape + +def reshape_tensor(g, input, shape_tensors): + shape_tensor = g.op("Concat", *shape_tensors, axis_i=0) if len(shape_tensors) > 1 else shape_tensors[0] + return g.op("Reshape", input, shape_tensor) + +def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes, contraction_axes, + batch_length, matmul_output_numel_tensor, contraction_numel_tensor, shape_tensor): + # If matmul_output_axes and contraction_axes are contiguous in input tensor, + # we can move Reshape to before Transpose, so it's possible that the Transpoase is fused to MatMul. + # Otherwise, we have to Transpose first to move those axes together and then Reshape. + is_matmul_output_axes_contiguous = is_axes_contiguous(matmul_output_axes) + is_contraction_axes_contiguous = is_axes_contiguous(contraction_axes) + if is_matmul_output_axes_contiguous and is_contraction_axes_contiguous: + # Combine contiguous axes to one axis. + first_matmul_output_axis = matmul_output_axes[0] if len(matmul_output_axes) > 1 else -1 + first_contraction_axis = contraction_axes[0] if len(contraction_axes) > 1 else -1 + # If length of matmul_output_axes and contraction_axes are less than 2, no need to Reshape, + # it needs an Unsqueeze and a Transpose if needed. + if first_matmul_output_axis == -1 and first_contraction_axis == -1: + assert not matmul_output_axes and len(contraction_axes) == 1 + if need_permute(perm): + new_tensor = sym_help._unsqueeze_helper(g, tensor, [-1]) + pos = batch_length if is_lhs else len(perm) + perm = perm[:pos] + [len(perm)] + perm[pos:] + new_tensor = g.op("Transpose", new_tensor, perm_i=perm) + else: + new_tensor = sym_help._unsqueeze_helper(g, tensor, [batch_length if is_lhs else -1]) + else: + axes_to_remove = contraction_axes[1:] # contraction_axes can't be empty. + if len(matmul_output_axes) > 1: + axes_to_remove = axes_to_remove + matmul_output_axes[1:] + remaining_axes = [axis for axis in range(rank) if axis not in axes_to_remove] + # Calculate the new shape, use 0 or -1 if possible. + shape_tensors = [] + all_zeros = True + for axis in remaining_axes: + if axis == first_matmul_output_axis: + shape_tensors.append(matmul_output_numel_tensor) + all_zeros = False + elif axis == first_contraction_axis: + shape_tensors.append(contraction_numel_tensor) + all_zeros = False + elif all_zeros: + shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))) + elif axis == remaining_axes[-1]: + shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + else: + single_axis_shape_tensor, _, shape_tensor = get_shape_tensor_by_axes( + g, tensor, shape_tensor, [axis], False) + shape_tensors.append(single_axis_shape_tensor) + # Adjust the perm. + perm = [axis for axis in perm if axis not in axes_to_remove] + new_axis = 0 + for axis in remaining_axes: + perm[perm.index(axis)] = new_axis + new_axis += 1 + # If matmul_output_axes is empty, need to add a dim-1 axis. + if not matmul_output_axes: + shape_tensors.append(g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))) + pos = batch_length if is_lhs else len(perm) + perm = perm[:pos] + [new_axis] + perm[pos:] + new_tensor = reshape_tensor(g, tensor, shape_tensors) + if need_permute(perm): + new_tensor = g.op("Transpose", new_tensor, perm_i=perm) + else: + if need_permute(perm): + new_tensor = g.op("Transpose", tensor, perm_i=perm) + # Calculate the new shape, use 0 or -1 if possible. + shape_tensors = [g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))] * batch_length + if is_lhs: + if matmul_output_numel_tensor is None: + matmul_output_numel_tensor = g.op("Constant", value_t=torch.tensor([1 - len(matmul_output_axes)], dtype=torch.int64)) + shape_tensors.append(matmul_output_numel_tensor) + shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + else: + if contraction_numel_tensor is None: # contraction_axes can't be empty, None here means only one contraction axis. + contraction_numel_tensor = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) + shape_tensors.append(contraction_numel_tensor) + shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + new_tensor = reshape_tensor(g, new_tensor, shape_tensors) + return new_tensor, shape_tensor + +@register_symbolic('einsum') +@parse_args('s', 'v') +def einsum(g, equation, tensor_list): + tensors = sym_help._unpack_list(tensor_list) + num_ops = len(tensors) + assert num_ops > 0 + + # Doesn't support implicit output is ellipsis or more than 2 oprands for now. + # Doesn't support ellipsis ('...') for now as not easy to get sizes of oprands. + if num_ops != 2 or equation.find('->') == -1 or '.' in equation: + return g.op("Einsum", *tensors, equation_s=equation) + + # Take "ks,ksm->sm" as example. After prcoess inputs, + # lhs_labels = [k,s], rhs_labels = [k,s,m], result_labels = [s,m]. + lhs_labels, rhs_labels, result_labels = parse_equation(equation) + + # Doesn't support repeated label in operand for now as it needs to take extra diagonal. + if len(lhs_labels) != len(set(lhs_labels)) or len(rhs_labels) != len(set(rhs_labels)): + return g.op("Einsum", *tensors, equation_s=equation) + + # Add contraction labels (labels not present in output). + # After process contraction labels, contraction_labels = [k], + # label_perm_map = {(s, 0), (m, 1), (k, 2)}, out_size = 2, perm_size = 3. + out_size = len(result_labels) + label_perm_map = dict([(label, idx) for idx, label in enumerate(result_labels)]) + perm_size = out_size + contraction_labels = [] + lhs_reduce_sum_axes = [] + rhs_reduce_sum_axes = [] + for label in lhs_labels + rhs_labels: + if label not in label_perm_map: + if label in lhs_labels and label in rhs_labels: + label_perm_map[label] = perm_size + contraction_labels.append(label) + perm_size += 1 + elif label in lhs_labels: + lhs_reduce_sum_axes.append(lhs_labels.index(label)) + else: + rhs_reduce_sum_axes.append(rhs_labels.index(label)) + + lhs_tensor = tensors[0] + rhs_tensor = tensors[1] + + # If lhs_reduce_sum_axes/rhs_reduce_sum_axes is not empty, ReduceSum on that axes, update lhs_labels/rhs_labels, + # and use the output as original_lhs_tensor/original_rhs_tensor. + if lhs_reduce_sum_axes: + lhs_tensor = sym_help._reducesum_helper(g, lhs_tensor, lhs_reduce_sum_axes, keepdims_i=False) + lhs_labels = [lhs_labels[axis] for axis in range(len(lhs_labels)) if axis not in lhs_reduce_sum_axes] + + if rhs_reduce_sum_axes: + rhs_tensor = sym_help._reducesum_helper(g, rhs_tensor, rhs_reduce_sum_axes, keepdims_i=False) + rhs_labels = [rhs_labels[axis] for axis in range(len(rhs_labels)) if axis not in rhs_reduce_sum_axes] + + # Need to unsqueeze and permute the inputs to order of output with contraction labels. + # lhs_perm = [1,2,0], lhs_unsqueeze_axes = [2]. + # rhs_perm = [1,2,0], rhs_unsqueeze_axes = []. + lhs_perm, lhs_unsqueeze_axes = map_labels_to_output(lhs_labels, label_perm_map) + rhs_perm, rhs_unsqueeze_axes = map_labels_to_output(rhs_labels, label_perm_map) + + # If there is no contraction labels, unsqueeze and permute the inputs and Mul them to get final result. + if not contraction_labels: + lhs_tensor = unsqueeze_and_permute_for_mul(g, lhs_tensor, lhs_unsqueeze_axes, lhs_perm) + rhs_tensor = unsqueeze_and_permute_for_mul(g, rhs_tensor, rhs_unsqueeze_axes, rhs_perm) + return g.op("Mul", lhs_tensor, rhs_tensor) + + # If contraction_labels is not empty, need a BatchedMatMul. + # Batched labels are those in all inputs and output. Below axes are based on output. + # batched_labels = [s], batched_axes = [0] for the example. + # Matmul output labels are those in one of inputs and output. + # matmul_output_labels = [m], matmul_output_axes = [1] for the example. + # contraction_labels = [k], contraction_axes = [2] for the example. + batched_axes = [] + matmul_output_axes = [] + contraction_axes = [axis for axis in range(out_size, perm_size)] + for axis in range(out_size): + label = result_labels[axis] + if label in lhs_labels and label in rhs_labels: + batched_axes.append(axis) + else: + matmul_output_axes.append(axis) + + # Based on above unsqueeze and permute on inputs, need to permute again. + # For lhs input, the new permute is batched_axes + matmul_output_axes + contraction_axes: [0, 1, 2], + # i.e., a.unsqueeze([2]).permute([1,2,0]).permute([0,1,2]) = [s,1,k] for the example. + # For rhs input, the new permute is batched_axes + contraction_axes + matmul_output_axes: [0, 2, 1]. + # i.e., b.unsqueeze([]).permute([1,2,0]).permute([0,2,1]) = [s,k,m] for the example. + lhs_perm = combine_unsqueeze_and_permute_for_matmul(lhs_unsqueeze_axes, lhs_perm, batched_axes + matmul_output_axes + contraction_axes) + rhs_perm = combine_unsqueeze_and_permute_for_matmul(rhs_unsqueeze_axes, rhs_perm, batched_axes + contraction_axes + matmul_output_axes) + + # Need to Reshape two input tensors before the BatchedMatMul and Reshape result to output shape. + # Reshape lhs input to [[batched_shapes], Mul(lhs_matmul_output_shapes), Mul(contraction_shapes)]. + # Reshape rhs input to [[batched_shapes], Mul(contraction_shapes), Mul(rhs_matmul_output_shapes)] + # Convert all axes based on inputs. + # lhs_contraction_axes = [0], rhs_contraction_axes = [0], lhs_matmul_output_axes = [], rhs_matmul_output_axes = [2] for the example. + lhs_contraction_axes = [lhs_labels.index(label) for label in contraction_labels] + rhs_contraction_axes = [rhs_labels.index(label) for label in contraction_labels] + lhs_matmul_output_axes = [lhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in lhs_labels] + rhs_matmul_output_axes = [rhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in rhs_labels] + + # Caches of input shape tensors to avoid generating duplicated graph. + lhs_shape_tensor = None + rhs_shape_tensor = None + + # contraction_numel_tensor should be tensor([size(k)]) for the example, but since length is 1, it's None here. + contraction_numel_tensor = None + if len(lhs_contraction_axes) > 1: + _, contraction_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes( + g, lhs_tensor, lhs_shape_tensor, lhs_contraction_axes, True) + + # Prepare some shape tensors for Reshape if needed. + # Both lhs_matmul_output_shape_tensor and lhs_matmul_output_numel_tensor is None for the example. + lhs_matmul_output_shape_tensor = None + lhs_matmul_output_numel_tensor = None + if len(lhs_matmul_output_axes) > 1: + lhs_matmul_output_shape_tensor, lhs_matmul_output_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes( + g, lhs_tensor, lhs_shape_tensor, lhs_matmul_output_axes, True) + + # Both rhs_matmul_output_shape_tensor and rhs_matmul_output_numel_tensor is None for the example. + rhs_matmul_output_shape_tensor = None + rhs_matmul_output_numel_tensor = None + if len(rhs_matmul_output_axes) > 1: + rhs_matmul_output_shape_tensor, rhs_matmul_output_numel_tensor, rhs_shape_tensor = get_shape_tensor_by_axes( + g, rhs_tensor, rhs_shape_tensor, rhs_matmul_output_axes, True) + + new_lhs_tensor = lhs_tensor + # Need to Reshape lhs_tensor if lhs_matmul_output_axes or lhs_contraction_axes is not 1, otherwise permute it directly. + # Need to Reshape the lhs_tensor for the example, the new shape is [size(s), 1, size(k)]. + if len(lhs_matmul_output_axes) != 1 or len(lhs_contraction_axes) != 1: + new_lhs_tensor, lhs_shape_tensor = permute_and_reshape_tensor( + g, lhs_tensor, True, len(lhs_labels), lhs_perm, lhs_matmul_output_axes, lhs_contraction_axes, + len(batched_axes), lhs_matmul_output_numel_tensor, contraction_numel_tensor, lhs_shape_tensor) + else: + if need_permute(lhs_perm): + new_lhs_tensor = g.op("Transpose", lhs_tensor, perm_i=lhs_perm) + + # Need to Reshape rhs_tensor if rhs_matmul_output_axes or rhs_contraction_axes is not 1, otherwise permute it directly. + # rhs_tensor's new shape should be [size(s), size(k), size(m)], but doesn't need to Reshape for the example. + new_rhs_tensor = rhs_tensor + if len(rhs_matmul_output_axes) != 1 or len(rhs_contraction_axes) != 1: + new_rhs_tensor, rhs_shape_tensor = permute_and_reshape_tensor( + g, rhs_tensor, False, len(rhs_labels), rhs_perm, rhs_matmul_output_axes, rhs_contraction_axes, + len(batched_axes), rhs_matmul_output_numel_tensor, contraction_numel_tensor, rhs_shape_tensor) + else: + if need_permute(rhs_perm): + new_rhs_tensor = g.op("Transpose", rhs_tensor, perm_i=rhs_perm) + + # Perform final BatchedMatMul. Output is shape [size(s), 1, size(m)] for the example. + result = g.op("MatMul", new_lhs_tensor, new_rhs_tensor) + + # Need to Reshape the result if lhs_matmul_output_axes or rhs_matmul_output_axes is not 1. + # Need to Reshape the result for the example, the new shape is [size(s), size(m)]. + if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1: + shape_tensors = [g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))] * len(batched_axes) + if lhs_matmul_output_axes: + if len(lhs_matmul_output_axes) == 1: + shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))) + else: + shape_tensors.append(lhs_matmul_output_shape_tensor) + if rhs_matmul_output_axes: + if len(rhs_matmul_output_axes) == 1: + shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + else: + shape_tensors.append(rhs_matmul_output_shape_tensor) + result = reshape_tensor(g, result, shape_tensors) + + # Now output axes is ordered by [batched_axes, lhs_matmul_output_axes, rhs_matmut_output_axes], + # if this is not same as output, need one permute. + labels = [result_labels[axis] for axis in batched_axes] + [ + lhs_labels[axis] for axis in lhs_matmul_output_axes] + [rhs_labels[axis] for axis in rhs_matmul_output_axes] + assert len(labels) == out_size + output_perm = [labels.index(label) for label in result_labels] + assert all(axis in output_perm for axis in range(out_size)) + if need_permute(output_perm): + result = g.op("Transpose", result, perm_i=output_perm) + + return result +# End of torch.einsum. diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index cefbaafc3a..d12fa746dd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1072,6 +1072,139 @@ def test_gradient_correctness_reducesum(dim, keepdim): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +@pytest.mark.parametrize("equation", ["s,se->se", "se,sc->sec", "se,se->s", "sec,sm->ecm", + "sec,ecm->sm", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms"]) +def test_gradient_correctness_einsum(equation): + class NeuralNetEinsum(torch.nn.Module): + def __init__(self, bias_size): + super(NeuralNetEinsum, self).__init__() + self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(bias_size))) + + def forward(self, left, right): + left = left + self.bias + return torch.einsum(equation, left, right) + + device = 'cuda' + K, S, M, E = 16, 1024, 768, 64 + C = int(S/E*2) + + SIZE_MAP = { 'K': K, 'S': S, 'E': E, 'C': C, 'M': M } + + pos1 = equation.find(',') + pos2 = equation.find('->') + lhs_op = equation[0:pos1] + rhs_op = equation[pos1 + 1:pos2] + lhs_shape = [] + for c in lhs_op: + lhs_shape.append(SIZE_MAP[c.upper()]) + rhs_shape = [] + for c in rhs_op: + rhs_shape.append(SIZE_MAP[c.upper()]) + + pt_model = NeuralNetEinsum(lhs_shape[-1]).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input_left, input_right): + prediction = model(input_left, input_right) + loss = prediction.sum() + loss.backward() + return prediction + + for _ in range(10): + pt_input_left = torch.rand(lhs_shape, device=device) + pt_input_right = torch.rand(rhs_shape, device=device) + ort_input_left = copy.deepcopy(pt_input_left) + ort_input_right = copy.deepcopy(pt_input_right) + pt_prediction = run_step(pt_model, pt_input_left, pt_input_right) + ort_prediction = run_step(ort_model, ort_input_left, ort_input_right) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + +def test_gradient_correctness_einsum_2(): + class NeuralNetEinsum(torch.nn.Module): + def __init__(self, bias_size): + super(NeuralNetEinsum, self).__init__() + self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(bias_size))) + + def forward(self, left, right): + left = left + self.bias + return torch.einsum(equation, left, right) + + device = 'cuda' + A, B, C, D = 16, 32, 8, 64 + + SIZE_MAP = { 'A': A, 'B': B, 'C': C, 'D': D } + + def to_string(perm): + result = '' + for v in perm: + result += chr(ord('a') + v) + return result + + lhs_candidates = [[0], [0,1], [0,1,2]] + perm = [0,1,2,3] + combs = list(itertools.combinations(perm, 1)) + list(itertools.combinations(perm, 2)) + list(itertools.combinations(perm, 3)) + rhs_candidates = [] + for comb in combs: + rhs_candidates += list(itertools.permutations(comb)) + + all_cases = [] + for lhs_candidate in lhs_candidates: + for rhs_candidate in [list(candidate) for candidate in rhs_candidates]: + union = list(set(lhs_candidate + rhs_candidate)) + # Union should contains contiguous numbers from 0, otherwise it's same as another case. + if any(v >= len(union) for v in union): + continue + # Numbers in right but not in left should be sorted, otherwise it's same as another case. + only_in_right = [v for v in rhs_candidate if v not in lhs_candidate] + if any(only_in_right[i] > only_in_right[i + 1] for i in range(len(only_in_right) - 1)): + continue + combs = [] + for i in range(1, len(union) + 1): + combs += list(itertools.combinations(union, i)) + output_candidates = [] + for comb in combs: + output_candidates += list(itertools.permutations(comb)) + # When output_candidates is too many, it will take long time to run. Sample part of them. + random.shuffle(output_candidates) + output_candidates = output_candidates[:8] + for output_candidate in [list(candidate) for candidate in output_candidates]: + all_cases.append((lhs_candidate, rhs_candidate, output_candidate)) + + for case in all_cases: + equation = to_string(case[0]) + ',' + to_string(case[1]) + '->' + to_string(case[2]) + pos1 = equation.find(',') + pos2 = equation.find('->') + lhs_op = equation[0:pos1] + rhs_op = equation[pos1 + 1:pos2] + lhs_shape = [] + for c in lhs_op: + lhs_shape.append(SIZE_MAP[c.upper()]) + rhs_shape = [] + for c in rhs_op: + rhs_shape.append(SIZE_MAP[c.upper()]) + + pt_model = NeuralNetEinsum(lhs_shape[-1]).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input_left, input_right): + prediction = model(input_left, input_right) + loss = prediction.sum() + loss.backward() + return prediction + + for _ in range(5): + pt_input_left = torch.rand(lhs_shape, device=device) + pt_input_right = torch.rand(rhs_shape, device=device) + ort_input_left = copy.deepcopy(pt_input_left) + ort_input_right = copy.deepcopy(pt_input_right) + pt_prediction = run_step(pt_model, pt_input_left, pt_input_right) + ort_prediction = run_step(ort_model, ort_input_left, ort_input_right) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-4, rtol=1e-5) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + # Since multinomial is a generator function, we do not have to test for gradient # Two consecutive calls on the torch.multinomail on a probability distribution with more # than one index with non-zero probability(eg, [0, 10, 3, 0]) will not result in