diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py index b8d7990e64..4f2ec74519 100644 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py @@ -18,6 +18,7 @@ import torch.jit import torch.onnx import torch.onnx._onnx_supported_ops from torch._decomp import decomposition_table +from torch._dynamo.utils import detect_fake_mode from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.passes.fake_tensor_prop import FakeTensorProp @@ -632,7 +633,10 @@ class OrtBackend: )(*args) # TODO(wechi): this is required for removing aten::_to_copy in _replace_to_copy_with_to. # We need input and output tensors' devices to decide if aten::_to_copy is just a Cast. - FakeTensorProp(prim_graph_module).propagate(*args) + fake_mode = detect_fake_mode(args) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode() + FakeTensorProp(prim_graph_module, mode=fake_mode).propagate(*args) _replace_to_copy_with_to(prim_graph_module) partitioner = CapabilityBasedPartitioner( prim_graph_module, self._supported_ops, allows_single_node_partition=False