mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Add megatron transforms for BART (#5521)
* Large model export and run ORT Python support * Megatron change refine a bit workaround self attention issue use partitioned name for weights when megatron model parallel is enabled Fix Megatron Transformer Issue (cuased by the renaming) Add UTs for T5 model parallel Fix megatron seed issue fix log a bit checkkpointing changes + rebase Unintended reshape transform change t5 layer norm changes add t5 layer norm kernel use template for t5 layer norm template definition changes no build error add CPU cuda kernel first unit test other forward unit tests add T5LayerNormGrad Add c++ transform and test for T5 LN minor fix BART MLP Megatron tranform Add concat slice transform + test Cosmetic improvements in concat slice transform Constant folding bug fix + megatron attention transform for BART Undo unnecessary changes * Cleanup * Remove unnecessary changes * Cleanup megatron * Windows build * Add self attention test graph * Correcting transforms + cleanup * review comments * review comments * fix build and test failures * Fix CI * fix windows CI Co-authored-by: Peng Wang <pengwa@microsoft.com> Co-authored-by: Aishwarya <aibhanda@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
a14cd6267b
commit
5aec34500d
13 changed files with 979 additions and 216 deletions
|
|
@ -95,8 +95,12 @@ TEST(CseTests, SimpleTestTraining) {
|
|||
.IsOK());
|
||||
|
||||
GraphTransformerManager graph_transformation_mgr(1);
|
||||
// need to declare variables to avoid build error after making
|
||||
// weights_to_train and updated_weight_names as non-const
|
||||
std::unordered_set<std::string> weights_to_train;
|
||||
std::unordered_map<std::string, std::string> updated_weight_names;
|
||||
auto transformers_to_register = onnxruntime::training::transformer_utils::GeneratePreTrainingTransformers(
|
||||
TransformerLevel::Level1, {}, {}, CPUExecutionProvider(CPUExecutionProviderInfo()));
|
||||
TransformerLevel::Level1, weights_to_train, {}, CPUExecutionProvider(CPUExecutionProviderInfo()), updated_weight_names);
|
||||
for (auto& entry : transformers_to_register) {
|
||||
ASSERT_TRUE(
|
||||
graph_transformation_mgr.Register(std::move(entry), TransformerLevel::Level1).IsOK());
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.onnx
vendored
Normal file
Binary file not shown.
105
onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py
vendored
Normal file
105
onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py
vendored
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
hidden_size = 4
|
||||
weight_dim_to_split = 16
|
||||
|
||||
X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", hidden_size])
|
||||
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", hidden_size])
|
||||
|
||||
a_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((weight_dim_to_split, hidden_size))
|
||||
a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.weight")
|
||||
|
||||
a_bias_np_vals = (0.01 * np.arange(weight_dim_to_split, dtype=np.float32)) # weight_dim_to_split numbers in total
|
||||
a_bias_initializer = numpy_helper.from_array(a_bias_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.bias")
|
||||
|
||||
dropout_np_vals = np.asarray([0.1], dtype=np.float32).reshape(())
|
||||
dropout_initializer = numpy_helper.from_array(dropout_np_vals, "ratio")
|
||||
|
||||
dropout_mode_np_vals = np.array([False], dtype=np.bool).reshape(())
|
||||
dropout_mode_initializer = numpy_helper.from_array(dropout_mode_np_vals, "mode")
|
||||
|
||||
b_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((hidden_size, weight_dim_to_split))
|
||||
b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.weight")
|
||||
|
||||
b_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) # hidden_size numbers in total
|
||||
b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.bias")
|
||||
|
||||
transpose1 = helper.make_node('Transpose', [a_weight_initializer.name], ['transpose1'], name='transpose1', perm=[1,0])
|
||||
transpose2 = helper.make_node('Transpose', [b_weight_initializer.name], ['transpose2'], name='transpose2', perm=[1,0])
|
||||
matmul = helper.make_node(
|
||||
'MatMul', # node name
|
||||
['input', 'transpose1'], # inputs
|
||||
['matmul'], # outputs
|
||||
name="matmul"
|
||||
)
|
||||
|
||||
biasgelu = helper.make_node(
|
||||
'BiasGelu', # node name
|
||||
['matmul', a_bias_initializer.name], # inputs
|
||||
['biasgelu'], # outputs
|
||||
name="biasgelu",
|
||||
domain="com.microsoft"
|
||||
)
|
||||
|
||||
dropout1 = helper.make_node('Dropout',
|
||||
["biasgelu", dropout_initializer.name, dropout_mode_initializer.name],
|
||||
['dropout1', "dropout1_mask"],
|
||||
name='dropout1')
|
||||
|
||||
matmul2 = helper.make_node(
|
||||
'MatMul', # node name
|
||||
['dropout1', "transpose2"], # inputs
|
||||
['matmul2'], # outputs
|
||||
name="matmul2"
|
||||
)
|
||||
|
||||
add = helper.make_node(
|
||||
'Add', # node name
|
||||
['matmul2', b_bias_initializer.name], # inputs
|
||||
['add'], # outputs
|
||||
name="add"
|
||||
)
|
||||
|
||||
dropout2 = helper.make_node('Dropout',
|
||||
["add", dropout_initializer.name, dropout_mode_initializer.name],
|
||||
['dropout2', "dropout2_mask"],
|
||||
name='dropout2')
|
||||
|
||||
identity = helper.make_node(
|
||||
'Identity', # node name
|
||||
['dropout2'], # inputs
|
||||
['output'], # outputs
|
||||
name="identity"
|
||||
)
|
||||
|
||||
# Create the graph (GraphProto)
|
||||
graph_def = helper.make_graph(
|
||||
[transpose1, transpose2, matmul, biasgelu, dropout1, matmul2, add, dropout2, identity],
|
||||
'test-model',
|
||||
[X],
|
||||
[Y],
|
||||
[a_weight_initializer, a_bias_initializer, b_weight_initializer, b_bias_initializer, dropout_initializer, dropout_mode_initializer]
|
||||
)
|
||||
|
||||
opsets = []
|
||||
onnxdomain = OperatorSetIdProto()
|
||||
onnxdomain.version = 12
|
||||
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
|
||||
opsets.append(onnxdomain)
|
||||
|
||||
msdomain = OperatorSetIdProto()
|
||||
msdomain.version = 1
|
||||
msdomain.domain = "com.microsoft"
|
||||
|
||||
opsets.append(msdomain)
|
||||
kwargs={}
|
||||
kwargs["opset_imports"] = opsets
|
||||
|
||||
# Create the model (ModelProto)
|
||||
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
|
||||
|
||||
onnx.save(model_def, 'bart_mlp_megatron_basic_test.onnx')
|
||||
BIN
onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.onnx
vendored
Normal file
Binary file not shown.
150
onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py
vendored
Normal file
150
onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py
vendored
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
batch = 6
|
||||
hidden_size = 4
|
||||
attention_head = 2
|
||||
hidden_per_attention = 2
|
||||
|
||||
relative_attention_num_buckets=32
|
||||
input_len=8
|
||||
output_len=8
|
||||
|
||||
X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [batch, input_len, hidden_size])
|
||||
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [output_len, batch, hidden_size])
|
||||
|
||||
q_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size))
|
||||
q_weight_initializer = numpy_helper.from_array(q_weight_np_vals, 'encoder.layers.0.self_attn.q_proj.weight')
|
||||
|
||||
k_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size))
|
||||
k_weight_initializer = numpy_helper.from_array(k_weight_np_vals, 'encoder.layers.0.self_attn.k_proj.weight')
|
||||
|
||||
v_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size))
|
||||
v_weight_initializer = numpy_helper.from_array(v_weight_np_vals, 'encoder.layers.0.self_attn.v_proj.weight')
|
||||
|
||||
q_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))
|
||||
q_bias_initializer = numpy_helper.from_array(q_bias_np_vals, 'encoder.layers.0.self_attn.q_proj.bias')
|
||||
|
||||
k_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))
|
||||
k_bias_initializer = numpy_helper.from_array(k_bias_np_vals, 'encoder.layers.0.self_attn.k_proj.bias')
|
||||
|
||||
v_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))
|
||||
v_bias_initializer = numpy_helper.from_array(v_bias_np_vals, 'encoder.layers.0.self_attn.v_proj.bias')
|
||||
|
||||
q_shape_initializer = numpy_helper.from_array(np.asarray([input_len, batch*attention_head , hidden_per_attention], dtype=np.int64), 'q_shape')
|
||||
k_shape_initializer = numpy_helper.from_array(np.asarray([-1, batch*attention_head , hidden_per_attention], dtype=np.int64), 'k_shape')
|
||||
v_shape_initializer = numpy_helper.from_array(np.asarray([-1, batch*attention_head , hidden_per_attention], dtype=np.int64), 'v_shape')
|
||||
|
||||
mul_np_vals = np.asarray([0.1767766922712326], dtype=np.float32).reshape(())
|
||||
mul_initializer = numpy_helper.from_array(mul_np_vals, "mul_const")
|
||||
|
||||
qk_shape_initializer = numpy_helper.from_array(np.asarray([batch, attention_head , input_len, input_len], dtype=np.int64), 'qk_shape')
|
||||
|
||||
dummy_condition_initializer = numpy_helper.from_array(np.zeros((batch, input_len), dtype=bool), 'dummy_cond')
|
||||
inf_const_initializer = numpy_helper.from_array(np.asarray([-np.inf], dtype=np.float32), 'inf_const')
|
||||
|
||||
where_shape_initializer = numpy_helper.from_array(np.asarray([batch*attention_head , input_len, input_len], dtype=np.int64), 'where_shape')
|
||||
|
||||
dropout_np_vals = np.asarray([0.1], dtype=np.float32).reshape(())
|
||||
dropout_initializer = numpy_helper.from_array(dropout_np_vals, "ratio")
|
||||
|
||||
dropout_mode_np_vals = np.array([False], dtype=np.bool).reshape(())
|
||||
dropout_mode_initializer = numpy_helper.from_array(dropout_mode_np_vals, "mode")
|
||||
|
||||
shape_initializer3 = numpy_helper.from_array(np.array([input_len, batch, attention_head * hidden_per_attention], dtype=np.int64), 'concat_shape_3')
|
||||
|
||||
dense_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size))
|
||||
dense_weight_initializer = numpy_helper.from_array(dense_weight_np_vals, 'encoder.layers.0.self_attn.out_proj.weight')
|
||||
|
||||
dense_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32))
|
||||
dense_bias_initializer = numpy_helper.from_array(dense_bias_np_vals, 'encoder.layers.0.self_attn.out_proj.bias')
|
||||
|
||||
|
||||
transpose_ip = helper.make_node('Transpose', ['input'], ['transpose_ip'], name='transpose_ip', perm=[1,0,2])
|
||||
|
||||
transpose_q = helper.make_node('Transpose', [q_weight_initializer.name], ['transpose_q'], name='transpose_q', perm=[1,0])
|
||||
transpose_k = helper.make_node('Transpose', [k_weight_initializer.name], ['transpose_k'], name='transpose_k', perm=[1,0])
|
||||
transpose_v = helper.make_node('Transpose', [v_weight_initializer.name], ['transpose_v'], name='transpose_v', perm=[1,0])
|
||||
|
||||
matmul_q = helper.make_node('MatMul', ['transpose_ip', 'transpose_q'], ['matmul_q'], name='matmul_q')
|
||||
matmul_k = helper.make_node('MatMul', ['transpose_ip', 'transpose_k'], ['matmul_k'], name='matmul_k')
|
||||
matmul_v = helper.make_node('MatMul', ['transpose_ip', 'transpose_v'], ['matmul_v'], name='matmul_v')
|
||||
|
||||
|
||||
add_q = helper.make_node('Add', ['matmul_q', q_bias_initializer.name], ['add_q'], name='add_q')
|
||||
add_k = helper.make_node('Add', ['matmul_k', k_bias_initializer.name], ['add_k'], name='add_k')
|
||||
add_v = helper.make_node('Add', ['matmul_v', v_bias_initializer.name], ['add_v'], name='add_v')
|
||||
|
||||
mul_q = helper.make_node('Mul', ['add_q' , 'mul_const'], ['mul_q'], name='mul_q')
|
||||
|
||||
reshape_q = helper.make_node('Reshape', ['mul_q', q_shape_initializer.name], ['reshape_q'], name='reshape_q')
|
||||
reshape_k = helper.make_node('Reshape', ['add_k', k_shape_initializer.name], ['reshape_k'], name='reshape_k')
|
||||
reshape_v = helper.make_node('Reshape', ['add_v', v_shape_initializer.name], ['reshape_v'], name='reshape_v')
|
||||
|
||||
transpose_q2 = helper.make_node('Transpose', ['reshape_q'], ['transpose_q2'], name='transpose_q2', perm=[1,0,2])
|
||||
transpose_k2 = helper.make_node('Transpose', ['reshape_k'], ['transpose_k2'], name='transpose_k2', perm=[1,2,0])
|
||||
transpose_v2 = helper.make_node('Transpose', ['reshape_v'], ['transpose_v2'], name='transpose_v2', perm=[1,0,2])
|
||||
|
||||
matmul = helper.make_node('MatMul', ['transpose_q2', 'transpose_k2'], ['matmul'], name='matmul')
|
||||
reshape_qk = helper.make_node("Reshape", ['matmul', qk_shape_initializer.name], ['reshape_qk'], name='reshape_qk')
|
||||
|
||||
|
||||
unsqueeze = helper.make_node('Unsqueeze', [dummy_condition_initializer.name],['unsqueeze_cond'], axes=[1,2], name='unsqueeze_cond')
|
||||
where = helper.make_node('Where', ['unsqueeze_cond', inf_const_initializer.name, 'reshape_qk'], ['where'], name='where')
|
||||
|
||||
reshape_where = helper.make_node("Reshape", ['where', where_shape_initializer.name], ['reshape_where'], name='reshape_where')
|
||||
|
||||
softmax = helper.make_node('Softmax', ['reshape_where'], ['softmax'], name='softmax', axis=2)
|
||||
dropout1 = helper.make_node('Dropout',
|
||||
["softmax", dropout_initializer.name, dropout_mode_initializer.name],
|
||||
['dropout1', "dropout1_mask"],
|
||||
name='dropout1')
|
||||
|
||||
matmul2 = helper.make_node('MatMul', ['dropout1', 'transpose_v2'], ['matmul2'], name='matmul2')
|
||||
transpose = helper.make_node('Transpose', ['matmul2'], ['transpose'], name='transpose', perm=[1,0,2])
|
||||
reshape = helper.make_node('Reshape', ['transpose', shape_initializer3.name], ['reshape'], name='reshape')
|
||||
|
||||
transpose_o_weight = helper.make_node('Transpose', [dense_weight_initializer.name], ['transpose_o_weight'], name='transpose_o_weight', perm=[1,0])
|
||||
matmul3 = helper.make_node('MatMul', ['reshape', 'transpose_o_weight'], ['matmul3'], name='matmul3')
|
||||
add3 = helper.make_node('Add', ['matmul3', dense_bias_initializer.name], ['add3'], name='add3')
|
||||
identity = helper.make_node('Identity', ['add3'], ['output'], name='identity')
|
||||
|
||||
# Create the graph (GraphProto)
|
||||
graph_def = helper.make_graph(
|
||||
[transpose_ip,transpose_q,transpose_k,transpose_v,matmul_q,matmul_k,matmul_v,add_q,add_k,add_v,
|
||||
mul_q,reshape_q,reshape_k,reshape_v,transpose_q2,transpose_k2,transpose_v2,matmul,reshape_qk,
|
||||
unsqueeze,where,reshape_where,softmax,dropout1,matmul2,transpose,reshape,transpose_o_weight,
|
||||
matmul3,add3,identity],
|
||||
'self-attention-megatron-test-model',
|
||||
[X],
|
||||
[Y],
|
||||
[q_weight_initializer,k_weight_initializer,v_weight_initializer,q_bias_initializer,k_bias_initializer,
|
||||
v_bias_initializer,q_shape_initializer,k_shape_initializer,v_shape_initializer,mul_initializer,
|
||||
qk_shape_initializer,dummy_condition_initializer,inf_const_initializer,where_shape_initializer,
|
||||
dropout_initializer,dropout_mode_initializer,shape_initializer3,
|
||||
dense_weight_initializer, dense_bias_initializer]
|
||||
)
|
||||
|
||||
opsets = []
|
||||
onnxdomain = OperatorSetIdProto()
|
||||
onnxdomain.version = 12
|
||||
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
|
||||
opsets.append(onnxdomain)
|
||||
|
||||
msdomain = OperatorSetIdProto()
|
||||
msdomain.version = 1
|
||||
msdomain.domain = 'com.microsoft'
|
||||
|
||||
opsets.append(msdomain)
|
||||
kwargs={}
|
||||
kwargs['opset_imports'] = opsets
|
||||
|
||||
# Create the model (ModelProto)
|
||||
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
|
||||
onnx.save(model_def, 'bart_self_attention_megatron_basic_test.onnx')
|
||||
|
||||
|
||||
|
|
@ -52,9 +52,10 @@ namespace transformer_utils {
|
|||
|
||||
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
||||
TransformerLevel level,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
const IExecutionProvider& execution_provider,
|
||||
std::unordered_map<std::string, std::string>& updated_weight_names,
|
||||
const std::vector<std::string>& transformers_and_rules_to_enable) {
|
||||
std::vector<std::unique_ptr<GraphTransformer>> transformers;
|
||||
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
|
||||
|
|
@ -94,7 +95,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
if (config.enable_gelu_approximation) {
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(compatible_eps));
|
||||
}
|
||||
|
||||
transformers.emplace_back(onnxruntime::make_unique<ConstantFolding>(execution_provider, compatible_eps, weights_to_train));
|
||||
transformers.emplace_back(onnxruntime::make_unique<ReshapeFusion>(compatible_eps));
|
||||
auto horizontal_parallel_size = training::DistributedRunContext::GroupSize(training::WorkerGroupType::HorizontalParallel);
|
||||
|
|
@ -102,7 +102,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
LOGS_DEFAULT(WARNING) << horizontal_parallel_size << "-way horizontal model parallel is enabled";
|
||||
transformers.emplace_back(onnxruntime::make_unique<MegatronTransformer>(
|
||||
training::DistributedRunContext::RankInGroup(training::WorkerGroupType::HorizontalParallel),
|
||||
horizontal_parallel_size, compatible_eps));
|
||||
horizontal_parallel_size, updated_weight_names, weights_to_train, compatible_eps));
|
||||
}
|
||||
transformers.emplace_back(onnxruntime::make_unique<ComputationReductionTransformer>(compatible_eps));
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,10 @@ namespace transformer_utils {
|
|||
/** Generates all pre-training transformers for this level. */
|
||||
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
||||
TransformerLevel level,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
const IExecutionProvider& execution_provider, // required for constant folding
|
||||
std::unordered_map<std::string, std::string>& updated_weight_names,
|
||||
const std::vector<std::string>& rules_and_transformers_to_enable = {});
|
||||
|
||||
/** Generates all predefined (both rule-based and non-rule-based) transformers for this level.
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_13 = {1
|
|||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_11_13 = {1, 11, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v2_11_13 = {2, 11, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v5_13 = {5, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v1_6_7_13 = {1, 6, 7, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v7_13 = {7, 13};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v9 = {9};
|
||||
const std::initializer_list<ONNX_NAMESPACE::OperatorSetVersion> opset_v9_13 = {9, 13};
|
||||
|
|
@ -42,11 +43,12 @@ const OpInfo reshape_info = OpInfo("Reshape", opset_v5_13);
|
|||
const OpInfo transpose_info = OpInfo("Transpose", opset_v1_13);
|
||||
const OpInfo matmul_info = OpInfo("MatMul", opset_v9_13);
|
||||
const OpInfo div_info = OpInfo("Div", opset_v7_13);
|
||||
const OpInfo mul_info = OpInfo("Mul", opset_v7_13);
|
||||
const OpInfo mul_info = OpInfo("Mul", opset_v1_6_7_13);
|
||||
const OpInfo sub_info = OpInfo("Sub", opset_v7_13);
|
||||
const OpInfo softmax_info = OpInfo("Softmax", opset_v1_11_13);
|
||||
const OpInfo trainable_dropout_info = OpInfo("TrainableDropout", opset_v9, kOnnxDomain);
|
||||
const OpInfo dropout_info = OpInfo("Dropout", opset_v12_13);
|
||||
const OpInfo where_info = OpInfo("Where", opset_v9);
|
||||
|
||||
struct NodeInfo {
|
||||
NodeInfo(const std::vector<OpInfo>& op_infos,
|
||||
|
|
@ -119,7 +121,6 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node
|
|||
LOGS_DEFAULT(WARNING) << "PartitionWeightByColumn: " << input_arg.Name() << " is not an initializer";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto data_type = tensor_proto->data_type();
|
||||
const ONNX_NAMESPACE::TensorShapeProto* shape = input_arg.Shape();
|
||||
int rank = shape->dim_size();
|
||||
|
|
@ -147,8 +148,9 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node
|
|||
auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto, graph.ModelPath());
|
||||
const float* a_weight = initializer->data<float>();
|
||||
|
||||
initializer_partition.set_name("rank_" + std::to_string(horizontal_parallel_rank_) +
|
||||
"_" + input_arg.Name() + "_partition");
|
||||
std::string new_initializer_name = input_arg.Name() + "_column_rank_" + std::to_string(horizontal_parallel_rank_);
|
||||
|
||||
initializer_partition.set_name(new_initializer_name);
|
||||
initializer_partition.set_data_type(data_type);
|
||||
|
||||
int64_t column_partition = column_count / horizontal_parallel_size_;
|
||||
|
|
@ -209,12 +211,12 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg
|
|||
<< horizontal_parallel_size_ << ", not supported currently.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto, graph.ModelPath());
|
||||
const float* a_weight = initializer->data<float>();
|
||||
|
||||
initializer_partition.set_name("rank_" + std::to_string(horizontal_parallel_rank_) +
|
||||
"_" + input_arg.Name() + "_partition");
|
||||
std::string new_initializer_name = input_arg.Name() + "_row_rank_" + std::to_string(horizontal_parallel_rank_);
|
||||
|
||||
initializer_partition.set_name(new_initializer_name);
|
||||
initializer_partition.set_data_type(data_type);
|
||||
|
||||
int64_t row_partition = row_count / horizontal_parallel_size_;
|
||||
|
|
@ -231,13 +233,13 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg
|
|||
const int64_t row_index_offset = horizontal_parallel_rank_ * row_partition;
|
||||
memcpy(result.data(), a_weight + row_index_offset * column_count, sizeof(float) * element_count);
|
||||
initializer_partition.set_raw_data(result.data(), element_count * sizeof(float));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger,
|
||||
std::vector<Node*>& nodes_to_clear_shape) const {
|
||||
std::vector<Node*>& nodes_to_clear_shape,
|
||||
int32_t& counter) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (auto node_index : node_topology_list) {
|
||||
|
|
@ -309,12 +311,15 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
|
|||
|
||||
NodeArg& a_weight_partition_arg = graph_utils::AddInitializer(graph, a_weight_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(node, 1, a_weight_partition_arg);
|
||||
updated_weight_names_.insert({a_weight_arg->Name(), a_weight_partition_arg.Name()});
|
||||
|
||||
NodeArg& a_bias_partition_arg = graph_utils::AddInitializer(graph, a_bias_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(add_node, 1, a_bias_partition_arg);
|
||||
updated_weight_names_.insert({b_weight_arg->Name(), a_bias_partition_arg.Name()});
|
||||
|
||||
NodeArg& b_weight_partition_arg = graph_utils::AddInitializer(graph, b_weight_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(matmul2_node, 1, b_weight_partition_arg);
|
||||
updated_weight_names_.insert({a_bias_arg->Name(), b_weight_partition_arg.Name()});
|
||||
|
||||
graph.RemoveInitializedTensor(a_weight_arg->Name());
|
||||
graph.RemoveInitializedTensor(b_weight_arg->Name());
|
||||
|
|
@ -349,6 +354,151 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
|
|||
mlp_g_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
graph_utils::ReplaceDownstreamNodeInput(graph, matmul2_node, 0, mlp_g_node, 0);
|
||||
modified = true;
|
||||
counter++;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/*
|
||||
DenseWeight -- Transpose \
|
||||
MatMul -- BiasGelu -- Dropout -- MatMul -- Add -- Dropout
|
||||
*/
|
||||
Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger,
|
||||
std::vector<Node*>& nodes_to_clear_shape,
|
||||
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
Node* second_op = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[1]->Name()));
|
||||
Node* first_op = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[0]->Name()));
|
||||
if (node.GetInputEdgesCount() > 0) {
|
||||
if (second_op == nullptr) {
|
||||
break;
|
||||
}
|
||||
if (first_op != nullptr && first_op->OpType().compare("MegatronF") == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (second_op->OpType().compare("Transpose") != 0) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
// check if transpose is only 2-dim
|
||||
if (!optimizer_utils::IsAttributeWithExpectedValues(*second_op, "perm", {1LL, 0LL})) {
|
||||
continue;
|
||||
}
|
||||
ProviderType provider_type = node.GetExecutionProviderType();
|
||||
|
||||
Node* biasgelu_node_ptr = graph.GetNode(node.OutputNodesBegin()->Index());
|
||||
Node& biasgelu_node = *biasgelu_node_ptr;
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(biasgelu_node, "BiasGelu", {1}, kMSDomain) ||
|
||||
biasgelu_node.GetExecutionProviderType() != provider_type ||
|
||||
biasgelu_node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
Node& dropout_node = *graph.GetNode(biasgelu_node.OutputNodesBegin()->Index());
|
||||
if (!IsExpectedOpAndProvider(dropout_node, dropout_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
Node& matmul2_node = *graph.GetNode(dropout_node.OutputNodesBegin()->Index());
|
||||
if (!IsExpectedOpAndProvider(matmul2_node, matmul_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
Node& add_node = *graph.GetNode(matmul2_node.OutputNodesBegin()->Index());
|
||||
if (!IsExpectedOpAndProvider(add_node, add_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
Node& dropout2_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
|
||||
if (!IsExpectedOpAndProvider(dropout2_node, dropout_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
Node* transpose_op_ptr = const_cast<Node*>(graph.GetProducerNode(matmul2_node.MutableInputDefs()[1]->Name()));
|
||||
if (transpose_op_ptr == nullptr || !IsExpectedOpAndProvider(*transpose_op_ptr, transpose_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
nodes_to_clear_shape.insert(nodes_to_clear_shape.end(), {&node, second_op, &biasgelu_node, &dropout_node,
|
||||
&matmul2_node, transpose_op_ptr});
|
||||
|
||||
auto dense_wi_weight_arg = second_op->MutableInputDefs()[0];
|
||||
ONNX_NAMESPACE::TensorProto dense_wi_weight_initializer_partition;
|
||||
if (!PartitionWeightByRow(graph, *dense_wi_weight_arg, dense_wi_weight_initializer_partition)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//since the bias doesnt get transposed, partitioning by col
|
||||
auto dense_wi_bias_arg = biasgelu_node.MutableInputDefs()[1];
|
||||
ONNX_NAMESPACE::TensorProto dense_wi_bias_initializer_partition;
|
||||
if (!PartitionWeightByColumn(graph, *dense_wi_bias_arg, dense_wi_bias_initializer_partition)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto dense_wo_weight_arg = transpose_op_ptr->MutableInputDefs()[0];
|
||||
ONNX_NAMESPACE::TensorProto dense_wo_weight_initializer_partition;
|
||||
if (!PartitionWeightByColumn(graph, *dense_wo_weight_arg, dense_wo_weight_initializer_partition)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wi_weight_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(*second_op, 0, dense_wi_weight_partition_arg);
|
||||
updated_weight_names_.insert({dense_wi_weight_arg->Name(), dense_wi_weight_partition_arg.Name()});
|
||||
|
||||
NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializer(graph, dense_wi_bias_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(biasgelu_node, 1, dense_wi_bias_partition_arg);
|
||||
updated_weight_names_.insert({dense_wi_bias_arg->Name(), dense_wi_bias_partition_arg.Name()});
|
||||
|
||||
NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wo_weight_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(*transpose_op_ptr, 0, dense_wo_weight_partition_arg);
|
||||
updated_weight_names_.insert({dense_wo_weight_arg->Name(), dense_wo_weight_partition_arg.Name()});
|
||||
|
||||
graph.RemoveInitializedTensor(dense_wi_weight_arg->Name());
|
||||
graph.RemoveInitializedTensor(dense_wi_bias_arg->Name());
|
||||
graph.RemoveInitializedTensor(dense_wo_weight_arg->Name());
|
||||
|
||||
dropout_nodes_to_transform.insert(&dropout_node);
|
||||
|
||||
const std::vector<NodeArg*> mlp_f_input_defs{node.MutableInputDefs()[0]};
|
||||
auto mlp_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto();
|
||||
auto& mlp_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BART_MLP_MegatronF_Output"), &mlp_f_type_info);
|
||||
Node& mlp_f_node = graph.AddNode(graph.GenerateNodeName("BART_MLP_MegatronF"),
|
||||
"MegatronF",
|
||||
"MLP MegatronF",
|
||||
mlp_f_input_defs,
|
||||
{&mlp_f_out_arg}, {}, kMSDomain);
|
||||
counter++;
|
||||
mlp_f_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
const Node::EdgeEnd* edge = graph_utils::GetInputEdge(node, 0);
|
||||
if (nullptr == edge) { // handle input/initializer
|
||||
graph_utils::ReplaceNodeInput(node, 0, *(mlp_f_node.MutableOutputDefs()[0]));
|
||||
} else {
|
||||
auto input_node = const_cast<Node*>(&edge->GetNode());
|
||||
graph_utils::ReplaceDownstreamNodeInput(graph, *input_node, edge->GetSrcArgIndex(), mlp_f_node, 0);
|
||||
}
|
||||
|
||||
const std::vector<NodeArg*> mlp_g_input_defs{matmul2_node.MutableOutputDefs()[0]};
|
||||
auto mlp_g_type_info = *matmul2_node.MutableOutputDefs()[0]->TypeAsProto();
|
||||
auto& mlp_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BART_MLP_MegatronG_Output"), &mlp_g_type_info);
|
||||
Node& mlp_g_node = graph.AddNode(graph.GenerateNodeName("BART_MLP_MegatronG"),
|
||||
"MegatronG",
|
||||
"MLP MegatronG",
|
||||
mlp_g_input_defs,
|
||||
{&mlp_g_out_arg}, {}, kMSDomain);
|
||||
mlp_g_node.AddAttribute("group_type", static_cast<int64_t>(training::WorkerGroupType::HorizontalParallel));
|
||||
mlp_g_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
graph_utils::ReplaceDownstreamNodeInput(graph, matmul2_node, 0, mlp_g_node, 0);
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -357,7 +507,8 @@ Status MegatronTransformer::TransformMLP(Graph& graph, bool& modified, int graph
|
|||
Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger,
|
||||
std::vector<Node*>& nodes_to_clear_shape,
|
||||
std::unordered_set<Node*>& self_attention_dropout_nodes) const {
|
||||
std::unordered_set<Node*>& dropout_nodes_to_transform,
|
||||
int32_t& counter) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
|
|
@ -406,7 +557,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
|
|||
// Get all useful nodes here as more vector push back below will change the index.
|
||||
Node& add_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15];
|
||||
Node& split_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14];
|
||||
Node& transpose_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12];
|
||||
Node& k_transpose_after_reshape_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12];
|
||||
Node* matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 11];
|
||||
Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 6];
|
||||
Node* matmul_node_ptr1 = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5];
|
||||
|
|
@ -414,7 +565,7 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
|
|||
Node& matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 2];
|
||||
|
||||
// Transpose node attribute checking.
|
||||
if (!optimizer_utils::IsAttributeWithExpectedValues(transpose_node, "perm", {0LL, 2LL, 1LL, 3LL}) ||
|
||||
if (!optimizer_utils::IsAttributeWithExpectedValues(k_transpose_after_reshape_node, "perm", {0LL, 2LL, 1LL, 3LL}) ||
|
||||
!optimizer_utils::IsAttributeWithExpectedValues(transpose_node1, "perm", {0LL, 2LL, 1LL, 3LL})) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -518,12 +669,15 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
|
|||
// Replace by the partition weights.
|
||||
NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(node, 1, qkv_weight_partition_arg);
|
||||
updated_weight_names_.insert({qkv_weight_arg->Name(), qkv_weight_partition_arg.Name()});
|
||||
|
||||
NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(add_node, 1, qkv_bias_partition_arg);
|
||||
updated_weight_names_.insert({qkv_bias_arg->Name(), qkv_bias_partition_arg.Name()});
|
||||
|
||||
NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(matmul_node, 1, dense_weight_partition_arg);
|
||||
updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()});
|
||||
|
||||
graph.RemoveInitializedTensor(qkv_weight_arg->Name());
|
||||
graph.RemoveInitializedTensor(qkv_bias_arg->Name());
|
||||
|
|
@ -554,16 +708,16 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
|
|||
}
|
||||
|
||||
if (dropout_node_ptr != nullptr) {
|
||||
self_attention_dropout_nodes.insert(dropout_node_ptr);
|
||||
dropout_nodes_to_transform.insert(dropout_node_ptr);
|
||||
}
|
||||
|
||||
// Add MegatronF before the 1st MatMul and MegatronG before the last Add.
|
||||
const std::vector<NodeArg*> sa_f_input_defs{node.MutableInputDefs()[0]};
|
||||
auto sa_f_type_info = *node.MutableInputDefs()[0]->TypeAsProto();
|
||||
auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SelfAttention_MegatronF_Output"), &sa_f_type_info);
|
||||
Node& sa_f_node = graph.AddNode(graph.GenerateNodeName("SelfAttention_MegatronF"),
|
||||
auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SeftAttention_MegatronF_Output"), &sa_f_type_info);
|
||||
Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "SeftAttention_MegatronF"),
|
||||
"MegatronF",
|
||||
"SelfAttention MegatronF",
|
||||
"SeftAttention MegatronF",
|
||||
sa_f_input_defs,
|
||||
{&sa_f_out_arg}, {}, kMSDomain);
|
||||
sa_f_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
|
|
@ -577,23 +731,400 @@ Status MegatronTransformer::TransformSelfAttention(Graph& graph, bool& modified,
|
|||
|
||||
const std::vector<NodeArg*> sa_g_input_defs{matmul_node.MutableOutputDefs()[0]};
|
||||
auto sa_g_type_info = *matmul_node.MutableOutputDefs()[0]->TypeAsProto(); // copy
|
||||
auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SelfAttention_MegatronG_Output"), &sa_g_type_info);
|
||||
Node& sa_g_node = graph.AddNode(graph.GenerateNodeName("SelfAttention_MegatronG"),
|
||||
auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("SeftAttention_MegatronG_Output"), &sa_g_type_info);
|
||||
Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "SelfAttention_MegatronG"),
|
||||
"MegatronG",
|
||||
"SelfAttention MegatronG",
|
||||
"Attention MegatronG",
|
||||
sa_g_input_defs,
|
||||
{&sa_g_out_arg}, {}, kMSDomain);
|
||||
sa_g_node.AddAttribute("group_type", static_cast<int64_t>(training::WorkerGroupType::HorizontalParallel));
|
||||
sa_g_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
graph_utils::ReplaceDownstreamNodeInput(graph, matmul_node, 0, sa_g_node, 0);
|
||||
modified = true;
|
||||
counter++;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MegatronTransformer::TransformBARTSelfAttention(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger,
|
||||
std::vector<Node*>& nodes_to_clear_shape,
|
||||
std::unordered_set<Node*>& dropout_nodes_to_transform,
|
||||
int32_t& counter) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
// Self attention sub-graph.
|
||||
//
|
||||
// MatMul->Add->Mul->Reshape->Transpose->MatMul->Reshape->Where->Reshape->Softmax->Dropout->MatMul->Transpose->Reshape->MatMul->Add->Droupout
|
||||
// MatMul->Add->Reshape->Transpose-------> | |
|
||||
// MatMul->Add->Reshape->Transpose----------------------------------------------------------> |
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node* q_matmul_input_node_ptr = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[0]->Name()));
|
||||
if (q_matmul_input_node_ptr != nullptr && q_matmul_input_node_ptr->OpType().compare("MegatronF") == 0) {
|
||||
continue;
|
||||
}
|
||||
std::vector<Node*> sub_graph_node_ptrs;
|
||||
sub_graph_node_ptrs.push_back(&node);
|
||||
ProviderType provider_type = node.GetExecutionProviderType();
|
||||
|
||||
std::vector<NodeInfo> linear_pattern = {
|
||||
NodeInfo({add_info}),
|
||||
NodeInfo({mul_info}),
|
||||
NodeInfo({reshape_info}),
|
||||
NodeInfo({transpose_info}),
|
||||
NodeInfo({matmul_info}),
|
||||
NodeInfo({add_info}, false), // -13
|
||||
NodeInfo({reshape_info}),
|
||||
NodeInfo({where_info}),
|
||||
NodeInfo({reshape_info}),
|
||||
NodeInfo({softmax_info}),
|
||||
NodeInfo({dropout_info}, false), // -8
|
||||
NodeInfo({matmul_info}),
|
||||
NodeInfo({add_info}, false), // -6
|
||||
NodeInfo({transpose_info}),
|
||||
NodeInfo({reshape_info}),
|
||||
NodeInfo({matmul_info}), // -3
|
||||
NodeInfo({add_info}),
|
||||
NodeInfo({dropout_info}, false)}; // -1
|
||||
if (!MatchLinearPattern(graph, &node, provider_type, linear_pattern, sub_graph_node_ptrs)) {
|
||||
continue;
|
||||
}
|
||||
// Get all useful nodes here as more vector push back below will change the index.
|
||||
// Other than the optional nodes in the pattern, all other node pointers are valid
|
||||
// if they match the linear pattern.
|
||||
Node* q_biasadd_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 18];
|
||||
Node* q_transpose_after_reshape_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15];
|
||||
Node* qk_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14];
|
||||
Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 8];
|
||||
Node* qkv_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 7];
|
||||
Node* transpose_node1_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5];
|
||||
Node& dense_matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 3];
|
||||
|
||||
// Transpose node attribute checking.
|
||||
if (!optimizer_utils::IsAttributeWithExpectedValues(*q_transpose_after_reshape_node_ptr, "perm", {1LL, 0LL, 2LL}) ||
|
||||
!optimizer_utils::IsAttributeWithExpectedValues(*transpose_node1_ptr, "perm", {1LL, 0LL, 2LL})) {
|
||||
continue;
|
||||
}
|
||||
// map between reshape node and dim of reshape that must be modified
|
||||
std::unordered_map<Node*, int64_t> reshape_node_ptrs;
|
||||
reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 16]] = 1;
|
||||
reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]] = 1;
|
||||
reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 10]] = 0;
|
||||
reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 4]] = 2;
|
||||
// till now node should be q matmul operation
|
||||
|
||||
std::vector<Node*> weight_transpose_node_ptrs;
|
||||
std::vector<Node*> bias_add_node_ptrs;
|
||||
|
||||
Node* q_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[1]->Name()));
|
||||
if (q_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*q_transpose_ptr, transpose_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
weight_transpose_node_ptrs.push_back(q_transpose_ptr);
|
||||
sub_graph_node_ptrs.push_back(q_transpose_ptr);
|
||||
bias_add_node_ptrs.push_back(q_biasadd_node_ptr);
|
||||
|
||||
Node* k_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(qk_matmul_node_ptr->MutableInputDefs()[1]->Name()));
|
||||
if (k_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_transpose_ptr, transpose_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
sub_graph_node_ptrs.push_back(k_transpose_ptr);
|
||||
|
||||
Node* k_reshape_ptr = const_cast<Node*>(graph.GetProducerNode(k_transpose_ptr->MutableInputDefs()[0]->Name()));
|
||||
if (k_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*k_reshape_ptr, reshape_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
reshape_node_ptrs[k_reshape_ptr] = 1;
|
||||
sub_graph_node_ptrs.push_back(k_reshape_ptr);
|
||||
|
||||
Node* k_add_ptr = const_cast<Node*>(graph.GetProducerNode(k_reshape_ptr->MutableInputDefs()[0]->Name()));
|
||||
if (k_add_ptr == nullptr || !IsExpectedOpAndProvider(*k_add_ptr, add_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
sub_graph_node_ptrs.push_back(k_add_ptr);
|
||||
bias_add_node_ptrs.push_back(k_add_ptr);
|
||||
|
||||
Node* k_matmul_ptr = const_cast<Node*>(graph.GetProducerNode(k_add_ptr->MutableInputDefs()[0]->Name()));
|
||||
if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
sub_graph_node_ptrs.push_back(k_matmul_ptr);
|
||||
|
||||
Node* k_weight_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(k_matmul_ptr->MutableInputDefs()[1]->Name()));
|
||||
if (k_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_weight_transpose_ptr, transpose_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
sub_graph_node_ptrs.push_back(k_weight_transpose_ptr);
|
||||
weight_transpose_node_ptrs.push_back(k_weight_transpose_ptr);
|
||||
|
||||
Node* v_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(qkv_matmul_node_ptr->MutableInputDefs()[1]->Name()));
|
||||
if (v_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_transpose_ptr, transpose_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
sub_graph_node_ptrs.push_back(v_transpose_ptr);
|
||||
|
||||
Node* v_reshape_ptr = const_cast<Node*>(graph.GetProducerNode(v_transpose_ptr->MutableInputDefs()[0]->Name()));
|
||||
if (v_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*v_reshape_ptr, reshape_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
reshape_node_ptrs[v_reshape_ptr] = 1;
|
||||
sub_graph_node_ptrs.push_back(v_reshape_ptr);
|
||||
|
||||
Node* v_add_ptr = const_cast<Node*>(graph.GetProducerNode(v_reshape_ptr->MutableInputDefs()[0]->Name()));
|
||||
if (v_add_ptr == nullptr || !IsExpectedOpAndProvider(*v_add_ptr, add_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
sub_graph_node_ptrs.push_back(v_add_ptr);
|
||||
bias_add_node_ptrs.push_back(v_add_ptr);
|
||||
|
||||
Node* v_matmul_ptr = const_cast<Node*>(graph.GetProducerNode(v_add_ptr->MutableInputDefs()[0]->Name()));
|
||||
if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
sub_graph_node_ptrs.push_back(v_matmul_ptr);
|
||||
|
||||
Node* v_weight_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(v_matmul_ptr->MutableInputDefs()[1]->Name()));
|
||||
if (v_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_weight_transpose_ptr, transpose_info, provider_type)) {
|
||||
continue;
|
||||
}
|
||||
sub_graph_node_ptrs.push_back(v_weight_transpose_ptr);
|
||||
weight_transpose_node_ptrs.push_back(v_weight_transpose_ptr);
|
||||
|
||||
// K and V matmul must have the same input
|
||||
Node* q_matmul_ptr = &node;
|
||||
if (k_matmul_ptr->MutableInputDefs()[0]->Name() != v_matmul_ptr->MutableInputDefs()[0]->Name()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check the constant value in the Reshape nodes.
|
||||
bool is_reshape_valid = true;
|
||||
for (auto x : reshape_node_ptrs) {
|
||||
Node* node_ptr = x.first;
|
||||
int64_t idx = x.second;
|
||||
auto shape_arg = node_ptr->MutableInputDefs()[1];
|
||||
const ONNX_NAMESPACE::TensorProto* tensor;
|
||||
if (!graph.GetInitializedTensor(shape_arg->Name(), tensor)) {
|
||||
is_reshape_valid = false;
|
||||
break;
|
||||
}
|
||||
auto data_type = tensor->data_type();
|
||||
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
|
||||
is_reshape_valid = false;
|
||||
break;
|
||||
}
|
||||
// The number of the values should be more than idx, and the idx'th value should be divisible by parallel size,
|
||||
// i.e., the attention head number should be divisible by parallel size.
|
||||
auto init_const = onnxruntime::make_unique<Initializer>(*tensor, graph.ModelPath());
|
||||
if (init_const->size() <= idx) {
|
||||
is_reshape_valid = false;
|
||||
break;
|
||||
}
|
||||
const int64_t* val = init_const->data<int64_t>();
|
||||
if (val[idx] % horizontal_parallel_size_ != 0) {
|
||||
LOGS_DEFAULT(WARNING) << "dim[" << idx << "]: " << val[idx]
|
||||
<< " is not divisible by horizontal_parallel_size_ "
|
||||
<< horizontal_parallel_size_ << ", not supported currently.";
|
||||
is_reshape_valid = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_reshape_valid) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Partition weights. If any of them fails, skip transforming the rest.
|
||||
std::vector<ONNX_NAMESPACE::TensorProto> qkv_weight_initializer_partitions;
|
||||
for (auto trans_ptr : weight_transpose_node_ptrs) {
|
||||
auto qkv_weight_arg = trans_ptr->MutableInputDefs()[0];
|
||||
ONNX_NAMESPACE::TensorProto qkv_weight_initializer_partition;
|
||||
if (!PartitionWeightByRow(graph, *qkv_weight_arg, qkv_weight_initializer_partition)) {
|
||||
break;
|
||||
}
|
||||
qkv_weight_initializer_partitions.push_back(qkv_weight_initializer_partition);
|
||||
}
|
||||
|
||||
// Partition bias. If any of them fails, skip transforming the rest.
|
||||
std::vector<ONNX_NAMESPACE::TensorProto> qkv_bias_initializer_partitions;
|
||||
for (auto add_ptr : bias_add_node_ptrs) {
|
||||
auto qkv_bias_arg = add_ptr->MutableInputDefs()[1];
|
||||
ONNX_NAMESPACE::TensorProto qkv_bias_initializer_partition;
|
||||
if (!PartitionWeightByColumn(graph, *qkv_bias_arg, qkv_bias_initializer_partition)) {
|
||||
break;
|
||||
}
|
||||
qkv_bias_initializer_partitions.push_back(qkv_bias_initializer_partition);
|
||||
}
|
||||
|
||||
// if all the weights or biases weren't transformed, skip transforming this subgraph
|
||||
if (weight_transpose_node_ptrs.size() != qkv_weight_initializer_partitions.size()) {
|
||||
continue;
|
||||
}
|
||||
if (bias_add_node_ptrs.size() != qkv_bias_initializer_partitions.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// transform the dense weight. If it fails, skip transforming this subgraph.
|
||||
Node* last_transpose = const_cast<Node*>(graph.GetProducerNode(dense_matmul_node.MutableInputDefs()[1]->Name()));
|
||||
auto dense_weight_arg = last_transpose->MutableInputDefs()[0];
|
||||
ONNX_NAMESPACE::TensorProto dense_weight_initializer_partition;
|
||||
if (!PartitionWeightByColumn(graph, *dense_weight_arg, dense_weight_initializer_partition)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ready to transform the sub-graph when reach here.
|
||||
// Replace node inputs
|
||||
size_t i = 0;
|
||||
for (auto trans_ptr : weight_transpose_node_ptrs) {
|
||||
auto weight_name = trans_ptr->MutableInputDefs()[0]->Name();
|
||||
NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]);
|
||||
graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg);
|
||||
graph.RemoveInitializedTensor(weight_name);
|
||||
updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()});
|
||||
i++;
|
||||
}
|
||||
i = 0;
|
||||
for (auto add_ptr : bias_add_node_ptrs) {
|
||||
auto bias_name = add_ptr->MutableInputDefs()[1]->Name();
|
||||
NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]);
|
||||
graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg);
|
||||
graph.RemoveInitializedTensor(bias_name);
|
||||
updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()});
|
||||
i++;
|
||||
}
|
||||
|
||||
NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition);
|
||||
graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg);
|
||||
graph.RemoveInitializedTensor(dense_weight_arg->Name());
|
||||
updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()});
|
||||
|
||||
// It's possible that the node vector contains nullptr due to some optinal node infos during linear pattern matching.
|
||||
std::copy_if(sub_graph_node_ptrs.begin(), sub_graph_node_ptrs.end(),
|
||||
std::back_inserter(nodes_to_clear_shape),
|
||||
[](Node* node_ptr) { return node_ptr != nullptr; });
|
||||
|
||||
// Change the constant for the reshape nodes.
|
||||
for (auto x : reshape_node_ptrs) {
|
||||
Node* node_ptr = x.first;
|
||||
int64_t idx = x.second;
|
||||
auto shape_arg = node_ptr->MutableInputDefs()[1];
|
||||
const ONNX_NAMESPACE::TensorProto* tensor;
|
||||
graph.GetInitializedTensor(shape_arg->Name(), tensor);
|
||||
auto data_type = tensor->data_type();
|
||||
auto init_const = onnxruntime::make_unique<Initializer>(*tensor, graph.ModelPath());
|
||||
const int64_t* val = init_const->data<int64_t>();
|
||||
int64_t size = init_const->size();
|
||||
ONNX_NAMESPACE::TensorProto tensor_partition;
|
||||
tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name()));
|
||||
tensor_partition.set_data_type(data_type);
|
||||
tensor_partition.add_dims(size);
|
||||
|
||||
std::vector<int64_t> val_partition;
|
||||
val_partition.reserve(size);
|
||||
val_partition.insert(val_partition.end(), val, val + size);
|
||||
val_partition[idx] /= horizontal_parallel_size_;
|
||||
tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t));
|
||||
NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition);
|
||||
graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition);
|
||||
graph.RemoveInitializedTensor(shape_arg->Name());
|
||||
}
|
||||
|
||||
if (dropout_node_ptr != nullptr) {
|
||||
dropout_nodes_to_transform.insert(dropout_node_ptr);
|
||||
}
|
||||
|
||||
// Add MegatronF before the 1st MatMul and MegatronG before the last Add.
|
||||
|
||||
NodeArg* prev_input_node_ptr = k_matmul_ptr->MutableInputDefs()[0];
|
||||
std::vector<Node*> new_consumer_nodes;
|
||||
const auto& node_consumers = graph.GetConsumerNodes(prev_input_node_ptr->Name());
|
||||
for (auto& n : node_consumers) {
|
||||
if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) {
|
||||
continue;
|
||||
}
|
||||
new_consumer_nodes.emplace_back(const_cast<Node*>(n));
|
||||
}
|
||||
|
||||
bool shared_same_input = k_matmul_ptr->MutableInputDefs()[0]->Name().compare(q_matmul_ptr->MutableInputDefs()[0]->Name()) == 0;
|
||||
|
||||
//then for q, and k&v will have different MegatronF node.
|
||||
{
|
||||
const std::vector<NodeArg*> sa_f_input_defs{prev_input_node_ptr};
|
||||
auto sa_f_type_info = *prev_input_node_ptr->TypeAsProto();
|
||||
auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(k_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &sa_f_type_info);
|
||||
Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronF"),
|
||||
"MegatronF",
|
||||
k_matmul_ptr->Name() + " BARTAttention MegatronF",
|
||||
sa_f_input_defs,
|
||||
{&sa_f_out_arg}, {}, kMSDomain);
|
||||
sa_f_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType());
|
||||
graph_utils::ReplaceNodeInput(*k_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
|
||||
graph_utils::ReplaceNodeInput(*v_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
|
||||
if (shared_same_input) {
|
||||
graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
|
||||
}
|
||||
new_consumer_nodes.push_back(&sa_f_node);
|
||||
}
|
||||
graph.UpdateConsumerNodes(prev_input_node_ptr->Name(), new_consumer_nodes);
|
||||
counter++;
|
||||
if (!shared_same_input) {
|
||||
{
|
||||
NodeArg* q_prev_input_node_ptr = q_matmul_ptr->MutableInputDefs()[0];
|
||||
std::vector<Node*> q_new_consumer_nodes;
|
||||
const auto& q_node_consumers = graph.GetConsumerNodes(q_prev_input_node_ptr->Name());
|
||||
for (auto& n : q_node_consumers) {
|
||||
if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) {
|
||||
continue;
|
||||
}
|
||||
q_new_consumer_nodes.emplace_back(const_cast<Node*>(n));
|
||||
}
|
||||
|
||||
const std::vector<NodeArg*> q_sa_f_input_defs{q_matmul_ptr->MutableInputDefs()[0]};
|
||||
auto q_sa_f_type_info = *q_matmul_ptr->MutableInputDefs()[0]->TypeAsProto();
|
||||
auto& q_sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(q_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &q_sa_f_type_info);
|
||||
Node& q_sa_f_node = graph.AddNode(graph.GenerateNodeName(q_matmul_ptr->Name() + "BARTAttention_MegatronF"),
|
||||
"MegatronF",
|
||||
q_matmul_ptr->Name() + " BARTAttention MegatronF",
|
||||
q_sa_f_input_defs,
|
||||
{&q_sa_f_out_arg}, {}, kMSDomain);
|
||||
q_sa_f_node.SetExecutionProviderType(q_matmul_ptr->GetExecutionProviderType());
|
||||
|
||||
graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(q_sa_f_node.MutableOutputDefs()[0]));
|
||||
q_new_consumer_nodes.push_back(&q_sa_f_node);
|
||||
graph.UpdateConsumerNodes(q_prev_input_node_ptr->Name(), q_new_consumer_nodes);
|
||||
// todo: need update the consumer node for the input_node as well.
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<NodeArg*> sa_g_input_defs{dense_matmul_node.MutableOutputDefs()[0]};
|
||||
auto sa_g_type_info = *dense_matmul_node.MutableOutputDefs()[0]->TypeAsProto(); // copy
|
||||
auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BARTAttention_MegatronG_Output"), &sa_g_type_info);
|
||||
Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronG"),
|
||||
"MegatronG",
|
||||
"BARTAttention MegatronG",
|
||||
sa_g_input_defs,
|
||||
{&sa_g_out_arg}, {}, kMSDomain);
|
||||
sa_g_node.AddAttribute("group_type", static_cast<int64_t>(training::WorkerGroupType::HorizontalParallel));
|
||||
sa_g_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType());
|
||||
graph_utils::ReplaceDownstreamNodeInput(graph, dense_matmul_node, 0, sa_g_node, 0);
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
|
||||
std::unordered_set<Node*>& self_attention_dropout_nodes) const {
|
||||
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (auto node_index : node_topology_list) {
|
||||
|
|
@ -610,18 +1141,17 @@ Status MegatronTransformer::TransformDropout(Graph& graph, bool& modified, int g
|
|||
}
|
||||
|
||||
// Only need to set the seed if it's a transformed self-attention dropout, or the seed attribute is not set.
|
||||
if (self_attention_dropout_nodes.find(&node) != self_attention_dropout_nodes.end() ||
|
||||
graph_utils::GetNodeAttribute(node, "seed") == nullptr) {
|
||||
if (dropout_nodes_to_transform.find(&node) != dropout_nodes_to_transform.end()) {
|
||||
int64_t seed = static_cast<int64_t>(HashName(node.MutableOutputDefs()[0]->Name())) + utils::GetRandomSeed();
|
||||
if (self_attention_dropout_nodes.find(&node) != self_attention_dropout_nodes.end()) {
|
||||
if (dropout_nodes_to_transform.find(&node) != dropout_nodes_to_transform.end()) {
|
||||
seed += horizontal_parallel_rank_;
|
||||
}
|
||||
|
||||
if (graph_utils::GetNodeAttribute(node, "seed") != nullptr) {
|
||||
node.ClearAttribute("seed");
|
||||
}
|
||||
|
||||
node.AddAttribute("seed", seed);
|
||||
counter++;
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -635,11 +1165,19 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
}
|
||||
|
||||
std::vector<Node*> nodes_to_clear_shape;
|
||||
std::unordered_set<Node*> self_attention_dropout_nodes;
|
||||
std::unordered_set<Node*> dropout_nodes_to_transform;
|
||||
|
||||
ORT_ENFORCE(TransformMLP(graph, modified, graph_level, logger, nodes_to_clear_shape).IsOK());
|
||||
ORT_ENFORCE(TransformSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, self_attention_dropout_nodes).IsOK());
|
||||
ORT_ENFORCE(TransformDropout(graph, modified, graph_level, logger, self_attention_dropout_nodes).IsOK());
|
||||
int32_t partitioned_mlp_count = 0;
|
||||
int32_t partitioned_bart_mlp_count = 0;
|
||||
int32_t partitioned_attention_count = 0;
|
||||
int32_t partitioned_bart_attention_count = 0;
|
||||
int32_t dropout_changed = 0;
|
||||
|
||||
ORT_ENFORCE(TransformMLP(graph, modified, graph_level, logger, nodes_to_clear_shape, partitioned_mlp_count).IsOK());
|
||||
ORT_ENFORCE(TransformBARTMLP(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_bart_mlp_count).IsOK());
|
||||
ORT_ENFORCE(TransformSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_attention_count).IsOK());
|
||||
ORT_ENFORCE(TransformBARTSelfAttention(graph, modified, graph_level, logger, nodes_to_clear_shape, dropout_nodes_to_transform, partitioned_bart_attention_count).IsOK());
|
||||
ORT_ENFORCE(TransformDropout(graph, modified, graph_level, logger, dropout_nodes_to_transform, dropout_changed).IsOK());
|
||||
|
||||
auto& graph_inputs = graph.GetInputs();
|
||||
for (auto& node : nodes_to_clear_shape) {
|
||||
|
|
@ -653,13 +1191,28 @@ Status MegatronTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
output->ClearShape();
|
||||
}
|
||||
|
||||
for (auto x : updated_weight_names_) {
|
||||
auto old_initializer_name = x.first;
|
||||
auto new_initializer_name = x.second;
|
||||
if (weights_to_train_.find(old_initializer_name) != weights_to_train_.end()) {
|
||||
weights_to_train_.erase(old_initializer_name);
|
||||
weights_to_train_.insert(new_initializer_name);
|
||||
}
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
graph.SetGraphResolveNeeded();
|
||||
auto ret = graph.Resolve();
|
||||
LOGS_DEFAULT(WARNING) << "Megatron transformer result : Partitioned " << partitioned_mlp_count << " MLP Blocks, "
|
||||
<< partitioned_bart_mlp_count << " BART MLP Blocks, " << partitioned_attention_count << " Attention Blocks, "
|
||||
<< partitioned_bart_attention_count << " BART Attention Blocks; Reset seed for " << dropout_changed
|
||||
<< " Dropout nodes. Error Message: " << ret.ErrorMessage() << std::endl;
|
||||
ORT_ENFORCE(ret.IsOK());
|
||||
} else {
|
||||
LOGS_DEFAULT(WARNING) << "Megatron transformer result : unmodified\n";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -11,33 +11,50 @@ namespace onnxruntime {
|
|||
class MegatronTransformer : public GraphTransformer {
|
||||
public:
|
||||
MegatronTransformer(int32_t horizontal_parallel_rank, int32_t horizontal_parallel_size,
|
||||
std::unordered_map<std::string, std::string>& updated_weight_names,
|
||||
std::unordered_set<std::string>& weights_to_train,
|
||||
const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("MegatronTransformer", compatible_execution_providers),
|
||||
horizontal_parallel_rank_(horizontal_parallel_rank),
|
||||
horizontal_parallel_size_(horizontal_parallel_size) {}
|
||||
horizontal_parallel_size_(horizontal_parallel_size),
|
||||
updated_weight_names_(updated_weight_names),
|
||||
weights_to_train_(weights_to_train) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger) const override;
|
||||
|
||||
private:
|
||||
Status TransformMLP(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
|
||||
std::vector<Node*>& nodes_to_clear_shape) const;
|
||||
std::vector<Node*>& nodes_to_clear_shape,
|
||||
int32_t& counter) const;
|
||||
|
||||
Status TransformSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
|
||||
std::vector<Node*>& nodes_to_clear_shape,
|
||||
std::unordered_set<Node*>& self_attention_dropout_nodes) const;
|
||||
std::unordered_set<Node*>& dropout_nodes_to_transform,
|
||||
int32_t& counter) const;
|
||||
|
||||
Status TransformBARTSelfAttention(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
|
||||
std::vector<Node*>& nodes_to_clear_shape,
|
||||
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const;
|
||||
|
||||
Status TransformBARTMLP(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger,
|
||||
std::vector<Node*>& nodes_to_clear_shape,
|
||||
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const;
|
||||
|
||||
Status TransformDropout(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger,
|
||||
std::unordered_set<Node*>& self_attention_dropout_nodes) const;
|
||||
std::unordered_set<Node*>& dropout_nodes_to_transform, int32_t& counter) const;
|
||||
|
||||
bool PartitionWeightByColumn(const Graph& graph, const NodeArg& input_arg,
|
||||
ONNX_NAMESPACE::TensorProto& initializer_partition, int stride = 1) const;
|
||||
ONNX_NAMESPACE::TensorProto& initializer_partition,
|
||||
int stride = 1) const;
|
||||
|
||||
bool PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg,
|
||||
ONNX_NAMESPACE::TensorProto& initializer_partition) const;
|
||||
bool PartitionWeightByRow(const Graph& graph, const NodeArg& input_arg, ONNX_NAMESPACE::TensorProto& initializer_partition) const;
|
||||
|
||||
const int32_t horizontal_parallel_rank_;
|
||||
const int32_t horizontal_parallel_size_;
|
||||
std::unordered_map<std::string, std::string>& updated_weight_names_;
|
||||
std::unordered_set<std::string>& weights_to_train_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -43,24 +43,36 @@ Status SetupOptimizerParams(
|
|||
const optional<std::string>& loss_scale_input_name,
|
||||
const TrainingSession::TrainingConfiguration& config,
|
||||
OptimizerGraphConfig& opt_graph_config_result,
|
||||
std::unordered_map<std::string, OptimizerNodeConfig>& opt_node_configs_result) {
|
||||
std::unordered_map<std::string, OptimizerNodeConfig>& opt_node_configs_result,
|
||||
std::unordered_map<std::string, std::string>& weight_name_map_after_graph_transform) {
|
||||
ORT_RETURN_IF_NOT(config.optimizer_config.has_value());
|
||||
const auto& optimizer_config = config.optimizer_config.value();
|
||||
|
||||
// This is the mapping from the new weight name to the original weight name
|
||||
// It is required to look up the optimizer config for the original weight
|
||||
// passed in the training session config
|
||||
std::unordered_map<std::string, std::string> reversed_weight_names_map;
|
||||
for (auto& p : weight_name_map_after_graph_transform) {
|
||||
reversed_weight_names_map.insert({p.second, p.first});
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
|
||||
for (const auto& weight_name : weight_names_to_train) {
|
||||
OptimizerNodeConfig opt_node_config{};
|
||||
opt_node_config.name = optimizer_config.name;
|
||||
opt_node_config.lr_feed_name = optimizer_config.learning_rate_input_name;
|
||||
|
||||
std::string original_weight_name = weight_name;
|
||||
if (reversed_weight_names_map.find(original_weight_name) != reversed_weight_names_map.end()) {
|
||||
original_weight_name = reversed_weight_names_map.at(original_weight_name);
|
||||
}
|
||||
try {
|
||||
opt_node_config.attributes = optimizer_config.weight_attributes_generator(weight_name);
|
||||
opt_node_config.attributes = optimizer_config.weight_attributes_generator(original_weight_name);
|
||||
} catch (const std::exception& ex) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
|
||||
}
|
||||
|
||||
try {
|
||||
opt_node_config.int_attributes = optimizer_config.weight_int_attributes_generator(weight_name);
|
||||
opt_node_config.int_attributes = optimizer_config.weight_int_attributes_generator(original_weight_name);
|
||||
} catch (const std::exception& ex) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
|
||||
}
|
||||
|
|
@ -232,7 +244,8 @@ Status TrainingSession::ConfigureForTraining(
|
|||
}
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config));
|
||||
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config,
|
||||
config_result));
|
||||
|
||||
if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) {
|
||||
ORT_IGNORE_RETURN_VALUE(Save(
|
||||
|
|
@ -244,8 +257,14 @@ Status TrainingSession::ConfigureForTraining(
|
|||
!filtered_config_weight_names_to_train.empty()
|
||||
? filtered_config_weight_names_to_train
|
||||
: GetTrainableModelInitializers(config.immutable_weights, loss_name);
|
||||
|
||||
for (const auto& weight_name_to_not_train : config.weight_names_to_not_train) {
|
||||
weight_names_to_train.erase(weight_name_to_not_train);
|
||||
if (config_result.weight_name_map_after_graph_transform.find(weight_name_to_not_train) !=
|
||||
config_result.weight_name_map_after_graph_transform.end()) {
|
||||
weight_names_to_train.erase(config_result.weight_name_map_after_graph_transform.at(weight_name_to_not_train));
|
||||
} else {
|
||||
weight_names_to_train.erase(weight_name_to_not_train);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -316,7 +335,7 @@ Status TrainingSession::ConfigureForTraining(
|
|||
std::unordered_map<std::string, OptimizerNodeConfig> opt_node_configs{};
|
||||
ORT_RETURN_IF_ERROR(SetupOptimizerParams(
|
||||
weights_to_train_, fp32_weight_name_to_mixed_precision_node_arg,
|
||||
loss_scale_input_name, config, opt_graph_config, opt_node_configs));
|
||||
loss_scale_input_name, config, opt_graph_config, opt_node_configs, config_result.weight_name_map_after_graph_transform));
|
||||
TrainingConfigurationResult::OptimizerConfigurationResult optimizer_config_result{};
|
||||
ORT_RETURN_IF_ERROR(BuildOptimizer(
|
||||
opt_graph_config, opt_node_configs,
|
||||
|
|
@ -512,8 +531,9 @@ static Status AddGradientAccumulationNodes(Graph& graph,
|
|||
return GraphAugmenter::AugmentGraph(graph, graph_defs);
|
||||
}
|
||||
|
||||
Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config) {
|
||||
Status TrainingSession::ApplyTransformationsToMainGraph(std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
TrainingConfigurationResult& config_result_out) {
|
||||
GraphTransformerManager graph_transformation_mgr{2};
|
||||
// TODO: ideally we can just reuse the CPU EP registered with the session, but in the training session case
|
||||
// the EPs are registered after ConfigureForTraining and before Initialize is called. Hence we don't have access
|
||||
|
|
@ -522,7 +542,7 @@ Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set
|
|||
// Create execution frame for executing constant nodes.
|
||||
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
|
||||
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
AddPreTrainingTransformers(*cpu_execution_provider, graph_transformation_mgr, weights_to_train, config);
|
||||
AddPreTrainingTransformers(*cpu_execution_provider, graph_transformation_mgr, weights_to_train, config, config_result_out);
|
||||
|
||||
// apply transformers
|
||||
Graph& graph = model_->MainGraph();
|
||||
|
|
@ -536,15 +556,16 @@ Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set
|
|||
// Registers all the pre transformers with transformer manager
|
||||
void TrainingSession::AddPreTrainingTransformers(const IExecutionProvider& execution_provider,
|
||||
GraphTransformerManager& transformer_manager,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
TrainingConfigurationResult& config_result_out,
|
||||
TransformerLevel graph_optimization_level,
|
||||
const std::vector<std::string>& custom_list) {
|
||||
auto add_transformers = [&](TransformerLevel level) {
|
||||
// Generate and register transformers for level
|
||||
|
||||
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
|
||||
level, weights_to_train, config, execution_provider, custom_list);
|
||||
level, weights_to_train, config, execution_provider, config_result_out.weight_name_map_after_graph_transform, custom_list);
|
||||
for (auto& entry : transformers_to_register) {
|
||||
transformer_manager.Register(std::move(entry), level);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -245,6 +245,9 @@ class TrainingSession : public InferenceSession {
|
|||
// The pipeline configuration output.
|
||||
// This is only set if an pipeline is enabled.
|
||||
optional<PipelineConfigurationResult> pipeline_config_result;
|
||||
|
||||
// Mapped initialized names after weight partitioning for example MegatronTransformer
|
||||
std::unordered_map<std::string, std::string> weight_name_map_after_graph_transform{};
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -392,14 +395,16 @@ class TrainingSession : public InferenceSession {
|
|||
common::Status InsertPipelineOps(const std::unordered_set<std::string>& initializer_names_to_preserve,
|
||||
pipeline::PipelineTensorNames& pipeline_tensor_names);
|
||||
|
||||
common::Status ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config);
|
||||
common::Status ApplyTransformationsToMainGraph(std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
TrainingConfigurationResult& config_result_out);
|
||||
|
||||
/** configure initial transformers for training */
|
||||
void AddPreTrainingTransformers(const IExecutionProvider& execution_provider, // for constant folding
|
||||
GraphTransformerManager& transformer_manager,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
TrainingConfigurationResult& config_result_out,
|
||||
TransformerLevel graph_optimization_level = TransformerLevel::MaxLevel,
|
||||
const std::vector<std::string>& custom_list = {});
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ struct TrainingParameters {
|
|||
int gradient_accumulation_steps = 1;
|
||||
int data_parallel_size = 1;
|
||||
int horizontal_parallel_size = 1;
|
||||
int pipeline_parallel_size = 1;
|
||||
int deepspeed_zero_stage = 0;
|
||||
bool enable_grad_norm_clip = true;
|
||||
bool set_gradients_as_graph_outputs = false;
|
||||
|
|
@ -65,11 +66,16 @@ TrainingConfigurationResult ConfigureSessionForTraining(
|
|||
//TODO tix, refactor the mpi related code to populate all fields correctly by default.
|
||||
ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size);
|
||||
ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size);
|
||||
|
||||
if (parameters.world_size % parameters.data_parallel_size != 0) {
|
||||
throw std::runtime_error("Cannot split data parallel group because world_size is not divisible");
|
||||
}
|
||||
|
||||
if (parameters.world_size % parameters.horizontal_parallel_size != 0) {
|
||||
throw std::runtime_error("Cannot split horizontal parallel group because world_size is not divisible");
|
||||
}
|
||||
|
||||
auto data_group_size = parameters.world_size / parameters.horizontal_parallel_size;
|
||||
auto data_group_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size);
|
||||
if (data_group_size != parameters.data_parallel_size) {
|
||||
LOGS(*(sess->GetLogger()), WARNING) << "data_parallel_size is not correct, tuned automatically to "
|
||||
<< data_group_size;
|
||||
|
|
@ -201,7 +207,10 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute)
|
||||
.def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute)
|
||||
.def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute)
|
||||
.def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers);
|
||||
.def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers)
|
||||
.def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size)
|
||||
.def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size)
|
||||
.def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size);
|
||||
|
||||
#if defined(USE_MPI)
|
||||
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });
|
||||
|
|
|
|||
|
|
@ -163,7 +163,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) {
|
|||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2), TransformerLevel::Level1);
|
||||
std::unordered_map<std::string, std::string> updated_weight_names;
|
||||
std::unordered_set<std::string> weights_to_train;
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
|
|
@ -231,7 +233,9 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionRank1) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2), TransformerLevel::Level1);
|
||||
std::unordered_map<std::string, std::string> updated_weight_names;
|
||||
std::unordered_set<std::string> weights_to_train;
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
|
|
@ -298,7 +302,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank0) {
|
|||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2), TransformerLevel::Level1);
|
||||
std::unordered_map<std::string, std::string> updated_weight_names;
|
||||
std::unordered_set<std::string> weights_to_train;
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(0, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
|
|
@ -363,7 +369,9 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2), TransformerLevel::Level1);
|
||||
std::unordered_map<std::string, std::string> updated_weight_names;
|
||||
std::unordered_set<std::string> weights_to_train;
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(1, 2, updated_weight_names, weights_to_train), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
|
|
@ -449,27 +457,35 @@ TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) {
|
|||
|
||||
// We only tested on CUDA run.
|
||||
#if defined(USE_CUDA)
|
||||
TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx";
|
||||
const int total_rank = 4;
|
||||
static void RunPartitionCorrectnessTest(std::string model_path,
|
||||
const logging::Logger& logger,
|
||||
const int total_rank,
|
||||
std::vector<std::string> input_names,
|
||||
std::vector<std::vector<int64_t>> input_dims) {
|
||||
const PathString model_uri = ToPathString(model_path) + ORT_TSTR(".onnx");
|
||||
// const int total_rank = 4;
|
||||
std::vector<Graph*> graphs;
|
||||
std::vector<std::shared_ptr<Model>> p_models(total_rank);
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
Status ret = Model::Load(model_uri, p_models[i], nullptr, logger);
|
||||
ORT_ENFORCE(ret.IsOK());
|
||||
Graph& graph = p_models[i]->MainGraph();
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(i, total_rank), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
|
||||
std::unordered_map<std::string, std::string> updated_weight_names;
|
||||
std::unordered_set<std::string> weights_to_train;
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(i, total_rank, updated_weight_names, weights_to_train), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger);
|
||||
ORT_ENFORCE(ret.IsOK());
|
||||
graphs.push_back(&graph);
|
||||
auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_rank_") + ToPathString(std::to_string(i)) + ORT_TSTR(".onnx");
|
||||
Model::Save(*p_models[i], model_uri2);
|
||||
}
|
||||
|
||||
onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, *logger_);
|
||||
onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, logger);
|
||||
auto& combine_graph = combine_model.MainGraph();
|
||||
auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph);
|
||||
ORT_ENFORCE(ret.IsOK());
|
||||
auto model_uri2 = "mlp_megatron_basic_test_partition_combine.onnx";
|
||||
auto model_uri2 = ToPathString(model_path) + ORT_TSTR("_partition_combine.onnx");
|
||||
Model::Save(combine_model, model_uri2);
|
||||
|
||||
float scale = 1.f;
|
||||
|
|
@ -479,10 +495,18 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
|||
std::default_random_engine generator{gsl::narrow_cast<uint32_t>(seed)};
|
||||
std::normal_distribution<float> distribution{mean, scale};
|
||||
|
||||
std::vector<int64_t> dims_X = {8, 16, 4};
|
||||
std::vector<float> values_X(TensorShape(dims_X).Size());
|
||||
std::for_each(values_X.begin(), values_X.end(),
|
||||
[&generator, &distribution](float& value) { value = distribution(generator); });
|
||||
ORT_ENFORCE(input_names.size() == input_dims.size());
|
||||
NameMLValMap feeds;
|
||||
for(size_t i = 0; i< input_dims.size();i++){
|
||||
std::vector<int64_t> dims_X = input_dims[i];
|
||||
std::vector<float> values_X(TensorShape(dims_X).Size());
|
||||
std::for_each(values_X.begin(), values_X.end(),
|
||||
[&generator, &distribution](float& value) { value = distribution(generator); });
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
|
||||
feeds.insert(std::make_pair(input_names[i], ml_value));
|
||||
}
|
||||
|
||||
std::vector<OrtValue> expected_ort_values;
|
||||
{
|
||||
|
|
@ -497,11 +521,6 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
|||
ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st;
|
||||
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("input", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back("output");
|
||||
|
|
@ -527,11 +546,6 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
|||
ASSERT_TRUE((st = session_object.Load(model_uri2)).IsOK()) << st;
|
||||
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
|
||||
NameMLValMap feeds;
|
||||
feeds.insert(std::make_pair("input", ml_value));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
|
|
@ -554,137 +568,21 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
||||
RunPartitionCorrectnessTest("testdata/transform/model_parallel/mlp_megatron_basic_test", *logger_, 4, {"input"}, {{8, 16, 4}});
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, MegatronBARTMLPPartitionCorrectnessTest) {
|
||||
RunPartitionCorrectnessTest("testdata/transform/model_parallel/bart_mlp_megatron_basic_test", *logger_, 4, {"input"}, {{8, 16, 4}});
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) {
|
||||
auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx";
|
||||
const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here.
|
||||
std::vector<Graph*> graphs;
|
||||
std::vector<std::shared_ptr<Model>> p_models(total_rank);
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
auto ret = Model::Load(model_uri, p_models[i], nullptr, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
Graph& graph = p_models[i]->MainGraph();
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MegatronTransformer>(i, total_rank), TransformerLevel::Level1);
|
||||
ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
graphs.push_back(&graph);
|
||||
}
|
||||
|
||||
// Dropout seed checking.
|
||||
const AttributeProto* attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[0], "dropout1"), "seed");
|
||||
ORT_ENFORCE(attr != nullptr && attr->has_i());
|
||||
int64_t dropout1_rank0_seed = attr->i();
|
||||
attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[0], "dropout2"), "seed");
|
||||
ORT_ENFORCE(attr != nullptr && attr->has_i());
|
||||
int64_t dropout2_rank0_seed = attr->i();
|
||||
for (auto i = 1; i < total_rank; i++) {
|
||||
attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[i], "dropout1"), "seed");
|
||||
ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout1_rank0_seed + i);
|
||||
attr = graph_utils::GetNodeAttribute(*GetNodeByName(*graphs[i], "dropout2"), "seed");
|
||||
ORT_ENFORCE(attr != nullptr && attr->has_i() && attr->i() == dropout2_rank0_seed);
|
||||
}
|
||||
|
||||
onnxruntime::Model combine_model("combine_graph", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}, {kMSDomain, 1}}, {}, *logger_);
|
||||
auto& combine_graph = combine_model.MainGraph();
|
||||
auto ret = horizontal_parallel_test_utils::MergeGraphsOnAllWorkers(graphs, combine_graph);
|
||||
ORT_ENFORCE(ret.IsOK());
|
||||
auto model_uri2 = "self_attention_megatron_basic_test_partition_combine.onnx";
|
||||
Model::Save(combine_model, model_uri2);
|
||||
|
||||
float scale = 1.f;
|
||||
float mean = 0.f;
|
||||
float seed = 123.f;
|
||||
|
||||
std::default_random_engine generator{gsl::narrow_cast<uint32_t>(seed)};
|
||||
std::normal_distribution<float> distribution{mean, scale};
|
||||
|
||||
std::vector<int64_t> dims_X = {8, 16, 4};
|
||||
std::vector<float> values_X(TensorShape(dims_X).Size());
|
||||
std::for_each(values_X.begin(), values_X.end(),
|
||||
[&generator, &distribution](float& value) { value = distribution(generator); });
|
||||
|
||||
std::vector<int64_t> dims_Mask = {8, 1, 16, 16};
|
||||
std::vector<float> values_Mask(TensorShape(dims_Mask).Size());
|
||||
std::for_each(values_Mask.begin(), values_Mask.end(),
|
||||
[&generator, &distribution](float& value) { value = distribution(generator); });
|
||||
|
||||
std::vector<OrtValue> expected_ort_values;
|
||||
{
|
||||
SessionOptions so;
|
||||
so.session_logid = "RawGraphRun";
|
||||
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
std::unique_ptr<IExecutionProvider> execution_provider = DefaultCudaExecutionProvider();
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
|
||||
|
||||
Status st;
|
||||
ASSERT_TRUE((st = session_object.Load(model_uri)).IsOK()) << st;
|
||||
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
|
||||
|
||||
NameMLValMap feeds;
|
||||
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
|
||||
feeds.insert(std::make_pair("input", ml_value));
|
||||
|
||||
OrtValue mask_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_Mask, values_Mask, &mask_value);
|
||||
feeds.insert(std::make_pair("mask", mask_value));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back("output");
|
||||
|
||||
// Now run
|
||||
RunOptions run_options;
|
||||
run_options.training_mode = true;
|
||||
st = session_object.Run(run_options, feeds, output_names, &expected_ort_values);
|
||||
EXPECT_TRUE(st.IsOK());
|
||||
}
|
||||
|
||||
std::vector<OrtValue> actual_ort_values;
|
||||
{
|
||||
SessionOptions so;
|
||||
so.session_logid = "SplitThenCombineRun";
|
||||
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
std::unique_ptr<IExecutionProvider> execution_provider = DefaultCudaExecutionProvider();
|
||||
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
|
||||
|
||||
Status st;
|
||||
ASSERT_TRUE((st = session_object.Load(model_uri2)).IsOK()) << st;
|
||||
ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
|
||||
|
||||
NameMLValMap feeds;
|
||||
OrtValue ml_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_X, values_X, &ml_value);
|
||||
feeds.insert(std::make_pair("input", ml_value));
|
||||
|
||||
OrtValue mask_value;
|
||||
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_Mask, values_Mask, &mask_value);
|
||||
feeds.insert(std::make_pair("mask", mask_value));
|
||||
|
||||
// prepare outputs
|
||||
std::vector<std::string> output_names;
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
output_names.push_back("output_rank_" + std::to_string(i));
|
||||
}
|
||||
|
||||
// Now run
|
||||
RunOptions run_options;
|
||||
run_options.training_mode = true;
|
||||
st = session_object.Run(run_options, feeds, output_names, &actual_ort_values);
|
||||
EXPECT_TRUE(st.IsOK());
|
||||
}
|
||||
|
||||
auto& expected_val = expected_ort_values[0].Get<Tensor>();
|
||||
for (auto i = 0; i < total_rank; i++) {
|
||||
auto& actual_val = actual_ort_values[i].Get<Tensor>();
|
||||
horizontal_parallel_test_utils::VerifyOutputs(expected_val, actual_val, true);
|
||||
horizontal_parallel_test_utils::VerifyOutputs(expected_val, actual_val, false);
|
||||
}
|
||||
RunPartitionCorrectnessTest("testdata/transform/model_parallel/self_attention_megatron_basic_test", *logger_, 2, {"input", "mask"}, {{8, 16, 4}, {8, 1, 16, 16}});
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, MegatronBARTSelfAttentionPartitionCorrectnessTest) {
|
||||
RunPartitionCorrectnessTest("testdata/transform/model_parallel/bart_self_attention_megatron_basic_test", *logger_, 2, {"input"}, {{6, 8, 4}});
|
||||
}
|
||||
// end of USE_CUDA
|
||||
#endif
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue