[Dort] Run type promotion pass to resolve dtype discrepancy (#18516)

Fixes CI failures mentioned in #18507

But we should not keep two separate dort impls in both pytorch and
onnxruntime. They are out of sync.
This commit is contained in:
Bowen Bao 2023-12-01 09:36:18 -08:00 committed by GitHub
parent 05a9c95764
commit fcea2cb7f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -576,6 +576,10 @@ class OrtBackend:
# rethrow FakeTensorProb failure because it is not yet currently handled.
raise
graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
self.resolved_onnx_exporter_options.diagnostic_context, graph_module
).run()
from torch.onnx._internal.fx import fx_onnx_interpreter
# Create the object to iterate through the nodes in graph one-by-one