Revert "Call jit decomposition in VariableType to increase forward AD coverage (#84151)"

This reverts commit 42d99e6f19.

Reverted https://github.com/pytorch/pytorch/pull/84151 on behalf of https://github.com/malfet due to Regressed test_jvpvjp_nn_functional_layer_norm_cuda_float32, see 42d99e6f19
PyTorch MergeBot 2022-09-07 18:02:27 +00:00
parent 31ef8ddb8c
commit acb4a09628
14 changed files with 213 additions and 280 deletions


@ -133,6 +133,20 @@ void vmapIncompatibleInplaceError(const char* schema_name) {
"please file a bug report instead.");
}
void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
// TODO: templatize based on op and keep static trace_exec
auto * trace_exec = torch::jit::GetDecompositionExecutor(schema);
trace_exec->run((*stack));
if (stack->back().isTuple()) {
IValue tup = stack->back();
stack->pop_back();
for (const auto& elem: tup.toTuple()->elements()) {
stack->push_back(elem);
}
}
}
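
Annotation: a minimal Python sketch of the unflattening above (illustrative only, not a real API). Boxed multi-output ops leave one IValue per output on the stack, while the torchscripted decomposition pushes back a single tuple, so the tuple is popped and its elements re-pushed individually.

# Sketch; a plain Python list stands in for torch::jit::Stack.
stack = ["consumed args...", ("output", "buffer")]  # executor left a tuple on top
if isinstance(stack[-1], tuple):
    tup = stack.pop()
    stack.extend(tup)  # one entry per output: [..., "output", "buffer"]
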
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
if (logical_scalar_tensor.scalar_type() != result_type) {


@ -195,6 +195,12 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
#define VARIADIC_BDIMS_BOXED(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack);
#define RUN_JIT_DECOMPOSITION(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>());
using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;
inline void find_and_unpack_tensors(


@ -15,7 +15,6 @@
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/InferSize.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
namespace at { namespace functorch {
@ -511,7 +510,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(chunk, chunk_batching_rule);
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
VMAP_SUPPORT(flip, flip_batch_rule);
m.impl("trace", torch::CppFunction::makeFromBoxedFunction<&torch::jit::run_jit_decomposition>());
RUN_JIT_DECOMPOSITION(trace)
VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
VMAP_SUPPORT(repeat, repeat_batch_rule);
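
Annotation: the kernel that RUN_JIT_DECOMPOSITION(trace) dispatches to runs a torchscripted Python decomposition. A hedged sketch of its shape (the name and body here are illustrative; the registered decomposition lives in the decomposition registry):

import torch

@torch.jit.script
def trace_decomposition(self: torch.Tensor) -> torch.Tensor:
    # trace(A) == sum of A's main diagonal, which vmap can batch
    return torch.diagonal(self).sum()

x = torch.randn(4, 4)
assert torch.allclose(trace_decomposition(x), torch.trace(x))
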


@ -389,9 +389,43 @@ WithoutTop::~WithoutTop() {
pushDynamicLayer(std::move(layer_));
}
static void dynamicLayerFrontFallback(
// NOTE: [forward-mode AD decompositions hack]
//
// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the
// jvp transform, AND we have a decomposition for the operation, then run
// the decomposition.
//
// Let's break that down. There are a couple of moving pieces.
//
// 0. How do we know what transform we're dispatching on?
// Easy, check the top of the DynamicLayerStack and read the transform.
//
// 1. Next, we must identify when an operation (e.g. nll_loss_backward)
// gets dispatched to.
// - register a special kernel to the DynamicLayerFrontMode key
// (see JVP_DECOMP)
// - that special kernel invokes dynamicLayerFrontFallbackOperator with
// an arg indicating we're going to use a decomp
//
// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp.
// We currently use python decompositions that we torchscript.
// Ideally c10::OperatorHandle would have a field that identifies the
// operator directly. For now, the decomposition names map 1:1 with the
// operator name:
// aten::nll_loss_backward -> nll_loss_backward
// aten::add.Tensor -> add_Tensor
static void call_decomposition_for_jvp(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
run_jit_decomposition(op, stack);
}
static void dynamicLayerFrontFallbackOperator(
const c10::OperatorHandle& op,
torch::jit::Stack* stack,
bool decomp_jvp) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
@ -400,6 +434,13 @@ static void dynamicLayerFrontFallback(
dump_local_tls();
}
#endif
// Hack: if jvp and we have a decomposition registered, then do the decomposition
if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp &&
decomp_jvp) {
return call_decomposition_for_jvp(op, stack);
}
// Save the current LocalDispatchKeySet (to the current DynamicLayer).
// Upon exiting the current scope, that LocalDispatchKeySet gets restored.
// When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
@ -419,6 +460,16 @@ restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
return c10::impl::ForceDispatchKeyGuard(key_set);
}
void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
return dynamicLayerFrontFallbackOperator(op, stack, false);
}
void dynamicLayerFrontFallBackWithDecomp(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
return dynamicLayerFrontFallbackOperator(op, stack, true);
}
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& layer = dynamicLayerStackAccessor().back();
auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
@ -435,5 +486,24 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}
#define JVP_DECOMP(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());
#define JVP_DECOMP2(op, overload) \
m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());
TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
JVP_DECOMP(nll_loss_backward);
JVP_DECOMP(nll_loss2d_backward);
JVP_DECOMP(_log_softmax_backward_data);
JVP_DECOMP(_softmax_backward_data);
OP_DECOMPOSE(log_sigmoid);
JVP_DECOMP(log_sigmoid_forward);
JVP_DECOMP(native_layer_norm_backward);
JVP_DECOMP(native_batch_norm_backward);
JVP_DECOMP(cudnn_batch_norm_backward);
}
}
} // namespace at
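
Annotation: to see the hack fire, compute forward-over-reverse through a loss whose backward pass hits one of the ops listed in JVP_DECOMP. A minimal sketch (shapes illustrative):

import torch
from functorch import grad, jacfwd

target = torch.tensor([1, 0, 4])

def loss(x):
    return torch.nn.functional.nll_loss(torch.log_softmax(x, dim=-1), target)

x = torch.randn(3, 5)
# grad's reverse pass calls nll_loss_backward / _log_softmax_backward_data;
# under the outer jvp, DynamicLayerFrontMode reroutes those ops to their
# torchscripted decompositions.
hessian = jacfwd(grad(loss))(x)
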


@ -1047,6 +1047,9 @@ class TestOperators(TestCase):
# RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
# this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
xfail('normal', ''),
xfail('_masked.log_softmax', ''), # NYI: forward-AD for _log_softmax_backward_data
xfail('_masked.softmax', ''), # NYI: forward-AD for _softmax_backward_data
xfail('_masked.softmin', ''), # NYI: forward-AD for _softmax_backward_data
xfail('cdist', ''), # NYI: forward-AD for _cdist_forward
xfail('cholesky', ''), # NYI: forward-AD for cholesky
xfail('eig', ''), # NYI: forward-AD for eig
@ -1055,7 +1058,10 @@ class TestOperators(TestCase):
xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d
xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
xfail('nn.functional.instance_norm', ''), # NYI: forward AD for native_batch_norm_backward
xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer
xfail('nn.functional.softmin', ''), # NYI: forward-AD for _softmax_backward_data
xfail('nn.functional.softmin', 'with_dtype'), # NYI: forward-AD for _softmax_backward_data
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('symeig', ''), # NYI: forward AD for symeig
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
@ -1069,6 +1075,7 @@ class TestOperators(TestCase):
xfail('scatter_reduce', 'mean'), # NYI: forward-AD for scatter_reduce
xfail('scatter_reduce', 'prod'), # NYI: forward-AD for scatter_reduce
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
xfail('native_layer_norm', ''), # NYI: forward-AD for native_layer_norm_backward
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
skip('as_strided_scatter', ''), # seems flaky
xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce
@ -1129,8 +1136,37 @@ class TestOperators(TestCase):
expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
return expected
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)
# HACK: obviously pytorch should also have the same coverage
# For things that do have the same coverage, we test that jvp x vjp
# are the same between PyTorch and functorch. For things that don't,
# we check that jacfwd(vjp) and jacrev(vjp) are the same. This results
# in slower tests.
FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
'nn.functional.nll_loss',
'softmax',
'log_softmax',
'nn.functional.cross_entropy',
'nn.functional.layer_norm',
'nn.functional.batch_norm',
}
if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
self.assertFalse(op.supports_fwgrad_bwgrad,
f"{op.name} now supports forward over reverse without a decomposition. " +
"Please remove the decomposition version")
def is_differentiable(t):
return isinstance(t, torch.Tensor) and t.dtype == torch.float32
args = (cotangents, *primals)
if op.name == 'nn.functional.binary_cross_entropy':
argnums = (0, 1) # targets is float32 but isn't differentiable
atol_rtol = 1.5e-4, 1.3e-06
else:
argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
atol_rtol = None
self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol)
else:
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)
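
Annotation: a standalone sketch of the jacfwd-vs-jacrev cross-check taken in the slow branch above (fn and shapes illustrative):

import torch
from functorch import vjp, jacfwd, jacrev

def f(x):
    return torch.nn.functional.log_softmax(x, dim=-1)

x = torch.randn(3, 4)
out, vjp_fn = vjp(f, x)
cotangents = torch.randn_like(out)
jf = jacfwd(lambda ct: vjp_fn(ct)[0])(cotangents)  # forward-over-reverse
jr = jacrev(lambda ct: vjp_fn(ct)[0])(cotangents)  # reverse-over-reverse
assert torch.allclose(jf, jr, atol=1e-5)
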
def _make_extremal_inputs(self, shape, device):
if shape is None:


@ -1956,20 +1956,7 @@
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
# HACK: This is just auto_element_wise followed by a view_as. The reason we have
# this is bc forward AD was complaining here about the shapes not being the same:
# the primal/tangent are 0-D/1-D respectively. This started happening after moving the
# jvp decomposition mechanism from functorch to core, possibly due to a batching rule.
# In functorch we rely on OP_DECOMPOSE, but now we compute forward AD using an actual
# formula.
#
# We'd like to avoid keeping the entire jvp decomposition mechanism in functorch,
# just for this single decomposition, but also want to avoid any cases from regressing:
# e.g. test_vmapjvpall_nn_functional_logsigmoid_cuda_float32 (passes on cpu, fails on CUDA).
#
# We should either figure out what is going on with vmap or perhaps fwd AD could
# be more tolerant about 0-dim vs 1-dim tensors
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj().view_as(self_p)
output: auto_element_wise
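
Annotation: auto_element_wise derives the jvp elementwise from the backward formula; for log_sigmoid the tangent is t * sigmoid(-x). A quick check of that identity with forward-mode AD (sketch; shapes illustrative):

import torch
import torch.autograd.forward_ad as fwAD

x, t = torch.randn(5), torch.randn(5)
with fwAD.dual_level():
    y = torch.nn.functional.logsigmoid(fwAD.make_dual(x, t))
    _, y_t = fwAD.unpack_dual(y)
# d/dx log_sigmoid(x) = sigmoid(-x)
assert torch.allclose(y_t, t * torch.sigmoid(-x))
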
- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())


@ -31,7 +31,6 @@ from torchgen.api import cpp
from torchgen.api.autograd import (
DifferentiableInput,
dispatch_strategy,
ForwardDerivative,
gen_differentiable_outputs,
is_differentiable,
NativeFunctionWithDifferentiabilityInfo,
@ -598,14 +597,8 @@ at::redispatch::${api_name}(${unpacked_args})"""
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
"""\
auto ${tmp_var} = ([&]() {
if (${try_jit_decomposition_bool} && ${any_has_forward_grad}) {
static c10::OperatorName full_name("aten::${op_name}", "${op_overload}");
static c10::optional<c10::OperatorHandle> opt_op = c10::Dispatcher::singleton().findSchema(full_name);
return impl::run_jit_decomposition_with_args_for_jvp<${returns_and_args}>("${op_name}", *opt_op, ks, ${arg_names});
} else {
${guard}
return ${base_type_call};
}
${guard}
return ${base_type_call};
})();
"""
)
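
Annotation: a sketch of how the restored template expands. The substituted values below are illustrative, not taken from any particular generated file:

from torchgen.code_template import CodeTemplate

tmpl = CodeTemplate("""\
auto ${tmp_var} = ([&]() {
  ${guard}
  return ${base_type_call};
})();
""")
print(tmpl.substitute(
    tmp_var="_tmp",
    guard="at::AutoDispatchBelowADInplaceOrView guard;",
    base_type_call="at::redispatch::mul(ks & c10::after_autograd_keyset, self_, other_)",
))
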
@ -649,12 +642,6 @@ isFwGradDefined(${req_inp})\
"""
)
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate(
"""\
isFwGradDefinedTensorList(${req_inp})\
"""
)
FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
"""\
auto ${inp}_t_raw = toNonOptFwGrad(${inp});
@ -985,23 +972,6 @@ def emit_body(
f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
)
if requires_derivative and not len(fw_derivatives) == 0:
assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(
differentiable_outputs
), (
"Expected the number of forward derivatives implemented to match the "
"number of differentiable outputs. NB: This only applies when at least "
"one forward derivative is implemented. Not implementing any forward "
"derivatives is also okay, and we would require inputs to the op to "
"not have associated tangents in that case."
)
try_jit_decomposition = (
requires_derivative
and len(fw_derivatives) == 0
and (not modifies_arguments(f))
and (not returns_void)
)
def emit_save_inputs() -> List[str]:
setup: List[str] = []
if info is None or not info.has_derivatives:
@ -1368,9 +1338,7 @@ def emit_body(
)
return call
def emit_call(
f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
) -> str:
def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
# the baseType operations still dispatch to non-Variable type, even if the arguments passed
@ -1384,51 +1352,13 @@ def emit_body(
else:
guard = "at::AutoDispatchBelowADInplaceOrView guard;"
try_jit_decomposition_bool = "true" if try_jit_decomposition else "false"
any_has_forward_grad = (
get_any_has_fw_grad_cond(derivative=None)
if requires_derivative
else "false"
)
return_types = ", ".join(
[cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns]
)
if len(f.func.returns) > 1:
return_types = f"std::tuple<{return_types}>"
arg_types = [
cpp.argument_type(a, binds="", symint=True).cpp_type()
for a in f.func.arguments.flat_all
]
arg_names = [
a.name
for a in cpp.arguments(
f.func.arguments,
faithful=True,
symint=True,
method=False,
cpp_no_default_args=set(),
)
]
if not modifies_arguments(f) and not returns_void:
# Just to keep things simple here, we only care about this path
# and always emit the if/else for now
call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
base_type_call=base_type_call,
tmp_var=TMP_VAR,
guard=guard,
try_jit_decomposition_bool=try_jit_decomposition_bool,
any_has_forward_grad=any_has_forward_grad,
op_name=cpp.name(f.func),
op_overload=f.func.name.overload_name,
returns_and_args=return_types + ", " + ", ".join(arg_types),
arg_names=arg_names,
base_type_call=base_type_call, tmp_var=TMP_VAR, guard=guard
)
call += wrap_output(f, unpacked_bindings, TMP_VAR)
else:
assert not try_jit_decomposition
call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
base_type_call=base_type_call, guard=guard
)
@ -1476,14 +1406,38 @@ def emit_body(
def emit_any_has_forward_grad() -> List[str]:
content: List[str] = []
for derivative in fw_derivatives:
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
assert derivative.required_inputs_fw_grad is not None
requires_fw_grad = " || ".join(
[
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
for inp in differentiable_inputs
if inp.name in derivative.required_inputs_fw_grad
]
)
if not requires_fw_grad:
# Handle functions like stack
# For these, we don't unpack anything and always call the user function
if not (
len(differentiable_inputs) == 1
and is_tensor_list_type(differentiable_inputs[0].type)
):
raise RuntimeError(
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
"forward AD formula does not use any input tangent) even though a forward gradient "
"formula has been defined for it. This case should only happen for function that "
"take a single TensorList as input. All other cases are not supported right now."
)
requires_fw_grad = "true"
if info and info.output_differentiability_conditions:
assert len(info.output_differentiability_conditions) == 1
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}"
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && ({requires_fw_grad})"
content.append(
f"auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};\n"
f"(void){get_any_has_forward_grad_name(derivative.var_names)};"
)
return content
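
Annotation: for a derivative whose formula reads the tangents of self and other, the restored loop emits a guard like the following (sketch; names illustrative):

required_inputs = ["self", "other"]
cond = " || ".join(f"isFwGradDefined({inp})" for inp in required_inputs)
print(f"auto _any_has_forward_grad_result = {cond};")
print("(void)_any_has_forward_grad_result;")
# -> auto _any_has_forward_grad_result = isFwGradDefined(self) || isFwGradDefined(other);
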
def emit_check_inplace() -> List[str]:
@ -1606,83 +1560,46 @@ def emit_body(
content.append("\n".join(fw_grad_setters))
return content
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
#
# Produces a condition string (e.g., "isFwGradDefined(grad_output) || isFwGradDefined(output)")
#
if derivative is None:
# (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs
# - Used in the out_fn case when we want to forbid fw derivatives
# - Used in the case where the fw_derivative is not defined, but we want
# to check whether there is a decomposition registered for jvp
to_check: List[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
)
):
if is_tensor_type(inp.type):
to_check.append(
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
)
elif is_tensor_list_type(inp.type):
to_check.append(
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute(
req_inp=inp.name
)
)
else:
raise RuntimeError(
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
)
return f'({" || ".join(to_check)})'
else:
# (2) If derivative is provided, use that information to determine which inputs
# to check fw_grad for
assert derivative.required_inputs_fw_grad is not None
if len(derivative.required_inputs_fw_grad) == 0:
# Handle functions like stack
# For these, we don't unpack anything and always call the user function
if not (
len(differentiable_inputs) == 1
and is_tensor_list_type(differentiable_inputs[0].type)
):
raise RuntimeError(
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
"forward AD formula does not use any input tangent) even though a forward gradient "
"formula has been defined for it. This case should only happen for function that "
"take a single TensorList as input. All other cases are not supported right now."
)
any_has_fw_grad = "true"
else:
any_has_fw_grad = " || ".join(
[
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
for inp in differentiable_inputs
if inp.name in derivative.required_inputs_fw_grad
]
)
any_has_fw_grad = f"({any_has_fw_grad})"
return any_has_fw_grad
def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
if is_out_fn:
msg = "because it is an out= function"
else:
msg = (
"because it has not been implemented yet.\\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation."
def get_msg() -> str:
if is_out_fn:
msg = "because it is an out= function"
else:
msg = (
"because it has not been implemented yet.\\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation."
)
return msg
res = ""
to_check: List[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
)
cond = get_any_has_fw_grad_cond(derivative=None)
return (
FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg)
if cond != ""
else ""
)
):
if is_tensor_type(inp.type):
to_check.append(
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
)
elif is_tensor_list_type(inp.type):
cond = FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp="_t")
res += FW_DERIVATIVE_FORBID_LIST_TEMPLATE.substitute(
arg=inp.name, cond=cond, name=name, msg=get_msg()
)
else:
raise RuntimeError(
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
)
if len(to_check) > 0:
cond = " || ".join(to_check)
res += FW_DERIVATIVE_FORBID_TEMPLATE.substitute(
cond=cond, name=name, msg=get_msg()
)
return res
body: List[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f)
@ -1696,7 +1613,7 @@ def emit_body(
body.extend(setup_derivative(differentiable_inputs))
body.append(declare_returned_variables(f))
body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
body.append(emit_call(f, unpacked_bindings))
if requires_derivative:
# set_flags has to appear after version_counter, because rebase_history
# requires that the counter is incremented before it is called
@ -1706,11 +1623,20 @@ def emit_body(
if is_out_fn:
body.append(emit_forbid_fw_derivatives(is_out_fn=True))
else:
if requires_derivative and not try_jit_decomposition:
if len(fw_derivatives) > 0:
body.extend(emit_fw_derivatives())
else:
if requires_derivative:
body.extend(emit_fw_derivatives())
if len(fw_derivatives) == 0:
body.append(emit_forbid_fw_derivatives())
else:
assert sum(
len(derivative.var_names) for derivative in fw_derivatives
) == len(differentiable_outputs), (
"Expected the number of forward derivatives implemented to match the "
"number of differentiable outputs. NB: This only applies when at least "
"one forward derivative is implemented. Not implementing any forward "
"derivatives is also okay, and we would require inputs to the op to "
"not have associated tangents in that case."
)
if requires_derivative:
# Save only after the forward AD has been set up


@ -1,5 +1,4 @@
#include "torch/csrc/autograd/VariableTypeUtils.h"
#include "torch/csrc/autograd/VariableTypeUtilsDependOnOps.h"
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/autograd/FunctionsManual.h"


@ -1,40 +0,0 @@
#pragma once
#include <torch/csrc/jit/runtime/decomposition_registry.h>
// This is the set of helpers in VariableTypeUtils that have a dependency on
// native_functions.yaml, meaning the file will need to be re-compiled every time
// an operator is changed or added. We cannot simply put these functions in
// VariableType.h and VariableTypeUtils.h, since they are included in files like
// ADInplaceOrViewType_X.cpp which don't always want to be recompiled.
namespace torch {
namespace autograd {
namespace impl {
// Depends on torch/csrc/jit/ir/ir.h -> aten/src/ATen/core/interned_strings.h
template <class Return, class... Args>
Return run_jit_decomposition_with_args_for_jvp(
c10::string_view name,
const c10::OperatorHandle& opHandle,
c10::DispatchKeySet dispatchKeySet,
Args... args) {
bool has_decomp = jit::has_jit_decomposition(opHandle.schema());
TORCH_CHECK_NOT_IMPLEMENTED(
has_decomp,
"Trying to use forward AD with ",
name,
" that does not support it"
"because it has not been implemented yet and does not have a decomposition.\\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation.");
return c10::KernelFunction::makeFromBoxedKernel(
c10::BoxedKernel::makeFromFunction<&jit::run_jit_decomposition>())
.call<Return, Args...>(opHandle, dispatchKeySet, args...);
}
} // namespace impl
} // namespace autograd
} // namespace torch


@ -100,23 +100,5 @@ inline bool isFwGradDefined(const c10::optional<at::Tensor>& t) {
return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined();
}
inline bool isFwGradDefinedTensorList(const at::TensorList& variables) {
bool ret = false;
for (auto& variable : variables) {
ret |= isFwGradDefined(variable);
}
return ret;
}
inline bool isFwGradDefinedTensorList(
const c10::List<c10::optional<at::Tensor>> li) {
bool ret = false;
for (auto i : c10::irange(li.size())) {
auto t = li.get(i);
ret |= (t.has_value() && isFwGradDefined(t.value()));
}
return ret;
}
} // namespace autograd
} // namespace torch


@ -160,26 +160,6 @@ void RegisterDecomposition(
schema_to_decomposition[&schema] = g;
}
void run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
const auto& schema = op.schema();
// TODO: templatize based on op and keep static trace_exec
auto* trace_exec = torch::jit::GetDecompositionExecutor(schema);
trace_exec->run((*stack));
if (stack->back().isTuple()) {
at::IValue tup = stack->back();
stack->pop_back();
for (const auto& elem : tup.toTuple()->elements()) {
stack->push_back(elem);
}
}
}
bool has_jit_decomposition(const FunctionSchema& schema) {
return GetDecompositionFunction(schema).has_value();
}
Function* GetDecompositionExecutor(const FunctionSchema& schema) {
auto maybe_func = GetDecompositionFunction(schema);
TORCH_INTERNAL_ASSERT(maybe_func);


@ -25,11 +25,5 @@ TORCH_API Function* GetDecompositionExecutor(const char* schema_literal);
TORCH_API Function* GetDecompositionExecutor(const FunctionSchema& schema);
TORCH_API void run_jit_decomposition(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
TORCH_API bool has_jit_decomposition(const FunctionSchema& schema);
} // namespace jit
} // namespace torch


@ -39,10 +39,6 @@ import torch._refs.nn.functional
import torch._refs.special
import torch._refs.linalg
# Make sure that decompositions used for test_forward_mode_AD and
# test_fn_fwgrad_bwgrad are registered to the jit
import torch._decomp.decompositions_for_jvp
import torch._prims as prims # noqa: F401
from torch.utils._pytree import tree_flatten
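
Annotation: with the jvp decomposition import gone, supports_fwgrad_bwgrad is dropped from the ops below because core PyTorch can no longer compute forward-over-reverse for them. For an op that does support it, the property reduces to roughly this check (sketch; torch.sin used as a supported stand-in):

import torch

x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
# check_fwd_over_rev differentiates the backward pass with forward-mode AD
torch.autograd.gradgradcheck(torch.sin, (x,), check_fwd_over_rev=True)
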
@ -10168,7 +10164,6 @@ op_db: List[OpInfo] = [
assert_jit_shape_analysis=True,
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=True),
OpInfo('softmax',
aliases=('special.softmax', 'nn.functional.softmax',),
@ -10178,7 +10173,6 @@ op_db: List[OpInfo] = [
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=True),
# `softmin` supports different dtypes based on whether the `dtype` argument
# is passed or not. Hence two OpInfo entries, one with dtype and other without.
@ -10191,7 +10185,6 @@ op_db: List[OpInfo] = [
assert_jit_shape_analysis=False,
assert_autodiffed=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False),
OpInfo('nn.functional.softmin',
variant_test_name="with_dtype",
@ -10200,7 +10193,6 @@ op_db: List[OpInfo] = [
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
assert_autodiffed=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False),
OpInfo(
"nn.functional.cross_entropy",
@ -10209,7 +10201,6 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_cross_entropy,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-3)}),
@ -10301,7 +10292,6 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
assert_jit_shape_analysis=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_native_layer_norm,
error_inputs_func=error_inputs_native_layer_norm,
skips=(
@ -10673,7 +10663,6 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
@ -10692,7 +10681,6 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
decorators=[
DecorateInfo(
@ -11732,7 +11720,6 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs_batch_norm,
skips=(
@ -11755,7 +11742,6 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[onlyCUDA, disablecuDNN],
skips=(
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
@ -14718,7 +14704,6 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_softmax_variant,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_autodiffed=True),
OpInfo(
'log_softmax',
@ -14728,7 +14713,6 @@ op_db: List[OpInfo] = [
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_autodiffed=True),
UnaryUfuncInfo('logit',
aten_backward_name='logit_backward',
@ -15605,7 +15589,6 @@ op_db: List[OpInfo] = [
supports_out=False,
sample_inputs_func=sample_inputs_nll_loss,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
skips=(
# RuntimeError:


@ -990,7 +990,6 @@ op_db: List[OpInfo] = [
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(
@ -1018,7 +1017,6 @@ op_db: List[OpInfo] = [
],
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(
@ -1039,7 +1037,6 @@ op_db: List[OpInfo] = [
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
),
OpInfo(