diff --git a/onnxruntime/test/optimizer/cse_test.cc b/onnxruntime/test/optimizer/cse_test.cc index bde7e4ff6e..32686b3683 100644 --- a/onnxruntime/test/optimizer/cse_test.cc +++ b/onnxruntime/test/optimizer/cse_test.cc @@ -95,8 +95,12 @@ TEST(CseTests, SimpleTestTraining) { .IsOK()); GraphTransformerManager graph_transformation_mgr(1); + // need to declare variables to avoid build error after making + // weights_to_train and updated_weight_names as non-const + std::unordered_set weights_to_train; + std::unordered_map updated_weight_names; auto transformers_to_register = onnxruntime::training::transformer_utils::GeneratePreTrainingTransformers( - TransformerLevel::Level1, {}, {}, CPUExecutionProvider(CPUExecutionProviderInfo())); + TransformerLevel::Level1, weights_to_train, {}, CPUExecutionProvider(CPUExecutionProviderInfo()), updated_weight_names); for (auto& entry : transformers_to_register) { ASSERT_TRUE( graph_transformation_mgr.Register(std::move(entry), TransformerLevel::Level1).IsOK()); diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.onnx b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.onnx new file mode 100644 index 0000000000..084cdea8dd Binary files /dev/null and b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.onnx differ diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py new file mode 100644 index 0000000000..b531f34bc8 --- /dev/null +++ b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py @@ -0,0 +1,105 @@ +import onnx +from onnx import helper +from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np + +hidden_size = 4 +weight_dim_to_split = 16 + +X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) +Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) + +a_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((weight_dim_to_split, hidden_size)) +a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.weight") + +a_bias_np_vals = (0.01 * np.arange(weight_dim_to_split, dtype=np.float32)) # weight_dim_to_split numbers in total +a_bias_initializer = numpy_helper.from_array(a_bias_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.bias") + +dropout_np_vals = np.asarray([0.1], dtype=np.float32).reshape(()) +dropout_initializer = numpy_helper.from_array(dropout_np_vals, "ratio") + +dropout_mode_np_vals = np.array([False], dtype=np.bool).reshape(()) +dropout_mode_initializer = numpy_helper.from_array(dropout_mode_np_vals, "mode") + +b_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((hidden_size, weight_dim_to_split)) +b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.weight") + +b_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) # hidden_size numbers in total +b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.bias") + +transpose1 = helper.make_node('Transpose', [a_weight_initializer.name], ['transpose1'], name='transpose1', perm=[1,0]) +transpose2 = helper.make_node('Transpose', [b_weight_initializer.name], ['transpose2'], name='transpose2', perm=[1,0]) +matmul = helper.make_node( + 'MatMul', # node name + ['input', 'transpose1'], # inputs + ['matmul'], # outputs + name="matmul" +) + +biasgelu = helper.make_node( + 'BiasGelu', # node name + ['matmul', a_bias_initializer.name], # inputs + ['biasgelu'], # outputs + name="biasgelu", + domain="com.microsoft" +) + +dropout1 = helper.make_node('Dropout', + ["biasgelu", dropout_initializer.name, dropout_mode_initializer.name], + ['dropout1', "dropout1_mask"], + name='dropout1') + +matmul2 = helper.make_node( + 'MatMul', # node name + ['dropout1', "transpose2"], # inputs + ['matmul2'], # outputs + name="matmul2" +) + +add = helper.make_node( + 'Add', # node name + ['matmul2', b_bias_initializer.name], # inputs + ['add'], # outputs + name="add" +) + +dropout2 = helper.make_node('Dropout', + ["add", dropout_initializer.name, dropout_mode_initializer.name], + ['dropout2', "dropout2_mask"], + name='dropout2') + +identity = helper.make_node( + 'Identity', # node name + ['dropout2'], # inputs + ['output'], # outputs + name="identity" +) + +# Create the graph (GraphProto) +graph_def = helper.make_graph( + [transpose1, transpose2, matmul, biasgelu, dropout1, matmul2, add, dropout2, identity], + 'test-model', + [X], + [Y], + [a_weight_initializer, a_bias_initializer, b_weight_initializer, b_bias_initializer, dropout_initializer, dropout_mode_initializer] +) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = "com.microsoft" + +opsets.append(msdomain) +kwargs={} +kwargs["opset_imports"] = opsets + +# Create the model (ModelProto) +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) + +onnx.save(model_def, 'bart_mlp_megatron_basic_test.onnx') \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.onnx b/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.onnx new file mode 100644 index 0000000000..f501c8c26f Binary files /dev/null and b/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.onnx differ diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py new file mode 100644 index 0000000000..81cd321c04 --- /dev/null +++ b/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py @@ -0,0 +1,150 @@ +import onnx +from onnx import helper +from onnx import TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np +import random + +batch = 6 +hidden_size = 4 +attention_head = 2 +hidden_per_attention = 2 + +relative_attention_num_buckets=32 +input_len=8 +output_len=8 + +X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [batch, input_len, hidden_size]) +Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [output_len, batch, hidden_size]) + +q_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) +q_weight_initializer = numpy_helper.from_array(q_weight_np_vals, 'encoder.layers.0.self_attn.q_proj.weight') + +k_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) +k_weight_initializer = numpy_helper.from_array(k_weight_np_vals, 'encoder.layers.0.self_attn.k_proj.weight') + +v_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) +v_weight_initializer = numpy_helper.from_array(v_weight_np_vals, 'encoder.layers.0.self_attn.v_proj.weight') + +q_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) +q_bias_initializer = numpy_helper.from_array(q_bias_np_vals, 'encoder.layers.0.self_attn.q_proj.bias') + +k_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) +k_bias_initializer = numpy_helper.from_array(k_bias_np_vals, 'encoder.layers.0.self_attn.k_proj.bias') + +v_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) +v_bias_initializer = numpy_helper.from_array(v_bias_np_vals, 'encoder.layers.0.self_attn.v_proj.bias') + +q_shape_initializer = numpy_helper.from_array(np.asarray([input_len, batch*attention_head , hidden_per_attention], dtype=np.int64), 'q_shape') +k_shape_initializer = numpy_helper.from_array(np.asarray([-1, batch*attention_head , hidden_per_attention], dtype=np.int64), 'k_shape') +v_shape_initializer = numpy_helper.from_array(np.asarray([-1, batch*attention_head , hidden_per_attention], dtype=np.int64), 'v_shape') + +mul_np_vals = np.asarray([0.1767766922712326], dtype=np.float32).reshape(()) +mul_initializer = numpy_helper.from_array(mul_np_vals, "mul_const") + +qk_shape_initializer = numpy_helper.from_array(np.asarray([batch, attention_head , input_len, input_len], dtype=np.int64), 'qk_shape') + +dummy_condition_initializer = numpy_helper.from_array(np.zeros((batch, input_len), dtype=bool), 'dummy_cond') +inf_const_initializer = numpy_helper.from_array(np.asarray([-np.inf], dtype=np.float32), 'inf_const') + +where_shape_initializer = numpy_helper.from_array(np.asarray([batch*attention_head , input_len, input_len], dtype=np.int64), 'where_shape') + +dropout_np_vals = np.asarray([0.1], dtype=np.float32).reshape(()) +dropout_initializer = numpy_helper.from_array(dropout_np_vals, "ratio") + +dropout_mode_np_vals = np.array([False], dtype=np.bool).reshape(()) +dropout_mode_initializer = numpy_helper.from_array(dropout_mode_np_vals, "mode") + +shape_initializer3 = numpy_helper.from_array(np.array([input_len, batch, attention_head * hidden_per_attention], dtype=np.int64), 'concat_shape_3') + +dense_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) +dense_weight_initializer = numpy_helper.from_array(dense_weight_np_vals, 'encoder.layers.0.self_attn.out_proj.weight') + +dense_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) +dense_bias_initializer = numpy_helper.from_array(dense_bias_np_vals, 'encoder.layers.0.self_attn.out_proj.bias') + + +transpose_ip = helper.make_node('Transpose', ['input'], ['transpose_ip'], name='transpose_ip', perm=[1,0,2]) + +transpose_q = helper.make_node('Transpose', [q_weight_initializer.name], ['transpose_q'], name='transpose_q', perm=[1,0]) +transpose_k = helper.make_node('Transpose', [k_weight_initializer.name], ['transpose_k'], name='transpose_k', perm=[1,0]) +transpose_v = helper.make_node('Transpose', [v_weight_initializer.name], ['transpose_v'], name='transpose_v', perm=[1,0]) + +matmul_q = helper.make_node('MatMul', ['transpose_ip', 'transpose_q'], ['matmul_q'], name='matmul_q') +matmul_k = helper.make_node('MatMul', ['transpose_ip', 'transpose_k'], ['matmul_k'], name='matmul_k') +matmul_v = helper.make_node('MatMul', ['transpose_ip', 'transpose_v'], ['matmul_v'], name='matmul_v') + + +add_q = helper.make_node('Add', ['matmul_q', q_bias_initializer.name], ['add_q'], name='add_q') +add_k = helper.make_node('Add', ['matmul_k', k_bias_initializer.name], ['add_k'], name='add_k') +add_v = helper.make_node('Add', ['matmul_v', v_bias_initializer.name], ['add_v'], name='add_v') + +mul_q = helper.make_node('Mul', ['add_q' , 'mul_const'], ['mul_q'], name='mul_q') + +reshape_q = helper.make_node('Reshape', ['mul_q', q_shape_initializer.name], ['reshape_q'], name='reshape_q') +reshape_k = helper.make_node('Reshape', ['add_k', k_shape_initializer.name], ['reshape_k'], name='reshape_k') +reshape_v = helper.make_node('Reshape', ['add_v', v_shape_initializer.name], ['reshape_v'], name='reshape_v') + +transpose_q2 = helper.make_node('Transpose', ['reshape_q'], ['transpose_q2'], name='transpose_q2', perm=[1,0,2]) +transpose_k2 = helper.make_node('Transpose', ['reshape_k'], ['transpose_k2'], name='transpose_k2', perm=[1,2,0]) +transpose_v2 = helper.make_node('Transpose', ['reshape_v'], ['transpose_v2'], name='transpose_v2', perm=[1,0,2]) + +matmul = helper.make_node('MatMul', ['transpose_q2', 'transpose_k2'], ['matmul'], name='matmul') +reshape_qk = helper.make_node("Reshape", ['matmul', qk_shape_initializer.name], ['reshape_qk'], name='reshape_qk') + + +unsqueeze = helper.make_node('Unsqueeze', [dummy_condition_initializer.name],['unsqueeze_cond'], axes=[1,2], name='unsqueeze_cond') +where = helper.make_node('Where', ['unsqueeze_cond', inf_const_initializer.name, 'reshape_qk'], ['where'], name='where') + +reshape_where = helper.make_node("Reshape", ['where', where_shape_initializer.name], ['reshape_where'], name='reshape_where') + +softmax = helper.make_node('Softmax', ['reshape_where'], ['softmax'], name='softmax', axis=2) +dropout1 = helper.make_node('Dropout', + ["softmax", dropout_initializer.name, dropout_mode_initializer.name], + ['dropout1', "dropout1_mask"], + name='dropout1') + +matmul2 = helper.make_node('MatMul', ['dropout1', 'transpose_v2'], ['matmul2'], name='matmul2') +transpose = helper.make_node('Transpose', ['matmul2'], ['transpose'], name='transpose', perm=[1,0,2]) +reshape = helper.make_node('Reshape', ['transpose', shape_initializer3.name], ['reshape'], name='reshape') + +transpose_o_weight = helper.make_node('Transpose', [dense_weight_initializer.name], ['transpose_o_weight'], name='transpose_o_weight', perm=[1,0]) +matmul3 = helper.make_node('MatMul', ['reshape', 'transpose_o_weight'], ['matmul3'], name='matmul3') +add3 = helper.make_node('Add', ['matmul3', dense_bias_initializer.name], ['add3'], name='add3') +identity = helper.make_node('Identity', ['add3'], ['output'], name='identity') + +# Create the graph (GraphProto) +graph_def = helper.make_graph( + [transpose_ip,transpose_q,transpose_k,transpose_v,matmul_q,matmul_k,matmul_v,add_q,add_k,add_v, + mul_q,reshape_q,reshape_k,reshape_v,transpose_q2,transpose_k2,transpose_v2,matmul,reshape_qk, + unsqueeze,where,reshape_where,softmax,dropout1,matmul2,transpose,reshape,transpose_o_weight, + matmul3,add3,identity], + 'self-attention-megatron-test-model', + [X], + [Y], + [q_weight_initializer,k_weight_initializer,v_weight_initializer,q_bias_initializer,k_bias_initializer, + v_bias_initializer,q_shape_initializer,k_shape_initializer,v_shape_initializer,mul_initializer, + qk_shape_initializer,dummy_condition_initializer,inf_const_initializer,where_shape_initializer, + dropout_initializer,dropout_mode_initializer,shape_initializer3, + dense_weight_initializer, dense_bias_initializer] +) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = 'com.microsoft' + +opsets.append(msdomain) +kwargs={} +kwargs['opset_imports'] = opsets + +# Create the model (ModelProto) +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +onnx.save(model_def, 'bart_self_attention_megatron_basic_test.onnx') + + diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index f1c433fc01..95412761cc 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -52,9 +52,10 @@ namespace transformer_utils { std::vector> GeneratePreTrainingTransformers( TransformerLevel level, - const std::unordered_set& weights_to_train, + std::unordered_set& weights_to_train, const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config, const IExecutionProvider& execution_provider, + std::unordered_map& updated_weight_names, const std::vector& transformers_and_rules_to_enable) { std::vector> transformers; std::unique_ptr rule_transformer = nullptr; @@ -94,7 +95,6 @@ std::vector> GeneratePreTrainingTransformers( if (config.enable_gelu_approximation) { transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); } - transformers.emplace_back(onnxruntime::make_unique(execution_provider, compatible_eps, weights_to_train)); transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); auto horizontal_parallel_size = training::DistributedRunContext::GroupSize(training::WorkerGroupType::HorizontalParallel); @@ -102,7 +102,7 @@ std::vector> GeneratePreTrainingTransformers( LOGS_DEFAULT(WARNING) << horizontal_parallel_size << "-way horizontal model parallel is enabled"; transformers.emplace_back(onnxruntime::make_unique( training::DistributedRunContext::RankInGroup(training::WorkerGroupType::HorizontalParallel), - horizontal_parallel_size, compatible_eps)); + horizontal_parallel_size, updated_weight_names, weights_to_train, compatible_eps)); } transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h index d6f932742e..f5c7bb601d 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h @@ -17,9 +17,10 @@ namespace transformer_utils { /** Generates all pre-training transformers for this level. */ std::vector> GeneratePreTrainingTransformers( TransformerLevel level, - const std::unordered_set& weights_to_train, + std::unordered_set& weights_to_train, const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config, const IExecutionProvider& execution_provider, // required for constant folding + std::unordered_map& updated_weight_names, const std::vector& rules_and_transformers_to_enable = {}); /** Generates all predefined (both rule-based and non-rule-based) transformers for this level. diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.cc b/orttraining/orttraining/core/optimizer/megatron_transformer.cc index 385db12f57..851086baeb 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.cc +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.cc @@ -32,6 +32,7 @@ const std::initializer_list opset_v1_13 = {1 const std::initializer_list opset_v1_11_13 = {1, 11, 13}; const std::initializer_list opset_v2_11_13 = {2, 11, 13}; const std::initializer_list opset_v5_13 = {5, 13}; +const std::initializer_list opset_v1_6_7_13 = {1, 6, 7, 13}; const std::initializer_list opset_v7_13 = {7, 13}; const std::initializer_list opset_v9 = {9}; const std::initializer_list opset_v9_13 = {9, 13}; @@ -42,11 +43,12 @@ const OpInfo reshape_info = OpInfo("Reshape", opset_v5_13); const OpInfo transpose_info = OpInfo("Transpose", opset_v1_13); const OpInfo matmul_info = OpInfo("MatMul", opset_v9_13); const OpInfo div_info = OpInfo("Div", opset_v7_13); -const OpInfo mul_info = OpInfo("Mul", opset_v7_13); +const OpInfo mul_info = OpInfo("Mul", opset_v1_6_7_13); const OpInfo sub_info = OpInfo("Sub", opset_v7_13); const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11_13); const OpInfo trainable_dropout_info = OpInfo("TrainableDropout", opset_v9, kOnnxDomain); const OpInfo dropout_info = OpInfo("Dropout", opset_v12_13); +const OpInfo where_info = OpInfo("Where", opset_v9); struct NodeInfo { NodeInfo(const std::vector& op_infos, @@ -119,7 +121,6 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node LOGS_DEFAULT(WARNING) << "PartitionWeightByColumn: " << input_arg.Name() << " is not an initializer"; return false; } - auto data_type = tensor_proto->data_type(); const ONNX_NAMESPACE::TensorShapeProto* shape = input_arg.Shape(); int rank = shape->dim_size(); @@ -147,8 +148,9 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node auto initializer = onnxruntime::make_unique(*tensor_proto, graph.ModelPath()); const float* a_weight = initializer->data(); - initializer_partition.set_name("rank_" + std::to_string(horizontal_parallel_rank_) + - "_" + input_arg.Name() + "_partition"); + std::string new_initializer_name = input_arg.Name() + "_column_rank_" + std::to_string(horizontal_parallel_rank_); + + initializer_partition.set_name(new_initializer_name); initializer_partition.set_data_type(data_type); int64_t column_partition = column_count / horizontal_parallel_size_; @@ -209,12 +211,12 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg << horizontal_parallel_size_ << ", not supported currently."; return false; } - auto initializer = onnxruntime::make_unique(*tensor_proto, graph.ModelPath()); const float* a_weight = initializer->data(); - initializer_partition.set_name("rank_" + std::to_string(horizontal_parallel_rank_) + - "_" + input_arg.Name() + "_partition"); + std::string new_initializer_name = input_arg.Name() + "_row_rank_" + std::to_string(horizontal_parallel_rank_); + + initializer_partition.set_name(new_initializer_name); initializer_partition.set_data_type(data_type); int64_t row_partition = row_count / horizontal_parallel_size_; @@ -231,13 +233,13 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg const int64_t row_index_offset = horizontal_parallel_rank_ * row_partition; memcpy(result.data(), a_weight + row_index_offset * column_count, sizeof(float) * element_count); initializer_partition.set_raw_data(result.data(), element_count * sizeof(float)); - return true; } Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, - std::vector& nodes_to_clear_shape) const { + std::vector& nodes_to_clear_shape, + int32_t& counter) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); for (auto node_index : node_topology_list) { @@ -309,12 +311,15 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph NodeArg& a_weight_partition_arg = graph_utils::AddInitializer(graph, a_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, a_weight_partition_arg); + updated_weight_names_.insert({a_weight_arg->Name(), a_weight_partition_arg.Name()}); NodeArg& a_bias_partition_arg = graph_utils::AddInitializer(graph, a_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, a_bias_partition_arg); + updated_weight_names_.insert({b_weight_arg->Name(), a_bias_partition_arg.Name()}); NodeArg& b_weight_partition_arg = graph_utils::AddInitializer(graph, b_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul2_node, 1, b_weight_partition_arg); + updated_weight_names_.insert({a_bias_arg->Name(), b_weight_partition_arg.Name()}); graph.RemoveInitializedTensor(a_weight_arg->Name()); graph.RemoveInitializedTensor(b_weight_arg->Name()); @@ -349,6 +354,151 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph mlp_g_node.SetExecutionProviderType(node.GetExecutionProviderType()); graph_utils::ReplaceDownstreamNodeInput(graph, matmul2_node, 0, mlp_g_node, 0); modified = true; + counter++; + } + + return Status::OK(); +} + +/* +DenseWeight -- Transpose \ + MatMul -- BiasGelu -- Dropout -- MatMul -- Add -- Dropout +*/ +Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger, + std::vector& nodes_to_clear_shape, + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + for (auto node_index : node_topology_list) { + auto& node = *graph.GetNode(node_index); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || + node.GetOutputEdgesCount() != 1) { + continue; + } + Node* second_op = const_cast(graph.GetProducerNode(node.MutableInputDefs()[1]->Name())); + Node* first_op = const_cast(graph.GetProducerNode(node.MutableInputDefs()[0]->Name())); + if (node.GetInputEdgesCount() > 0) { + if (second_op == nullptr) { + break; + } + if (first_op != nullptr && first_op->OpType().compare("MegatronF") == 0) { + continue; + } + + if (second_op->OpType().compare("Transpose") != 0) { + continue; + } + } else { + continue; + } + // check if transpose is only 2-dim + if (!optimizer_utils::IsAttributeWithExpectedValues(*second_op, "perm", {1LL, 0LL})) { + continue; + } + ProviderType provider_type = node.GetExecutionProviderType(); + + Node* biasgelu_node_ptr = graph.GetNode(node.OutputNodesBegin()->Index()); + Node& biasgelu_node = *biasgelu_node_ptr; + if (!graph_utils::IsSupportedOptypeVersionAndDomain(biasgelu_node, "BiasGelu", {1}, kMSDomain) || + biasgelu_node.GetExecutionProviderType() != provider_type || + biasgelu_node.GetOutputEdgesCount() != 1) { + continue; + } + Node& dropout_node = *graph.GetNode(biasgelu_node.OutputNodesBegin()->Index()); + if (!IsExpectedOpAndProvider(dropout_node, dropout_info, provider_type)) { + continue; + } + Node& matmul2_node = *graph.GetNode(dropout_node.OutputNodesBegin()->Index()); + if (!IsExpectedOpAndProvider(matmul2_node, matmul_info, provider_type)) { + continue; + } + Node& add_node = *graph.GetNode(matmul2_node.OutputNodesBegin()->Index()); + if (!IsExpectedOpAndProvider(add_node, add_info, provider_type)) { + continue; + } + Node& dropout2_node = *graph.GetNode(add_node.OutputNodesBegin()->Index()); + if (!IsExpectedOpAndProvider(dropout2_node, dropout_info, provider_type)) { + continue; + } + Node* transpose_op_ptr = const_cast(graph.GetProducerNode(matmul2_node.MutableInputDefs()[1]->Name())); + if (transpose_op_ptr == nullptr || !IsExpectedOpAndProvider(*transpose_op_ptr, transpose_info, provider_type)) { + continue; + } + + nodes_to_clear_shape.insert(nodes_to_clear_shape.end(), {&node, second_op, &biasgelu_node, &dropout_node, + &matmul2_node, transpose_op_ptr}); + + auto dense_wi_weight_arg = second_op->MutableInputDefs()[0]; + ONNX_NAMESPACE::TensorProto dense_wi_weight_initializer_partition; + if (!PartitionWeightByRow(graph, *dense_wi_weight_arg, dense_wi_weight_initializer_partition)) { + continue; + } + + //since the bias doesnt get transposed, partitioning by col + auto dense_wi_bias_arg = biasgelu_node.MutableInputDefs()[1]; + ONNX_NAMESPACE::TensorProto dense_wi_bias_initializer_partition; + if (!PartitionWeightByColumn(graph, *dense_wi_bias_arg, dense_wi_bias_initializer_partition)) { + continue; + } + + auto dense_wo_weight_arg = transpose_op_ptr->MutableInputDefs()[0]; + ONNX_NAMESPACE::TensorProto dense_wo_weight_initializer_partition; + if (!PartitionWeightByColumn(graph, *dense_wo_weight_arg, dense_wo_weight_initializer_partition)) { + continue; + } + + NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wi_weight_initializer_partition); + graph_utils::ReplaceNodeInput(*second_op, 0, dense_wi_weight_partition_arg); + updated_weight_names_.insert({dense_wi_weight_arg->Name(), dense_wi_weight_partition_arg.Name()}); + + NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializer(graph, dense_wi_bias_initializer_partition); + graph_utils::ReplaceNodeInput(biasgelu_node, 1, dense_wi_bias_partition_arg); + updated_weight_names_.insert({dense_wi_bias_arg->Name(), dense_wi_bias_partition_arg.Name()}); + + NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wo_weight_initializer_partition); + graph_utils::ReplaceNodeInput(*transpose_op_ptr, 0, dense_wo_weight_partition_arg); + updated_weight_names_.insert({dense_wo_weight_arg->Name(), dense_wo_weight_partition_arg.Name()}); + + graph.RemoveInitializedTensor(dense_wi_weight_arg->Name()); + graph.RemoveInitializedTensor(dense_wi_bias_arg->Name()); + graph.RemoveInitializedTensor(dense_wo_weight_arg->Name()); + + dropout_nodes_to_transform.insert(&dropout_node); + + const std::vector mlp_f_input_defs{node.MutableInputDefs()[0]}; + auto mlp_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto(); + auto& mlp_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BART_MLP_MegatronF_Output"), &mlp_f_type_info); + Node& mlp_f_node = graph.AddNode(graph.GenerateNodeName("BART_MLP_MegatronF"), + "MegatronF", + "MLP MegatronF", + mlp_f_input_defs, + {&mlp_f_out_arg}, {}, kMSDomain); + counter++; + mlp_f_node.SetExecutionProviderType(node.GetExecutionProviderType()); + const Node::EdgeEnd* edge = graph_utils::GetInputEdge(node, 0); + if (nullptr == edge) { // handle input/initializer + graph_utils::ReplaceNodeInput(node, 0, *(mlp_f_node.MutableOutputDefs()[0])); + } else { + auto input_node = const_cast(&edge->GetNode()); + graph_utils::ReplaceDownstreamNodeInput(graph, *input_node, edge->GetSrcArgIndex(), mlp_f_node, 0); + } + + const std::vector mlp_g_input_defs{matmul2_node.MutableOutputDefs()[0]}; + auto mlp_g_type_info = *matmul2_node.MutableOutputDefs()[0]->TypeAsProto(); + auto& mlp_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BART_MLP_MegatronG_Output"), &mlp_g_type_info); + Node& mlp_g_node = graph.AddNode(graph.GenerateNodeName("BART_MLP_MegatronG"), + "MegatronG", + "MLP MegatronG", + mlp_g_input_defs, + {&mlp_g_out_arg}, {}, kMSDomain); + mlp_g_node.AddAttribute("group_type", static_cast(training::WorkerGroupType::HorizontalParallel)); + mlp_g_node.SetExecutionProviderType(node.GetExecutionProviderType()); + graph_utils::ReplaceDownstreamNodeInput(graph, matmul2_node, 0, mlp_g_node, 0); + modified = true; } return Status::OK(); @@ -357,7 +507,8 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, std::vector& nodes_to_clear_shape, - std::unordered_set& self_attention_dropout_nodes) const { + std::unordered_set& dropout_nodes_to_transform, + int32_t& counter) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -406,7 +557,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, // Get all useful nodes here as more vector push back below will change the index. Node& add_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15]; Node& split_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14]; - Node& transpose_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]; + Node& k_transpose_after_reshape_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]; Node* matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 11]; Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 6]; Node* matmul_node_ptr1 = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5]; @@ -414,7 +565,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, Node& matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 2]; // Transpose node attribute checking. - if (!optimizer_utils::IsAttributeWithExpectedValues(transpose_node, "perm", {0LL, 2LL, 1LL, 3LL}) || + if (!optimizer_utils::IsAttributeWithExpectedValues(k_transpose_after_reshape_node, "perm", {0LL, 2LL, 1LL, 3LL}) || !optimizer_utils::IsAttributeWithExpectedValues(transpose_node1, "perm", {0LL, 2LL, 1LL, 3LL})) { continue; } @@ -518,12 +669,15 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, // Replace by the partition weights. NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, qkv_weight_partition_arg); + updated_weight_names_.insert({qkv_weight_arg->Name(), qkv_weight_partition_arg.Name()}); NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, qkv_bias_partition_arg); + updated_weight_names_.insert({qkv_bias_arg->Name(), qkv_bias_partition_arg.Name()}); NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul_node, 1, dense_weight_partition_arg); + updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); graph.RemoveInitializedTensor(qkv_weight_arg->Name()); graph.RemoveInitializedTensor(qkv_bias_arg->Name()); @@ -554,16 +708,16 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, } if (dropout_node_ptr != nullptr) { - self_attention_dropout_nodes.insert(dropout_node_ptr); + dropout_nodes_to_transform.insert(dropout_node_ptr); } // Add MegatronF before the 1st MatMul and MegatronG before the last Add. const std::vector sa_f_input_defs{node.MutableInputDefs()[0]}; auto sa_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto(); - auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SelfAttention_MegatronF_Output"), &sa_f_type_info); - Node& sa_f_node = graph.AddNode(graph.GenerateNodeName("SelfAttention_MegatronF"), + auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SeftAttention_MegatronF_Output"), &sa_f_type_info); + Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "SeftAttention_MegatronF"), "MegatronF", - "SelfAttention MegatronF", + "SeftAttention MegatronF", sa_f_input_defs, {&sa_f_out_arg}, {}, kMSDomain); sa_f_node.SetExecutionProviderType(node.GetExecutionProviderType()); @@ -577,23 +731,400 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, const std::vector sa_g_input_defs{matmul_node.MutableOutputDefs()[0]}; auto sa_g_type_info = *matmul_node.MutableOutputDefs()[0]->TypeAsProto(); // copy - auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SelfAttention_MegatronG_Output"), &sa_g_type_info); - Node& sa_g_node = graph.AddNode(graph.GenerateNodeName("SelfAttention_MegatronG"), + auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SeftAttention_MegatronG_Output"), &sa_g_type_info); + Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "SelfAttention_MegatronG"), "MegatronG", - "SelfAttention MegatronG", + "Attention MegatronG", sa_g_input_defs, {&sa_g_out_arg}, {}, kMSDomain); sa_g_node.AddAttribute("group_type", static_cast(training::WorkerGroupType::HorizontalParallel)); sa_g_node.SetExecutionProviderType(node.GetExecutionProviderType()); graph_utils::ReplaceDownstreamNodeInput(graph, matmul_node, 0, sa_g_node, 0); modified = true; + counter++; + } + + return Status::OK(); +} + +Status MegatronTransformer::TransformBARTSelfAttention(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger, + std::vector& nodes_to_clear_shape, + std::unordered_set& dropout_nodes_to_transform, + int32_t& counter) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + // Self attention sub-graph. + // + // MatMul->Add->Mul->Reshape->Transpose->MatMul->Reshape->Where->Reshape->Softmax->Dropout->MatMul->Transpose->Reshape->MatMul->Add->Droupout + // MatMul->Add->Reshape->Transpose-------> | | + // MatMul->Add->Reshape->Transpose----------------------------------------------------------> | + for (auto node_index : node_topology_list) { + auto& node = *graph.GetNode(node_index); + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || + node.GetOutputEdgesCount() != 1) { + continue; + } + + Node* q_matmul_input_node_ptr = const_cast(graph.GetProducerNode(node.MutableInputDefs()[0]->Name())); + if (q_matmul_input_node_ptr != nullptr && q_matmul_input_node_ptr->OpType().compare("MegatronF") == 0) { + continue; + } + std::vector sub_graph_node_ptrs; + sub_graph_node_ptrs.push_back(&node); + ProviderType provider_type = node.GetExecutionProviderType(); + + std::vector linear_pattern = { + NodeInfo({add_info}), + NodeInfo({mul_info}), + NodeInfo({reshape_info}), + NodeInfo({transpose_info}), + NodeInfo({matmul_info}), + NodeInfo({add_info}, false), // -13 + NodeInfo({reshape_info}), + NodeInfo({where_info}), + NodeInfo({reshape_info}), + NodeInfo({softmax_info}), + NodeInfo({dropout_info}, false), // -8 + NodeInfo({matmul_info}), + NodeInfo({add_info}, false), // -6 + NodeInfo({transpose_info}), + NodeInfo({reshape_info}), + NodeInfo({matmul_info}), // -3 + NodeInfo({add_info}), + NodeInfo({dropout_info}, false)}; // -1 + if (!MatchLinearPattern(graph, &node, provider_type, linear_pattern, sub_graph_node_ptrs)) { + continue; + } + // Get all useful nodes here as more vector push back below will change the index. + // Other than the optional nodes in the pattern, all other node pointers are valid + // if they match the linear pattern. + Node* q_biasadd_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 18]; + Node* q_transpose_after_reshape_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15]; + Node* qk_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14]; + Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 8]; + Node* qkv_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 7]; + Node* transpose_node1_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5]; + Node& dense_matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 3]; + + // Transpose node attribute checking. + if (!optimizer_utils::IsAttributeWithExpectedValues(*q_transpose_after_reshape_node_ptr, "perm", {1LL, 0LL, 2LL}) || + !optimizer_utils::IsAttributeWithExpectedValues(*transpose_node1_ptr, "perm", {1LL, 0LL, 2LL})) { + continue; + } + // map between reshape node and dim of reshape that must be modified + std::unordered_map reshape_node_ptrs; + reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 16]] = 1; + reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]] = 1; + reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 10]] = 0; + reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 4]] = 2; + // till now node should be q matmul operation + + std::vector weight_transpose_node_ptrs; + std::vector bias_add_node_ptrs; + + Node* q_transpose_ptr = const_cast(graph.GetProducerNode(node.MutableInputDefs()[1]->Name())); + if (q_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*q_transpose_ptr, transpose_info, provider_type)) { + continue; + } + weight_transpose_node_ptrs.push_back(q_transpose_ptr); + sub_graph_node_ptrs.push_back(q_transpose_ptr); + bias_add_node_ptrs.push_back(q_biasadd_node_ptr); + + Node* k_transpose_ptr = const_cast(graph.GetProducerNode(qk_matmul_node_ptr->MutableInputDefs()[1]->Name())); + if (k_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_transpose_ptr, transpose_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(k_transpose_ptr); + + Node* k_reshape_ptr = const_cast(graph.GetProducerNode(k_transpose_ptr->MutableInputDefs()[0]->Name())); + if (k_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*k_reshape_ptr, reshape_info, provider_type)) { + continue; + } + reshape_node_ptrs[k_reshape_ptr] = 1; + sub_graph_node_ptrs.push_back(k_reshape_ptr); + + Node* k_add_ptr = const_cast(graph.GetProducerNode(k_reshape_ptr->MutableInputDefs()[0]->Name())); + if (k_add_ptr == nullptr || !IsExpectedOpAndProvider(*k_add_ptr, add_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(k_add_ptr); + bias_add_node_ptrs.push_back(k_add_ptr); + + Node* k_matmul_ptr = const_cast(graph.GetProducerNode(k_add_ptr->MutableInputDefs()[0]->Name())); + if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(k_matmul_ptr); + + Node* k_weight_transpose_ptr = const_cast(graph.GetProducerNode(k_matmul_ptr->MutableInputDefs()[1]->Name())); + if (k_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_weight_transpose_ptr, transpose_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(k_weight_transpose_ptr); + weight_transpose_node_ptrs.push_back(k_weight_transpose_ptr); + + Node* v_transpose_ptr = const_cast(graph.GetProducerNode(qkv_matmul_node_ptr->MutableInputDefs()[1]->Name())); + if (v_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_transpose_ptr, transpose_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(v_transpose_ptr); + + Node* v_reshape_ptr = const_cast(graph.GetProducerNode(v_transpose_ptr->MutableInputDefs()[0]->Name())); + if (v_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*v_reshape_ptr, reshape_info, provider_type)) { + continue; + } + reshape_node_ptrs[v_reshape_ptr] = 1; + sub_graph_node_ptrs.push_back(v_reshape_ptr); + + Node* v_add_ptr = const_cast(graph.GetProducerNode(v_reshape_ptr->MutableInputDefs()[0]->Name())); + if (v_add_ptr == nullptr || !IsExpectedOpAndProvider(*v_add_ptr, add_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(v_add_ptr); + bias_add_node_ptrs.push_back(v_add_ptr); + + Node* v_matmul_ptr = const_cast(graph.GetProducerNode(v_add_ptr->MutableInputDefs()[0]->Name())); + if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(v_matmul_ptr); + + Node* v_weight_transpose_ptr = const_cast(graph.GetProducerNode(v_matmul_ptr->MutableInputDefs()[1]->Name())); + if (v_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_weight_transpose_ptr, transpose_info, provider_type)) { + continue; + } + sub_graph_node_ptrs.push_back(v_weight_transpose_ptr); + weight_transpose_node_ptrs.push_back(v_weight_transpose_ptr); + + // K and V matmul must have the same input + Node* q_matmul_ptr = &node; + if (k_matmul_ptr->MutableInputDefs()[0]->Name() != v_matmul_ptr->MutableInputDefs()[0]->Name()) { + continue; + } + + // Check the constant value in the Reshape nodes. + bool is_reshape_valid = true; + for (auto x : reshape_node_ptrs) { + Node* node_ptr = x.first; + int64_t idx = x.second; + auto shape_arg = node_ptr->MutableInputDefs()[1]; + const ONNX_NAMESPACE::TensorProto* tensor; + if (!graph.GetInitializedTensor(shape_arg->Name(), tensor)) { + is_reshape_valid = false; + break; + } + auto data_type = tensor->data_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + is_reshape_valid = false; + break; + } + // The number of the values should be more than idx, and the idx'th value should be divisible by parallel size, + // i.e., the attention head number should be divisible by parallel size. + auto init_const = onnxruntime::make_unique(*tensor, graph.ModelPath()); + if (init_const->size() <= idx) { + is_reshape_valid = false; + break; + } + const int64_t* val = init_const->data(); + if (val[idx] % horizontal_parallel_size_ != 0) { + LOGS_DEFAULT(WARNING) << "dim[" << idx << "]: " << val[idx] + << " is not divisible by horizontal_parallel_size_ " + << horizontal_parallel_size_ << ", not supported currently."; + is_reshape_valid = false; + break; + } + } + + if (!is_reshape_valid) { + continue; + } + + // Partition weights. If any of them fails, skip transforming the rest. + std::vector qkv_weight_initializer_partitions; + for (auto trans_ptr : weight_transpose_node_ptrs) { + auto qkv_weight_arg = trans_ptr->MutableInputDefs()[0]; + ONNX_NAMESPACE::TensorProto qkv_weight_initializer_partition; + if (!PartitionWeightByRow(graph, *qkv_weight_arg, qkv_weight_initializer_partition)) { + break; + } + qkv_weight_initializer_partitions.push_back(qkv_weight_initializer_partition); + } + + // Partition bias. If any of them fails, skip transforming the rest. + std::vector qkv_bias_initializer_partitions; + for (auto add_ptr : bias_add_node_ptrs) { + auto qkv_bias_arg = add_ptr->MutableInputDefs()[1]; + ONNX_NAMESPACE::TensorProto qkv_bias_initializer_partition; + if (!PartitionWeightByColumn(graph, *qkv_bias_arg, qkv_bias_initializer_partition)) { + break; + } + qkv_bias_initializer_partitions.push_back(qkv_bias_initializer_partition); + } + + // if all the weights or biases weren't transformed, skip transforming this subgraph + if (weight_transpose_node_ptrs.size() != qkv_weight_initializer_partitions.size()) { + continue; + } + if (bias_add_node_ptrs.size() != qkv_bias_initializer_partitions.size()) { + continue; + } + + // transform the dense weight. If it fails, skip transforming this subgraph. + Node* last_transpose = const_cast(graph.GetProducerNode(dense_matmul_node.MutableInputDefs()[1]->Name())); + auto dense_weight_arg = last_transpose->MutableInputDefs()[0]; + ONNX_NAMESPACE::TensorProto dense_weight_initializer_partition; + if (!PartitionWeightByColumn(graph, *dense_weight_arg, dense_weight_initializer_partition)) { + continue; + } + + // Ready to transform the sub-graph when reach here. + // Replace node inputs + size_t i = 0; + for (auto trans_ptr : weight_transpose_node_ptrs) { + auto weight_name = trans_ptr->MutableInputDefs()[0]->Name(); + NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]); + graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg); + graph.RemoveInitializedTensor(weight_name); + updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()}); + i++; + } + i = 0; + for (auto add_ptr : bias_add_node_ptrs) { + auto bias_name = add_ptr->MutableInputDefs()[1]->Name(); + NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]); + graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg); + graph.RemoveInitializedTensor(bias_name); + updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()}); + i++; + } + + NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); + graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg); + graph.RemoveInitializedTensor(dense_weight_arg->Name()); + updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); + + // It's possible that the node vector contains nullptr due to some optinal node infos during linear pattern matching. + std::copy_if(sub_graph_node_ptrs.begin(), sub_graph_node_ptrs.end(), + std::back_inserter(nodes_to_clear_shape), + [](Node* node_ptr) { return node_ptr != nullptr; }); + + // Change the constant for the reshape nodes. + for (auto x : reshape_node_ptrs) { + Node* node_ptr = x.first; + int64_t idx = x.second; + auto shape_arg = node_ptr->MutableInputDefs()[1]; + const ONNX_NAMESPACE::TensorProto* tensor; + graph.GetInitializedTensor(shape_arg->Name(), tensor); + auto data_type = tensor->data_type(); + auto init_const = onnxruntime::make_unique(*tensor, graph.ModelPath()); + const int64_t* val = init_const->data(); + int64_t size = init_const->size(); + ONNX_NAMESPACE::TensorProto tensor_partition; + tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name())); + tensor_partition.set_data_type(data_type); + tensor_partition.add_dims(size); + + std::vector val_partition; + val_partition.reserve(size); + val_partition.insert(val_partition.end(), val, val + size); + val_partition[idx] /= horizontal_parallel_size_; + tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t)); + NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition); + graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition); + graph.RemoveInitializedTensor(shape_arg->Name()); + } + + if (dropout_node_ptr != nullptr) { + dropout_nodes_to_transform.insert(dropout_node_ptr); + } + + // Add MegatronF before the 1st MatMul and MegatronG before the last Add. + + NodeArg* prev_input_node_ptr = k_matmul_ptr->MutableInputDefs()[0]; + std::vector new_consumer_nodes; + const auto& node_consumers = graph.GetConsumerNodes(prev_input_node_ptr->Name()); + for (auto& n : node_consumers) { + if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) { + continue; + } + new_consumer_nodes.emplace_back(const_cast(n)); + } + + bool shared_same_input = k_matmul_ptr->MutableInputDefs()[0]->Name().compare(q_matmul_ptr->MutableInputDefs()[0]->Name()) == 0; + + //then for q, and k&v will have different MegatronF node. + { + const std::vector sa_f_input_defs{prev_input_node_ptr}; + auto sa_f_type_info = *prev_input_node_ptr->TypeAsProto(); + auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(k_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &sa_f_type_info); + Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronF"), + "MegatronF", + k_matmul_ptr->Name() + " BARTAttention MegatronF", + sa_f_input_defs, + {&sa_f_out_arg}, {}, kMSDomain); + sa_f_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType()); + graph_utils::ReplaceNodeInput(*k_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0])); + graph_utils::ReplaceNodeInput(*v_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0])); + if (shared_same_input) { + graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0])); + } + new_consumer_nodes.push_back(&sa_f_node); + } + graph.UpdateConsumerNodes(prev_input_node_ptr->Name(), new_consumer_nodes); + counter++; + if (!shared_same_input) { + { + NodeArg* q_prev_input_node_ptr = q_matmul_ptr->MutableInputDefs()[0]; + std::vector q_new_consumer_nodes; + const auto& q_node_consumers = graph.GetConsumerNodes(q_prev_input_node_ptr->Name()); + for (auto& n : q_node_consumers) { + if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) { + continue; + } + q_new_consumer_nodes.emplace_back(const_cast(n)); + } + + const std::vector q_sa_f_input_defs{q_matmul_ptr->MutableInputDefs()[0]}; + auto q_sa_f_type_info = *q_matmul_ptr->MutableInputDefs()[0]->TypeAsProto(); + auto& q_sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(q_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &q_sa_f_type_info); + Node& q_sa_f_node = graph.AddNode(graph.GenerateNodeName(q_matmul_ptr->Name() + "BARTAttention_MegatronF"), + "MegatronF", + q_matmul_ptr->Name() + " BARTAttention MegatronF", + q_sa_f_input_defs, + {&q_sa_f_out_arg}, {}, kMSDomain); + q_sa_f_node.SetExecutionProviderType(q_matmul_ptr->GetExecutionProviderType()); + + graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(q_sa_f_node.MutableOutputDefs()[0])); + q_new_consumer_nodes.push_back(&q_sa_f_node); + graph.UpdateConsumerNodes(q_prev_input_node_ptr->Name(), q_new_consumer_nodes); + // todo: need update the consumer node for the input_node as well. + } + } + + const std::vector sa_g_input_defs{dense_matmul_node.MutableOutputDefs()[0]}; + auto sa_g_type_info = *dense_matmul_node.MutableOutputDefs()[0]->TypeAsProto(); // copy + auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BARTAttention_MegatronG_Output"), &sa_g_type_info); + Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronG"), + "MegatronG", + "BARTAttention MegatronG", + sa_g_input_defs, + {&sa_g_out_arg}, {}, kMSDomain); + sa_g_node.AddAttribute("group_type", static_cast(training::WorkerGroupType::HorizontalParallel)); + sa_g_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType()); + graph_utils::ReplaceDownstreamNodeInput(graph, dense_matmul_node, 0, sa_g_node, 0); + + modified = true; } return Status::OK(); } Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, - std::unordered_set& self_attention_dropout_nodes) const { + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); for (auto node_index : node_topology_list) { @@ -610,18 +1141,17 @@ Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int g } // Only need to set the seed if it's a transformed self-attention dropout, or the seed attribute is not set. - if (self_attention_dropout_nodes.find(&node) != self_attention_dropout_nodes.end() || - graph_utils::GetNodeAttribute(node, "seed") == nullptr) { + if (dropout_nodes_to_transform.find(&node) != dropout_nodes_to_transform.end()) { int64_t seed = static_cast(HashName(node.MutableOutputDefs()[0]->Name())) + utils::GetRandomSeed(); - if (self_attention_dropout_nodes.find(&node) != self_attention_dropout_nodes.end()) { + if (dropout_nodes_to_transform.find(&node) != dropout_nodes_to_transform.end()) { seed += horizontal_parallel_rank_; } if (graph_utils::GetNodeAttribute(node, "seed") != nullptr) { node.ClearAttribute("seed"); } - node.AddAttribute("seed", seed); + counter++; modified = true; } } @@ -635,11 +1165,19 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le } std::vector nodes_to_clear_shape; - std::unordered_set self_attention_dropout_nodes; + std::unordered_set dropout_nodes_to_transform; - ORT_ENFORCE(TransformMLP(graph, modified, graph_level, logger, nodes_to_clear_shape).IsOK()); - ORT_ENFORCE(TransformSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, self_attention_dropout_nodes).IsOK()); - ORT_ENFORCE(TransformDropout(graph, modified, graph_level, logger, self_attention_dropout_nodes).IsOK()); + int32_t partitioned_mlp_count = 0; + int32_t partitioned_bart_mlp_count = 0; + int32_t partitioned_attention_count = 0; + int32_t partitioned_bart_attention_count = 0; + int32_t dropout_changed = 0; + + ORT_ENFORCE(TransformMLP(graph, modified, graph_level, logger, nodes_to_clear_shape, partitioned_mlp_count).IsOK()); + ORT_ENFORCE(TransformBARTMLP(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_bart_mlp_count).IsOK()); + ORT_ENFORCE(TransformSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_attention_count).IsOK()); + ORT_ENFORCE(TransformBARTSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_bart_attention_count).IsOK()); + ORT_ENFORCE(TransformDropout(graph, modified, graph_level, logger, dropout_nodes_to_transform, dropout_changed).IsOK()); auto& graph_inputs = graph.GetInputs(); for (auto& node : nodes_to_clear_shape) { @@ -653,13 +1191,28 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le output->ClearShape(); } + for (auto x : updated_weight_names_) { + auto old_initializer_name = x.first; + auto new_initializer_name = x.second; + if (weights_to_train_.find(old_initializer_name) != weights_to_train_.end()) { + weights_to_train_.erase(old_initializer_name); + weights_to_train_.insert(new_initializer_name); + } + } + if (modified) { graph.SetGraphResolveNeeded(); auto ret = graph.Resolve(); + LOGS_DEFAULT(WARNING) << "Megatron transformer result : Partitioned " << partitioned_mlp_count << " MLP Blocks, " + << partitioned_bart_mlp_count << " BART MLP Blocks, " << partitioned_attention_count << " Attention Blocks, " + << partitioned_bart_attention_count << " BART Attention Blocks; Reset seed for " << dropout_changed + << " Dropout nodes. Error Message: " << ret.ErrorMessage() << std::endl; ORT_ENFORCE(ret.IsOK()); + } else { + LOGS_DEFAULT(WARNING) << "Megatron transformer result : unmodified\n"; } return Status::OK(); } -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.h b/orttraining/orttraining/core/optimizer/megatron_transformer.h index 168d70748a..02690d00cd 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.h +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.h @@ -11,33 +11,50 @@ namespace onnxruntime { class MegatronTransformer : public GraphTransformer { public: MegatronTransformer(int32_t horizontal_parallel_rank, int32_t horizontal_parallel_size, + std::unordered_map& updated_weight_names, + std::unordered_set& weights_to_train, const std::unordered_set& compatible_execution_providers = {}) noexcept : GraphTransformer("MegatronTransformer", compatible_execution_providers), horizontal_parallel_rank_(horizontal_parallel_rank), - horizontal_parallel_size_(horizontal_parallel_size) {} + horizontal_parallel_size_(horizontal_parallel_size), + updated_weight_names_(updated_weight_names), + weights_to_train_(weights_to_train) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; private: Status TransformMLP(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, - std::vector& nodes_to_clear_shape) const; + std::vector& nodes_to_clear_shape, + int32_t& counter) const; Status TransformSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, std::vector& nodes_to_clear_shape, - std::unordered_set& self_attention_dropout_nodes) const; + std::unordered_set& dropout_nodes_to_transform, + int32_t& counter) const; + + Status TransformBARTSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, + std::vector& nodes_to_clear_shape, + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const; + + Status TransformBARTMLP(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger, + std::vector& nodes_to_clear_shape, + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const; Status TransformDropout(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger, - std::unordered_set& self_attention_dropout_nodes) const; + std::unordered_set& dropout_nodes_to_transform, int32_t& counter) const; bool PartitionWeightByColumn(const Graph& graph, const NodeArg& input_arg, - ONNX_NAMESPACE::TensorProto& initializer_partition, int stride = 1) const; + ONNX_NAMESPACE::TensorProto& initializer_partition, + int stride = 1) const; - bool PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg, - ONNX_NAMESPACE::TensorProto& initializer_partition) const; + bool PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg, ONNX_NAMESPACE::TensorProto& initializer_partition) const; const int32_t horizontal_parallel_rank_; const int32_t horizontal_parallel_size_; + std::unordered_map& updated_weight_names_; + std::unordered_set& weights_to_train_; }; -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index b49db28db2..9b52a55749 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -43,24 +43,36 @@ Status SetupOptimizerParams( const optional& loss_scale_input_name, const TrainingSession::TrainingConfiguration& config, OptimizerGraphConfig& opt_graph_config_result, - std::unordered_map& opt_node_configs_result) { + std::unordered_map& opt_node_configs_result, + std::unordered_map& weight_name_map_after_graph_transform) { ORT_RETURN_IF_NOT(config.optimizer_config.has_value()); const auto& optimizer_config = config.optimizer_config.value(); + // This is the mapping from the new weight name to the original weight name + // It is required to look up the optimizer config for the original weight + // passed in the training session config + std::unordered_map reversed_weight_names_map; + for (auto& p : weight_name_map_after_graph_transform) { + reversed_weight_names_map.insert({p.second, p.first}); + } + std::unordered_map opt_node_configs{}; for (const auto& weight_name : weight_names_to_train) { OptimizerNodeConfig opt_node_config{}; opt_node_config.name = optimizer_config.name; opt_node_config.lr_feed_name = optimizer_config.learning_rate_input_name; - + std::string original_weight_name = weight_name; + if (reversed_weight_names_map.find(original_weight_name) != reversed_weight_names_map.end()) { + original_weight_name = reversed_weight_names_map.at(original_weight_name); + } try { - opt_node_config.attributes = optimizer_config.weight_attributes_generator(weight_name); + opt_node_config.attributes = optimizer_config.weight_attributes_generator(original_weight_name); } catch (const std::exception& ex) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()); } try { - opt_node_config.int_attributes = optimizer_config.weight_int_attributes_generator(weight_name); + opt_node_config.int_attributes = optimizer_config.weight_int_attributes_generator(original_weight_name); } catch (const std::exception& ex) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()); } @@ -232,7 +244,8 @@ Status TrainingSession::ConfigureForTraining( } } - ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config)); + ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config, + config_result)); if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) { ORT_IGNORE_RETURN_VALUE(Save( @@ -244,8 +257,14 @@ Status TrainingSession::ConfigureForTraining( !filtered_config_weight_names_to_train.empty() ? filtered_config_weight_names_to_train : GetTrainableModelInitializers(config.immutable_weights, loss_name); + for (const auto& weight_name_to_not_train : config.weight_names_to_not_train) { - weight_names_to_train.erase(weight_name_to_not_train); + if (config_result.weight_name_map_after_graph_transform.find(weight_name_to_not_train) != + config_result.weight_name_map_after_graph_transform.end()) { + weight_names_to_train.erase(config_result.weight_name_map_after_graph_transform.at(weight_name_to_not_train)); + } else { + weight_names_to_train.erase(weight_name_to_not_train); + } } { @@ -316,7 +335,7 @@ Status TrainingSession::ConfigureForTraining( std::unordered_map opt_node_configs{}; ORT_RETURN_IF_ERROR(SetupOptimizerParams( weights_to_train_, fp32_weight_name_to_mixed_precision_node_arg, - loss_scale_input_name, config, opt_graph_config, opt_node_configs)); + loss_scale_input_name, config, opt_graph_config, opt_node_configs, config_result.weight_name_map_after_graph_transform)); TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{}; ORT_RETURN_IF_ERROR(BuildOptimizer( opt_graph_config, opt_node_configs, @@ -512,8 +531,9 @@ static Status AddGradientAccumulationNodes(Graph& graph, return GraphAugmenter::AugmentGraph(graph, graph_defs); } -Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set& weights_to_train, - const TrainingConfiguration::GraphTransformerConfiguration& config) { +Status TrainingSession::ApplyTransformationsToMainGraph(std::unordered_set& weights_to_train, + const TrainingConfiguration::GraphTransformerConfiguration& config, + TrainingConfigurationResult& config_result_out) { GraphTransformerManager graph_transformation_mgr{2}; // TODO: ideally we can just reuse the CPU EP registered with the session, but in the training session case // the EPs are registered after ConfigureForTraining and before Initialize is called. Hence we don't have access @@ -522,7 +542,7 @@ Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set // Create execution frame for executing constant nodes. std::unique_ptr cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo()); - AddPreTrainingTransformers(*cpu_execution_provider, graph_transformation_mgr, weights_to_train, config); + AddPreTrainingTransformers(*cpu_execution_provider, graph_transformation_mgr, weights_to_train, config, config_result_out); // apply transformers Graph& graph = model_->MainGraph(); @@ -536,15 +556,16 @@ Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set // Registers all the pre transformers with transformer manager void TrainingSession::AddPreTrainingTransformers(const IExecutionProvider& execution_provider, GraphTransformerManager& transformer_manager, - const std::unordered_set& weights_to_train, + std::unordered_set& weights_to_train, const TrainingConfiguration::GraphTransformerConfiguration& config, + TrainingConfigurationResult& config_result_out, TransformerLevel graph_optimization_level, const std::vector& custom_list) { auto add_transformers = [&](TransformerLevel level) { // Generate and register transformers for level auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers( - level, weights_to_train, config, execution_provider, custom_list); + level, weights_to_train, config, execution_provider, config_result_out.weight_name_map_after_graph_transform, custom_list); for (auto& entry : transformers_to_register) { transformer_manager.Register(std::move(entry), level); } diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 58eb1d526c..1b687b8820 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -245,6 +245,9 @@ class TrainingSession : public InferenceSession { // The pipeline configuration output. // This is only set if an pipeline is enabled. optional pipeline_config_result; + + // Mapped initialized names after weight partitioning for example MegatronTransformer + std::unordered_map weight_name_map_after_graph_transform{}; }; /** @@ -392,14 +395,16 @@ class TrainingSession : public InferenceSession { common::Status InsertPipelineOps(const std::unordered_set& initializer_names_to_preserve, pipeline::PipelineTensorNames& pipeline_tensor_names); - common::Status ApplyTransformationsToMainGraph(const std::unordered_set& weights_to_train, - const TrainingConfiguration::GraphTransformerConfiguration& config); + common::Status ApplyTransformationsToMainGraph(std::unordered_set& weights_to_train, + const TrainingConfiguration::GraphTransformerConfiguration& config, + TrainingConfigurationResult& config_result_out); /** configure initial transformers for training */ void AddPreTrainingTransformers(const IExecutionProvider& execution_provider, // for constant folding GraphTransformerManager& transformer_manager, - const std::unordered_set& weights_to_train, + std::unordered_set& weights_to_train, const TrainingConfiguration::GraphTransformerConfiguration& config, + TrainingConfigurationResult& config_result_out, TransformerLevel graph_optimization_level = TransformerLevel::MaxLevel, const std::vector& custom_list = {}); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index fd9f0be0e4..2ae57b30b9 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -43,6 +43,7 @@ struct TrainingParameters { int gradient_accumulation_steps = 1; int data_parallel_size = 1; int horizontal_parallel_size = 1; + int pipeline_parallel_size = 1; int deepspeed_zero_stage = 0; bool enable_grad_norm_clip = true; bool set_gradients_as_graph_outputs = false; @@ -65,11 +66,16 @@ TrainingConfigurationResult ConfigureSessionForTraining( //TODO tix, refactor the mpi related code to populate all fields correctly by default. ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size); ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size); + + if (parameters.world_size % parameters.data_parallel_size != 0) { + throw std::runtime_error("Cannot split data parallel group because world_size is not divisible"); + } + if (parameters.world_size % parameters.horizontal_parallel_size != 0) { throw std::runtime_error("Cannot split horizontal parallel group because world_size is not divisible"); } - auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size; + auto data_group_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size); if (data_group_size != parameters.data_parallel_size) { LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to " << data_group_size; @@ -201,7 +207,10 @@ void addObjectMethodsForTraining(py::module& m) { .def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute) .def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute) .def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute) - .def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers); + .def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers) + .def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size) + .def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size) + .def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size); #if defined(USE_MPI) m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); }); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 6be077c7f3..340a4c07bc 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -163,7 +163,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { ASSERT_TRUE(ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2), TransformerLevel::Level1); + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -231,7 +233,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2), TransformerLevel::Level1); + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -298,7 +302,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) { ASSERT_TRUE(ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2), TransformerLevel::Level1); + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(0, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -363,7 +369,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2), TransformerLevel::Level1); + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(1, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1); ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); @@ -449,27 +457,35 @@ TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) { // We only tested on CUDA run. #if defined(USE_CUDA) -TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { - auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; - const int total_rank = 4; +static void RunPartitionCorrectnessTest(std::string model_path, +const logging::Logger& logger, +const int total_rank, +std::vector input_names, +std::vector> input_dims) { + const PathString model_uri = ToPathString(model_path) + ORT_TSTR(".onnx"); + // const int total_rank = 4; std::vector graphs; std::vector> p_models(total_rank); for (auto i = 0; i < total_rank; i++) { - auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_); - ASSERT_TRUE(ret.IsOK()); + Status ret = Model::Load(model_uri, p_models[i], nullptr, logger); + ORT_ENFORCE(ret.IsOK()); Graph& graph = p_models[i]->MainGraph(); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); - ASSERT_TRUE(ret.IsOK()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + std::unordered_map updated_weight_names; + std::unordered_set weights_to_train; + graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank, updated_weight_names, weights_to_train), TransformerLevel::Level1); + ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger); + ORT_ENFORCE(ret.IsOK()); graphs.push_back(&graph); + auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_rank_") + ToPathString(std::to_string(i)) + ORT_TSTR(".onnx"); + Model::Save(*p_models[i], model_uri2); } - onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, *logger_); + onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, logger); auto& combine_graph = combine_model.MainGraph(); auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph); ORT_ENFORCE(ret.IsOK()); - auto model_uri2 = "mlp_megatron_basic_test_partition_combine.onnx"; + auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_combine.onnx"); Model::Save(combine_model, model_uri2); float scale = 1.f; @@ -479,10 +495,18 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { std::default_random_engine generator{gsl::narrow_cast(seed)}; std::normal_distribution distribution{mean, scale}; - std::vector dims_X = {8, 16, 4}; - std::vector values_X(TensorShape(dims_X).Size()); - std::for_each(values_X.begin(), values_X.end(), - [&generator, &distribution](float& value) { value = distribution(generator); }); + ORT_ENFORCE(input_names.size() == input_dims.size()); + NameMLValMap feeds; + for(size_t i = 0; i< input_dims.size();i++){ + std::vector dims_X = input_dims[i]; + std::vector values_X(TensorShape(dims_X).Size()); + std::for_each(values_X.begin(), values_X.end(), + [&generator, &distribution](float& value) { value = distribution(generator); }); + + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); + feeds.insert(std::make_pair(input_names[i], ml_value)); + } std::vector expected_ort_values; { @@ -497,11 +521,6 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st; ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); - NameMLValMap feeds; - feeds.insert(std::make_pair("input", ml_value)); - // prepare outputs std::vector output_names; output_names.push_back("output"); @@ -527,11 +546,6 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { ASSERT_TRUE((st = session_object.Load(model_uri2)).IsOK()) << st; ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); - NameMLValMap feeds; - feeds.insert(std::make_pair("input", ml_value)); - // prepare outputs std::vector output_names; for (auto i = 0; i < total_rank; i++) { @@ -554,137 +568,21 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { } } +TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { + RunPartitionCorrectnessTest("testdata/transform/model_parallel/mlp_megatron_basic_test", *logger_, 4, {"input"}, {{8, 16, 4}}); +} + +TEST_F(GraphTransformationTests, MegatronBARTMLPPartitionCorrectnessTest) { + RunPartitionCorrectnessTest("testdata/transform/model_parallel/bart_mlp_megatron_basic_test", *logger_, 4, {"input"}, {{8, 16, 4}}); +} + TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) { - auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; - const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here. - std::vector graphs; - std::vector> p_models(total_rank); - for (auto i = 0; i < total_rank; i++) { - auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_); - ASSERT_TRUE(ret.IsOK()); - Graph& graph = p_models[i]->MainGraph(); - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(i, total_rank), TransformerLevel::Level1); - ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); - ASSERT_TRUE(ret.IsOK()); - graphs.push_back(&graph); - } - - // Dropout seed checking. - const AttributeProto* attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[0], "dropout1"), "seed"); - ORT_ENFORCE(attr != nullptr && attr->has_i()); - int64_t dropout1_rank0_seed = attr->i(); - attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[0], "dropout2"), "seed"); - ORT_ENFORCE(attr != nullptr && attr->has_i()); - int64_t dropout2_rank0_seed = attr->i(); - for (auto i = 1; i < total_rank; i++) { - attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[i], "dropout1"), "seed"); - ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout1_rank0_seed + i); - attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[i], "dropout2"), "seed"); - ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout2_rank0_seed); - } - - onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, *logger_); - auto& combine_graph = combine_model.MainGraph(); - auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph); - ORT_ENFORCE(ret.IsOK()); - auto model_uri2 = "self_attention_megatron_basic_test_partition_combine.onnx"; - Model::Save(combine_model, model_uri2); - - float scale = 1.f; - float mean = 0.f; - float seed = 123.f; - - std::default_random_engine generator{gsl::narrow_cast(seed)}; - std::normal_distribution distribution{mean, scale}; - - std::vector dims_X = {8, 16, 4}; - std::vector values_X(TensorShape(dims_X).Size()); - std::for_each(values_X.begin(), values_X.end(), - [&generator, &distribution](float& value) { value = distribution(generator); }); - - std::vector dims_Mask = {8, 1, 16, 16}; - std::vector values_Mask(TensorShape(dims_Mask).Size()); - std::for_each(values_Mask.begin(), values_Mask.end(), - [&generator, &distribution](float& value) { value = distribution(generator); }); - - std::vector expected_ort_values; - { - SessionOptions so; - so.session_logid = "RawGraphRun"; - - InferenceSession session_object{so, GetEnvironment()}; - std::unique_ptr execution_provider = DefaultCudaExecutionProvider(); - EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - - Status st; - ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st; - ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - - NameMLValMap feeds; - - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); - feeds.insert(std::make_pair("input", ml_value)); - - OrtValue mask_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_Mask, values_Mask, &mask_value); - feeds.insert(std::make_pair("mask", mask_value)); - - // prepare outputs - std::vector output_names; - output_names.push_back("output"); - - // Now run - RunOptions run_options; - run_options.training_mode = true; - st = session_object.Run(run_options, feeds, output_names, &expected_ort_values); - EXPECT_TRUE(st.IsOK()); - } - - std::vector actual_ort_values; - { - SessionOptions so; - so.session_logid = "SplitThenCombineRun"; - - InferenceSession session_object{so, GetEnvironment()}; - std::unique_ptr execution_provider = DefaultCudaExecutionProvider(); - EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - - Status st; - ASSERT_TRUE((st = session_object.Load(model_uri2)).IsOK()) << st; - ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; - - NameMLValMap feeds; - OrtValue ml_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value); - feeds.insert(std::make_pair("input", ml_value)); - - OrtValue mask_value; - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_Mask, values_Mask, &mask_value); - feeds.insert(std::make_pair("mask", mask_value)); - - // prepare outputs - std::vector output_names; - for (auto i = 0; i < total_rank; i++) { - output_names.push_back("output_rank_" + std::to_string(i)); - } - - // Now run - RunOptions run_options; - run_options.training_mode = true; - st = session_object.Run(run_options, feeds, output_names, &actual_ort_values); - EXPECT_TRUE(st.IsOK()); - } - - auto& expected_val = expected_ort_values[0].Get(); - for (auto i = 0; i < total_rank; i++) { - auto& actual_val = actual_ort_values[i].Get(); - horizontal_parallel_test_utils::VerifyOutputs(expected_val, actual_val, true); - horizontal_parallel_test_utils::VerifyOutputs(expected_val, actual_val, false); - } + RunPartitionCorrectnessTest("testdata/transform/model_parallel/self_attention_megatron_basic_test", *logger_, 2, {"input", "mask"}, {{8, 16, 4}, {8, 1, 16, 16}}); } +TEST_F(GraphTransformationTests, MegatronBARTSelfAttentionPartitionCorrectnessTest) { + RunPartitionCorrectnessTest("testdata/transform/model_parallel/bart_self_attention_megatron_basic_test", *logger_, 2, {"input"}, {{6, 8, 4}}); +} // end of USE_CUDA #endif