Enable cast propagation in the frontend. (#8517)

satyajandhyala 2021-07-28 17:06:49 -07:00 committed by GitHub
parent 00d8f8ce95
commit 5e2f4263db
2 changed files with 3 additions and 3 deletions

@@ -101,7 +101,7 @@ class GraphExecutionManager(GraphExecutionInterface):
# predetermined list of opcodes considered safe to move before/after cast operation.
# - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform any computation, such as Transpose, Split, Reshape, etc.,
# whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, GeLU, Dropout, LayerNormalization, etc.
- self._propagate_cast_ops_level = -1
+ self._propagate_cast_ops_level = 1
# List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
self._propagate_cast_ops_allow = []
# Whether to allow fusion of the layer norm subgraph if doing so will cause modified precision.

@@ -201,7 +201,7 @@ class ORTTrainerOptions(object):
},
'propagate_cast_ops_level': {
'type': 'integer',
- 'default': -1
+ 'default': 1
},
'propagate_cast_ops_allow': {
'type': 'list',
@@ -382,7 +382,7 @@ class ORTTrainerOptions(object):
INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes.
FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike
INSERT_AND_REDUCE does not touch opcodes outside expanded float16 region.
- graph_transformer.propagate_cast_ops_config.level(integer, default -1)
+ graph_transformer.propagate_cast_ops_config.level(integer, default 1)
Optimize by moving Cast operations if propagate_cast_ops_level is non-negative.
Use predetermined list of opcodes considered safe to move before/after cast operation
if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise.
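
Note on the level semantics described in the docstring above: a negative level disables cast propagation, zero uses only the user-supplied allow list, and a positive level uses a predetermined list. The following is a minimal sketch of that selection logic, not the ONNX Runtime implementation. The function name `resolve_cast_propagation_ops` and the two list constants are hypothetical, the opcode lists contain only the examples named in the comments of this commit, and treating Level 2 as a superset of Level 1 is an assumption.

```python
from typing import List, Optional

# Hypothetical stand-ins for the predetermined "FP16 safe" lists.
# Only the opcodes named in the commit's comments are listed here;
# the real lists live inside ONNX Runtime and are longer.
LEVEL1_SAFE_OPS = ["Transpose", "Split", "Reshape"]
# Assumption: Level 2 extends Level 1 with opcodes that perform computation.
LEVEL2_SAFE_OPS = LEVEL1_SAFE_OPS + ["GeLU", "Dropout", "LayerNormalization"]


def resolve_cast_propagation_ops(level: int, allow: List[str]) -> Optional[List[str]]:
    """Return the opcodes a Cast may be moved across, or None if disabled."""
    if level < 0:
        # Negative level: the optimization is off (the default before this commit).
        return None
    if level == 0:
        # Zero: rely solely on the user-provided allow list.
        return list(allow)
    if level == 1:
        # Level 1 (the new default): opcodes that do no computation.
        return list(LEVEL1_SAFE_OPS)
    # Level 2 and above: also include opcodes that perform computation.
    return list(LEVEL2_SAFE_OPS)
```

Under these assumptions, the old default of -1 yields `None` (no cast propagation), while the new default of 1 yields the Level 1 list regardless of what the allow list contains.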