mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
[ORTModule] Enable SimplifiedLayerNormalization Fusion (#11580)
* Enable SimplifiedLayerNormalization fusion
* Remove the allow_layer_norm_mod_precision flag
This commit is contained in:
parent
03abcb0640
commit
54d1573d2f
13 changed files with 10 additions and 76 deletions
|
|
@ -46,7 +46,7 @@ X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
|
|||
| |
|
||||
+---------------------+
|
||||
|
||||
In recent pytorch, Cast nodes may be inserted before Pow to ensure that both inputs 'base' and 'power' are the same type
|
||||
In recent pytorch, Cast nodes may be inserted before Pow to ensure that both inputs 'base' and 'power' are the same type
|
||||
due to restriction in older opsets. Therefore, Layer Normalization will also handle the case below :
|
||||
+---------------------+
|
||||
| |
|
||||
|
|
@ -431,7 +431,7 @@ X --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul
|
|||
| ^
|
||||
| |
|
||||
+----------------------------------------------+
|
||||
Additional FP16 patterns supported if allow_precision_change_ is true:
|
||||
Additional FP16 patterns supported:
|
||||
|
||||
X --> Cast1 --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Cast2 --> Mul
|
||||
| ^ ^
|
||||
|
|
@ -535,7 +535,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
|
||||
const Node* p_pow_input_node = graph_utils::GetInputNode(pow_node, 0);
|
||||
bool has_leading_cast = false;
|
||||
if (allow_precision_change_ && p_pow_input_node) {
|
||||
if (p_pow_input_node) {
|
||||
Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index());
|
||||
// If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div)
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(pow_input_node, "Cast", {9, 13}) &&
|
||||
|
|
@ -548,8 +548,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
|
||||
// div --> mul or div --> cast --> mul
|
||||
Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
|
||||
if (allow_precision_change_ &&
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13}) &&
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13}) &&
|
||||
optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
|
||||
nodes_to_remove.push_back(*next_node);
|
||||
next_node = graph.GetNode(next_node->OutputNodesBegin()->Index());
|
||||
|
|
|
|||
|
|
@ -35,15 +35,10 @@ The formula corresponding to LayerNorm activation subgraph:
|
|||
*/
|
||||
class SimplifiedLayerNormFusion : public GraphTransformer {
|
||||
public:
|
||||
SimplifiedLayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
|
||||
const bool allow_precision_change = false) noexcept
|
||||
: GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers),
|
||||
allow_precision_change_(allow_precision_change) {}
|
||||
SimplifiedLayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("SimplifiedLayerNormFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
private:
|
||||
bool allow_precision_change_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3847,27 +3847,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_PrecisionChangeDisallowed) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SimplifiedLayerNormFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Div"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1);
|
||||
ASSERT_TRUE(op_to_count["ReduceMean"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Pow"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Sqrt"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == 2);
|
||||
ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 0);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_PrecisionChangeAllowed) {
|
||||
TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
|
|
@ -3875,7 +3855,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest_Precisio
|
|||
|
||||
InlinedHashSet<std::string_view> compatible_eps;
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps, true), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
|
|||
bool transformer_layer_recompute{false};
|
||||
// Number of layers to apply recompute
|
||||
int number_recompute_layers{0};
|
||||
bool allow_layer_norm_mod_precision{false};
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
|
||||
transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps, config.allow_layer_norm_mod_precision));
|
||||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<FastGeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));
|
||||
|
||||
|
|
|
|||
|
|
@ -118,7 +118,6 @@ struct TrainingParameters {
|
|||
std::vector<std::string> propagate_cast_ops_allow;
|
||||
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy propagate_cast_ops_strategy =
|
||||
GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::FloodFill;
|
||||
bool allow_layer_norm_mod_precision = false;
|
||||
|
||||
// graph dumping
|
||||
std::string model_after_graph_transforms_path;
|
||||
|
|
@ -294,7 +293,6 @@ TrainingConfigurationResult ConfigureSessionForTraining(
|
|||
config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy;
|
||||
config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level;
|
||||
config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow;
|
||||
config.graph_transformer_config.allow_layer_norm_mod_precision = parameters.allow_layer_norm_mod_precision;
|
||||
|
||||
if (!parameters.model_after_graph_transforms_path.empty()) {
|
||||
config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path);
|
||||
|
|
@ -548,8 +546,7 @@ for every transfered tensor.
|
|||
.def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
|
||||
.def_readwrite("enable_adasum", &TrainingParameters::enable_adasum)
|
||||
.def_readwrite("propagate_cast_ops_level", &TrainingParameters::propagate_cast_ops_level)
|
||||
.def_readwrite("propagate_cast_ops_allow", &TrainingParameters::propagate_cast_ops_allow)
|
||||
.def_readwrite("allow_layer_norm_mod_precision", &TrainingParameters::allow_layer_norm_mod_precision);
|
||||
.def_readwrite("propagate_cast_ops_allow", &TrainingParameters::propagate_cast_ops_allow);
|
||||
|
||||
#if defined(USE_MPI)
|
||||
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });
|
||||
|
|
@ -787,7 +784,6 @@ for every transfered tensor.
|
|||
.def_readwrite("gelu_recompute", &TrainingGraphTransformerConfiguration::gelu_recompute)
|
||||
.def_readwrite("transformer_layer_recompute", &TrainingGraphTransformerConfiguration::transformer_layer_recompute)
|
||||
.def_readwrite("number_recompute_layers", &TrainingGraphTransformerConfiguration::number_recompute_layers)
|
||||
.def_readwrite("allow_layer_norm_mod_precision", &TrainingGraphTransformerConfiguration::allow_layer_norm_mod_precision)
|
||||
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
|
||||
|
||||
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(
|
||||
|
|
|
|||
|
|
@ -126,8 +126,6 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._propagate_cast_ops_level = 1
|
||||
# List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
|
||||
self._propagate_cast_ops_allow = []
|
||||
# Whether allow fusion of layer norm subgraph if doing so will cause modified precision.
|
||||
self._allow_layer_norm_mod_precision = False
|
||||
|
||||
# Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL
|
||||
# To be instantiated in the concrete implementation of GraphExecutionManager
|
||||
|
|
@ -463,7 +461,6 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
graph_transformer_config.propagate_cast_ops_config.level = self._propagate_cast_ops_level
|
||||
graph_transformer_config.propagate_cast_ops_config.allow = self._propagate_cast_ops_allow
|
||||
graph_transformer_config.propagate_cast_ops_config.strategy = self._propagate_cast_ops_strategy
|
||||
graph_transformer_config.allow_layer_norm_mod_precision = self._allow_layer_norm_mod_precision
|
||||
return graph_transformer_config
|
||||
|
||||
def _initialize_graph_builder(self, training):
|
||||
|
|
|
|||
|
|
@ -81,18 +81,6 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data):
|
|||
ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction
|
||||
|
||||
|
||||
def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data):
|
||||
"""Loads AllowLayerNormModPrecision from json file onto ORTModule."""
|
||||
|
||||
assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key)
|
||||
log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. Loading attributes from file.")
|
||||
|
||||
assert isinstance(
|
||||
data.AllowLayerNormModPrecision, bool
|
||||
), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean"
|
||||
ortmodule_config_accessor._allow_layer_norm_mod_precision = data.AllowLayerNormModPrecision
|
||||
|
||||
|
||||
def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data):
|
||||
"""Loads EnableGradAccOptimization from json file onto ORTModule."""
|
||||
|
||||
|
|
@ -218,7 +206,6 @@ def _define_load_function_keys():
|
|||
_load_propagate_cast_ops.loading_key = "PropagateCastOps"
|
||||
_load_use_external_gpu_allocator.loading_key = "UseExternalGPUAllocator"
|
||||
_load_enable_custom_autograd_function.loading_key = "EnableCustomAutogradFunction"
|
||||
_load_allow_layer_norm_mod_precision.loading_key = "AllowLayerNormModPrecision"
|
||||
_load_enable_grad_acc_optimization.loading_key = "EnableGradAccOptimization"
|
||||
_load_run_symbolic_shape_infer.loading_key = "RunSymbolicShapeInference"
|
||||
_load_use_static_shape.loading_key = "UseStaticShape"
|
||||
|
|
@ -242,7 +229,6 @@ def load_from_json(ortmodule, path=None):
|
|||
},
|
||||
"UseExternalGPUAllocator" : false, # bool flag
|
||||
"EnableCustomAutogradFunction": true, # bool flag
|
||||
"AllowLayerNormModPrecision": true, # bool flag
|
||||
"EnableGradAccOptimization": true, # bool flag
|
||||
"UseStaticShape": true, # bool flag
|
||||
"RunSymbolicShapeInference": false, # bool flag
|
||||
|
|
@ -299,7 +285,6 @@ def load_from_json(ortmodule, path=None):
|
|||
_load_propagate_cast_ops.loading_key: _load_propagate_cast_ops,
|
||||
_load_use_external_gpu_allocator.loading_key: _load_use_external_gpu_allocator,
|
||||
_load_enable_custom_autograd_function.loading_key: _load_enable_custom_autograd_function,
|
||||
_load_allow_layer_norm_mod_precision.loading_key: _load_allow_layer_norm_mod_precision,
|
||||
_load_enable_grad_acc_optimization.loading_key: _load_enable_grad_acc_optimization,
|
||||
_load_run_symbolic_shape_infer.loading_key: _load_run_symbolic_shape_infer,
|
||||
_load_use_static_shape.loading_key: _load_use_static_shape,
|
||||
|
|
|
|||
|
|
@ -210,10 +210,6 @@ class ORTTrainerOptions(object):
|
|||
'default': []
|
||||
}
|
||||
}
|
||||
},
|
||||
'allow_layer_norm_mod_precision': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
@ -389,9 +385,6 @@ class ORTTrainerOptions(object):
|
|||
if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise.
|
||||
graph_transformer.propagate_cast_ops_config.allow(list of str, [])
|
||||
List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
|
||||
graph_transformer.allow_layer_norm_mod_precision(bool, default False)
|
||||
Enable LayerNormalization/SimplifiedLayerNormalization fusion
|
||||
even if it requires modified compute precision
|
||||
attn_dropout_recompute (bool, default is False):
|
||||
enable recomputing attention dropout to save memory
|
||||
gelu_recompute (bool, default is False):
|
||||
|
|
@ -630,7 +623,6 @@ _ORTTRAINER_OPTIONS_SCHEMA = {
|
|||
"gelu_recompute": {"type": "boolean", "default": False},
|
||||
"transformer_layer_recompute": {"type": "boolean", "default": False},
|
||||
"number_recompute_layers": {"type": "integer", "min": 0, "default": 0},
|
||||
"allow_layer_norm_mod_precision": {"type": "boolean", "default": False},
|
||||
"propagate_cast_ops_config": {
|
||||
"type": "dict",
|
||||
"default_setter": lambda _: {},
|
||||
|
|
|
|||
|
|
@ -46,9 +46,6 @@ def test_load_config_from_json_1():
|
|||
# test enable custom autograd function
|
||||
assert ort_model_attributes._enable_custom_autograd_function == True
|
||||
|
||||
# test allow layer norm mod precision
|
||||
assert ort_model_attributes._allow_layer_norm_mod_precision == True
|
||||
|
||||
# test use static shape
|
||||
assert ort_model_attributes._use_static_shape == True
|
||||
|
||||
|
|
@ -102,9 +99,6 @@ def test_load_config_from_json_2():
|
|||
# test enable custom autograd function
|
||||
assert ort_model_attributes._enable_custom_autograd_function == False
|
||||
|
||||
# test allow layer norm mod precision
|
||||
assert ort_model_attributes._allow_layer_norm_mod_precision == False
|
||||
|
||||
# test use static shape
|
||||
assert ort_model_attributes._use_static_shape == False
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
},
|
||||
"UseExternalGPUAllocator" : false,
|
||||
"EnableCustomAutogradFunction": true,
|
||||
"AllowLayerNormModPrecision": true,
|
||||
"EnableGradAccOptimization": true,
|
||||
"UseStaticShape": true,
|
||||
"RunSymbolicShapeInference": false,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
},
|
||||
"UseExternalGPUAllocator" : true,
|
||||
"EnableCustomAutogradFunction": false,
|
||||
"AllowLayerNormModPrecision": false,
|
||||
"EnableGradAccOptimization": false,
|
||||
"UseStaticShape": false,
|
||||
"RunSymbolicShapeInference": true,
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ def testORTTrainerOptionsDefaultValues(test_input):
|
|||
"gelu_recompute": False,
|
||||
"transformer_layer_recompute": False,
|
||||
"number_recompute_layers": 0,
|
||||
"allow_layer_norm_mod_precision": False,
|
||||
"propagate_cast_ops_config": {"strategy": PropagateCastOpsStrategy.FLOOD_FILL, "level": 1, "allow": []},
|
||||
},
|
||||
"utils": {
|
||||
|
|
|
|||
Loading…
Reference in a new issue