mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
add function initialization back to graph resolve (#4434)
This commit is contained in:
parent
0fdb1e9f60
commit
dd73e8c016
2 changed files with 2 additions and 7 deletions
|
|
@ -2049,6 +2049,8 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) {
|
|||
node.op_ = nullptr;
|
||||
}
|
||||
|
||||
InitFunctionBodyForNode(node);
|
||||
|
||||
if (!node.op_) {
|
||||
return Status(ONNXRUNTIME, FAIL, "Fatal error: " + node.OpType() + " is not a registered function/op");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -602,13 +602,6 @@ Status TransformGraphForMixedPrecision(Graph& graph,
|
|||
LossSubgraph loss_subgraph(graph);
|
||||
|
||||
// Stage 1: Convert whole graph including forward and backward to FP16
|
||||
// Initialize function body for all function nodes
|
||||
// This is required to make sure after converting inputs\weights to FP16
|
||||
// the new NodeArg updates are correctly propagated to the function body nodes as well.
|
||||
for (auto& node : graph.Nodes()) {
|
||||
graph.InitFunctionBodyForNode(node);
|
||||
}
|
||||
|
||||
// Insert Cast node to convert inputs from FP32 to FP16
|
||||
// If all consumers are from loss graph, don't convert it, and remove it from To-32 loss graph inputs.
|
||||
for (const NodeArg* input : graph.GetInputs()) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue