Detect fake tensor mode if it has already been created. (#16220)

### Description

Detect the fake tensor mode if one has already been created, and reuse
it instead of always creating a new one. This follows the example in PyTorch:
86c7652503/torch/_inductor/compile_fx.py (L280)


### Motivation and Context
As of the torch nightly from 6/2/23, running a torch dynamo graph on
the ORT backend fails with the following error:

```
E           torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
E           AssertionError: Mixing fake modes NYI
E           
E           
E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True
```
The issue is that `ort_backend.py` creates a new fake tensor mode even
though one has already been created by torch.
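
The "detect, then fall back to creating" pattern behind this fix can be sketched without torch. This is only an illustration under stated assumptions: `StubFakeMode`, `StubFakeTensor`, `detect_mode`, and `mode_for_propagation` are hypothetical stand-ins invented for this example, not torch APIs, and the real `detect_fake_mode` may behave differently in detail.

```python
from dataclasses import dataclass
from typing import Any, Iterable, Optional


class StubFakeMode:
    """Hypothetical stand-in for torch's FakeTensorMode."""


@dataclass
class StubFakeTensor:
    """Hypothetical stand-in for a FakeTensor, which remembers its mode."""

    fake_mode: StubFakeMode


def detect_mode(args: Iterable[Any]) -> Optional[StubFakeMode]:
    """Return the single mode shared by the fake inputs, or None if none is found."""
    found: Optional[StubFakeMode] = None
    for arg in args:
        if isinstance(arg, StubFakeTensor):
            if found is None:
                found = arg.fake_mode
            else:
                # Inputs from two different modes cannot be reconciled here.
                assert found is arg.fake_mode, "Mixing fake modes NYI"
    return found


def mode_for_propagation(args: Iterable[Any]) -> StubFakeMode:
    """Reuse the caller's mode when present; create one only as a fallback."""
    mode = detect_mode(args)
    if mode is None:
        mode = StubFakeMode()
    return mode


# The inputs already carry a mode, so it is reused rather than replaced.
existing = StubFakeMode()
reused = mode_for_propagation([StubFakeTensor(existing), StubFakeTensor(existing), 3])

# No fake inputs at all, so a fresh mode is created.
created = mode_for_propagation([1.0, 2.0])
```

Creating a fresh mode unconditionally corresponds to the failure described above: the inputs would belong to one mode while propagation runs under another.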
ashari4 2023-06-03 01:17:49 -05:00 committed by GitHub
parent 2e66bc8669
commit 18c97381cd

@@ -18,6 +18,7 @@ import torch.jit
 import torch.onnx
 import torch.onnx._onnx_supported_ops
 from torch._decomp import decomposition_table
+from torch._dynamo.utils import detect_fake_mode
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp
@@ -632,7 +633,10 @@ class OrtBackend:
 )(*args)
 # TODO(wechi): this is required for removing aten::_to_copy in _replace_to_copy_with_to.
 # We need input and output tensors' devices to decide if aten::_to_copy is just a Cast.
-FakeTensorProp(prim_graph_module).propagate(*args)
+fake_mode = detect_fake_mode(args)
+if not fake_mode:
+    fake_mode = torch._subclasses.FakeTensorMode()
+FakeTensorProp(prim_graph_module, mode=fake_mode).propagate(*args)
 _replace_to_copy_with_to(prim_graph_module)
 partitioner = CapabilityBasedPartitioner(
     prim_graph_module, self._supported_ops, allows_single_node_partition=False