From 18c97381cd69db860f698451ef87722466b3542e Mon Sep 17 00:00:00 2001 From: ashari4 <70242157+ashari4@users.noreply.github.com> Date: Sat, 3 Jun 2023 01:17:49 -0500 Subject: [PATCH] Detect fake tensor mode if it has already been created. (#16220) ### Description Detect fake tensor mode if it has already been created. Follows this example in pytorch: https://github.com/pytorch/pytorch/blob/86c76525033418f6e9fed6134a6af66eff79d1d5/torch/_inductor/compile_fx.py#L280 ### Motivation and Context As of torch nightly 6/2/23, when trying to run a torch dynamo graph on the ORT backend, we observe ``` E torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised: E AssertionError: Mixing fake modes NYI E E E You can suppress this exception and fall back to eager by setting: E import torch._dynamo E torch._dynamo.config.suppress_errors = True ``` The issue is that `ort_backend.py` creates a new fake tensor mode even though one has already been created by torch. --- .../orttraining/python/training/torchdynamo/ort_backend.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py index b8d7990e64..4f2ec74519 100644 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py @@ -18,6 +18,7 @@ import torch.jit import torch.onnx import torch.onnx._onnx_supported_ops from torch._decomp import decomposition_table +from torch._dynamo.utils import detect_fake_mode from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.passes.fake_tensor_prop import FakeTensorProp @@ -632,7 +633,10 @@ class OrtBackend: )(*args) # TODO(wechi): this is required for removing aten::_to_copy in _replace_to_copy_with_to. # We need input and output tensors' devices to decide if aten::_to_copy is just a Cast. - FakeTensorProp(prim_graph_module).propagate(*args) + fake_mode = detect_fake_mode(args) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode() + FakeTensorProp(prim_graph_module, mode=fake_mode).propagate(*args) _replace_to_copy_with_to(prim_graph_module) partitioner = CapabilityBasedPartitioner( prim_graph_module, self._supported_ops, allows_single_node_partition=False