diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index c3ef898849..c58ec49458 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -101,7 +101,7 @@ class GraphExecutionManager(GraphExecutionInterface): # predetermined list of opcodes considered safe to move before/after cast operation. # - Onnxruntime Level 1 predetermind "FP16 safe" opcodes include only opcode that do not perform any computation such as Transpose, Split, Reshape, etc. # whereas Level 2 perdetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, GeLU, Dropout, LayerNormalization, etc. - self._propagate_cast_ops_level = -1 + self._propagate_cast_ops_level = 1 # List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. self._propagate_cast_ops_allow = [] # Whether allow fusion of layer norm subgraph if doing so will cause modified precision. diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index 4e66c7146b..d0e756b799 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -201,7 +201,7 @@ class ORTTrainerOptions(object): }, 'propagate_cast_ops_level': { 'type': 'integer', - 'default': -1 + 'default': 1 }, 'propagate_cast_ops_allow': { 'type': 'list', @@ -382,7 +382,7 @@ class ORTTrainerOptions(object): INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes. FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike INSERT_AND_REDUCE does not touch opcodes outside expanded float16 region. - graph_transformer.propagate_cast_ops_config.level(integer, default -1) + graph_transformer.propagate_cast_ops_config.level(integer, default 1) Optimize by moving Cast operations if propagate_cast_ops_level is non-negative. Use predetermined list of opcodes considered safe to move before/after cast operation if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise.