diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py index a576bc20ed..9bafe39a5c 100644 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py @@ -576,6 +576,10 @@ class OrtBackend: # rethrow FakeTensorProb failure because it is not yet currently handled. raise + graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion( + self.resolved_onnx_exporter_options.diagnostic_context, graph_module + ).run() + from torch.onnx._internal.fx import fx_onnx_interpreter # Create the object to iterate through the nodes in graph one-by-one