diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 25feb5b8d7..9bbb3c4d57 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -17,7 +17,7 @@ static constexpr std::array supported_data_types{"tensor(fl // Default epsilon static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f; -static bool IsSupportedDataType(const Node& node, int first_n_inputs=-1) { +static bool IsSupportedDataType(const Node& node, int first_n_inputs = -1) { int input_index = 0; for (const auto& input_arg : node.InputDefs()) { if (first_n_inputs != -1 && input_index >= first_n_inputs) { @@ -558,15 +558,18 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // 3. x->Cast(to:fp16)->y : SimplifiedLayerNorm(T:float,V:fp16) // 4. x->y : SimplifiedLayerNorm(T:float,V:float) // They all work for GPU EP. - // For CPU EP, we have only SimplifiedlayerNorm(T:float,V:float) implementation, so only #4 works. But for #1 and - // #2, if we treat the entry Cast as a normal node, meaning has_leading_cast is false, then for #2, we can still - // fuse it to "Cast(to:float)->SimplifiedlayerNorm(T:float,V:float)" (same as applying #4 to the x->y after - // Cast), so the condition for CPU EP to fuse or not is always setting has_leading_cast to false and checking if - // there is a Cast between x and y. Having Cast between means cannot fuse. + // For CPU EP, we have only SimplifiedlayerNorm(T:float,V:float) implementation, so only #4 works. We made an + // exception here, since pre-training optimization happens without device assignment. skip_device_check_ is the + // flag to disable device check intended only for pre-training optimization. 
+ // For #1 and #2, if we treat the entry Cast as a normal node, meaning has_leading_cast is false, then for #2, + // we can still fuse it to "Cast(to:float)->SimplifiedlayerNorm(T:float,V:float)" (same as applying #4 to the x->y + // after Cast), so the condition for CPU EP to fuse or not is always setting has_leading_cast to false and checking + // if there is a Cast between x and y. Having Cast between means cannot fuse. const Node* p_pow_input_node = graph_utils::GetInputNode(pow_node, 0); bool has_leading_cast = false; - bool is_gpu_ep = pow_node.GetExecutionProviderType() == kCudaExecutionProvider || - pow_node.GetExecutionProviderType() == kRocmExecutionProvider; + bool is_gpu_ep = (pow_node.GetExecutionProviderType() == kCudaExecutionProvider || + pow_node.GetExecutionProviderType() == kRocmExecutionProvider) || + skip_device_check_; if (is_gpu_ep && p_pow_input_node) { Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index()); // If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.h b/onnxruntime/core/optimizer/layer_norm_fusion.h index ea10c50ee4..18b176a802 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.h +++ b/onnxruntime/core/optimizer/layer_norm_fusion.h @@ -35,10 +35,18 @@ The formula corresponding to LayerNorm activation subgraph: */ class SimplifiedLayerNormFusion : public GraphTransformer { public: - SimplifiedLayerNormFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers) {} + SimplifiedLayerNormFusion(const InlinedHashSet& compatible_execution_providers = {}, + bool skip_device_check = false) noexcept + : GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers), + skip_device_check_(skip_device_check) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + private: + // 
A flag indicating whether the device check is skipped for some cases. + // This is introduced for pre-training optimizations: when the optimization passes run, + // device placement is NOT done yet. + bool skip_device_check_; }; } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index a068de1f4f..689d1aaeec 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -103,7 +103,12 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); +#if defined(USE_CUDA) || defined(USE_ROCM) + transformers.emplace_back(std::make_unique(compatible_eps, + true /* skip_device_check*/)); +#else transformers.emplace_back(std::make_unique(compatible_eps)); +#endif transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps));