mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Enable cast propagation in the frontend. (#8517)
This commit is contained in:
parent
00d8f8ce95
commit
5e2f4263db
2 changed files with 3 additions and 3 deletions
|
|
@ -101,7 +101,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# predetermined list of opcodes considered safe to move before/after cast operation.
|
||||
# - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform any computation such as Transpose, Split, Reshape, etc.
|
||||
# whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, GeLU, Dropout, LayerNormalization, etc.
|
||||
self._propagate_cast_ops_level = -1
|
||||
self._propagate_cast_ops_level = 1
|
||||
# List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
|
||||
self._propagate_cast_ops_allow = []
|
||||
# Whether to allow fusion of the layer norm subgraph if doing so will cause modified precision.
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ class ORTTrainerOptions(object):
|
|||
},
|
||||
'propagate_cast_ops_level': {
|
||||
'type': 'integer',
|
||||
'default': -1
|
||||
'default': 1
|
||||
},
|
||||
'propagate_cast_ops_allow': {
|
||||
'type': 'list',
|
||||
|
|
@ -382,7 +382,7 @@ class ORTTrainerOptions(object):
|
|||
INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes.
|
||||
FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike
|
||||
INSERT_AND_REDUCE does not touch opcodes outside expanded float16 region.
|
||||
graph_transformer.propagate_cast_ops_config.level(integer, default -1)
|
||||
graph_transformer.propagate_cast_ops_config.level(integer, default 1)
|
||||
Optimize by moving Cast operations if propagate_cast_ops_level is non-negative.
|
||||
Use predetermined list of opcodes considered safe to move before/after cast operation
|
||||
if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise.
|
||||
|
|
|
|||
Loading…
Reference in a new issue