mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Make static dispatch turn off variable before entering the kernel. (#26908)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26908 This is because in static dispatch, variable is not supported. But you might still "have" a variable somehow (e.g., in JIT initialization code), so we need to translate between the two worlds. Attempt to fix #26764 Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D17607550 Pulled By: ezyang fbshipit-source-id: ea2cb931215d749efd95259ad7a1a8ae96c38aed
This commit is contained in:
parent
e75a49341b
commit
9159a601ca
3 changed files with 482 additions and 0 deletions
File diff suppressed because it is too large
Load diff
|
|
@ -239,10 +239,17 @@ static inline ${return_type} ${api_name}(${formals}) {
|
|||
|
||||
# In order to rely on the linker to strip unused ops, it requires us to dispatch statically
|
||||
# in Functions.h and TensorMethods.h.
|
||||
#
|
||||
# NB: The default body also needs to apply a variable guard, as in some
|
||||
# situations what we think is a default body actually does have an
|
||||
# explicit derivative, and thereby would have gotten unwrapped by
|
||||
# the time you get to the implementation.
|
||||
STATIC_DISPATCH_FUNCTION_DEFAULT_BODY = CodeTemplate("""\
|
||||
at::AutoNonVariableTypeMode _var_guard(true);
|
||||
${return_call} TypeDefault::${native_type_method_dispatch}(${native_arguments});
|
||||
""")
|
||||
STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\
|
||||
at::AutoNonVariableTypeMode _var_guard(true);
|
||||
switch(tensorTypeIdToBackend(impl::dispatchTypeId(${type_set}))) {
|
||||
${static_dispatch_function_switches}
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
#include <ATen/core/EnableNamedTensor.h>
|
||||
#include <ATen/core/LegacyTypeDispatch.h>
|
||||
|
||||
#ifdef USE_STATIC_DISPATCH
|
||||
#include <ATen/TypeDefault.h>
|
||||
|
|
|
|||
Loading…
Reference in a new issue