mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Previously, when we SymInt-ified a schema, it was a BC-breaking change for everyone who had registered functions for that function: they had to accept c10::SymInt where they previously accepted int64_t. This is not great. With this change, I accept old type registrations transparently.

The idea is in several parts:

- At the registration site, at compile time I have no idea whether the function being registered has a SymInt schema. So I must defer the exact compatibility check. What I do instead is check whether the function pointer registered to me has SymInt in its arguments. If it does, I assume it is new-style and ensure it is also registered to a special sym_ slot on KernelFunction. If not, it only goes in the conventional slot.
- At the dispatcher site, I know at compile time whether or not this is a SymInt function. If it is, I check for a sym_ slot on the KernelFunction and preferentially use that. If no such slot exists, I fall back to the regular slot, but I convert all SymInt arguments to int64_t arguments (asserting that no true symbolic integer was passed). I can skip this test entirely if the function doesn't have any SymInts in it; in that case I know that only the original slot could have been registered. Fortunately, both branches of the short circuit typecheck, so I didn't have to use SFINAE or if-constexpr to make it work; just a plain if statement that I expect the compiler to optimize away.
- Schema validation is now modestly more complicated. There are two parts. First, function schema validation proceeds by checking whether the signature in question has any SymInt-like types in it. If it does, we do function schema validation against the real types; if it doesn't, we do validation against the fake types (but only for symint; MemoryFormat is always MemoryFormat). Second, cpp signature validation also keeps track of a "symint" cpp signature and a "non-symint" cpp signature. We only compare symint with symint, and non-symint with non-symint. I did not implement checking for a conflict between a symint and a non-symint cpp signature, though in principle you could try converting the SymInt types to non-SymInt types and doing the comparison that way.

To show it is working, I remove a bunch of c10::asIntArrayRefSlow shims, as the dispatcher is able to insert them automatically now. I didn't update the Metal registrations (though they can get similar treatment) as OSS CI coverage is insufficient for this case.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84557
Approved by: https://github.com/wconstab

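The registration-site and dispatcher-site logic described in the commit message can be illustrated with a small standalone sketch. This is not the PyTorch implementation: `SymInt`, `KernelSlots`, `register_kernel`, `dispatch`, `old_mul`, and `new_mul` are hypothetical stand-ins invented for illustration, the operator is fixed to two integer arguments, and (unlike the description above, where new-style kernels are "also" placed in the sym_ slot) a new-style kernel is stored only in the `sym` slot to keep the types simple. The sketch uses `if constexpr` at registration for brevity, whereas the commit notes that the real dispatcher path needed only a plain `if`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <stdexcept>
#include <type_traits>

// Hypothetical stand-in for c10::SymInt: a concrete value or a symbolic one.
class SymInt {
 public:
  SymInt(int64_t v) : value_(v) {}
  static SymInt symbolic() {
    SymInt s(0);
    s.value_.reset();  // no concrete value: treat as symbolic
    return s;
  }
  bool is_symbolic() const { return !value_.has_value(); }
  // Conversion used on the fallback path; fails for a truly symbolic value.
  int64_t expect_int() const {
    if (!value_) {
      throw std::runtime_error("symbolic SymInt passed to an int64_t kernel");
    }
    return *value_;
  }

 private:
  std::optional<int64_t> value_;
};

// Compile-time check: does any argument type of the kernel equal SymInt?
template <class... Args>
constexpr bool has_symint_v =
    (std::is_same_v<std::decay_t<Args>, SymInt> || ...);

// Two slots for one operator taking two integer-like arguments.
struct KernelSlots {
  std::function<int64_t(int64_t, int64_t)> plain;  // conventional slot
  std::function<int64_t(SymInt, SymInt)> sym;      // sym_ slot
};

// Registration site: inspect only the function pointer's argument types.
template <class... Args>
void register_kernel(KernelSlots& k, int64_t (*fn)(Args...)) {
  if constexpr (has_symint_v<Args...>) {
    k.sym = fn;    // assumed new-style kernel
  } else {
    k.plain = fn;  // old-style kernel, conventional slot only
  }
}

// Dispatcher site: prefer the sym slot, else convert and use the plain slot.
int64_t dispatch(const KernelSlots& k, SymInt a, SymInt b) {
  if (k.sym) {
    return k.sym(a, b);
  }
  if (k.plain) {
    return k.plain(a.expect_int(), b.expect_int());
  }
  throw std::runtime_error("no kernel registered");
}

// Old-style kernel written against int64_t.
int64_t old_mul(int64_t a, int64_t b) { return a * b; }

// New-style kernel; -1 is a placeholder for "result is symbolic".
int64_t new_mul(SymInt a, SymInt b) {
  if (a.is_symbolic() || b.is_symbolic()) {
    return -1;
  }
  return a.expect_int() * b.expect_int();
}
```

In this sketch, an old-style kernel such as `old_mul` keeps working when called with concrete `SymInt` values and fails loudly only if a symbolic value reaches it, which is the compatibility behavior the commit message describes.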
| | | |
|---|---|---|
| .. | ||
| no_python_abi_suffix_test | ||
| self_compiler_include_dirs_test | ||
| torch_test_cpp_extension | ||
| cpp_c10d_extension.cpp | ||
| cpp_c10d_extension.hpp | ||
| cpp_frontend_extension.cpp | ||
| cublas_extension.cpp | ||
| cuda_dlink_extension.cpp | ||
| cuda_dlink_extension_add.cu | ||
| cuda_dlink_extension_add.cuh | ||
| cuda_dlink_extension_kernel.cu | ||
| cuda_extension.cpp | ||
| cuda_extension.cu | ||
| cuda_extension_kernel.cu | ||
| cuda_extension_kernel2.cu | ||
| cudnn_extension.cpp | ||
| cusolver_extension.cpp | ||
| dangling_impl_extension.cpp | ||
| doubler.h | ||
| extension.cpp | ||
| jit_extension.cpp | ||
| jit_extension2.cpp | ||
| open_registration_extension.cpp | ||
| ort_extension.cpp | ||
| rng_extension.cpp | ||
| setup.py | ||
| torch_library.cu | ||