[ORTModule] Enable SimplifiedLayerNormalization Fusion (#11580)

* enable SimplifiedLayerNormalization fusion

* remove allow_layer_norm_mod_precision flag
This commit is contained in:
Vincent Wang 2022-06-01 15:09:39 +08:00 committed by GitHub
parent 03abcb0640
commit 54d1573d2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 10 additions and 76 deletions

View file

@ -46,7 +46,7 @@ X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
| |
+---------------------+
In recent pytorch, Cast nodes may be inserted before Pow to ensure that both inputs 'base' and 'power' are the same type
In recent pytorch, Cast nodes may be inserted before Pow to ensure that both inputs 'base' and 'power' are the same type
due to restriction in older opsets. Therefore, Layer Normalization will also handle the case below :
+---------------------+
| |
@ -431,7 +431,7 @@ X --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
| ^
| |
+----------------------------------------------+
Additional FP16 patterns supported if allow_precision_change_ is true:
Additional FP16 patterns supported:
X --> Cast1 --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Cast2 --> Mul
| ^ ^
@ -535,7 +535,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
const Node* p_pow_input_node = graph_utils::GetInputNode(pow_node, 0);
bool has_leading_cast = false;
if (allow_precision_change_ && p_pow_input_node) {
if (p_pow_input_node) {
Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index());
// If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div)
if (graph_utils::IsSupportedOptypeVersionAndDomain(pow_input_node, "Cast", {9, 13}) &&
@ -548,8 +548,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
// div --> mul or div --> cast --> mul
Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
if (allow_precision_change_ &&
graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13}) &&
if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13}) &&
optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
nodes_to_remove.push_back(*next_node);
next_node = graph.GetNode(next_node->OutputNodesBegin()->Index());

View file

@ -35,15 +35,10 @@ The formula corresponding to LayerNorm activation subgraph:
*/
class SimplifiedLayerNormFusion : public GraphTransformer {
public:
SimplifiedLayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const bool allow_precision_change = false) noexcept
: GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers),
allow_precision_change_(allow_precision_change) {}
SimplifiedLayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
private:
bool allow_precision_change_;
};
} // namespace onnxruntime

View file

@ -3847,27 +3847,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
}
}
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_PrecisionChangeDisallowed) {
auto model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SimplifiedLayerNormFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 1);
ASSERT_TRUE(op_to_count["Add"] == 1);
ASSERT_TRUE(op_to_count["ReduceMean"] == 1);
ASSERT_TRUE(op_to_count["Pow"] == 1);
ASSERT_TRUE(op_to_count["Sqrt"] == 1);
ASSERT_TRUE(op_to_count["Cast"] == 2);
ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 0);
}
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_PrecisionChangeAllowed) {
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) {
auto model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
@ -3875,7 +3855,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_Precisio
InlinedHashSet<std::string_view> compatible_eps;
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps, true), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);

View file

@ -21,7 +21,6 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
bool transformer_layer_recompute{false};
// Number of layers to apply recompute
int number_recompute_layers{0};
bool allow_layer_norm_mod_precision{false};
};
} // namespace training

View file

@ -96,7 +96,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps, config.allow_layer_norm_mod_precision));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<FastGeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));

View file

@ -118,7 +118,6 @@ struct TrainingParameters {
std::vector<std::string> propagate_cast_ops_allow;
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy propagate_cast_ops_strategy =
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::FloodFill;
bool allow_layer_norm_mod_precision = false;
// graph dumping
std::string model_after_graph_transforms_path;
@ -294,7 +293,6 @@ TrainingConfigurationResult ConfigureSessionForTraining(
config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy;
config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level;
config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow;
config.graph_transformer_config.allow_layer_norm_mod_precision = parameters.allow_layer_norm_mod_precision;
if (!parameters.model_after_graph_transforms_path.empty()) {
config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path);
@ -548,8 +546,7 @@ for every transfered tensor.
.def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
.def_readwrite("enable_adasum", &TrainingParameters::enable_adasum)
.def_readwrite("propagate_cast_ops_level", &TrainingParameters::propagate_cast_ops_level)
.def_readwrite("propagate_cast_ops_allow", &TrainingParameters::propagate_cast_ops_allow)
.def_readwrite("allow_layer_norm_mod_precision", &TrainingParameters::allow_layer_norm_mod_precision);
.def_readwrite("propagate_cast_ops_allow", &TrainingParameters::propagate_cast_ops_allow);
#if defined(USE_MPI)
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });
@ -787,7 +784,6 @@ for every transfered tensor.
.def_readwrite("gelu_recompute", &TrainingGraphTransformerConfiguration::gelu_recompute)
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
.def_readwrite("allow_layer_norm_mod_precision", &TrainingGraphTransformerConfiguration::allow_layer_norm_mod_precision)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(

View file

@ -126,8 +126,6 @@ class GraphExecutionManager(GraphExecutionInterface):
self._propagate_cast_ops_level = 1
# List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
self._propagate_cast_ops_allow = []
# Whether allow fusion of layer norm subgraph if doing so will cause modified precision.
self._allow_layer_norm_mod_precision = False
# Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
# To be instantiated in the concrete implementation of GraphExecutionManager
@ -463,7 +461,6 @@ class GraphExecutionManager(GraphExecutionInterface):
graph_transformer_config.propagate_cast_ops_config.level = self._propagate_cast_ops_level
graph_transformer_config.propagate_cast_ops_config.allow = self._propagate_cast_ops_allow
graph_transformer_config.propagate_cast_ops_config.strategy = self._propagate_cast_ops_strategy
graph_transformer_config.allow_layer_norm_mod_precision = self._allow_layer_norm_mod_precision
return graph_transformer_config
def _initialize_graph_builder(self, training):

View file

@ -81,18 +81,6 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction
def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data):
"""Loads AllowLayerNormModPrecision from json file onto ORTModule."""
assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key)
log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. Loading attributes from file.")
assert isinstance(
data.AllowLayerNormModPrecision, bool
), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean"
ortmodule_config_accessor._allow_layer_norm_mod_precision = data.AllowLayerNormModPrecision
def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
"""Loads EnableGradAccOptimization from json file onto ORTModule."""
@ -218,7 +206,6 @@ def _define_load_function_keys():
_load_propagate_cast_ops.loading_key = "PropagateCastOps"
_load_use_external_gpu_allocator.loading_key = "UseExternalGPUAllocator"
_load_enable_custom_autograd_function.loading_key = "EnableCustomAutogradFunction"
_load_allow_layer_norm_mod_precision.loading_key = "AllowLayerNormModPrecision"
_load_enable_grad_acc_optimization.loading_key = "EnableGradAccOptimization"
_load_run_symbolic_shape_infer.loading_key = "RunSymbolicShapeInference"
_load_use_static_shape.loading_key = "UseStaticShape"
@ -242,7 +229,6 @@ def load_from_json(ortmodule, path=None):
},
"UseExternalGPUAllocator" : false, # bool flag
"EnableCustomAutogradFunction": true, # bool flag
"AllowLayerNormModPrecision": true, # bool flag
"EnableGradAccOptimization": true, # bool flag
"UseStaticShape": true, # bool flag
"RunSymbolicShapeInference": false, # bool flag
@ -299,7 +285,6 @@ def load_from_json(ortmodule, path=None):
_load_propagate_cast_ops.loading_key: _load_propagate_cast_ops,
_load_use_external_gpu_allocator.loading_key: _load_use_external_gpu_allocator,
_load_enable_custom_autograd_function.loading_key: _load_enable_custom_autograd_function,
_load_allow_layer_norm_mod_precision.loading_key: _load_allow_layer_norm_mod_precision,
_load_enable_grad_acc_optimization.loading_key: _load_enable_grad_acc_optimization,
_load_run_symbolic_shape_infer.loading_key: _load_run_symbolic_shape_infer,
_load_use_static_shape.loading_key: _load_use_static_shape,

View file

@ -210,10 +210,6 @@ class ORTTrainerOptions(object):
'default': []
}
}
},
'allow_layer_norm_mod_precision': {
'type': 'boolean',
'default': False
}
}
},
@ -389,9 +385,6 @@ class ORTTrainerOptions(object):
if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise.
graph_transformer.propagate_cast_ops_config.allow(list of str, [])
List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
graph_transformer.allow_layer_norm_mod_precision(bool, default False)
Enable LayerNormalization/SimplifiedLayerNormalization fusion
even if it requires modified compute precision
attn_dropout_recompute (bool, default is False):
enable recomputing attention dropout to save memory
gelu_recompute (bool, default is False):
@ -630,7 +623,6 @@ _ORTTRAINER_OPTIONS_SCHEMA = {
"gelu_recompute": {"type": "boolean", "default": False},
"transformer_layer_recompute": {"type": "boolean", "default": False},
"number_recompute_layers": {"type": "integer", "min": 0, "default": 0},
"allow_layer_norm_mod_precision": {"type": "boolean", "default": False},
"propagate_cast_ops_config": {
"type": "dict",
"default_setter": lambda _: {},

View file

@ -46,9 +46,6 @@ def test_load_config_from_json_1():
# test enable custom autograd function
assert ort_model_attributes._enable_custom_autograd_function == True
# test allow layer norm mod precision
assert ort_model_attributes._allow_layer_norm_mod_precision == True
# test use static shape
assert ort_model_attributes._use_static_shape == True
@ -102,9 +99,6 @@ def test_load_config_from_json_2():
# test enable custom autograd function
assert ort_model_attributes._enable_custom_autograd_function == False
# test allow layer norm mod precision
assert ort_model_attributes._allow_layer_norm_mod_precision == False
# test use static shape
assert ort_model_attributes._use_static_shape == False

View file

@ -7,7 +7,6 @@
},
"UseExternalGPUAllocator" : false,
"EnableCustomAutogradFunction": true,
"AllowLayerNormModPrecision": true,
"EnableGradAccOptimization": true,
"UseStaticShape": true,
"RunSymbolicShapeInference": false,

View file

@ -7,7 +7,6 @@
},
"UseExternalGPUAllocator" : true,
"EnableCustomAutogradFunction": false,
"AllowLayerNormModPrecision": false,
"EnableGradAccOptimization": false,
"UseStaticShape": false,
"RunSymbolicShapeInference": true,

View file

@ -84,7 +84,6 @@ def testORTTrainerOptionsDefaultValues(test_input):
"gelu_recompute": False,
"transformer_layer_recompute": False,
"number_recompute_layers": 0,
"allow_layer_norm_mod_precision": False,
"propagate_cast_ops_config": {"strategy": PropagateCastOpsStrategy.FLOOD_FILL, "level": 1, "allow": []},
},
"utils": {