From 5d8ce817cb5651b8ff579ef91e8dcfb6cf8583ed Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 8 Mar 2023 05:59:20 +0800 Subject: [PATCH] Fix simplified layer norm fusion for training (#14866) ### Fix simplified layer norm fusion for training Co-author with @prathikr. Fix bug identified by @prathikr. https://github.com/microsoft/onnxruntime/issues/14822. Running T5 model enabling deepspeed, we see simplified layer norm is not fused because the device check did not pass https://github.com/microsoft/onnxruntime/blob/b7fde84341f5e7e4fc8b202e9aabad4d087ec15c/onnxruntime/core/optimizer/layer_norm_fusion.cc#L568. Since during pretraining optimization pass, there is no device placement, so the device check not fulfilled is expected. On the other hand, the device check is still valid to avoid simplified layer norm fusion works correctly for CPU runs. As a mitigation, added a flag to indicate whether the fusion is triggered by pre-training optimization or not. There is a risk though, when we run ORTModule training with CPU EP, but I feel the risk can be much reduced if we check CUDA/ROCM is enabled for the build. ``` CUDA_VISIBLE_DEVICES=0 python examples/onnxruntime/training/summarization/run_summarization.py --model_name_or_path t5-small --do_train --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --predict_with_generate --overwrite_output_dir --output_dir /bert_ort/pengwa/output --fp16 --max_steps 1 --logging_steps 1 --deepspeed aml_ds_config_zero_1.json ``` ### Motivation and Context --- .../core/optimizer/layer_norm_fusion.cc | 19 +++++++++++-------- .../core/optimizer/layer_norm_fusion.h | 12 ++++++++++-- .../core/optimizer/graph_transformer_utils.cc | 5 +++++ 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 25feb5b8d7..9bbb3c4d57 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -17,7 +17,7 @@ static constexpr std::array supported_data_types{"tensor(fl // Default epsilon static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f; -static bool IsSupportedDataType(const Node& node, int first_n_inputs=-1) { +static bool IsSupportedDataType(const Node& node, int first_n_inputs = -1) { int input_index = 0; for (const auto& input_arg : node.InputDefs()) { if (first_n_inputs != -1 && input_index >= first_n_inputs) { @@ -558,15 +558,18 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // 3. x->Cast(to:fp16)->y : SimplifiedLayerNorm(T:float,V:fp16) // 4. x->y : SimplifiedLayerNorm(T:float,V:float) // They all work for GPU EP. - // For CPU EP, we have only SimplifiedlayerNorm(T:float,V:float) implementation, so only #4 works. But for #1 and - // #2, if we treat the entry Cast as a normal node, meaning has_leading_cast is false, then for #2, we can still - // fuse it to "Cast(to:float)->SimplifiedlayerNorm(T:float,V:float)" (same as applying #4 to the x->y after - // Cast), so the condition for CPU EP to fuse or not is always setting has_leading_cast to false and checking if - // there is a Cast between x and y. Having Cast between means cannot fuse. + // For CPU EP, we have only SimplifiedlayerNorm(T:float,V:float) implementation, so only #4 works. We made an + // exception here, since pre-training optimization happens without device assignment. skip_device_check_ is the + // flag to disable device check intent only for pre-training optimization. + // For #1 and #2, if we treat the entry Cast as a normal node, meaning has_leading_cast is false, then for #2, + // we can still fuse it to "Cast(to:float)->SimplifiedlayerNorm(T:float,V:float)" (same as applying #4 to the x->y + // after Cast), so the condition for CPU EP to fuse or not is always setting has_leading_cast to false and checking + // if there is a Cast between x and y. Having Cast between means cannot fuse. const Node* p_pow_input_node = graph_utils::GetInputNode(pow_node, 0); bool has_leading_cast = false; - bool is_gpu_ep = pow_node.GetExecutionProviderType() == kCudaExecutionProvider || - pow_node.GetExecutionProviderType() == kRocmExecutionProvider; + bool is_gpu_ep = (pow_node.GetExecutionProviderType() == kCudaExecutionProvider || + pow_node.GetExecutionProviderType() == kRocmExecutionProvider) || + skip_device_check_; if (is_gpu_ep && p_pow_input_node) { Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index()); // If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.h b/onnxruntime/core/optimizer/layer_norm_fusion.h index ea10c50ee4..18b176a802 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.h +++ b/onnxruntime/core/optimizer/layer_norm_fusion.h @@ -35,10 +35,18 @@ The formula corresponding to LayerNorm activation subgraph: */ class SimplifiedLayerNormFusion : public GraphTransformer { public: - SimplifiedLayerNormFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers) {} + SimplifiedLayerNormFusion(const InlinedHashSet& compatible_execution_providers = {}, + bool skip_device_check = false) noexcept + : GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers), + skip_device_check_(skip_device_check) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + private: + // A flag indicate whether device check is skipped for some cases. + // This is introduced for pre-training optimizations, where when optimization passes are running, + // devices placement is NOT done yet. + bool skip_device_check_; }; } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index a068de1f4f..689d1aaeec 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -103,7 +103,12 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); +#if defined(USE_CUDA) || defined(USE_ROCM) + transformers.emplace_back(std::make_unique(compatible_eps, + true /* skip_device_check*/)); +#else transformers.emplace_back(std::make_unique(compatible_eps)); +#endif transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps));