Add megatron transforms for BART (#5521)

* Large model export and run ORT Python support

* Megatron change

refine a bit

workaround self attention issue

use partitioned name for weights when megatron model parallel is enabled

Fix Megatron Transformer Issue (caused by the renaming)

Add UTs for T5 model parallel

Fix megatron seed issue

fix log a bit

checkpointing changes + rebase

Unintended reshape transform change

t5 layer norm changes

add t5 layer norm kernel

use template for t5 layer norm

template definition changes

no build error

add CPU cuda kernel

first unit test

other forward unit tests

add T5LayerNormGrad

Add c++ transform and test for T5 LN

minor fix

BART MLP Megatron transform

Add concat slice transform + test

Cosmetic improvements in concat slice transform

Constant folding bug fix + megatron attention transform for BART

Undo unnecessary changes

* Cleanup

* Remove unnecessary changes

* Cleanup megatron

* Windows build

* Add self attention test graph

* Correcting transforms + cleanup

* review comments

* review comments

* fix build and test failures

* Fix CI

* fix windows CI

Co-authored-by: Peng Wang <pengwa@microsoft.com>
Co-authored-by: Aishwarya <aibhanda@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
ashbhandare 2020-11-11 16:21:36 -08:00 committed by GitHub
parent a14cd6267b
commit 5aec34500d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 979 additions and 216 deletions

View file

@ -95,8 +95,12 @@ TEST(CseTests, SimpleTestTraining) {
.IsOK());
GraphTransformerManager graph_transformation_mgr(1);
// need to declare variables to avoid build error after making
// weights_to_train and updated_weight_names as non-const
std::unordered_set<std::string> weights_to_train;
std::unordered_map<std::string, std::string> updated_weight_names;
auto transformers_to_register = onnxruntime::training::transformer_utils::GeneratePreTrainingTransformers(
TransformerLevel::Level1, {}, {}, CPUExecutionProvider(CPUExecutionProviderInfo()));
TransformerLevel::Level1, weights_to_train, {}, CPUExecutionProvider(CPUExecutionProviderInfo()), updated_weight_names);
for (auto& entry : transformers_to_register) {
ASSERT_TRUE(
graph_transformation_mgr.Register(std::move(entry), TransformerLevel::Level1).IsOK());

View file

@ -0,0 +1,105 @@
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np
# Builds a small test graph and saves it as 'bart_mlp_megatron_basic_test.onnx':
#
#   wi.weight -- Transpose \
#                  input -- MatMul -- BiasGelu -- Dropout -- MatMul -- Add -- Dropout -- Identity -- output
#                                          wo.weight -- Transpose /
hidden_size = 4
weight_dim_to_split = 16

# Graph input/output: symbolic batch and sequence dims, fixed last dim.
X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", hidden_size])
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", hidden_size])

# First dense layer: weight is stored as (weight_dim_to_split, hidden_size) and transposed in the graph.
a_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((weight_dim_to_split, hidden_size))
a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.weight")

a_bias_np_vals = (0.01 * np.arange(weight_dim_to_split, dtype=np.float32))  # weight_dim_to_split numbers in total
a_bias_initializer = numpy_helper.from_array(a_bias_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.bias")

# Scalar Dropout inputs: ratio (float) and training mode (bool).
dropout_np_vals = np.asarray([0.1], dtype=np.float32).reshape(())
dropout_initializer = numpy_helper.from_array(dropout_np_vals, "ratio")

# Use the builtin bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
dropout_mode_np_vals = np.array([False], dtype=bool).reshape(())
dropout_mode_initializer = numpy_helper.from_array(dropout_mode_np_vals, "mode")

# Second dense layer: weight is stored as (hidden_size, weight_dim_to_split) and transposed in the graph.
b_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((hidden_size, weight_dim_to_split))
b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.weight")

b_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))  # hidden_size numbers in total
b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.bias")

transpose1 = helper.make_node('Transpose', [a_weight_initializer.name], ['transpose1'], name='transpose1', perm=[1,0])
transpose2 = helper.make_node('Transpose', [b_weight_initializer.name], ['transpose2'], name='transpose2', perm=[1,0])

matmul = helper.make_node(
    'MatMul',  # op type
    ['input', 'transpose1'],  # inputs
    ['matmul'],  # outputs
    name="matmul"
)

biasgelu = helper.make_node(
    'BiasGelu',  # op type
    ['matmul', a_bias_initializer.name],  # inputs
    ['biasgelu'],  # outputs
    name="biasgelu",
    domain="com.microsoft"
)

dropout1 = helper.make_node('Dropout',
                            ["biasgelu", dropout_initializer.name, dropout_mode_initializer.name],
                            ['dropout1', "dropout1_mask"],
                            name='dropout1')

matmul2 = helper.make_node(
    'MatMul',  # op type
    ['dropout1', "transpose2"],  # inputs
    ['matmul2'],  # outputs
    name="matmul2"
)

add = helper.make_node(
    'Add',  # op type
    ['matmul2', b_bias_initializer.name],  # inputs
    ['add'],  # outputs
    name="add"
)

dropout2 = helper.make_node('Dropout',
                            ["add", dropout_initializer.name, dropout_mode_initializer.name],
                            ['dropout2', "dropout2_mask"],
                            name='dropout2')

identity = helper.make_node(
    'Identity',  # op type
    ['dropout2'],  # inputs
    ['output'],  # outputs
    name="identity"
)

# Create the graph (GraphProto)
graph_def = helper.make_graph(
    [transpose1, transpose2, matmul, biasgelu, dropout1, matmul2, add, dropout2, identity],
    'test-model',
    [X],
    [Y],
    [a_weight_initializer, a_bias_initializer, b_weight_initializer, b_bias_initializer, dropout_initializer, dropout_mode_initializer]
)

# Opset imports: default ONNX domain at version 12 plus com.microsoft (needed for BiasGelu).
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = ""  # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)

msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets.append(msdomain)

kwargs={}
kwargs["opset_imports"] = opsets

# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
onnx.save(model_def, 'bart_mlp_megatron_basic_test.onnx')

View file

@ -0,0 +1,150 @@
import onnx
from onnx import helper
from onnx import TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np
import random
# Builds a small self-attention test graph (separate q/k/v projections, masked
# softmax, dropout, output projection) and saves it as
# 'bart_self_attention_megatron_basic_test.onnx'.
batch = 6
hidden_size = 4
attention_head = 2
hidden_per_attention = 2
input_len=8
output_len=8

X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [batch, input_len, hidden_size])
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [output_len, batch, hidden_size])

# q/k/v projection weights, each (hidden_size, hidden_size) and transposed in the graph.
q_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size))
q_weight_initializer = numpy_helper.from_array(q_weight_np_vals, 'encoder.layers.0.self_attn.q_proj.weight')
k_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size))
k_weight_initializer = numpy_helper.from_array(k_weight_np_vals, 'encoder.layers.0.self_attn.k_proj.weight')
v_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size))
v_weight_initializer = numpy_helper.from_array(v_weight_np_vals, 'encoder.layers.0.self_attn.v_proj.weight')

# q/k/v projection biases.
q_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))
q_bias_initializer = numpy_helper.from_array(q_bias_np_vals, 'encoder.layers.0.self_attn.q_proj.bias')
k_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))
k_bias_initializer = numpy_helper.from_array(k_bias_np_vals, 'encoder.layers.0.self_attn.k_proj.bias')
v_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))
v_bias_initializer = numpy_helper.from_array(v_bias_np_vals, 'encoder.layers.0.self_attn.v_proj.bias')

# Target shapes for the Reshape nodes that split the projections into heads.
q_shape_initializer = numpy_helper.from_array(np.asarray([input_len, batch*attention_head , hidden_per_attention], dtype=np.int64), 'q_shape')
k_shape_initializer = numpy_helper.from_array(np.asarray([-1, batch*attention_head , hidden_per_attention], dtype=np.int64), 'k_shape')
v_shape_initializer = numpy_helper.from_array(np.asarray([-1, batch*attention_head , hidden_per_attention], dtype=np.int64), 'v_shape')

# Scalar constant multiplied into the q projection.
mul_np_vals = np.asarray([0.1767766922712326], dtype=np.float32).reshape(())
mul_initializer = numpy_helper.from_array(mul_np_vals, "mul_const")

# Constants for the Where-based masking of the q*k scores.
qk_shape_initializer = numpy_helper.from_array(np.asarray([batch, attention_head , input_len, input_len], dtype=np.int64), 'qk_shape')
dummy_condition_initializer = numpy_helper.from_array(np.zeros((batch, input_len), dtype=bool), 'dummy_cond')
inf_const_initializer = numpy_helper.from_array(np.asarray([-np.inf], dtype=np.float32), 'inf_const')
where_shape_initializer = numpy_helper.from_array(np.asarray([batch*attention_head , input_len, input_len], dtype=np.int64), 'where_shape')

# Scalar Dropout inputs: ratio (float) and training mode (bool).
dropout_np_vals = np.asarray([0.1], dtype=np.float32).reshape(())
dropout_initializer = numpy_helper.from_array(dropout_np_vals, "ratio")
# Use the builtin bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
dropout_mode_np_vals = np.array([False], dtype=bool).reshape(())
dropout_mode_initializer = numpy_helper.from_array(dropout_mode_np_vals, "mode")

# Shape used to merge the heads back together before the output projection.
shape_initializer3 = numpy_helper.from_array(np.array([input_len, batch, attention_head * hidden_per_attention], dtype=np.int64), 'concat_shape_3')

# Output projection weight and bias.
dense_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size))
dense_weight_initializer = numpy_helper.from_array(dense_weight_np_vals, 'encoder.layers.0.self_attn.out_proj.weight')
dense_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))
dense_bias_initializer = numpy_helper.from_array(dense_bias_np_vals, 'encoder.layers.0.self_attn.out_proj.bias')

# Input and weight transposes.
transpose_ip = helper.make_node('Transpose', ['input'], ['transpose_ip'], name='transpose_ip', perm=[1,0,2])
transpose_q = helper.make_node('Transpose', [q_weight_initializer.name], ['transpose_q'], name='transpose_q', perm=[1,0])
transpose_k = helper.make_node('Transpose', [k_weight_initializer.name], ['transpose_k'], name='transpose_k', perm=[1,0])
transpose_v = helper.make_node('Transpose', [v_weight_initializer.name], ['transpose_v'], name='transpose_v', perm=[1,0])

# q/k/v projections: MatMul followed by bias Add (q is additionally scaled by mul_const).
matmul_q = helper.make_node('MatMul', ['transpose_ip', 'transpose_q'], ['matmul_q'], name='matmul_q')
matmul_k = helper.make_node('MatMul', ['transpose_ip', 'transpose_k'], ['matmul_k'], name='matmul_k')
matmul_v = helper.make_node('MatMul', ['transpose_ip', 'transpose_v'], ['matmul_v'], name='matmul_v')
add_q = helper.make_node('Add', ['matmul_q', q_bias_initializer.name], ['add_q'], name='add_q')
add_k = helper.make_node('Add', ['matmul_k', k_bias_initializer.name], ['add_k'], name='add_k')
add_v = helper.make_node('Add', ['matmul_v', v_bias_initializer.name], ['add_v'], name='add_v')
mul_q = helper.make_node('Mul', ['add_q' , 'mul_const'], ['mul_q'], name='mul_q')

# Split into heads and move the head dimension first.
reshape_q = helper.make_node('Reshape', ['mul_q', q_shape_initializer.name], ['reshape_q'], name='reshape_q')
reshape_k = helper.make_node('Reshape', ['add_k', k_shape_initializer.name], ['reshape_k'], name='reshape_k')
reshape_v = helper.make_node('Reshape', ['add_v', v_shape_initializer.name], ['reshape_v'], name='reshape_v')
transpose_q2 = helper.make_node('Transpose', ['reshape_q'], ['transpose_q2'], name='transpose_q2', perm=[1,0,2])
transpose_k2 = helper.make_node('Transpose', ['reshape_k'], ['transpose_k2'], name='transpose_k2', perm=[1,2,0])
transpose_v2 = helper.make_node('Transpose', ['reshape_v'], ['transpose_v2'], name='transpose_v2', perm=[1,0,2])

# q*k scores, Where-based masking, softmax and dropout.
matmul = helper.make_node('MatMul', ['transpose_q2', 'transpose_k2'], ['matmul'], name='matmul')
reshape_qk = helper.make_node("Reshape", ['matmul', qk_shape_initializer.name], ['reshape_qk'], name='reshape_qk')
unsqueeze = helper.make_node('Unsqueeze', [dummy_condition_initializer.name],['unsqueeze_cond'], axes=[1,2], name='unsqueeze_cond')
where = helper.make_node('Where', ['unsqueeze_cond', inf_const_initializer.name, 'reshape_qk'], ['where'], name='where')
reshape_where = helper.make_node("Reshape", ['where', where_shape_initializer.name], ['reshape_where'], name='reshape_where')
softmax = helper.make_node('Softmax', ['reshape_where'], ['softmax'], name='softmax', axis=2)
dropout1 = helper.make_node('Dropout',
                            ["softmax", dropout_initializer.name, dropout_mode_initializer.name],
                            ['dropout1', "dropout1_mask"],
                            name='dropout1')

# Weighted values, head merge and output projection.
matmul2 = helper.make_node('MatMul', ['dropout1', 'transpose_v2'], ['matmul2'], name='matmul2')
transpose = helper.make_node('Transpose', ['matmul2'], ['transpose'], name='transpose', perm=[1,0,2])
reshape = helper.make_node('Reshape', ['transpose', shape_initializer3.name], ['reshape'], name='reshape')
transpose_o_weight = helper.make_node('Transpose', [dense_weight_initializer.name], ['transpose_o_weight'], name='transpose_o_weight', perm=[1,0])
matmul3 = helper.make_node('MatMul', ['reshape', 'transpose_o_weight'], ['matmul3'], name='matmul3')
add3 = helper.make_node('Add', ['matmul3', dense_bias_initializer.name], ['add3'], name='add3')
identity = helper.make_node('Identity', ['add3'], ['output'], name='identity')

# Create the graph (GraphProto)
graph_def = helper.make_graph(
    [transpose_ip,transpose_q,transpose_k,transpose_v,matmul_q,matmul_k,matmul_v,add_q,add_k,add_v,
     mul_q,reshape_q,reshape_k,reshape_v,transpose_q2,transpose_k2,transpose_v2,matmul,reshape_qk,
     unsqueeze,where,reshape_where,softmax,dropout1,matmul2,transpose,reshape,transpose_o_weight,
     matmul3,add3,identity],
    'self-attention-megatron-test-model',
    [X],
    [Y],
    [q_weight_initializer,k_weight_initializer,v_weight_initializer,q_bias_initializer,k_bias_initializer,
     v_bias_initializer,q_shape_initializer,k_shape_initializer,v_shape_initializer,mul_initializer,
     qk_shape_initializer,dummy_condition_initializer,inf_const_initializer,where_shape_initializer,
     dropout_initializer,dropout_mode_initializer,shape_initializer3,
     dense_weight_initializer, dense_bias_initializer]
)

# Opset imports: default ONNX domain at version 12 plus com.microsoft.
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = ""  # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)

msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = 'com.microsoft'
opsets.append(msdomain)

kwargs={}
kwargs['opset_imports'] = opsets

# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
onnx.save(model_def, 'bart_self_attention_megatron_basic_test.onnx')

View file

@ -52,9 +52,10 @@ namespace transformer_utils {
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
std::unordered_set<std::string>& weights_to_train,
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config,
const IExecutionProvider& execution_provider,
std::unordered_map<std::string, std::string>& updated_weight_names,
const std::vector<std::string>& transformers_and_rules_to_enable) {
std::vector<std::unique_ptr<GraphTransformer>> transformers;
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
@ -94,7 +95,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
if (config.enable_gelu_approximation) {
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(compatible_eps));
}
transformers.emplace_back(onnxruntime::make_unique<ConstantFolding>(execution_provider, compatible_eps, weights_to_train));
transformers.emplace_back(onnxruntime::make_unique<ReshapeFusion>(compatible_eps));
auto horizontal_parallel_size = training::DistributedRunContext::GroupSize(training::WorkerGroupType::HorizontalParallel);
@ -102,7 +102,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
LOGS_DEFAULT(WARNING) << horizontal_parallel_size << "-way horizontal model parallel is enabled";
transformers.emplace_back(onnxruntime::make_unique<MegatronTransformer>(
training::DistributedRunContext::RankInGroup(training::WorkerGroupType::HorizontalParallel),
horizontal_parallel_size, compatible_eps));
horizontal_parallel_size, updated_weight_names, weights_to_train, compatible_eps));
}
transformers.emplace_back(onnxruntime::make_unique<ComputationReductionTransformer>(compatible_eps));

View file

@ -17,9 +17,10 @@ namespace transformer_utils {
/** Generates all pre-training transformers for this level. */
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
std::unordered_set<std::string>& weights_to_train,
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config,
const IExecutionProvider& execution_provider, // required for constant folding
std::unordered_map<std::string, std::string>& updated_weight_names,
const std::vector<std::string>& rules_and_transformers_to_enable = {});
/** Generates all predefined (both rule-based and non-rule-based) transformers for this level.

View file

@ -32,6 +32,7 @@ const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_13 = {1
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_11_13 = {1, 11, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v2_11_13 = {2, 11, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v5_13 = {5, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_6_7_13 = {1, 6, 7, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v7_13 = {7, 13};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v9 = {9};
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v9_13 = {9, 13};
@ -42,11 +43,12 @@ const OpInfo reshape_info = OpInfo("Reshape", opset_v5_13);
const OpInfo transpose_info = OpInfo("Transpose", opset_v1_13);
const OpInfo matmul_info = OpInfo("MatMul", opset_v9_13);
const OpInfo div_info = OpInfo("Div", opset_v7_13);
const OpInfo mul_info = OpInfo("Mul", opset_v7_13);
const OpInfo mul_info = OpInfo("Mul", opset_v1_6_7_13);
const OpInfo sub_info = OpInfo("Sub", opset_v7_13);
const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11_13);
const OpInfo trainable_dropout_info = OpInfo("TrainableDropout", opset_v9, kOnnxDomain);
const OpInfo dropout_info = OpInfo("Dropout", opset_v12_13);
const OpInfo where_info = OpInfo("Where", opset_v9);
struct NodeInfo {
NodeInfo(const std::vector<OpInfo>& op_infos,
@ -119,7 +121,6 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node
LOGS_DEFAULT(WARNING) << "PartitionWeightByColumn: " << input_arg.Name() << " is not an initializer";
return false;
}
auto data_type = tensor_proto->data_type();
const ONNX_NAMESPACE::TensorShapeProto* shape = input_arg.Shape();
int rank = shape->dim_size();
@ -147,8 +148,9 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node
auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto, graph.ModelPath());
const float* a_weight = initializer->data<float>();
initializer_partition.set_name("rank_" + std::to_string(horizontal_parallel_rank_) +
"_" + input_arg.Name() + "_partition");
std::string new_initializer_name = input_arg.Name() + "_column_rank_" + std::to_string(horizontal_parallel_rank_);
initializer_partition.set_name(new_initializer_name);
initializer_partition.set_data_type(data_type);
int64_t column_partition = column_count / horizontal_parallel_size_;
@ -209,12 +211,12 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg
<< horizontal_parallel_size_ << ", not supported currently.";
return false;
}
auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto, graph.ModelPath());
const float* a_weight = initializer->data<float>();
initializer_partition.set_name("rank_" + std::to_string(horizontal_parallel_rank_) +
"_" + input_arg.Name() + "_partition");
std::string new_initializer_name = input_arg.Name() + "_row_rank_" + std::to_string(horizontal_parallel_rank_);
initializer_partition.set_name(new_initializer_name);
initializer_partition.set_data_type(data_type);
int64_t row_partition = row_count / horizontal_parallel_size_;
@ -231,13 +233,13 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg
const int64_t row_index_offset = horizontal_parallel_rank_ * row_partition;
memcpy(result.data(), a_weight + row_index_offset * column_count, sizeof(float) * element_count);
initializer_partition.set_raw_data(result.data(), element_count * sizeof(float));
return true;
}
Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger,
std::vector<Node*>& nodes_to_clear_shape) const {
std::vector<Node*>& nodes_to_clear_shape,
int32_t& counter) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
@ -309,12 +311,15 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
// Swap each original initializer for its per-rank partition and record the
// original-name -> partition-name mapping in updated_weight_names_.
// Each key must be the original name of the SAME tensor whose partition is the value.
NodeArg& a_weight_partition_arg = graph_utils::AddInitializer(graph, a_weight_initializer_partition);
graph_utils::ReplaceNodeInput(node, 1, a_weight_partition_arg);
updated_weight_names_.insert({a_weight_arg->Name(), a_weight_partition_arg.Name()});
NodeArg& a_bias_partition_arg = graph_utils::AddInitializer(graph, a_bias_initializer_partition);
graph_utils::ReplaceNodeInput(add_node, 1, a_bias_partition_arg);
// Fixed: the bias partition was previously recorded under b_weight_arg's name.
updated_weight_names_.insert({a_bias_arg->Name(), a_bias_partition_arg.Name()});
NodeArg& b_weight_partition_arg = graph_utils::AddInitializer(graph, b_weight_initializer_partition);
graph_utils::ReplaceNodeInput(matmul2_node, 1, b_weight_partition_arg);
// Fixed: the second weight partition was previously recorded under a_bias_arg's name.
updated_weight_names_.insert({b_weight_arg->Name(), b_weight_partition_arg.Name()});
graph.RemoveInitializedTensor(a_weight_arg->Name());
graph.RemoveInitializedTensor(b_weight_arg->Name());
@ -349,6 +354,151 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
mlp_g_node.SetExecutionProviderType(node.GetExecutionProviderType());
graph_utils::ReplaceDownstreamNodeInput(graph, matmul2_node, 0, mlp_g_node, 0);
modified = true;
counter++;
}
return Status::OK();
}
/*
Matches the BART MLP pattern below and rewrites it for horizontal model parallelism:

DenseWeight -- Transpose \
               MatMul -- BiasGelu -- Dropout -- MatMul -- Add -- Dropout

For every match:
  - the first dense weight is partitioned by row, its bias by column, and the
    second dense weight by column; the partitions replace the original
    initializers and the name mapping is recorded in updated_weight_names_;
  - a MegatronF node is inserted on input 0 of the first MatMul and a MegatronG
    node on the output of the second MatMul;
  - the first Dropout is added to dropout_nodes_to_transform;
  - counter is incremented and modified is set to true.

Nodes that do not match are skipped; the function returns Status::OK() unless
Recurse() reports an error.
*/
Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, int graph_level,
                                             const logging::Logger& logger,
                                             std::vector<Node*>& nodes_to_clear_shape,
                                             std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  for (auto node_index : node_topology_list) {
    auto& node = *graph.GetNode(node_index);
    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

    // The pattern is anchored on the first MatMul, which must feed exactly one consumer.
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) ||
        !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
        node.GetOutputEdgesCount() != 1) {
      continue;
    }

    Node* second_op = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[1]->Name()));
    Node* first_op = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[0]->Name()));
    if (node.GetInputEdgesCount() > 0) {
      if (second_op == nullptr) {
        // The second input has no producer node, so this MatMul cannot match.
        // Skip only this node: a `break` here would stop the scan and miss
        // every matching MatMul later in topological order.
        continue;
      }
      if (first_op != nullptr && first_op->OpType().compare("MegatronF") == 0) {
        // Input 0 is already produced by a MegatronF node.
        continue;
      }
      if (second_op->OpType().compare("Transpose") != 0) {
        continue;
      }
    } else {
      continue;
    }

    // check if transpose is only 2-dim
    if (!optimizer_utils::IsAttributeWithExpectedValues(*second_op, "perm", {1LL, 0LL})) {
      continue;
    }

    // Walk the single-consumer chain: BiasGelu -> Dropout -> MatMul -> Add -> Dropout.
    ProviderType provider_type = node.GetExecutionProviderType();
    Node* biasgelu_node_ptr = graph.GetNode(node.OutputNodesBegin()->Index());
    Node& biasgelu_node = *biasgelu_node_ptr;
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(biasgelu_node, "BiasGelu", {1}, kMSDomain) ||
        biasgelu_node.GetExecutionProviderType() != provider_type ||
        biasgelu_node.GetOutputEdgesCount() != 1) {
      continue;
    }

    Node& dropout_node = *graph.GetNode(biasgelu_node.OutputNodesBegin()->Index());
    if (!IsExpectedOpAndProvider(dropout_node, dropout_info, provider_type)) {
      continue;
    }

    Node& matmul2_node = *graph.GetNode(dropout_node.OutputNodesBegin()->Index());
    if (!IsExpectedOpAndProvider(matmul2_node, matmul_info, provider_type)) {
      continue;
    }

    Node& add_node = *graph.GetNode(matmul2_node.OutputNodesBegin()->Index());
    if (!IsExpectedOpAndProvider(add_node, add_info, provider_type)) {
      continue;
    }

    Node& dropout2_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
    if (!IsExpectedOpAndProvider(dropout2_node, dropout_info, provider_type)) {
      continue;
    }

    // The second MatMul's weight must also come through a Transpose.
    Node* transpose_op_ptr = const_cast<Node*>(graph.GetProducerNode(matmul2_node.MutableInputDefs()[1]->Name()));
    if (transpose_op_ptr == nullptr || !IsExpectedOpAndProvider(*transpose_op_ptr, transpose_info, provider_type)) {
      continue;
    }

    // NOTE(review): these nodes are queued before the partitioning below is known
    // to succeed, so they stay queued even if a partition step bails out.
    nodes_to_clear_shape.insert(nodes_to_clear_shape.end(), {&node, second_op, &biasgelu_node, &dropout_node,
                                                             &matmul2_node, transpose_op_ptr});

    // Compute all three partitions first; the graph is only modified once every one succeeded.
    auto dense_wi_weight_arg = second_op->MutableInputDefs()[0];
    ONNX_NAMESPACE::TensorProto dense_wi_weight_initializer_partition;
    if (!PartitionWeightByRow(graph, *dense_wi_weight_arg, dense_wi_weight_initializer_partition)) {
      continue;
    }

    // since the bias doesn't get transposed, partitioning by col
    auto dense_wi_bias_arg = biasgelu_node.MutableInputDefs()[1];
    ONNX_NAMESPACE::TensorProto dense_wi_bias_initializer_partition;
    if (!PartitionWeightByColumn(graph, *dense_wi_bias_arg, dense_wi_bias_initializer_partition)) {
      continue;
    }

    auto dense_wo_weight_arg = transpose_op_ptr->MutableInputDefs()[0];
    ONNX_NAMESPACE::TensorProto dense_wo_weight_initializer_partition;
    if (!PartitionWeightByColumn(graph, *dense_wo_weight_arg, dense_wo_weight_initializer_partition)) {
      continue;
    }

    // Replace each original initializer with its partition and record the rename.
    NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wi_weight_initializer_partition);
    graph_utils::ReplaceNodeInput(*second_op, 0, dense_wi_weight_partition_arg);
    updated_weight_names_.insert({dense_wi_weight_arg->Name(), dense_wi_weight_partition_arg.Name()});

    NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializer(graph, dense_wi_bias_initializer_partition);
    graph_utils::ReplaceNodeInput(biasgelu_node, 1, dense_wi_bias_partition_arg);
    updated_weight_names_.insert({dense_wi_bias_arg->Name(), dense_wi_bias_partition_arg.Name()});

    NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wo_weight_initializer_partition);
    graph_utils::ReplaceNodeInput(*transpose_op_ptr, 0, dense_wo_weight_partition_arg);
    updated_weight_names_.insert({dense_wo_weight_arg->Name(), dense_wo_weight_partition_arg.Name()});

    graph.RemoveInitializedTensor(dense_wi_weight_arg->Name());
    graph.RemoveInitializedTensor(dense_wi_bias_arg->Name());
    graph.RemoveInitializedTensor(dense_wo_weight_arg->Name());

    dropout_nodes_to_transform.insert(&dropout_node);

    // Insert MegatronF in front of input 0 of the first MatMul.
    const std::vector<NodeArg*> mlp_f_input_defs{node.MutableInputDefs()[0]};
    auto mlp_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto();
    auto& mlp_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BART_MLP_MegatronF_Output"), &mlp_f_type_info);
    Node& mlp_f_node = graph.AddNode(graph.GenerateNodeName("BART_MLP_MegatronF"),
                                     "MegatronF",
                                     "MLP MegatronF",
                                     mlp_f_input_defs,
                                     {&mlp_f_out_arg}, {}, kMSDomain);
    counter++;
    mlp_f_node.SetExecutionProviderType(node.GetExecutionProviderType());
    const Node::EdgeEnd* edge = graph_utils::GetInputEdge(node, 0);
    if (nullptr == edge) {  // handle input/initializer
      graph_utils::ReplaceNodeInput(node, 0, *(mlp_f_node.MutableOutputDefs()[0]));
    } else {
      auto input_node = const_cast<Node*>(&edge->GetNode());
      graph_utils::ReplaceDownstreamNodeInput(graph, *input_node, edge->GetSrcArgIndex(), mlp_f_node, 0);
    }

    // Insert MegatronG on the output of the second MatMul (i.e. before the bias Add).
    const std::vector<NodeArg*> mlp_g_input_defs{matmul2_node.MutableOutputDefs()[0]};
    auto mlp_g_type_info = *matmul2_node.MutableOutputDefs()[0]->TypeAsProto();
    auto& mlp_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BART_MLP_MegatronG_Output"), &mlp_g_type_info);
    Node& mlp_g_node = graph.AddNode(graph.GenerateNodeName("BART_MLP_MegatronG"),
                                     "MegatronG",
                                     "MLP MegatronG",
                                     mlp_g_input_defs,
                                     {&mlp_g_out_arg}, {}, kMSDomain);
    mlp_g_node.AddAttribute("group_type", static_cast<int64_t>(training::WorkerGroupType::HorizontalParallel));
    mlp_g_node.SetExecutionProviderType(node.GetExecutionProviderType());
    graph_utils::ReplaceDownstreamNodeInput(graph, matmul2_node, 0, mlp_g_node, 0);
    modified = true;
  }
  return Status::OK();
