From 54d1573d2f6ab27cdf6e1bf5848a51bc03e0853c Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 1 Jun 2022 15:09:39 +0800 Subject: [PATCH] [ORTModule] Enable SimplifiedLayerNormalization Fusion (#11580) * enable SimplifiedLayerNormalization fuse * remove allow_layer_norm_mod_precision flag --- .../core/optimizer/layer_norm_fusion.cc | 9 ++++--- .../core/optimizer/layer_norm_fusion.h | 9 ++----- .../test/optimizer/graph_transform_test.cc | 24 ++----------------- .../core/optimizer/graph_transformer_config.h | 1 - .../core/optimizer/graph_transformer_utils.cc | 2 +- .../python/orttraining_pybind_state.cc | 6 +---- .../ortmodule/_graph_execution_manager.py | 3 --- .../json_config/_load_config_from_json.py | 15 ------------ .../python/training/orttrainer_options.py | 8 ------- ...test_ortmodule_experimental_json_config.py | 6 ----- ..._ortmodule_experimental_json_config_1.json | 1 - ..._ortmodule_experimental_json_config_2.json | 1 - .../orttraining_test_orttrainer_frontend.py | 1 - 13 files changed, 10 insertions(+), 76 deletions(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index e5a9798425..cc54b0e4bf 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -46,7 +46,7 @@ X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul | | +---------------------+ -In recent pytorch, Cast nodes may be inserted before Pow to ensure that both inputs 'base' and 'power' are the same type +In recent pytorch, Cast nodes may be inserted before Pow to ensure that both inputs 'base' and 'power' are the same type due to restriction in older opsets. Therefore, Layer Normalization will also handle the case below : +---------------------+ | | @@ -431,7 +431,7 @@ X --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul | ^ | | +----------------------------------------------+ -Additional FP16 patterns supported if allow_precision_change_ is true: +Additional FP16 patterns supported: X --> Cast1 --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Cast2 --> Mul | ^ ^ @@ -535,7 +535,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr const Node* p_pow_input_node = graph_utils::GetInputNode(pow_node, 0); bool has_leading_cast = false; - if (allow_precision_change_ && p_pow_input_node) { + if (p_pow_input_node) { Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index()); // If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div) if (graph_utils::IsSupportedOptypeVersionAndDomain(pow_input_node, "Cast", {9, 13}) && @@ -548,8 +548,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // div --> mul or div --> cast --> mul Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index()); - if (allow_precision_change_ && - graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13}) && + if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13}) && optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) { nodes_to_remove.push_back(*next_node); next_node = graph.GetNode(next_node->OutputNodesBegin()->Index()); diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.h b/onnxruntime/core/optimizer/layer_norm_fusion.h index 11bccfddf2..ea10c50ee4 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.h +++ b/onnxruntime/core/optimizer/layer_norm_fusion.h @@ -35,15 +35,10 @@ The formula corresponding to LayerNorm activation subgraph: */ class SimplifiedLayerNormFusion : public GraphTransformer { public: - SimplifiedLayerNormFusion(const InlinedHashSet& compatible_execution_providers = {}, - const bool allow_precision_change = false) noexcept - : GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers), - allow_precision_change_(allow_precision_change) {} + SimplifiedLayerNormFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - - private: - bool allow_precision_change_; }; } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 7caf1638ec..aa85e9029f 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -3847,27 +3847,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { } } -TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_PrecisionChangeDisallowed) { - auto model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Div"] == 1); - ASSERT_TRUE(op_to_count["Add"] == 1); - ASSERT_TRUE(op_to_count["ReduceMean"] == 1); - ASSERT_TRUE(op_to_count["Pow"] == 1); - ASSERT_TRUE(op_to_count["Sqrt"] == 1); - ASSERT_TRUE(op_to_count["Cast"] == 2); - ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 0); -} - -TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_PrecisionChangeAllowed) { +TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) { auto model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); @@ -3875,7 +3855,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_Precisio InlinedHashSet compatible_eps; onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(compatible_eps, true), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(compatible_eps), TransformerLevel::Level2)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_config.h b/orttraining/orttraining/core/optimizer/graph_transformer_config.h index ef0c4cf57c..3b1bf22ae2 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_config.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_config.h @@ -21,7 +21,6 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat bool transformer_layer_recompute{false}; // Number of layers to apply recompute int number_recompute_layers{0}; - bool allow_layer_norm_mod_precision{false}; }; } // namespace training diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 5f9269946b..2012979462 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -96,7 +96,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); - transformers.emplace_back(std::make_unique(compatible_eps, config.allow_layer_norm_mod_precision)); + transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index cb2a7d7439..afa4b338f2 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -118,7 +118,6 @@ struct TrainingParameters { std::vector propagate_cast_ops_allow; GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy propagate_cast_ops_strategy = GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::FloodFill; - bool allow_layer_norm_mod_precision = false; // graph dumping std::string model_after_graph_transforms_path; @@ -294,7 +293,6 @@ TrainingConfigurationResult ConfigureSessionForTraining( config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy; config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level; config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow; - config.graph_transformer_config.allow_layer_norm_mod_precision = parameters.allow_layer_norm_mod_precision; if (!parameters.model_after_graph_transforms_path.empty()) { config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path); @@ -548,8 +546,7 @@ for every transfered tensor. .def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path) .def_readwrite("enable_adasum", &TrainingParameters::enable_adasum) .def_readwrite("propagate_cast_ops_level", &TrainingParameters::propagate_cast_ops_level) - .def_readwrite("propagate_cast_ops_allow", &TrainingParameters::propagate_cast_ops_allow) - .def_readwrite("allow_layer_norm_mod_precision", &TrainingParameters::allow_layer_norm_mod_precision); + .def_readwrite("propagate_cast_ops_allow", &TrainingParameters::propagate_cast_ops_allow); #if defined(USE_MPI) m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); }); @@ -787,7 +784,6 @@ for every transfered tensor. .def_readwrite("gelu_recompute", &TrainingGraphTransformerConfiguration::gelu_recompute) .def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute) .def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers) - .def_readwrite("allow_layer_norm_mod_precision", &TrainingGraphTransformerConfiguration::allow_layer_norm_mod_precision) .def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config); py::class_ module_graph_builder_config( diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 9c62fbebf1..47c0608311 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -126,8 +126,6 @@ class GraphExecutionManager(GraphExecutionInterface): self._propagate_cast_ops_level = 1 # List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. self._propagate_cast_ops_allow = [] - # Whether allow fusion of layer norm subgraph if doing so will cause modified precision. - self._allow_layer_norm_mod_precision = False # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL # To be instantiated in the concrete implementation of GraphExecutionManager @@ -463,7 +461,6 @@ class GraphExecutionManager(GraphExecutionInterface): graph_transformer_config.propagate_cast_ops_config.level = self._propagate_cast_ops_level graph_transformer_config.propagate_cast_ops_config.allow = self._propagate_cast_ops_allow graph_transformer_config.propagate_cast_ops_config.strategy = self._propagate_cast_ops_strategy - graph_transformer_config.allow_layer_norm_mod_precision = self._allow_layer_norm_mod_precision return graph_transformer_config def _initialize_graph_builder(self, training): diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py index e896b9f6e6..f251df2736 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py @@ -81,18 +81,6 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data): ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction -def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data): - """Loads AllowLayerNormModPrecision from json file onto ORTModule.""" - - assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key) - log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. Loading attributes from file.") - - assert isinstance( - data.AllowLayerNormModPrecision, bool - ), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean" - ortmodule_config_accessor._allow_layer_norm_mod_precision = data.AllowLayerNormModPrecision - - def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data): """Loads EnableGradAccOptimization from json file onto ORTModule.""" @@ -218,7 +206,6 @@ def _define_load_function_keys(): _load_propagate_cast_ops.loading_key = "PropagateCastOps" _load_use_external_gpu_allocator.loading_key = "UseExternalGPUAllocator" _load_enable_custom_autograd_function.loading_key = "EnableCustomAutogradFunction" - _load_allow_layer_norm_mod_precision.loading_key = "AllowLayerNormModPrecision" _load_enable_grad_acc_optimization.loading_key = "EnableGradAccOptimization" _load_run_symbolic_shape_infer.loading_key = "RunSymbolicShapeInference" _load_use_static_shape.loading_key = "UseStaticShape" @@ -242,7 +229,6 @@ def load_from_json(ortmodule, path=None): }, "UseExternalGPUAllocator" : false, # bool flag "EnableCustomAutogradFunction": true, # bool flag - "AllowLayerNormModPrecision": true, # bool flag "EnableGradAccOptimization": true, # bool flag "UseStaticShape": true, # bool flag "RunSymbolicShapeInference": false, # bool flag @@ -299,7 +285,6 @@ def load_from_json(ortmodule, path=None): _load_propagate_cast_ops.loading_key: _load_propagate_cast_ops, _load_use_external_gpu_allocator.loading_key: _load_use_external_gpu_allocator, _load_enable_custom_autograd_function.loading_key: _load_enable_custom_autograd_function, - _load_allow_layer_norm_mod_precision.loading_key: _load_allow_layer_norm_mod_precision, _load_enable_grad_acc_optimization.loading_key: _load_enable_grad_acc_optimization, _load_run_symbolic_shape_infer.loading_key: _load_run_symbolic_shape_infer, _load_use_static_shape.loading_key: _load_use_static_shape, diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index 9431e2eaa4..080b8202a5 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -210,10 +210,6 @@ class ORTTrainerOptions(object): 'default': [] } } - }, - 'allow_layer_norm_mod_precision': { - 'type': 'boolean', - 'default': False } } }, @@ -389,9 +385,6 @@ class ORTTrainerOptions(object): if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise. graph_transformer.propagate_cast_ops_config.allow(list of str, []) List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. - graph_transformer.allow_layer_norm_mod_precision(bool, default False) - Enable LayerNormalization/SimplifiedLayerNormalization fusion - even if it requires modified compute precision attn_dropout_recompute (bool, default is False): enable recomputing attention dropout to save memory gelu_recompute (bool, default is False): @@ -630,7 +623,6 @@ _ORTTRAINER_OPTIONS_SCHEMA = { "gelu_recompute": {"type": "boolean", "default": False}, "transformer_layer_recompute": {"type": "boolean", "default": False}, "number_recompute_layers": {"type": "integer", "min": 0, "default": 0}, - "allow_layer_norm_mod_precision": {"type": "boolean", "default": False}, "propagate_cast_ops_config": { "type": "dict", "default_setter": lambda _: {}, diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py index 8eebb767f6..7b2e08dc9e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py @@ -46,9 +46,6 @@ def test_load_config_from_json_1(): # test enable custom autograd function assert ort_model_attributes._enable_custom_autograd_function == True - # test allow layer norm mod precision - assert ort_model_attributes._allow_layer_norm_mod_precision == True - # test use static shape assert ort_model_attributes._use_static_shape == True @@ -102,9 +99,6 @@ def test_load_config_from_json_2(): # test enable custom autograd function assert ort_model_attributes._enable_custom_autograd_function == False - # test allow layer norm mod precision - assert ort_model_attributes._allow_layer_norm_mod_precision == False - # test use static shape assert ort_model_attributes._use_static_shape == False diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json index bf51e5d710..d8a5390100 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_1.json @@ -7,7 +7,6 @@ }, "UseExternalGPUAllocator" : false, "EnableCustomAutogradFunction": true, - "AllowLayerNormModPrecision": true, "EnableGradAccOptimization": true, "UseStaticShape": true, "RunSymbolicShapeInference": false, diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json index 9c42a211a1..ed1bb29a4c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config_2.json @@ -7,7 +7,6 @@ }, "UseExternalGPUAllocator" : true, "EnableCustomAutogradFunction": false, - "AllowLayerNormModPrecision": false, "EnableGradAccOptimization": false, "UseStaticShape": false, "RunSymbolicShapeInference": true, diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 5dd9e13684..924750e738 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -84,7 +84,6 @@ def testORTTrainerOptionsDefaultValues(test_input): "gelu_recompute": False, "transformer_layer_recompute": False, "number_recompute_layers": 0, - "allow_layer_norm_mod_precision": False, "propagate_cast_ops_config": {"strategy": PropagateCastOpsStrategy.FLOOD_FILL, "level": 1, "allow": []}, }, "utils": {