pytorch/test/cpp_extensions
Edward Z. Yang 19e27b1556 Make dispatcher registrations of SymInt functions backwards compatible (#84557)
Previously, SymInt-ifying a schema was a BC-breaking change for
everyone who had registered a kernel for that operator: they had to
accept c10::SymInt where they previously accepted int64_t.
This is not great.

With this change, I accept registrations that use the old types
transparently. The idea has three parts:

- At the registration site, I have no way of knowing at compile time
  whether the function being registered has a SymInt schema. So I
  must defer the exact compatibility check. What I do instead is
  check whether the function pointer registered to me takes a SymInt
  argument. If it does, I assume it is new-style and ensure it is
  also registered to a special sym_ slot on KernelFunction.
  If not, it only goes in the conventional slot.

- At the dispatcher site, I know at compile time whether or not this
  is a SymInt function. If it is, I check for a sym_ slot on the
  KernelFunction and use it preferentially. If no such slot exists,
  I fall back to the regular slot, but first convert all SymInt
  arguments to int64_t (asserting that no truly symbolic integer was
  passed). I can skip this check entirely if the function doesn't
  have any SymInts in it; in that case I know that only the original
  slot could have been registered. Fortunately, both branches of the
  short circuit typecheck, so I didn't have to use SFINAE or
  if-constexpr to make it work; just a plain if statement that I
  expect the compiler to optimize away.

- Schema validation is now modestly more complicated. There are two
  parts. First, function schema validation checks whether the
  signature in question has any SymInt-like types in it. If it does, we
  validate against the real types; if it doesn't, we validate against
  the fake types (but only for SymInt; MemoryFormat is always
  MemoryFormat). Second, cpp signature validation now tracks a "symint"
  cpp signature and a "non-symint" cpp signature separately, and only
  compares symint with symint and non-symint with non-symint. I did not
  implement conflict checking between a symint and a non-symint cpp
  signature, though in principle you could convert the SymInt types to
  non-SymInt types and compare them that way.
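The registration-site and dispatcher-site steps above can be illustrated with a standalone C++17 sketch. Everything in it is hypothetical and simplified: `SymInt` is a stand-in for `c10::SymInt`, and `KernelFunction`, `register_kernel`, `expect_int`, `dispatch`, `double_old`, and `double_new` are illustrative names, not PyTorch's API. PyTorch's real `KernelFunction` is type-erased; this sketch instead uses two typed `std::function` slots for a single operator whose schema takes one SymInt, and it uses `if constexpr` at registration only because typed slots require it.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <type_traits>

// Hypothetical stand-in for c10::SymInt: a concrete integer, or a symbolic one.
struct SymInt {
  int64_t value;
  bool is_symbolic = false;
};

// Registration-time trait: does the registered function's own signature
// take any SymInt parameter (by value or by reference)?
template <typename F>
struct has_symint_arg;
template <typename R, typename... Args>
struct has_symint_arg<R(Args...)>
    : std::bool_constant<(std::is_same_v<std::decay_t<Args>, SymInt> || ...)> {};

// Hypothetical two-slot kernel for an operator whose schema takes one SymInt.
struct KernelFunction {
  std::function<int64_t(int64_t)> regular;  // conventional slot (old-style kernels)
  std::function<int64_t(SymInt)> sym;       // sym_ slot (SymInt-aware kernels)
};

// Registration site: the schema is unknown here, so inspect only the
// registered function's type. (The described design also keeps new-style
// kernels in the conventional slot; these typed slots cannot express that.)
template <typename F>
KernelFunction register_kernel(F* fn) {
  KernelFunction k;
  if constexpr (has_symint_arg<F>::value) {
    k.sym = fn;      // new-style kernel: sym_ slot
  } else {
    k.regular = fn;  // old-style kernel: conventional slot only
  }
  return k;
}

// Narrow a SymInt to int64_t, refusing truly symbolic values.
inline int64_t expect_int(const SymInt& s) {
  if (s.is_symbolic) {
    throw std::runtime_error("symbolic int passed to an int64_t kernel");
  }
  return s.value;
}

// Dispatcher site: this schema is known to contain a SymInt, so prefer the
// sym_ slot, else fall back to the old-style kernel after narrowing.
// A schema without SymInts could skip this check entirely.
inline int64_t dispatch(const KernelFunction& k, SymInt n) {
  if (k.sym) {
    return k.sym(n);
  }
  return k.regular(expect_int(n));
}

// The same operator implemented before and after SymInt-ification.
inline int64_t double_old(int64_t n) { return 2 * n; }
inline int64_t double_new(SymInt n) { return 2 * n.value; }

static_assert(!has_symint_arg<decltype(double_old)>::value);
static_assert(has_symint_arg<decltype(double_new)>::value);
```

With these definitions, `dispatch(register_kernel(&double_old), SymInt{21})` narrows the argument and returns 42, while passing `SymInt{7, true}` to the same old-style kernel throws instead of silently discarding the symbolic value. As in the description, the dispatch-side choice is a plain `if`.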

To show it is working, I remove a bunch of c10::asIntArrayRefSlow shims, as the dispatcher is able to insert them automatically now.
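The shims being removed converted a list of SymInts into plain integers inside each kernel, failing on truly symbolic values. A minimal standalone C++17 sketch of the before/after, assuming a hypothetical `SymInt` stand-in and a hypothetical `as_int_array_slow` helper that only approximates `c10::asIntArrayRefSlow` (the real shim returns a non-owning `IntArrayRef`; this one returns a `std::vector` copy). `numel_before` and `numel_after` are likewise illustrative names.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for c10::SymInt: a concrete integer, or a symbolic one.
struct SymInt {
  int64_t value;
  bool is_symbolic = false;
};

// Hypothetical analogue of the asIntArrayRefSlow shim: copy SymInts into
// plain integers, throwing if any entry is truly symbolic.
std::vector<int64_t> as_int_array_slow(const std::vector<SymInt>& syms) {
  std::vector<int64_t> out;
  out.reserve(syms.size());
  for (const SymInt& s : syms) {
    if (s.is_symbolic) {
      throw std::runtime_error("expected concrete sizes, got a symbolic int");
    }
    out.push_back(s.value);
  }
  return out;
}

// Before: a kernel written against the SymInt schema called the shim itself.
int64_t numel_before(const std::vector<SymInt>& sizes) {
  int64_t n = 1;
  for (int64_t d : as_int_array_slow(sizes)) n *= d;
  return n;
}

// After: the kernel keeps a plain-integer signature; in the described design
// the dispatcher performs the equivalent conversion before calling it.
int64_t numel_after(const std::vector<int64_t>& sizes) {
  int64_t n = 1;
  for (int64_t d : sizes) n *= d;
  return n;
}
```

Both versions compute the same result for concrete sizes. The difference is where the conversion lives: in every kernel before, and in one place in the dispatcher after.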

I didn't update the Metal registrations (though they could get similar treatment), because OSS CI coverage for them is insufficient.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D39280965](https://our.internmc.facebook.com/intern/diff/D39280965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84557
Approved by: https://github.com/wconstab
2022-09-07 16:30:21 +00:00
no_python_abi_suffix_test
self_compiler_include_dirs_test
torch_test_cpp_extension
cpp_c10d_extension.cpp
cpp_c10d_extension.hpp
cpp_frontend_extension.cpp
cublas_extension.cpp
cuda_dlink_extension.cpp
cuda_dlink_extension_add.cu
cuda_dlink_extension_add.cuh
cuda_dlink_extension_kernel.cu
cuda_extension.cpp
cuda_extension.cu
cuda_extension_kernel.cu
cuda_extension_kernel2.cu
cudnn_extension.cpp
cusolver_extension.cpp
dangling_impl_extension.cpp
doubler.h
extension.cpp
jit_extension.cpp
jit_extension2.cpp
open_registration_extension.cpp (last changed in #84557, 2022-09-07 16:30:21 +00:00)
ort_extension.cpp (last changed in #84557, 2022-09-07 16:30:21 +00:00)
rng_extension.cpp
setup.py
torch_library.cu