mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Fix simplified layer norm fusion for training (#14866)
### Fix simplified layer norm fusion for training
Co-author with @prathikr.
Fix bug identified by @prathikr.
https://github.com/microsoft/onnxruntime/issues/14822.
Running T5 model enabling deepspeed, we see simplified layer norm is not
fused because the device check did not pass
b7fde84341/onnxruntime/core/optimizer/layer_norm_fusion.cc (L568).
Since during pretraining optimization pass, there is no device
placement, so the device check not fulfilled is expected.
On the other hand, the device check is still valid to avoid simplified
layer norm fusion works correctly for CPU runs. As a mitigation, added a
flag to indicate whether the fusion is triggered by pre-training
optimization or not. There is a risk though, when we run ORTModule
training with CPU EP, but I feel the risk can be much reduced if we
check CUDA/ROCM is enabled for the build.
```
CUDA_VISIBLE_DEVICES=0 python examples/onnxruntime/training/summarization/run_summarization.py --model_name_or_path t5-small --do_train --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --predict_with_generate --overwrite_output_dir --output_dir /bert_ort/pengwa/output --fp16 --max_steps 1 --logging_steps 1 --deepspeed aml_ds_config_zero_1.json
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
65f1f840f6
commit
5d8ce817cb
3 changed files with 26 additions and 10 deletions
|
|
@ -17,7 +17,7 @@ static constexpr std::array<std::string_view, 3> supported_data_types{"tensor(fl
|
|||
// Default epsilon
|
||||
static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f;
|
||||
|
||||
static bool IsSupportedDataType(const Node& node, int first_n_inputs=-1) {
|
||||
static bool IsSupportedDataType(const Node& node, int first_n_inputs = -1) {
|
||||
int input_index = 0;
|
||||
for (const auto& input_arg : node.InputDefs()) {
|
||||
if (first_n_inputs != -1 && input_index >= first_n_inputs) {
|
||||
|
|
@ -558,15 +558,18 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
// 3. x->Cast(to:fp16)->y : SimplifiedLayerNorm(T:float,V:fp16)
|
||||
// 4. x->y : SimplifiedLayerNorm(T:float,V:float)
|
||||
// They all work for GPU EP.
|
||||
// For CPU EP, we have only SimplifiedlayerNorm(T:float,V:float) implementation, so only #4 works. But for #1 and
|
||||
// #2, if we treat the entry Cast as a normal node, meaning has_leading_cast is false, then for #2, we can still
|
||||
// fuse it to "Cast(to:float)->SimplifiedlayerNorm(T:float,V:float)" (same as applying #4 to the x->y after
|
||||
// Cast), so the condition for CPU EP to fuse or not is always setting has_leading_cast to false and checking if
|
||||
// there is a Cast between x and y. Having Cast between means cannot fuse.
|
||||
// For CPU EP, we have only SimplifiedlayerNorm(T:float,V:float) implementation, so only #4 works. We made an
|
||||
// exception here, since pre-training optimization happens without device assignment. skip_device_check_ is the
|
||||
// flag to disable device check intent only for pre-training optimization.
|
||||
// For #1 and #2, if we treat the entry Cast as a normal node, meaning has_leading_cast is false, then for #2,
|
||||
// we can still fuse it to "Cast(to:float)->SimplifiedlayerNorm(T:float,V:float)" (same as applying #4 to the x->y
|
||||
// after Cast), so the condition for CPU EP to fuse or not is always setting has_leading_cast to false and checking
|
||||
// if there is a Cast between x and y. Having Cast between means cannot fuse.
|
||||
const Node* p_pow_input_node = graph_utils::GetInputNode(pow_node, 0);
|
||||
bool has_leading_cast = false;
|
||||
bool is_gpu_ep = pow_node.GetExecutionProviderType() == kCudaExecutionProvider ||
|
||||
pow_node.GetExecutionProviderType() == kRocmExecutionProvider;
|
||||
bool is_gpu_ep = (pow_node.GetExecutionProviderType() == kCudaExecutionProvider ||
|
||||
pow_node.GetExecutionProviderType() == kRocmExecutionProvider) ||
|
||||
skip_device_check_;
|
||||
if (is_gpu_ep && p_pow_input_node) {
|
||||
Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index());
|
||||
// If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div)
|
||||
|
|
|
|||
|
|
@ -35,10 +35,18 @@ The formula corresponding to LayerNorm activation subgraph:
|
|||
*/
|
||||
class SimplifiedLayerNormFusion : public GraphTransformer {
|
||||
public:
|
||||
SimplifiedLayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers) {}
|
||||
SimplifiedLayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
|
||||
bool skip_device_check = false) noexcept
|
||||
: GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers),
|
||||
skip_device_check_(skip_device_check) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
private:
|
||||
// A flag indicate whether device check is skipped for some cases.
|
||||
// This is introduced for pre-training optimizations, where when optimization passes are running,
|
||||
// devices placement is NOT done yet.
|
||||
bool skip_device_check_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -103,7 +103,12 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps,
|
||||
true /* skip_device_check*/));
|
||||
#else
|
||||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps));
|
||||
#endif
|
||||
transformers.emplace_back(std::make_unique<FastGeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<QuickGeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));
|
||||
|
|
|
|||
Loading…
Reference in a new issue