@ -357,7 +507,8 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger,
std::vector<Node*>& nodes_to_clear_shape,
std::unordered_set<Node*>& self_attention_dropout_nodes) const {
std::unordered_set<Node*>& dropout_nodes_to_transform,
int32_t& counter) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
@ -406,7 +557,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
// Get all useful nodes here as more vector push back below will change the index.
Node& add_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15];
Node& split_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14];
Node& transpose_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12];
Node& k_transpose_after_reshape_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12];
Node* matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 11];
Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 6];
Node* matmul_node_ptr1 = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5];
@ -414,7 +565,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
Node& matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 2];
// Transpose node attribute checking.
if (!optimizer_utils::IsAttributeWithExpectedValues(transpose_node, "perm", {0LL, 2LL, 1LL, 3LL}) ||
if (!optimizer_utils::IsAttributeWithExpectedValues(k_transpose_after_reshape_node, "perm", {0LL, 2LL, 1LL, 3LL}) ||
!optimizer_utils::IsAttributeWithExpectedValues(transpose_node1, "perm", {0LL, 2LL, 1LL, 3LL})) {
continue;
}
@ -518,12 +669,15 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
// Replace by the partition weights.
NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partition);
graph_utils::ReplaceNodeInput(node, 1, qkv_weight_partition_arg);
updated_weight_names_.insert({qkv_weight_arg->Name(), qkv_weight_partition_arg.Name()});
NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partition);
graph_utils::ReplaceNodeInput(add_node, 1, qkv_bias_partition_arg);
updated_weight_names_.insert({qkv_bias_arg->Name(), qkv_bias_partition_arg.Name()});
NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition);
graph_utils::ReplaceNodeInput(matmul_node, 1, dense_weight_partition_arg);
updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()});
graph.RemoveInitializedTensor(qkv_weight_arg->Name());
graph.RemoveInitializedTensor(qkv_bias_arg->Name());
@ -554,16 +708,16 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
}
if (dropout_node_ptr != nullptr) {
self_attention_dropout_nodes.insert(dropout_node_ptr);
dropout_nodes_to_transform.insert(dropout_node_ptr);
}
// Add MegatronF before the 1st MatMul and MegatronG before the last Add.
const std::vector<NodeArg*> sa_f_input_defs{node.MutableInputDefs()[0]};
auto sa_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto();
auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SelfAttention_MegatronF_Output"), &sa_f_type_info);
Node& sa_f_node = graph.AddNode(graph.GenerateNodeName("SelfAttention_MegatronF"),
auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SeftAttention_MegatronF_Output"), &sa_f_type_info);
Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "SeftAttention_MegatronF"),
"MegatronF",
"SelfAttention MegatronF",
"SeftAttention MegatronF",
sa_f_input_defs,
{&sa_f_out_arg}, {}, kMSDomain);
sa_f_node.SetExecutionProviderType(node.GetExecutionProviderType());
@ -577,23 +731,400 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
const std::vector<NodeArg*> sa_g_input_defs{matmul_node.MutableOutputDefs()[0]};
auto sa_g_type_info = *matmul_node.MutableOutputDefs()[0]->TypeAsProto(); // copy
auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SelfAttention_MegatronG_Output"), &sa_g_type_info);
Node& sa_g_node = graph.AddNode(graph.GenerateNodeName("SelfAttention_MegatronG"),
auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SeftAttention_MegatronG_Output"), &sa_g_type_info);
Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "SelfAttention_MegatronG"),
"MegatronG",
"SelfAttention MegatronG",
"Attention MegatronG",
sa_g_input_defs,
{&sa_g_out_arg}, {}, kMSDomain);
sa_g_node.AddAttribute("group_type", static_cast<int64_t>(training::WorkerGroupType::HorizontalParallel));
sa_g_node.SetExecutionProviderType(node.GetExecutionProviderType());
graph_utils::ReplaceDownstreamNodeInput(graph, matmul_node, 0, sa_g_node, 0);
modified = true;
counter++;
}
return Status::OK();
}
// Partitions BART-style self-attention sub-graphs for horizontal (Megatron) model parallelism.
//
// Nodes are visited in topological order. For every MatMul that starts the linear pattern documented
// below, the function:
//   1. walks backwards from the Q*K and attention*V MatMuls to locate the K and V projection chains,
//   2. verifies that the relevant Reshape shape constants are INT64 initializers whose partitioned
//      dimension is divisible by horizontal_parallel_size_,
//   3. partitions the Q/K/V weights (PartitionWeightByRow), the Q/K/V biases and the dense weight
//      (PartitionWeightByColumn), and records old-name -> new-name pairs in updated_weight_names_,
//   4. rewrites the Reshape shape constants with the partitioned dimension,
//   5. inserts MegatronF node(s) in front of the Q/K/V MatMuls and a MegatronG node after the dense MatMul.
//
// Parameters:
//   graph                       graph to transform in place.
//   modified                    set to true when at least one sub-graph was transformed.
//   graph_level, logger         forwarded to Recurse() for each visited node.
//   nodes_to_clear_shape        receives every (non-null) node of a transformed sub-graph.
//   dropout_nodes_to_transform  receives the attention Dropout node of a transformed sub-graph, if present.
//   counter                     incremented once per transformed sub-graph.
//
// Returns Status::OK() unless Recurse() reports an error. Sub-graphs that fail any check are skipped.
Status MegatronTransformer::TransformBARTSelfAttention(Graph& graph, bool& modified, int graph_level,
                                                       const logging::Logger& logger,
                                                       std::vector<Node*>& nodes_to_clear_shape,
                                                       std::unordered_set<Node*>& dropout_nodes_to_transform,
                                                       int32_t& counter) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

  // Self attention sub-graph.
  //
  // MatMul->Add->Mul->Reshape->Transpose->MatMul->Reshape->Where->Reshape->Softmax->Dropout->MatMul->Transpose->Reshape->MatMul->Add->Dropout
  // MatMul->Add->Reshape->Transpose------->  |                                                  |
  // MatMul->Add->Reshape->Transpose---------------------------------------------------------->  |
  for (auto node_index : node_topology_list) {
    auto& node = *graph.GetNode(node_index);
    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

    if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) ||
        !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
        node.GetOutputEdgesCount() != 1) {
      continue;
    }

    // A MatMul already fed by a MegatronF node is not considered again.
    Node* q_matmul_input_node_ptr = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[0]->Name()));
    if (q_matmul_input_node_ptr != nullptr && q_matmul_input_node_ptr->OpType().compare("MegatronF") == 0) {
      continue;
    }

    std::vector<Node*> sub_graph_node_ptrs;
    sub_graph_node_ptrs.push_back(&node);
    ProviderType provider_type = node.GetExecutionProviderType();

    std::vector<NodeInfo> linear_pattern = {
        NodeInfo({add_info}),
        NodeInfo({mul_info}),
        NodeInfo({reshape_info}),
        NodeInfo({transpose_info}),
        NodeInfo({matmul_info}),
        NodeInfo({add_info}, false),  // -13
        NodeInfo({reshape_info}),
        NodeInfo({where_info}),
        NodeInfo({reshape_info}),
        NodeInfo({softmax_info}),
        NodeInfo({dropout_info}, false),  // -8
        NodeInfo({matmul_info}),
        NodeInfo({add_info}, false),  // -6
        NodeInfo({transpose_info}),
        NodeInfo({reshape_info}),
        NodeInfo({matmul_info}),  // -3
        NodeInfo({add_info}),
        NodeInfo({dropout_info}, false)};  // -1
    if (!MatchLinearPattern(graph, &node, provider_type, linear_pattern, sub_graph_node_ptrs)) {
      continue;
    }

    // Get all useful nodes here as more vector push back below will change the index.
    // Other than the optional nodes in the pattern, all other node pointers are valid
    // if they match the linear pattern.
    Node* q_biasadd_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 18];
    Node* q_transpose_after_reshape_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15];
    Node* qk_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14];
    Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 8];  // optional, may be nullptr
    Node* qkv_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 7];
    Node* transpose_node1_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5];
    Node& dense_matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 3];

    // Transpose node attribute checking.
    if (!optimizer_utils::IsAttributeWithExpectedValues(*q_transpose_after_reshape_node_ptr, "perm", {1LL, 0LL, 2LL}) ||
        !optimizer_utils::IsAttributeWithExpectedValues(*transpose_node1_ptr, "perm", {1LL, 0LL, 2LL})) {
      continue;
    }

    // map between reshape node and dim of reshape that must be modified
    std::unordered_map<Node*, int64_t> reshape_node_ptrs;
    reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 16]] = 1;
    reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]] = 1;
    reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 10]] = 0;
    reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 4]] = 2;

    // till now node should be q matmul operation
    std::vector<Node*> weight_transpose_node_ptrs;
    std::vector<Node*> bias_add_node_ptrs;

    // Q projection: the weight input of the Q MatMul must be produced by a Transpose.
    Node* q_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[1]->Name()));
    if (q_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*q_transpose_ptr, transpose_info, provider_type)) {
      continue;
    }
    weight_transpose_node_ptrs.push_back(q_transpose_ptr);
    sub_graph_node_ptrs.push_back(q_transpose_ptr);
    bias_add_node_ptrs.push_back(q_biasadd_node_ptr);

    // K projection: walk backwards Transpose <- Reshape <- Add <- MatMul <- (weight) Transpose.
    Node* k_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(qk_matmul_node_ptr->MutableInputDefs()[1]->Name()));
    if (k_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_transpose_ptr, transpose_info, provider_type)) {
      continue;
    }
    sub_graph_node_ptrs.push_back(k_transpose_ptr);

    Node* k_reshape_ptr = const_cast<Node*>(graph.GetProducerNode(k_transpose_ptr->MutableInputDefs()[0]->Name()));
    if (k_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*k_reshape_ptr, reshape_info, provider_type)) {
      continue;
    }
    reshape_node_ptrs[k_reshape_ptr] = 1;
    sub_graph_node_ptrs.push_back(k_reshape_ptr);

    Node* k_add_ptr = const_cast<Node*>(graph.GetProducerNode(k_reshape_ptr->MutableInputDefs()[0]->Name()));
    if (k_add_ptr == nullptr || !IsExpectedOpAndProvider(*k_add_ptr, add_info, provider_type)) {
      continue;
    }
    sub_graph_node_ptrs.push_back(k_add_ptr);
    bias_add_node_ptrs.push_back(k_add_ptr);

    Node* k_matmul_ptr = const_cast<Node*>(graph.GetProducerNode(k_add_ptr->MutableInputDefs()[0]->Name()));
    if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) {
      continue;
    }
    sub_graph_node_ptrs.push_back(k_matmul_ptr);

    Node* k_weight_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(k_matmul_ptr->MutableInputDefs()[1]->Name()));
    if (k_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_weight_transpose_ptr, transpose_info, provider_type)) {
      continue;
    }
    sub_graph_node_ptrs.push_back(k_weight_transpose_ptr);
    weight_transpose_node_ptrs.push_back(k_weight_transpose_ptr);

    // V projection: same backwards walk, starting from the attention*V MatMul.
    Node* v_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(qkv_matmul_node_ptr->MutableInputDefs()[1]->Name()));
    if (v_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_transpose_ptr, transpose_info, provider_type)) {
      continue;
    }
    sub_graph_node_ptrs.push_back(v_transpose_ptr);

    Node* v_reshape_ptr = const_cast<Node*>(graph.GetProducerNode(v_transpose_ptr->MutableInputDefs()[0]->Name()));
    if (v_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*v_reshape_ptr, reshape_info, provider_type)) {
      continue;
    }
    reshape_node_ptrs[v_reshape_ptr] = 1;
    sub_graph_node_ptrs.push_back(v_reshape_ptr);

    Node* v_add_ptr = const_cast<Node*>(graph.GetProducerNode(v_reshape_ptr->MutableInputDefs()[0]->Name()));
    if (v_add_ptr == nullptr || !IsExpectedOpAndProvider(*v_add_ptr, add_info, provider_type)) {
      continue;
    }
    sub_graph_node_ptrs.push_back(v_add_ptr);
    bias_add_node_ptrs.push_back(v_add_ptr);

    // Validate the V MatMul itself (not the K MatMul again): it is dereferenced right below,
    // so a missing producer must cause the sub-graph to be skipped.
    Node* v_matmul_ptr = const_cast<Node*>(graph.GetProducerNode(v_add_ptr->MutableInputDefs()[0]->Name()));
    if (v_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*v_matmul_ptr, matmul_info, provider_type)) {
      continue;
    }
    sub_graph_node_ptrs.push_back(v_matmul_ptr);

    Node* v_weight_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(v_matmul_ptr->MutableInputDefs()[1]->Name()));
    if (v_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_weight_transpose_ptr, transpose_info, provider_type)) {
      continue;
    }
    sub_graph_node_ptrs.push_back(v_weight_transpose_ptr);
    weight_transpose_node_ptrs.push_back(v_weight_transpose_ptr);

    // K and V matmul must have the same input
    Node* q_matmul_ptr = &node;
    if (k_matmul_ptr->MutableInputDefs()[0]->Name() != v_matmul_ptr->MutableInputDefs()[0]->Name()) {
      continue;
    }

    // Check the constant value in the Reshape nodes.
    bool is_reshape_valid = true;
    for (auto x : reshape_node_ptrs) {
      Node* node_ptr = x.first;
      int64_t idx = x.second;
      auto shape_arg = node_ptr->MutableInputDefs()[1];
      const ONNX_NAMESPACE::TensorProto* tensor;
      if (!graph.GetInitializedTensor(shape_arg->Name(), tensor)) {
        is_reshape_valid = false;
        break;
      }
      auto data_type = tensor->data_type();
      if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
        is_reshape_valid = false;
        break;
      }
      // The number of the values should be more than idx, and the idx'th value should be divisible by parallel size,
      // i.e., the attention head number should be divisible by parallel size.
      auto init_const = onnxruntime::make_unique<Initializer>(*tensor, graph.ModelPath());
      if (init_const->size() <= idx) {
        is_reshape_valid = false;
        break;
      }
      const int64_t* val = init_const->data<int64_t>();
      if (val[idx] % horizontal_parallel_size_ != 0) {
        LOGS_DEFAULT(WARNING) << "dim[" << idx << "]: " << val[idx]
                              << " is not divisible by horizontal_parallel_size_ "
                              << horizontal_parallel_size_ << ", not supported currently.";
        is_reshape_valid = false;
        break;
      }
    }
    if (!is_reshape_valid) {
      continue;
    }

    // Partition weights. If any of them fails, skip transforming the rest.
    std::vector<ONNX_NAMESPACE::TensorProto> qkv_weight_initializer_partitions;
    for (auto trans_ptr : weight_transpose_node_ptrs) {
      auto qkv_weight_arg = trans_ptr->MutableInputDefs()[0];
      ONNX_NAMESPACE::TensorProto qkv_weight_initializer_partition;
      if (!PartitionWeightByRow(graph, *qkv_weight_arg, qkv_weight_initializer_partition)) {
        break;
      }
      qkv_weight_initializer_partitions.push_back(qkv_weight_initializer_partition);
    }

    // Partition bias. If any of them fails, skip transforming the rest.
    std::vector<ONNX_NAMESPACE::TensorProto> qkv_bias_initializer_partitions;
    for (auto add_ptr : bias_add_node_ptrs) {
      auto qkv_bias_arg = add_ptr->MutableInputDefs()[1];
      ONNX_NAMESPACE::TensorProto qkv_bias_initializer_partition;
      if (!PartitionWeightByColumn(graph, *qkv_bias_arg, qkv_bias_initializer_partition)) {
        break;
      }
      qkv_bias_initializer_partitions.push_back(qkv_bias_initializer_partition);
    }

    // if all the weights or biases weren't transformed, skip transforming this subgraph
    if (weight_transpose_node_ptrs.size() != qkv_weight_initializer_partitions.size()) {
      continue;
    }
    if (bias_add_node_ptrs.size() != qkv_bias_initializer_partitions.size()) {
      continue;
    }

    // transform the dense weight. If it fails, skip transforming this subgraph.
    Node* last_transpose = const_cast<Node*>(graph.GetProducerNode(dense_matmul_node.MutableInputDefs()[1]->Name()));
    if (last_transpose == nullptr) {
      // The dense MatMul weight has no producer node (e.g. it is fed directly by an initializer or a
      // graph input); this layout is not handled, so skip instead of dereferencing a null pointer.
      continue;
    }
    auto dense_weight_arg = last_transpose->MutableInputDefs()[0];
    ONNX_NAMESPACE::TensorProto dense_weight_initializer_partition;
    if (!PartitionWeightByColumn(graph, *dense_weight_arg, dense_weight_initializer_partition)) {
      continue;
    }

    // Ready to transform the sub-graph when reach here. Nothing in the graph has been mutated so far.
    // Replace node inputs
    for (size_t i = 0; i < weight_transpose_node_ptrs.size(); ++i) {
      Node* trans_ptr = weight_transpose_node_ptrs[i];
      auto weight_name = trans_ptr->MutableInputDefs()[0]->Name();  // copy: the input is replaced below
      NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]);
      graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg);
      graph.RemoveInitializedTensor(weight_name);
      updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()});
    }

    for (size_t i = 0; i < bias_add_node_ptrs.size(); ++i) {
      Node* add_ptr = bias_add_node_ptrs[i];
      auto bias_name = add_ptr->MutableInputDefs()[1]->Name();  // copy: the input is replaced below
      NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]);
      graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg);
      graph.RemoveInitializedTensor(bias_name);
      updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()});
    }

    NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition);
    graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg);
    graph.RemoveInitializedTensor(dense_weight_arg->Name());
    updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()});

    // It's possible that the node vector contains nullptr due to some optional node infos during linear pattern matching.
    std::copy_if(sub_graph_node_ptrs.begin(), sub_graph_node_ptrs.end(),
                 std::back_inserter(nodes_to_clear_shape),
                 [](Node* node_ptr) { return node_ptr != nullptr; });

    // Change the constant for the reshape nodes.
    // The old shape initializers are removed only after every Reshape has been rewritten, so that
    // Reshape nodes sharing one shape initializer can all still read it.
    std::unordered_set<std::string> stale_shape_initializer_names;
    for (auto x : reshape_node_ptrs) {
      Node* node_ptr = x.first;
      int64_t idx = x.second;
      auto shape_arg = node_ptr->MutableInputDefs()[1];
      const ONNX_NAMESPACE::TensorProto* tensor;
      graph.GetInitializedTensor(shape_arg->Name(), tensor);  // presence was verified in the validation loop above
      auto data_type = tensor->data_type();
      auto init_const = onnxruntime::make_unique<Initializer>(*tensor, graph.ModelPath());
      const int64_t* val = init_const->data<int64_t>();
      int64_t size = init_const->size();

      ONNX_NAMESPACE::TensorProto tensor_partition;
      tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name()));
      tensor_partition.set_data_type(data_type);
      tensor_partition.add_dims(size);

      std::vector<int64_t> val_partition;
      val_partition.reserve(size);
      val_partition.insert(val_partition.end(), val, val + size);
      val_partition[idx] /= horizontal_parallel_size_;
      tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t));

      NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition);
      stale_shape_initializer_names.insert(shape_arg->Name());
      graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition);
    }
    for (const auto& stale_name : stale_shape_initializer_names) {
      graph.RemoveInitializedTensor(stale_name);
    }

    if (dropout_node_ptr != nullptr) {
      dropout_nodes_to_transform.insert(dropout_node_ptr);
    }

    // Add MegatronF before the 1st MatMul and MegatronG before the last Add.
    NodeArg* prev_input_node_ptr = k_matmul_ptr->MutableInputDefs()[0];
    // Consumers of the K/V input other than the Q/K/V MatMuls keep reading the original tensor.
    std::vector<Node*> new_consumer_nodes;
    const auto& node_consumers = graph.GetConsumerNodes(prev_input_node_ptr->Name());
    for (auto& n : node_consumers) {
      if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) {
        continue;
      }
      new_consumer_nodes.emplace_back(const_cast<Node*>(n));
    }

    bool shared_same_input = k_matmul_ptr->MutableInputDefs()[0]->Name().compare(q_matmul_ptr->MutableInputDefs()[0]->Name()) == 0;

    // If Q does not share its input with K and V, then Q and K&V will have different MegatronF nodes.
    {
      const std::vector<NodeArg*> sa_f_input_defs{prev_input_node_ptr};
      auto sa_f_type_info = *prev_input_node_ptr->TypeAsProto();
      auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(k_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &sa_f_type_info);
      Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronF"),
                                      "MegatronF",
                                      k_matmul_ptr->Name() + " BARTAttention MegatronF",
                                      sa_f_input_defs,
                                      {&sa_f_out_arg}, {}, kMSDomain);
      sa_f_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType());
      graph_utils::ReplaceNodeInput(*k_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
      graph_utils::ReplaceNodeInput(*v_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
      if (shared_same_input) {
        graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
      }
      new_consumer_nodes.push_back(&sa_f_node);
    }
    graph.UpdateConsumerNodes(prev_input_node_ptr->Name(), new_consumer_nodes);
    counter++;

    if (!shared_same_input) {
      {
        NodeArg* q_prev_input_node_ptr = q_matmul_ptr->MutableInputDefs()[0];
        std::vector<Node*> q_new_consumer_nodes;
        const auto& q_node_consumers = graph.GetConsumerNodes(q_prev_input_node_ptr->Name());
        for (auto& n : q_node_consumers) {
          if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) {
            continue;
          }
          q_new_consumer_nodes.emplace_back(const_cast<Node*>(n));
        }

        const std::vector<NodeArg*> q_sa_f_input_defs{q_matmul_ptr->MutableInputDefs()[0]};
        auto q_sa_f_type_info = *q_matmul_ptr->MutableInputDefs()[0]->TypeAsProto();
        auto& q_sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(q_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &q_sa_f_type_info);
        Node& q_sa_f_node = graph.AddNode(graph.GenerateNodeName(q_matmul_ptr->Name() + "BARTAttention_MegatronF"),
                                          "MegatronF",
                                          q_matmul_ptr->Name() + " BARTAttention MegatronF",
                                          q_sa_f_input_defs,
                                          {&q_sa_f_out_arg}, {}, kMSDomain);
        q_sa_f_node.SetExecutionProviderType(q_matmul_ptr->GetExecutionProviderType());

        graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(q_sa_f_node.MutableOutputDefs()[0]));
        q_new_consumer_nodes.push_back(&q_sa_f_node);
        graph.UpdateConsumerNodes(q_prev_input_node_ptr->Name(), q_new_consumer_nodes);
        // TODO: need to update the consumer nodes for the input_node as well.
      }
    }

    const std::vector<NodeArg*> sa_g_input_defs{dense_matmul_node.MutableOutputDefs()[0]};
    auto sa_g_type_info = *dense_matmul_node.MutableOutputDefs()[0]->TypeAsProto();  // copy
    auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BARTAttention_MegatronG_Output"), &sa_g_type_info);
    Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronG"),
                                    "MegatronG",
                                    "BARTAttention MegatronG",
                                    sa_g_input_defs,
                                    {&sa_g_out_arg}, {}, kMSDomain);
    sa_g_node.AddAttribute("group_type", static_cast<int64_t>(training::WorkerGroupType::HorizontalParallel));
    sa_g_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType());
    graph_utils::ReplaceDownstreamNodeInput(graph, dense_matmul_node, 0, sa_g_node, 0);
    modified = true;
  }

  return Status::OK();
}
Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
std::unordered_set<Node*>& self_attention_dropout_nodes) const {
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
@ -610,18 +1141,17 @@ Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int g
}
// Only need to set the seed if it's a transformed self-attention dropout, or the seed attribute is not set.
if (self_attention_dropout_nodes.find(&node) != self_attention_dropout_nodes.end() ||
graph_utils::GetNodeAttribute(node, "seed") == nullptr) {
if (dropout_nodes_to_transform.find(&node) != dropout_nodes_to_transform.end()) {
int64_t seed = static_cast<int64_t>(HashName(node.MutableOutputDefs()[0]->Name())) + utils::GetRandomSeed();
if (self_attention_dropout_nodes.find(&node) != self_attention_dropout_nodes.end()) {
if (dropout_nodes_to_transform.find(&node) != dropout_nodes_to_transform.end()) {
seed += horizontal_parallel_rank_;
}
if (graph_utils::GetNodeAttribute(node, "seed") != nullptr) {
node.ClearAttribute("seed");
}
node.AddAttribute("seed", seed);
counter++;
modified = true;
}
}
@ -635,11 +1165,19 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le
}
std::vector<Node*> nodes_to_clear_shape;
std::unordered_set<Node*> self_attention_dropout_nodes;
std::unordered_set<Node*> dropout_nodes_to_transform;
ORT_ENFORCE(TransformMLP(graph, modified, graph_level, logger, nodes_to_clear_shape).IsOK());
ORT_ENFORCE(TransformSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, self_attention_dropout_nodes).IsOK());
ORT_ENFORCE(TransformDropout(graph, modified, graph_level, logger, self_attention_dropout_nodes).IsOK());
int32_t partitioned_mlp_count = 0;
int32_t partitioned_bart_mlp_count = 0;
int32_t partitioned_attention_count = 0;
int32_t partitioned_bart_attention_count = 0;
int32_t dropout_changed = 0;
ORT_ENFORCE(TransformMLP(graph, modified, graph_level, logger, nodes_to_clear_shape, partitioned_mlp_count).IsOK());
ORT_ENFORCE(TransformBARTMLP(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_bart_mlp_count).IsOK());
ORT_ENFORCE(TransformSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_attention_count).IsOK());
ORT_ENFORCE(TransformBARTSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_bart_attention_count).IsOK());
ORT_ENFORCE(TransformDropout(graph, modified, graph_level, logger, dropout_nodes_to_transform, dropout_changed).IsOK());
auto& graph_inputs = graph.GetInputs();
for (auto& node : nodes_to_clear_shape) {
@ -653,13 +1191,28 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le
output->ClearShape();
}
for (auto x : updated_weight_names_) {
auto old_initializer_name = x.first;
auto new_initializer_name = x.second;
if (weights_to_train_.find(old_initializer_name) != weights_to_train_.end()) {
weights_to_train_.erase(old_initializer_name);
weights_to_train_.insert(new_initializer_name);
}
}
if (modified) {
graph.SetGraphResolveNeeded();
auto ret = graph.Resolve();
LOGS_DEFAULT(WARNING) << "Megatron transformer result : Partitioned " << partitioned_mlp_count << " MLP Blocks, "
<< partitioned_bart_mlp_count << " BART MLP Blocks, " << partitioned_attention_count << " Attention Blocks, "
<< partitioned_bart_attention_count << " BART Attention Blocks; Reset seed for " << dropout_changed
<< " Dropout nodes. Error Message: " << ret.ErrorMessage() << std::endl;
ORT_ENFORCE(ret.IsOK());
} else {
LOGS_DEFAULT(WARNING) << "Megatron transformer result : unmodified\n";
}
return Status::OK();
}
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -11,33 +11,50 @@ namespace onnxruntime {
class MegatronTransformer : public GraphTransformer {
public:
MegatronTransformer(int32_t horizontal_parallel_rank, int32_t horizontal_parallel_size,
std::unordered_map<std::string, std::string>& updated_weight_names,
std::unordered_set<std::string>& weights_to_train,
const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("MegatronTransformer", compatible_execution_providers),
horizontal_parallel_rank_(horizontal_parallel_rank),
horizontal_parallel_size_(horizontal_parallel_size) {}
horizontal_parallel_size_(horizontal_parallel_size),
updated_weight_names_(updated_weight_names),
weights_to_train_(weights_to_train) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const override;
private:
Status TransformMLP(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
std::vector<Node*>& nodes_to_clear_shape) const;
std::vector<Node*>& nodes_to_clear_shape,
int32_t& counter) const;
Status TransformSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
std::vector<Node*>& nodes_to_clear_shape,
std::unordered_set<Node*>& self_attention_dropout_nodes) const;
std::unordered_set<Node*>& dropout_nodes_to_transform,
int32_t& counter) const;
Status TransformBARTSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
std::vector<Node*>& nodes_to_clear_shape,
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const;
Status TransformBARTMLP(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger,
std::vector<Node*>& nodes_to_clear_shape,
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const;
Status TransformDropout(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
std::unordered_set<Node*>& self_attention_dropout_nodes) const;
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const;
bool PartitionWeightByColumn(const Graph& graph, const NodeArg& input_arg,
ONNX_NAMESPACE::TensorProto& initializer_partition, int stride = 1) const;
ONNX_NAMESPACE::TensorProto& initializer_partition,
int stride = 1) const;
bool PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg,
ONNX_NAMESPACE::TensorProto& initializer_partition) const;
bool PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg, ONNX_NAMESPACE::TensorProto& initializer_partition) const;
const int32_t horizontal_parallel_rank_;
const int32_t horizontal_parallel_size_;
std::unordered_map<std::string, std::string>& updated_weight_names_;
std::unordered_set<std::string>& weights_to_train_;
};
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -43,24 +43,36 @@ Status SetupOptimizerParams(
const optional<std::string>& loss_scale_input_name,
const TrainingSession::TrainingConfiguration& config,
OptimizerGraphConfig& opt_graph_config_result,
std::unordered_map<std::string, OptimizerNodeConfig>& opt_node_configs_result) {
std::unordered_map<std::string, OptimizerNodeConfig>& opt_node_configs_result,
std::unordered_map<std::string, std::string>& weight_name_map_after_graph_transform) {
ORT_RETURN_IF_NOT(config.optimizer_config.has_value());
const auto& optimizer_config = config.optimizer_config.value();
// This is the mapping from the new weight name to the original weight name
// It is required to look up the optimizer config for the original weight
// passed in the training session config
std::unordered_map<std::string, std::string> reversed_weight_names_map;
for (auto& p : weight_name_map_after_graph_transform) {
reversed_weight_names_map.insert({p.second, p.first});
}
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
for (const auto& weight_name : weight_names_to_train) {
OptimizerNodeConfig opt_node_config{};
opt_node_config.name = optimizer_config.name;
opt_node_config.lr_feed_name = optimizer_config.learning_rate_input_name;
std::string original_weight_name = weight_name;
if (reversed_weight_names_map.find(original_weight_name) != reversed_weight_names_map.end()) {
original_weight_name = reversed_weight_names_map.at(original_weight_name);
}
try {
opt_node_config.attributes = optimizer_config.weight_attributes_generator(weight_name);
opt_node_config.attributes = optimizer_config.weight_attributes_generator(original_weight_name);
} catch (const std::exception& ex) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
}
try {
opt_node_config.int_attributes = optimizer_config.weight_int_attributes_generator(weight_name);
opt_node_config.int_attributes = optimizer_config.weight_int_attributes_generator(original_weight_name);
} catch (const std::exception& ex) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
}
@ -232,7 +244,8 @@ Status TrainingSession::ConfigureForTraining(
}
}
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config));
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config,
config_result));
if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
@ -244,8 +257,14 @@ Status TrainingSession::ConfigureForTraining(
!filtered_config_weight_names_to_train.empty()
? filtered_config_weight_names_to_train
: GetTrainableModelInitializers(config.immutable_weights, loss_name);
for (const auto& weight_name_to_not_train : config.weight_names_to_not_train) {
weight_names_to_train.erase(weight_name_to_not_train);
if (config_result.weight_name_map_after_graph_transform.find(weight_name_to_not_train) !=
config_result.weight_name_map_after_graph_transform.end()) {
weight_names_to_train.erase(config_result.weight_name_map_after_graph_transform.at(weight_name_to_not_train));
} else {
weight_names_to_train.erase(weight_name_to_not_train);
}
}
{
@ -316,7 +335,7 @@ Status TrainingSession::ConfigureForTraining(
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
ORT_RETURN_IF_ERROR(SetupOptimizerParams(
weights_to_train_, fp32_weight_name_to_mixed_precision_node_arg,
loss_scale_input_name, config, opt_graph_config, opt_node_configs));
loss_scale_input_name, config, opt_graph_config, opt_node_configs, config_result.weight_name_map_after_graph_transform));
TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{};
ORT_RETURN_IF_ERROR(BuildOptimizer(
opt_graph_config, opt_node_configs,
@ -512,8 +531,9 @@ static Status AddGradientAccumulationNodes(Graph& graph,
return GraphAugmenter::AugmentGraph(graph, graph_defs);
}
Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train,
const TrainingConfiguration::GraphTransformerConfiguration& config) {
Status TrainingSession::ApplyTransformationsToMainGraph(std::unordered_set<std::string>& weights_to_train,
const TrainingConfiguration::GraphTransformerConfiguration& config,
TrainingConfigurationResult& config_result_out) {
GraphTransformerManager graph_transformation_mgr{2};
// TODO: ideally we can just reuse the CPU EP registered with the session, but in the training session case
// the EPs are registered after ConfigureForTraining and before Initialize is called. Hence we don't have access
@ -522,7 +542,7 @@ Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set
// Create execution frame for executing constant nodes.
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
AddPreTrainingTransformers(*cpu_execution_provider, graph_transformation_mgr, weights_to_train, config);
AddPreTrainingTransformers(*cpu_execution_provider, graph_transformation_mgr, weights_to_train, config, config_result_out);
// apply transformers
Graph& graph = model_->MainGraph();
@ -536,15 +556,16 @@ Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set
// Registers all the pre transformers with transformer manager
void TrainingSession::AddPreTrainingTransformers(const IExecutionProvider& execution_provider,
GraphTransformerManager& transformer_manager,
const std::unordered_set<std::string>& weights_to_train,
std::unordered_set<std::string>& weights_to_train,
const TrainingConfiguration::GraphTransformerConfiguration& config,
TrainingConfigurationResult& config_result_out,
TransformerLevel graph_optimization_level,
const std::vector<std::string>& custom_list) {
auto add_transformers = [&](TransformerLevel level) {
// Generate and register transformers for level
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
level, weights_to_train, config, execution_provider, custom_list);
level, weights_to_train, config, execution_provider, config_result_out.weight_name_map_after_graph_transform, custom_list);
for (auto& entry : transformers_to_register) {
transformer_manager.Register(std::move(entry), level);
}

View file

@ -245,6 +245,9 @@ class TrainingSession : public InferenceSession {
// The pipeline configuration output.
// This is only set if an pipeline is enabled.
optional<PipelineConfigurationResult> pipeline_config_result;
// Maps original initializer names to their new names after graph transforms that partition weights (e.g. MegatronTransformer).
std::unordered_map<std::string, std::string> weight_name_map_after_graph_transform{};
};
/**
@ -392,14 +395,16 @@ class TrainingSession : public InferenceSession {
common::Status InsertPipelineOps(const std::unordered_set<std::string>& initializer_names_to_preserve,
pipeline::PipelineTensorNames& pipeline_tensor_names);
common::Status ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train,
const TrainingConfiguration::GraphTransformerConfiguration& config);
common::Status ApplyTransformationsToMainGraph(std::unordered_set<std::string>& weights_to_train,
const TrainingConfiguration::GraphTransformerConfiguration& config,
TrainingConfigurationResult& config_result_out);
/** configure initial transformers for training */
void AddPreTrainingTransformers(const IExecutionProvider& execution_provider, // for constant folding
GraphTransformerManager& transformer_manager,
const std::unordered_set<std::string>& weights_to_train,
std::unordered_set<std::string>& weights_to_train,
const TrainingConfiguration::GraphTransformerConfiguration& config,
TrainingConfigurationResult& config_result_out,
TransformerLevel graph_optimization_level = TransformerLevel::MaxLevel,
const std::vector<std::string>& custom_list = {});

View file

@ -43,6 +43,7 @@ struct TrainingParameters {
int gradient_accumulation_steps = 1;
int data_parallel_size = 1;
int horizontal_parallel_size = 1;
int pipeline_parallel_size = 1;
int deepspeed_zero_stage = 0;
bool enable_grad_norm_clip = true;
bool set_gradients_as_graph_outputs = false;
@ -65,11 +66,16 @@ TrainingConfigurationResult ConfigureSessionForTraining(
//TODO tix, refactor the mpi related code to populate all fields correctly by default.
ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size);
ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size);
if (parameters.world_size % parameters.data_parallel_size != 0) {
throw std::runtime_error("Cannot split data parallel group because world_size is not divisible");
}
if (parameters.world_size % parameters.horizontal_parallel_size != 0) {
throw std::runtime_error("Cannot split horizontal parallel group because world_size is not divisible");
}
auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size;
auto data_group_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size);
if (data_group_size != parameters.data_parallel_size) {
LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to "
<< data_group_size;
@ -201,7 +207,10 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute)
.def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute)
.def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers);
.def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers)
.def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size)
.def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size)
.def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size);
#if defined(USE_MPI)
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });

View file

@ -163,7 +163,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) {
ASSERT_TRUE(ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2), TransformerLevel::Level1);
std::unordered_map<std::string, std::string> updated_weight_names;
std::unordered_set<std::string> weights_to_train;
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1);
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
ASSERT_TRUE(ret.IsOK());
@ -231,7 +233,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2), TransformerLevel::Level1);
std::unordered_map<std::string, std::string> updated_weight_names;
std::unordered_set<std::string> weights_to_train;
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1);
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
ASSERT_TRUE(ret.IsOK());
@ -298,7 +302,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) {
ASSERT_TRUE(ret.IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2), TransformerLevel::Level1);
std::unordered_map<std::string, std::string> updated_weight_names;
std::unordered_set<std::string> weights_to_train;
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1);
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
ASSERT_TRUE(ret.IsOK());
@ -363,7 +369,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2), TransformerLevel::Level1);
std::unordered_map<std::string, std::string> updated_weight_names;
std::unordered_set<std::string> weights_to_train;
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1);
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
ASSERT_TRUE(ret.IsOK());
@ -449,27 +457,35 @@ TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) {
// We have only tested this on CUDA runs.
#if defined(USE_CUDA)
TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx";
const int total_rank = 4;
static void RunPartitionCorrectnessTest(std::string model_path,
const logging::Logger& logger,
const int total_rank,
std::vector<std::string> input_names,
std::vector<std::vector<int64_t>> input_dims) {
const PathString model_uri = ToPathString(model_path) + ORT_TSTR(".onnx");
// const int total_rank = 4;
std::vector<Graph*> graphs;
std::vector<std::shared_ptr<Model>> p_models(total_rank);
for (auto i = 0; i < total_rank; i++) {
auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_);
ASSERT_TRUE(ret.IsOK());
Status ret = Model::Load(model_uri, p_models[i], nullptr, logger);
ORT_ENFORCE(ret.IsOK());
Graph& graph = p_models[i]->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(i, total_rank), TransformerLevel::Level1);
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
ASSERT_TRUE(ret.IsOK());
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
std::unordered_map<std::string, std::string> updated_weight_names;
std::unordered_set<std::string> weights_to_train;
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(i, total_rank, updated_weight_names, weights_to_train), TransformerLevel::Level1);
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger);
ORT_ENFORCE(ret.IsOK());
graphs.push_back(&graph);
auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_rank_") + ToPathString(std::to_string(i)) + ORT_TSTR(".onnx");
Model::Save(*p_models[i], model_uri2);
}
onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, *logger_);
onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, logger);
auto& combine_graph = combine_model.MainGraph();
auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph);
ORT_ENFORCE(ret.IsOK());
auto model_uri2 = "mlp_megatron_basic_test_partition_combine.onnx";
auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_combine.onnx");
Model::Save(combine_model, model_uri2);
float scale = 1.f;
@ -479,10 +495,18 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
std::default_random_engine generator{gsl::narrow_cast<uint32_t>(seed)};
std::normal_distribution<float> distribution{mean, scale};
std::vector<int64_t> dims_X = {8, 16, 4};
std::vector<float> values_X(TensorShape(dims_X).Size());
std::for_each(values_X.begin(), values_X.end(),
[&generator, &distribution](float& value) { value = distribution(generator); });
ORT_ENFORCE(input_names.size() == input_dims.size());
NameMLValMap feeds;
for(size_t i = 0; i< input_dims.size();i++){
std::vector<int64_t> dims_X = input_dims[i];
std::vector<float> values_X(TensorShape(dims_X).Size());
std::for_each(values_X.begin(), values_X.end(),
[&generator, &distribution](float& value) { value = distribution(generator); });
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
feeds.insert(std::make_pair(input_names[i], ml_value));
}
std::vector<OrtValue> expected_ort_values;
{
@ -497,11 +521,6 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st;
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
NameMLValMap feeds;
feeds.insert(std::make_pair("input", ml_value));
// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("output");
@ -527,11 +546,6 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
ASSERT_TRUE((st = session_object.Load(model_uri2)).IsOK()) << st;
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
NameMLValMap feeds;
feeds.insert(std::make_pair("input", ml_value));
// prepare outputs
std::vector<std::string> output_names;
for (auto i = 0; i < total_rank; i++) {
@ -554,137 +568,21 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
}
}
// Drives the shared RunPartitionCorrectnessTest helper on the basic MLP
// Megatron test graph: 4 ranks, one float feed named "input".
TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
  // Model path without the ".onnx" extension; the helper appends it.
  const std::string model_path = "testdata/transform/model_parallel/mlp_megatron_basic_test";
  const int total_rank = 4;
  const std::vector<std::string> input_names{"input"};
  const std::vector<std::vector<int64_t>> input_dims{{8, 16, 4}};

  RunPartitionCorrectnessTest(model_path, *logger_, total_rank, input_names, input_dims);
}
// Drives the shared RunPartitionCorrectnessTest helper on the BART MLP
// Megatron test graph: 4 ranks, one float feed named "input".
TEST_F(GraphTransformationTests, MegatronBARTMLPPartitionCorrectnessTest) {
  // Model path without the ".onnx" extension; the helper appends it.
  const std::string model_path = "testdata/transform/model_parallel/bart_mlp_megatron_basic_test";
  const int total_rank = 4;
  const std::vector<std::string> input_names{"input"};
  const std::vector<std::vector<int64_t>> input_dims{{8, 16, 4}};

  RunPartitionCorrectnessTest(model_path, *logger_, total_rank, input_names, input_dims);
}
TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) {
auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx";
const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here.
std::vector<Graph*> graphs;
std::vector<std::shared_ptr<Model>> p_models(total_rank);
for (auto i = 0; i < total_rank; i++) {
auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_);
ASSERT_TRUE(ret.IsOK());
Graph& graph = p_models[i]->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(i, total_rank), TransformerLevel::Level1);
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
ASSERT_TRUE(ret.IsOK());
graphs.push_back(&graph);
}
// Dropout seed checking.
const AttributeProto* attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[0], "dropout1"), "seed");
ORT_ENFORCE(attr != nullptr && attr->has_i());
int64_t dropout1_rank0_seed = attr->i();
attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[0], "dropout2"), "seed");
ORT_ENFORCE(attr != nullptr && attr->has_i());
int64_t dropout2_rank0_seed = attr->i();
for (auto i = 1; i < total_rank; i++) {
attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[i], "dropout1"), "seed");
ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout1_rank0_seed + i);
attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[i], "dropout2"), "seed");
ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout2_rank0_seed);
}
onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, *logger_);
auto& combine_graph = combine_model.MainGraph();
auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph);
ORT_ENFORCE(ret.IsOK());
auto model_uri2 = "self_attention_megatron_basic_test_partition_combine.onnx";
Model::Save(combine_model, model_uri2);
float scale = 1.f;
float mean = 0.f;
float seed = 123.f;
std::default_random_engine generator{gsl::narrow_cast<uint32_t>(seed)};
std::normal_distribution<float> distribution{mean, scale};
std::vector<int64_t> dims_X = {8, 16, 4};
std::vector<float> values_X(TensorShape(dims_X).Size());
std::for_each(values_X.begin(), values_X.end(),
[&generator, &distribution](float& value) { value = distribution(generator); });
std::vector<int64_t> dims_Mask = {8, 1, 16, 16};
std::vector<float> values_Mask(TensorShape(dims_Mask).Size());
std::for_each(values_Mask.begin(), values_Mask.end(),
[&generator, &distribution](float& value) { value = distribution(generator); });
std::vector<OrtValue> expected_ort_values;
{
SessionOptions so;
so.session_logid = "RawGraphRun";
InferenceSession session_object{so, GetEnvironment()};
std::unique_ptr<IExecutionProvider> execution_provider = DefaultCudaExecutionProvider();
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
Status st;
ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st;
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
NameMLValMap feeds;
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
feeds.insert(std::make_pair("input", ml_value));
OrtValue mask_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_Mask, values_Mask, &mask_value);
feeds.insert(std::make_pair("mask", mask_value));
// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("output");
// Now run
RunOptions run_options;
run_options.training_mode = true;
st = session_object.Run(run_options, feeds, output_names, &expected_ort_values);
EXPECT_TRUE(st.IsOK());
}
std::vector<OrtValue> actual_ort_values;
{
SessionOptions so;
so.session_logid = "SplitThenCombineRun";
InferenceSession session_object{so, GetEnvironment()};
std::unique_ptr<IExecutionProvider> execution_provider = DefaultCudaExecutionProvider();
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
Status st;
ASSERT_TRUE((st = session_object.Load(model_uri2)).IsOK()) << st;
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
NameMLValMap feeds;
OrtValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
feeds.insert(std::make_pair("input", ml_value));
OrtValue mask_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_Mask, values_Mask, &mask_value);
feeds.insert(std::make_pair("mask", mask_value));
// prepare outputs
std::vector<std::string> output_names;
for (auto i = 0; i < total_rank; i++) {
output_names.push_back("output_rank_" + std::to_string(i));
}
// Now run
RunOptions run_options;
run_options.training_mode = true;
st = session_object.Run(run_options, feeds, output_names, &actual_ort_values);
EXPECT_TRUE(st.IsOK());
}
auto& expected_val = expected_ort_values[0].Get<Tensor>();
for (auto i = 0; i < total_rank; i++) {
auto& actual_val = actual_ort_values[i].Get<Tensor>();
horizontal_parallel_test_utils::VerifyOutputs(expected_val, actual_val, true);
horizontal_parallel_test_utils::VerifyOutputs(expected_val, actual_val, false);
}
RunPartitionCorrectnessTest("testdata/transform/model_parallel/self_attention_megatron_basic_test", *logger_, 2, {"input", "mask"}, {{8, 16, 4}, {8, 1, 16, 16}});
}
// Drives the shared RunPartitionCorrectnessTest helper on the BART
// self-attention Megatron test graph: 2 ranks, one float feed named "input".
TEST_F(GraphTransformationTests, MegatronBARTSelfAttentionPartitionCorrectnessTest) {
  // Model path without the ".onnx" extension; the helper appends it.
  const std::string model_path = "testdata/transform/model_parallel/bart_self_attention_megatron_basic_test";
  const int total_rank = 2;
  const std::vector<std::string> input_names{"input"};
  const std::vector<std::vector<int64_t>> input_dims{{6, 8, 4}};

  RunPartitionCorrectnessTest(model_path, *logger_, total_rank, input_names, input_dims);
}
// end of USE_CUDA
#endif