From fcea2cb7f184d608efa1e5c72f9e25072e82009d Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Fri, 1 Dec 2023 09:36:18 -0800 Subject: [PATCH] [Dort] Run type promotion pass to resolve dtype discrepancy (#18516) Fixes CI failures mentioned in #18507 But we should not keep two separate dort impls in both pytorch and onnxruntime. They are out of sync. --- .../orttraining/python/training/torchdynamo/ort_backend.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py index a576bc20ed..9bafe39a5c 100644 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py @@ -576,6 +576,10 @@ class OrtBackend: # rethrow FakeTensorProb failure because it is not yet currently handled. raise + graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion( + self.resolved_onnx_exporter_options.diagnostic_context, graph_module + ).run() + from torch.onnx._internal.fx import fx_onnx_interpreter # Create the object to iterate through the nodes in graph one-by-one