New calling convention for Python dispatcher (#85133)
Instead of calling into the Python dispatcher for EVERY dispatcher
call, we now have a two-step process. First, we
getattr(op: OpOverload, dispatch_key) to "load" the handler for the
function. This can either be a conventional function (in which
case we will call it, in the same way the old Python dispatcher
worked), or it can be a DispatchKey, in which case we will directly
call that DispatchKey in C++, bypassing marshalling between Python
and C++ entirely. OpOverload.__getattr__ is carefully written so
that it will cache the loaded handler.
A further optimization would be to define __slots__ on OpOverload
and to ensure that the DispatchKey strings are interned.
The resulting Python dispatcher is less flexible: after the first
lookup, the handler is cached and we won't recompute it. Furthermore,
by default, dispatches will not go into Python, so you won't get
stack frames for the Python dispatcher. But we get a substantial
performance improvement: on the following microbenchmark we go from
2.5s to 1.9s, roughly a 24% reduction in wall time.
```
import time
import torch
from functorch import make_fx

def f(x):
    for i in range(1000):
        x = x * x
    return x

begin = time.time()
res = make_fx(f, tracing_mode="symbolic")(torch.randn(10, 20))
print(time.time() - begin)
```
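The following is a minimal sketch of the two-step convention. It is written in plain Python so it runs without PyTorch. `FakeOpOverload`, `DirectKey` and `dispatch` are hypothetical names. The sketch assumes the loaded handler is cached as an ordinary instance attribute, so `__getattr__` only runs on a miss. That is one way to build such a cache; it is not a copy of the real `OpOverload.__getattr__`.

```python
from typing import Callable, Dict, Union


class DirectKey:
    """Hypothetical marker: skip Python and call this key's kernel directly."""

    def __init__(self, name: str) -> None:
        self.name = name


Handler = Union[Callable[..., object], DirectKey]


class FakeOpOverload:
    """Hypothetical operator overload with a lazily filled handler cache."""

    def __init__(self, python_kernels: Dict[str, Callable[..., object]]) -> None:
        self._python_kernels = python_kernels
        self._misses = 0  # how many times the slow path ran

    def __getattr__(self, key: str) -> Handler:
        # Python only calls __getattr__ when normal lookup fails: a cache miss.
        if key.startswith("_"):
            raise AttributeError(key)
        self._misses += 1
        handler: Handler = self._python_kernels.get(key, DirectKey(key))
        # Cache as a plain instance attribute; later getattr() calls find it
        # directly and never reach __getattr__ again.
        setattr(self, key, handler)
        return handler


def dispatch(op: FakeOpOverload, key: str, *args: object) -> object:
    handler = getattr(op, key)  # step 1: "load" the handler
    if isinstance(handler, DirectKey):
        # step 2a: stand-in for calling the C++ kernel with no Python round trip
        return ("direct", handler.name, args)
    return handler(*args)  # step 2b: call the Python handler
```

Deleting the cached attribute (for example `delattr(op, "Python")`) sends the next lookup through `__getattr__` again. This is the sketch's analogue of the cache invalidation that a cache-once design requires.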
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85133
Approved by: https://github.com/wconstab
2022-09-16 17:23:01 +00:00
import torch._C
2022-09-15 00:43:36 +00:00
from contextlib import contextmanager
Add crossref debug mode for functionalization, catches stride errors (#89498)
The idea is to add a custom handler to the Functionalize key in the Python
dispatcher that runs the functionalized version alongside a
non-functionalized version and checks that their outputs agree at the
end. (Technically, for metadata mutation we should also check the
inputs, but for now we're relying on those functions returning self.)
I turned this on for test_functionalize.py (new TestCrossRefFunctionalize)
and found a bunch of failures that look legit.
This probably doesn't interact nicely with tracing at the same time;
that likely needs more special logic (the direct approach is to just
disable tracing while we create the nested fake tensor mode, but I don't
know whether there's a more principled way to organize this).
There are some misc fixups which I can split out if people really want.
- xfail_inherited_tests moved to test common_utils
- Bindings for _dispatch_tls_set_dispatch_key_included,
  _dispatch_tls_is_dispatch_key_included and _functionalization_reapply_views_tls
- Type stubs for _enable_functionalization, _disable_functionalization
- all_known_overloads utility to let you iterate over all OpOverloads
  in all namespaces. Iterator support on all torch._ops objects to let
  you iterate over their members.
- suspend_functionalization lets you temporarily disable functionalization mode
  in a context
- check_metadata_matches for easily comparing the outputs of functions and seeing
  whether they match (TODO: there are a few copies of this logic, consolidate!)
- _fmt for easily printing the metadata of a tensor without its data
- _uncache_dispatch for removing a particular dispatch key from the cache,
  so that we force it to regenerate
- New only_cuda kwarg on check_significant_strides, so you can also run the
  stride test when inputs are not CUDA
- Functionalize in torch._C.DispatchKey
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89498
Approved by: https://github.com/malfet
2022-11-22 15:47:48 +00:00
import unittest.mock
import torch
import torch.utils._pytree as pytree
import itertools
2023-03-27 15:04:39 +00:00
from typing import Iterator
import torch._ops

__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']

@contextmanager
def no_python_dispatcher():
    g = torch._C._DisablePythonDispatcher()
    try:
        yield
    finally:
        del g

@contextmanager
def enable_python_dispatcher():
    g = torch._C._EnablePythonDispatcher()
    try:
        yield
    finally:
        del g
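Both context managers above hold a guard object for the duration of the `with` block and release it in `finally`. The guard is presumably a C++ object whose constructor sets dispatcher state and whose destructor restores it.

The pure-Python sketch below imitates that pattern with a hypothetical `EnableGuard`, using a dict in place of thread-local state. It relies on CPython's reference counting to run `__del__` as soon as `del g` drops the last reference. Other interpreters may delay that call.

```python
from contextlib import contextmanager

# Hypothetical stand-in for thread-local dispatcher state.
STATE = {"python_dispatcher": False}


class EnableGuard:
    """Hypothetical guard: sets the flag when created, restores it when destroyed."""

    def __init__(self) -> None:
        self._prev = STATE["python_dispatcher"]
        STATE["python_dispatcher"] = True

    def __del__(self) -> None:
        STATE["python_dispatcher"] = self._prev


@contextmanager
def enable_flag():
    g = EnableGuard()
    try:
        yield
    finally:
        # In CPython, dropping the last reference runs __del__ right here.
        del g
```

Because each guard remembers the previous value, nested blocks unwind correctly: an inner exit restores the outer block's setting, not the global default.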
CROSSREF_FUNCTIONALIZE = False

def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle. It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point). This is DIFFERENT
    from the set of registered operators, which will in general be a larger set,
    as that would include all operators on which we ran C++ static initializers or
    Python operator registration. This does not eagerly populate the list on
    torch.ops.aten; this list is lazy!

    In other words, this is good for traversing everything that has an
    OpOverload object allocated in Python. We use it for cache invalidation, but
    don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't guaranteed
    to be complete either, as a subsequent lazy load of a library which triggers more
    registrations could add more things to the set.
    """
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
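The generator above only sees what has already been materialized, because iterating a `torch.ops` object yields only the members created so far.

The sketch below reproduces that behaviour with a hypothetical `Lazy` container. It builds children on first attribute access and iterates only over the ones already built. The hypothetical `all_loaded_leaves` walks three levels, in the same way `all_py_loaded_overloads` walks namespace, packet and overload.

```python
from typing import Callable, Dict, Iterator, List


class Lazy:
    """Hypothetical container whose members are built on first access."""

    def __init__(self, factories: Dict[str, Callable[[], object]]) -> None:
        self._factories = factories  # everything that *could* be loaded
        self._loaded: List[str] = []  # names actually accessed so far

    def __getattr__(self, name: str) -> object:
        # Only called when normal lookup fails, i.e. on first access.
        if name.startswith("_") or name not in self._factories:
            raise AttributeError(name)
        value = self._factories[name]()
        setattr(self, name, value)  # later accesses skip __getattr__
        self._loaded.append(name)
        return value

    def __iter__(self) -> Iterator[str]:
        # Yield names of loaded members only, not everything registered.
        return iter(list(self._loaded))


def all_loaded_leaves(root: Lazy) -> Iterator[object]:
    for ns_name in root:
        ns = getattr(root, ns_name)
        for packet_name in ns:
            packet = getattr(ns, packet_name)
            for overload_name in packet:
                yield getattr(packet, overload_name)
```

Registered-but-unused entries, such as `"Scalar"` in a tree built this way, never show up. This is the incompleteness the docstring warns about.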
@contextmanager
def suspend_functionalization():
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize)
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
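`suspend_functionalization` follows a save, disable, restore pattern:

- It records whether the Functionalize key was active and which `reapply_views` setting was in effect.
- It turns functionalization off only if it was on.
- On exit, even after an exception, it re-enables functionalization with the same setting.

The plain-Python sketch below shows that pattern. The `TLS` dict and the `enable`/`disable` helpers are hypothetical. For illustration, we assume that disabling also forgets the `reapply_views` setting, which is why it has to be saved up front.

```python
from contextlib import contextmanager

# Hypothetical stand-in for thread-local functionalization state.
TLS = {"enabled": False, "reapply_views": False}


def enable(reapply_views: bool) -> None:
    TLS["enabled"] = True
    TLS["reapply_views"] = reapply_views


def disable() -> None:
    # Assumption for this sketch: disabling also resets the sub-setting.
    TLS["enabled"] = False
    TLS["reapply_views"] = False


@contextmanager
def suspend():
    was_enabled = TLS["enabled"]  # save both pieces of state first
    reapply = TLS["reapply_views"]
    if was_enabled:
        disable()
    try:
        yield
    finally:
        if was_enabled:
            enable(reapply_views=reapply)  # restore exactly what was there
```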
def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False)
    assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
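The stride check goes through `torch._prims_common.check_significant_strides`, called with `only_cuda=False` so that it also runs on CPU inputs. As we understand it, a stride only matters for a dimension of length greater than 1, because a length-1 dimension is never stepped along.

The sketch below illustrates that rule on plain tuples. `significant_strides_match` is a hypothetical helper. Take the exact rules of the real function, for example how it treats empty tensors, from its source.

```python
from typing import Optional, Sequence, Tuple


def significant_strides_match(
    sizes: Sequence[int],
    strides_a: Sequence[int],
    strides_b: Sequence[int],
) -> Tuple[bool, Optional[int]]:
    """Hypothetical helper: compare strides only where they affect addressing."""
    if not (len(sizes) == len(strides_a) == len(strides_b)):
        return False, None
    for idx, (size, sa, sb) in enumerate(zip(sizes, strides_a, strides_b)):
        # A dimension of length 0 or 1 is never stepped along, so its
        # stride is irrelevant and may legitimately differ.
        if size > 1 and sa != sb:
            return False, idx  # report the first significant mismatch
    return True, None
```

The returned index plays the role of `idx` in the `(mismatch at index {idx})` part of the assertion message above.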
def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
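`check_metadata_matches` does two things:

- It flattens both results into leaves, so a tuple on one side and a list on the other still compare equal. The TODO above notes that the tree specs are not checked.
- It takes `desc` as a zero-argument callable, so the potentially expensive description is only built when an assertion fails.

The self-contained sketch below shows both ideas. `Meta`, `flatten` and `check_outputs_match` are hypothetical, and `flatten` is a much-simplified stand-in for `torch.utils._pytree.tree_flatten`.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass(frozen=True)
class Meta:
    """Hypothetical stand-in for a tensor: metadata only, no data."""
    size: Tuple[int, ...]
    dtype: str


def flatten(x: Any) -> List[Any]:
    """Simplified pytree-style flatten: tuples and lists are interchangeable."""
    if isinstance(x, (tuple, list)):
        return [leaf for item in x for leaf in flatten(item)]
    if isinstance(x, dict):
        return [leaf for key in x for leaf in flatten(x[key])]
    return [x]


def check_outputs_match(n: Any, r: Any, desc: Callable[[], str]) -> None:
    n_vals, r_vals = flatten(n), flatten(r)
    # Assertion messages are only evaluated on failure, so desc() runs lazily.
    assert len(n_vals) == len(r_vals), f"{desc()}: {len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if not isinstance(rv, Meta):
            continue  # non-tensor leaves are not compared
        assert nv.size == rv.size, f"{desc()} output {i}: sizes {nv.size} != {rv.size}"
        assert nv.dtype == rv.dtype, f"{desc()} output {i}: dtype {nv.dtype} != {rv.dtype}"
```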
class Lit:
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s

def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
    else:
        return a
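`Lit` exists so that formatted arguments print as code rather than as quoted strings. A container's `repr` calls `repr` on each element, and `Lit.__repr__` returns the raw text.

The sketch below repeats `Lit` and swaps `torch.Tensor` for a hypothetical `TensorMeta` record, so it runs without PyTorch. Converting the size to a plain `tuple`, as `_fmt` does, presumably keeps the output a valid `torch.empty_strided` call instead of printing `torch.Size([...])`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class TensorMeta:
    """Hypothetical stand-in for torch.Tensor, for illustration only."""
    size: Tuple[int, ...]
    stride: Tuple[int, ...]
    dtype: str


class Lit:
    """Holds text whose repr is the text itself, without quotes."""

    def __init__(self, s: str) -> None:
        self.s = s

    def __repr__(self) -> str:
        return self.s


def fmt(a: object) -> object:
    # Replace tensor-like values with a constructor call that has the same
    # size, stride and dtype; leave everything else untouched.
    if isinstance(a, TensorMeta):
        return Lit(f"torch.empty_strided({a.size}, {a.stride}, dtype={a.dtype})")
    return a


args = [TensorMeta((2, 3), (3, 1), "torch.float32"), 5, "x"]
print(repr([fmt(a) for a in args]))
# [torch.empty_strided((2, 3), (3, 1), dtype=torch.float32), 5, 'x']
```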
def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode
    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides. This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        with suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs))
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()),
                )
            )
            return f"{op}({fmt_args})"
        check_metadata_matches(f_r, r, desc)
        return r
    return handler
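`handler` is a cross-reference check. With functionalization suspended, it computes the expected result on fake copies of the inputs. It then runs the real kernel through `op._op_dk(final_key, ...)`, compares output metadata, and returns only the real result.

The sketch below shows that shape without PyTorch. `make_crossref`, `real_impl`, `reference_impl` and `metadata` are hypothetical names. `copy.deepcopy` stands in for fakeifying and detaching the arguments, and the caller-supplied `metadata` function stands in for the size, dtype and stride comparison.

```python
import copy
from typing import Any, Callable


def make_crossref(
    real_impl: Callable[..., Any],
    reference_impl: Callable[..., Any],
    metadata: Callable[[Any], Any],
) -> Callable[..., Any]:
    """Hypothetical factory: wrap real_impl so every call is cross-checked."""

    def handler(*args: Any, **kwargs: Any) -> Any:
        # Snapshot inputs for the error message before anything can mutate them.
        snapshot = copy.deepcopy((args, kwargs))
        # Run the reference on its own copies so it cannot touch the real inputs.
        ref_args, ref_kwargs = copy.deepcopy((args, kwargs))
        expected = reference_impl(*ref_args, **ref_kwargs)
        actual = real_impl(*args, **kwargs)
        if metadata(expected) != metadata(actual):
            s_args, s_kwargs = snapshot
            raise AssertionError(
                f"{real_impl.__name__}(*{s_args!r}, **{s_kwargs!r}): "
                f"{metadata(expected)!r} != {metadata(actual)!r}"
            )
        return actual  # callers only ever see the real result

    return handler
```

Returning the real result, not the reference one, means enabling the check can only add assertion failures. It never changes what the program computes.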
# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
                'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
            yield
    finally:
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
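Dispatch handlers are cached after the first lookup. Flipping `CROSSREF_FUNCTIONALIZE` alone would therefore presumably not affect operators whose Functionalize handler was already computed. That appears to be why `enable_crossref_functionalize` uncaches the key both before patching the flag and again afterwards.

The plain-Python sketch below shows the pattern. `Config`, `CachedOp` and `enable_crossref` are hypothetical. The sketch calls `unittest.mock.patch.object` on a class attribute rather than patching a module path by string, so it works however the snippet is loaded.

```python
import unittest.mock
from contextlib import contextmanager
from typing import Callable, Dict, Iterable


class Config:
    crossref = False  # hypothetical flag read only when a handler is computed


class CachedOp:
    """Hypothetical op that computes a handler per key once, then reuses it."""

    def __init__(self) -> None:
        self._cache: Dict[str, str] = {}

    def handler(self, key: str) -> str:
        if key not in self._cache:
            self._cache[key] = "crossref" if Config.crossref else "plain"
        return self._cache[key]

    def uncache(self, key: str) -> None:
        self._cache.pop(key, None)  # force recomputation on next lookup


@contextmanager
def enable_crossref(get_ops: Callable[[], Iterable[CachedOp]], key: str):
    for op in get_ops():
        op.uncache(key)  # drop handlers computed with the flag off
    try:
        with unittest.mock.patch.object(Config, "crossref", True):
            yield
    finally:
        # Re-enumerate so ops first used inside the block are covered too.
        for op in get_ops():
            op.uncache(key)  # drop handlers computed with the flag on
```

Taking a zero-argument `get_ops` mirrors how the original calls `all_py_loaded_overloads()` afresh in `finally`. Operators first loaded inside the block are therefore invalidated as well.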