mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Call jit decomposition in VariableType to increase forward AD coverage (#84151)"
This reverts commit42d99e6f19. Reverted https://github.com/pytorch/pytorch/pull/84151 on behalf of https://github.com/malfet due to Regressed test_jvpvjp_nn_functional_layer_norm_cuda_float32, see42d99e6f19
This commit is contained in:
parent
31ef8ddb8c
commit
acb4a09628
14 changed files with 213 additions and 280 deletions
|
|
@ -133,6 +133,20 @@ void vmapIncompatibleInplaceError(const char* schema_name) {
|
|||
"please file a bug report instead.");
|
||||
}
|
||||
|
||||
void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
const auto& schema = op.schema();
|
||||
// TODO: templatize based on op and keep static trace_exec
|
||||
auto * trace_exec = torch::jit::GetDecompositionExecutor(schema);
|
||||
trace_exec->run((*stack));
|
||||
if (stack->back().isTuple()) {
|
||||
IValue tup = stack->back();
|
||||
stack->pop_back();
|
||||
for (const auto& elem: tup.toTuple()->elements()) {
|
||||
stack->push_back(elem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
|
||||
auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
|
||||
if (logical_scalar_tensor.scalar_type() != result_type) {
|
||||
|
|
|
|||
|
|
@ -195,6 +195,12 @@ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t
|
|||
#define VARIADIC_BDIMS_BOXED(op) \
|
||||
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
|
||||
|
||||
void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
||||
|
||||
#define RUN_JIT_DECOMPOSITION(op) \
|
||||
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>());
|
||||
|
||||
|
||||
using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;
|
||||
|
||||
inline void find_and_unpack_tensors(
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@
|
|||
#include <c10/core/SymIntArrayRef.h>
|
||||
#include <c10/util/SmallBuffer.h>
|
||||
#include <ATen/InferSize.h>
|
||||
#include <torch/csrc/jit/runtime/decomposition_registry.h>
|
||||
|
||||
namespace at { namespace functorch {
|
||||
|
||||
|
|
@ -511,7 +510,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
|||
VMAP_SUPPORT(chunk, chunk_batching_rule);
|
||||
m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
|
||||
VMAP_SUPPORT(flip, flip_batch_rule);
|
||||
m.impl("trace", torch::CppFunction::makeFromBoxedFunction<&torch::jit::run_jit_decomposition>());
|
||||
RUN_JIT_DECOMPOSITION(trace)
|
||||
VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
|
||||
VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
|
||||
VMAP_SUPPORT(repeat, repeat_batch_rule);
|
||||
|
|
|
|||
|
|
@ -389,9 +389,43 @@ WithoutTop::~WithoutTop() {
|
|||
pushDynamicLayer(std::move(layer_));
|
||||
}
|
||||
|
||||
static void dynamicLayerFrontFallback(
|
||||
// NOTE: [forward-mode AD decompositions hack]
|
||||
//
|
||||
// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the
|
||||
// jvp transform, AND we have a decomposition for the operation, then run
|
||||
// the decomposition.
|
||||
//
|
||||
// Let's break that down. There are a douple of moving pieces.
|
||||
//
|
||||
// 0. How do we know what transform we're dispatching on?
|
||||
// Easy, check the top of the DynamicLayerStack and read the transform.
|
||||
//
|
||||
// 1. Next, we must identify when an operation (e.g. nll_loss_backward)
|
||||
// gets dispatched to.
|
||||
// - register a special kernel to the DynamicLayerFrontMode key
|
||||
// (see JVP_DECOMP)
|
||||
// - that special kernel invokes dynamicLayerFrontFallbackOperator with
|
||||
// an arg indicating we're going to use a decomp
|
||||
//
|
||||
// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp.
|
||||
// We currently use python decompositions that we torchscript.
|
||||
|
||||
// Ideally c10::OperatorHandle would have a field like this
|
||||
// to identify the operator.
|
||||
// The stuff here should map 1:1 with the operator name.
|
||||
// aten::nll_loss_backward -> nll_loss_backward
|
||||
// aten::add.Tensor -> add_Tensor
|
||||
|
||||
static void call_decomposition_for_jvp(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack) {
|
||||
run_jit_decomposition(op, stack);
|
||||
}
|
||||
|
||||
static void dynamicLayerFrontFallbackOperator(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack,
|
||||
bool decomp_jvp) {
|
||||
auto& dynamicLayerStack = dynamicLayerStackAccessor();
|
||||
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
|
||||
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
|
||||
|
|
@ -400,6 +434,13 @@ static void dynamicLayerFrontFallback(
|
|||
dump_local_tls();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Hack: if jvp and we have a decomposition registered, then do the decomposition
|
||||
if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp &&
|
||||
decomp_jvp) {
|
||||
return call_decomposition_for_jvp(op, stack);
|
||||
}
|
||||
|
||||
// Save the current LocalDispatchKeySet (to the current DynamicLayer).
|
||||
// Upon exiting the current scope, that LocalDispatchKeySet gets restored.
|
||||
// When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
|
||||
|
|
@ -419,6 +460,16 @@ restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
|
|||
return c10::impl::ForceDispatchKeyGuard(key_set);
|
||||
}
|
||||
|
||||
void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
return dynamicLayerFrontFallbackOperator(op, stack, false);
|
||||
}
|
||||
|
||||
void dynamicLayerFrontFallBackWithDecomp(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack) {
|
||||
return dynamicLayerFrontFallbackOperator(op, stack, true);
|
||||
}
|
||||
|
||||
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
auto& layer = dynamicLayerStackAccessor().back();
|
||||
auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
|
||||
|
|
@ -435,5 +486,24 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
|
|||
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
|
||||
}
|
||||
|
||||
#define JVP_DECOMP(op) \
|
||||
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());
|
||||
|
||||
#define JVP_DECOMP2(op, overload) \
|
||||
m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>());
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
|
||||
JVP_DECOMP(nll_loss_backward);
|
||||
JVP_DECOMP(nll_loss2d_backward);
|
||||
JVP_DECOMP(_log_softmax_backward_data);
|
||||
JVP_DECOMP(_softmax_backward_data);
|
||||
OP_DECOMPOSE(log_sigmoid);
|
||||
JVP_DECOMP(log_sigmoid_forward);
|
||||
JVP_DECOMP(native_layer_norm_backward);
|
||||
JVP_DECOMP(native_batch_norm_backward);
|
||||
JVP_DECOMP(cudnn_batch_norm_backward);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -1047,6 +1047,9 @@ class TestOperators(TestCase):
|
|||
# RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
|
||||
# this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
|
||||
xfail('normal', ''),
|
||||
xfail('_masked.log_softmax', ''), # NYI: forward-AD for _log_softmax_backward_data
|
||||
xfail('_masked.softmax', ''), # NYI: forward-AD for _softmax_backward_data
|
||||
xfail('_masked.softmin', ''), # NYI: forward-AD for _softmax_backward_data
|
||||
xfail('cdist', ''), # NYI: forward-AD for _cdist_forward
|
||||
xfail('cholesky', ''), # NYI: forward-AD for cholesky
|
||||
xfail('eig', ''), # NYI: forward-AD for eig
|
||||
|
|
@ -1055,7 +1058,10 @@ class TestOperators(TestCase):
|
|||
xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d
|
||||
xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
|
||||
xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
|
||||
xfail('nn.functional.instance_norm', ''), # NYI: forward AD for native_batch_norm_backward
|
||||
xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer
|
||||
xfail('nn.functional.softmin', ''), # NYI: forward-AD for _softmax_backward_data
|
||||
xfail('nn.functional.softmin', 'with_dtype'), # NYI: forward-AD for _softmax_backward_data
|
||||
xfail('renorm', ''), # NYI: forward AD for renorm
|
||||
xfail('symeig', ''), # NYI: forward AD for symeig
|
||||
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
|
||||
|
|
@ -1069,6 +1075,7 @@ class TestOperators(TestCase):
|
|||
xfail('scatter_reduce', 'mean'), # NYI: forward-AD for scatter_reduce
|
||||
xfail('scatter_reduce', 'prod'), # NYI: forward-AD for scatter_reduce
|
||||
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
|
||||
xfail('native_layer_norm', ''), # NYI: forward-AD for native_layer_norm_backward
|
||||
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
|
||||
skip('as_strided_scatter', ''), # seems flaky
|
||||
xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce
|
||||
|
|
@ -1129,8 +1136,37 @@ class TestOperators(TestCase):
|
|||
expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
|
||||
return expected
|
||||
|
||||
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
|
||||
self.assertEqual(result, expected)
|
||||
# HACK: obviously pytorch should also have the same coverage
|
||||
# For things that do have the same coverage, we test that jvp x vjp
|
||||
# are the same between PyTorch and functorch. For things that don't,
|
||||
# we check that jacfwd(vjp) and jacrev(vjp) are the same. This results
|
||||
# in slower tests.
|
||||
FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
|
||||
'nn.functional.nll_loss',
|
||||
'softmax',
|
||||
'log_softmax',
|
||||
'nn.functional.cross_entropy',
|
||||
'nn.functional.layer_norm',
|
||||
'nn.functional.batch_norm',
|
||||
}
|
||||
if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
|
||||
self.assertFalse(op.supports_fwgrad_bwgrad,
|
||||
f"{op.name} now supports forward over reverse without a decomposition. " +
|
||||
"Please remove the decomposition version")
|
||||
|
||||
def is_differentiable(t):
|
||||
return isinstance(t, torch.Tensor) and t.dtype == torch.float32
|
||||
args = (cotangents, *primals)
|
||||
if op.name == 'nn.functional.binary_cross_entropy':
|
||||
argnums = (0, 1) # targets is float32 but isn't differentiable
|
||||
atol_rtol = 1.5e-4, 1.3e-06
|
||||
else:
|
||||
argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
|
||||
atol_rtol = None
|
||||
self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol)
|
||||
else:
|
||||
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def _make_extremal_inputs(self, shape, device):
|
||||
if shape is None:
|
||||
|
|
|
|||
|
|
@ -1956,20 +1956,7 @@
|
|||
|
||||
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
|
||||
self: log_sigmoid_backward(grad, self, buffer)
|
||||
# HACK: This is just auto_element_wise followed by a view_as. The reason we have
|
||||
# this is bc forward AD was complaining here about the shapes not being the same:
|
||||
# the primal/tangent are 0-D/1-D respectively. This started happening after moving the
|
||||
# jvp decomposition mechanism from functorch to core, possibly due to a batching rule.
|
||||
# In functorch we rely on OP_DECOMPOSE, but now we compute forward AD using an actual
|
||||
# formula.
|
||||
#
|
||||
# We'd like to avoid keeping the entire jvp decomposition mechanism in functorch,
|
||||
# just for this single decomposition, but also want to avoid any cases from regressing:
|
||||
# e.g. test_vmapjvpall_nn_functional_logsigmoid_cuda_float32 (passes on cpu, fails on CUDA).
|
||||
#
|
||||
# We should either figure out what is going on with vmap or perhaps fwd AD could
|
||||
# be more tolerant about 0-dim vs 1-dim tensors
|
||||
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj().view_as(self_p)
|
||||
output: auto_element_wise
|
||||
|
||||
- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
|
||||
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from torchgen.api import cpp
|
|||
from torchgen.api.autograd import (
|
||||
DifferentiableInput,
|
||||
dispatch_strategy,
|
||||
ForwardDerivative,
|
||||
gen_differentiable_outputs,
|
||||
is_differentiable,
|
||||
NativeFunctionWithDifferentiabilityInfo,
|
||||
|
|
@ -598,14 +597,8 @@ at::redispatch::${api_name}(${unpacked_args})"""
|
|||
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
|
||||
"""\
|
||||
auto ${tmp_var} = ([&]() {
|
||||
if (${try_jit_decomposition_bool} && ${any_has_forward_grad}) {
|
||||
static c10::OperatorName full_name("aten::${op_name}", "${op_overload}");
|
||||
static c10::optional<c10::OperatorHandle> opt_op = c10::Dispatcher::singleton().findSchema(full_name);
|
||||
return impl::run_jit_decomposition_with_args_for_jvp<${returns_and_args}>("${op_name}", *opt_op, ks, ${arg_names});
|
||||
} else {
|
||||
${guard}
|
||||
return ${base_type_call};
|
||||
}
|
||||
${guard}
|
||||
return ${base_type_call};
|
||||
})();
|
||||
"""
|
||||
)
|
||||
|
|
@ -649,12 +642,6 @@ isFwGradDefined(${req_inp})\
|
|||
"""
|
||||
)
|
||||
|
||||
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate(
|
||||
"""\
|
||||
isFwGradDefinedTensorList(${req_inp})\
|
||||
"""
|
||||
)
|
||||
|
||||
FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
|
||||
"""\
|
||||
auto ${inp}_t_raw = toNonOptFwGrad(${inp});
|
||||
|
|
@ -985,23 +972,6 @@ def emit_body(
|
|||
f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
|
||||
)
|
||||
|
||||
if requires_derivative and not len(fw_derivatives) == 0:
|
||||
assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(
|
||||
differentiable_outputs
|
||||
), (
|
||||
"Expected the number of forward derivatives implemented to match the "
|
||||
"number of differentiable outputs. NB: This only applies when at least "
|
||||
"one forward derivative is implemented. Not implementing any forward "
|
||||
"derivatives is also okay, and we would require inputs to the op to "
|
||||
"not have associated tangents in that case."
|
||||
)
|
||||
try_jit_decomposition = (
|
||||
requires_derivative
|
||||
and len(fw_derivatives) == 0
|
||||
and (not modifies_arguments(f))
|
||||
and (not returns_void)
|
||||
)
|
||||
|
||||
def emit_save_inputs() -> List[str]:
|
||||
setup: List[str] = []
|
||||
if info is None or not info.has_derivatives:
|
||||
|
|
@ -1368,9 +1338,7 @@ def emit_body(
|
|||
)
|
||||
return call
|
||||
|
||||
def emit_call(
|
||||
f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
|
||||
) -> str:
|
||||
def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
|
||||
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
|
||||
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
|
||||
# the baseType operations still dispatch to non-Variable type, even if the arguments passed
|
||||
|
|
@ -1384,51 +1352,13 @@ def emit_body(
|
|||
else:
|
||||
guard = "at::AutoDispatchBelowADInplaceOrView guard;"
|
||||
|
||||
try_jit_decomposition_bool = "true" if try_jit_decomposition else "false"
|
||||
any_has_forward_grad = (
|
||||
get_any_has_fw_grad_cond(derivative=None)
|
||||
if requires_derivative
|
||||
else "false"
|
||||
)
|
||||
return_types = ", ".join(
|
||||
[cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns]
|
||||
)
|
||||
if len(f.func.returns) > 1:
|
||||
return_types = f"std::tuple<{return_types}>"
|
||||
|
||||
arg_types = [
|
||||
cpp.argument_type(a, binds="", symint=True).cpp_type()
|
||||
for a in f.func.arguments.flat_all
|
||||
]
|
||||
arg_names = [
|
||||
a.name
|
||||
for a in cpp.arguments(
|
||||
f.func.arguments,
|
||||
faithful=True,
|
||||
symint=True,
|
||||
method=False,
|
||||
cpp_no_default_args=set(),
|
||||
)
|
||||
]
|
||||
|
||||
if not modifies_arguments(f) and not returns_void:
|
||||
# Just to keep things simple here, we only care about this path
|
||||
# and always emit the if/else for now
|
||||
call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
|
||||
base_type_call=base_type_call,
|
||||
tmp_var=TMP_VAR,
|
||||
guard=guard,
|
||||
try_jit_decomposition_bool=try_jit_decomposition_bool,
|
||||
any_has_forward_grad=any_has_forward_grad,
|
||||
op_name=cpp.name(f.func),
|
||||
op_overload=f.func.name.overload_name,
|
||||
returns_and_args=return_types + ", " + ", ".join(arg_types),
|
||||
arg_names=arg_names,
|
||||
base_type_call=base_type_call, tmp_var=TMP_VAR, guard=guard
|
||||
)
|
||||
|
||||
call += wrap_output(f, unpacked_bindings, TMP_VAR)
|
||||
else:
|
||||
assert not try_jit_decomposition
|
||||
call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
|
||||
base_type_call=base_type_call, guard=guard
|
||||
)
|
||||
|
|
@ -1476,14 +1406,38 @@ def emit_body(
|
|||
def emit_any_has_forward_grad() -> List[str]:
|
||||
content: List[str] = []
|
||||
for derivative in fw_derivatives:
|
||||
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
|
||||
assert derivative.required_inputs_fw_grad is not None
|
||||
requires_fw_grad = " || ".join(
|
||||
[
|
||||
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
|
||||
for inp in differentiable_inputs
|
||||
if inp.name in derivative.required_inputs_fw_grad
|
||||
]
|
||||
)
|
||||
if not requires_fw_grad:
|
||||
# Handle functions like stack
|
||||
# For these, we don't unpack anything and always call the user function
|
||||
if not (
|
||||
len(differentiable_inputs) == 1
|
||||
and is_tensor_list_type(differentiable_inputs[0].type)
|
||||
):
|
||||
raise RuntimeError(
|
||||
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
|
||||
"forward AD formula does not use any input tangent) even though a forward gradient "
|
||||
"formula has been defined for it. This case should only happen for function that "
|
||||
"take a single TensorList as input. All other cases are not supported right now."
|
||||
)
|
||||
requires_fw_grad = "true"
|
||||
|
||||
if info and info.output_differentiability_conditions:
|
||||
assert len(info.output_differentiability_conditions) == 1
|
||||
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}"
|
||||
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && ({requires_fw_grad})"
|
||||
|
||||
content.append(
|
||||
f"auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};\n"
|
||||
f"(void){get_any_has_forward_grad_name(derivative.var_names)};"
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
def emit_check_inplace() -> List[str]:
|
||||
|
|
@ -1606,83 +1560,46 @@ def emit_body(
|
|||
content.append("\n".join(fw_grad_setters))
|
||||
return content
|
||||
|
||||
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
|
||||
#
|
||||
# Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)")
|
||||
#
|
||||
if derivative is None:
|
||||
# (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs
|
||||
# - Used in the out_fn case when we want to forbid fw derivatives
|
||||
# - Used in the case where the fw_derivative is not defined, but we want
|
||||
# To check if there is a decomposition registered for jvp
|
||||
to_check: List[str] = []
|
||||
for inp in list(
|
||||
mapMaybe(
|
||||
gen_differentiable_input,
|
||||
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
|
||||
)
|
||||
):
|
||||
if is_tensor_type(inp.type):
|
||||
to_check.append(
|
||||
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
|
||||
)
|
||||
elif is_tensor_list_type(inp.type):
|
||||
to_check.append(
|
||||
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute(
|
||||
req_inp=inp.name
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
|
||||
)
|
||||
return f'({" || ".join(to_check)})'
|
||||
else:
|
||||
# (2) If derivative is provided, use that information to determine which inputs
|
||||
# to check fw_grad for
|
||||
assert derivative.required_inputs_fw_grad is not None
|
||||
|
||||
if len(derivative.required_inputs_fw_grad) == 0:
|
||||
# Handle functions like stack
|
||||
# For these, we don't unpack anything and always call the user function
|
||||
if not (
|
||||
len(differentiable_inputs) == 1
|
||||
and is_tensor_list_type(differentiable_inputs[0].type)
|
||||
):
|
||||
raise RuntimeError(
|
||||
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
|
||||
"forward AD formula does not use any input tangent) even though a forward gradient "
|
||||
"formula has been defined for it. This case should only happen for function that "
|
||||
"take a single TensorList as input. All other cases are not supported right now."
|
||||
)
|
||||
any_has_fw_grad = "true"
|
||||
else:
|
||||
any_has_fw_grad = " || ".join(
|
||||
[
|
||||
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
|
||||
for inp in differentiable_inputs
|
||||
if inp.name in derivative.required_inputs_fw_grad
|
||||
]
|
||||
)
|
||||
any_has_fw_grad = f"({any_has_fw_grad})"
|
||||
|
||||
return any_has_fw_grad
|
||||
|
||||
def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
|
||||
if is_out_fn:
|
||||
msg = "because it is an out= function"
|
||||
else:
|
||||
msg = (
|
||||
"because it has not been implemented yet.\\nPlease file an issue "
|
||||
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
|
||||
"so that we can prioritize its implementation."
|
||||
def get_msg() -> str:
|
||||
if is_out_fn:
|
||||
msg = "because it is an out= function"
|
||||
else:
|
||||
msg = (
|
||||
"because it has not been implemented yet.\\nPlease file an issue "
|
||||
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
|
||||
"so that we can prioritize its implementation."
|
||||
)
|
||||
return msg
|
||||
|
||||
res = ""
|
||||
to_check: List[str] = []
|
||||
for inp in list(
|
||||
mapMaybe(
|
||||
gen_differentiable_input,
|
||||
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
|
||||
)
|
||||
cond = get_any_has_fw_grad_cond(derivative=None)
|
||||
return (
|
||||
FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg)
|
||||
if cond != ""
|
||||
else ""
|
||||
)
|
||||
):
|
||||
if is_tensor_type(inp.type):
|
||||
to_check.append(
|
||||
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
|
||||
)
|
||||
elif is_tensor_list_type(inp.type):
|
||||
cond = FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp="_t")
|
||||
res += FW_DERIVATIVE_FORBID_LIST_TEMPLATE.substitute(
|
||||
arg=inp.name, cond=cond, name=name, msg=get_msg()
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
|
||||
)
|
||||
|
||||
if len(to_check) > 0:
|
||||
cond = " || ".join(to_check)
|
||||
res += FW_DERIVATIVE_FORBID_TEMPLATE.substitute(
|
||||
cond=cond, name=name, msg=get_msg()
|
||||
)
|
||||
return res
|
||||
|
||||
body: List[str] = []
|
||||
unpack_args_stats, unpacked_bindings = unpack_args(f)
|
||||
|
|
@ -1696,7 +1613,7 @@ def emit_body(
|
|||
body.extend(setup_derivative(differentiable_inputs))
|
||||
body.append(declare_returned_variables(f))
|
||||
|
||||
body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
|
||||
body.append(emit_call(f, unpacked_bindings))
|
||||
if requires_derivative:
|
||||
# set_flags has to appear after version_counter, because rebase_history
|
||||
# requires that the counter is incremented before it is called
|
||||
|
|
@ -1706,11 +1623,20 @@ def emit_body(
|
|||
if is_out_fn:
|
||||
body.append(emit_forbid_fw_derivatives(is_out_fn=True))
|
||||
else:
|
||||
if requires_derivative and not try_jit_decomposition:
|
||||
if len(fw_derivatives) > 0:
|
||||
body.extend(emit_fw_derivatives())
|
||||
else:
|
||||
if requires_derivative:
|
||||
body.extend(emit_fw_derivatives())
|
||||
if len(fw_derivatives) == 0:
|
||||
body.append(emit_forbid_fw_derivatives())
|
||||
else:
|
||||
assert sum(
|
||||
len(derivative.var_names) for derivative in fw_derivatives
|
||||
) == len(differentiable_outputs), (
|
||||
"Expected the number of forward derivatives implemented to match the "
|
||||
"number of differentiable outputs. NB: This only applies when at least "
|
||||
"one forward derivative is implemented. Not implementing any forward "
|
||||
"derivatives is also okay, and we would require inputs to the op to "
|
||||
"not have associated tangents in that case."
|
||||
)
|
||||
|
||||
if requires_derivative:
|
||||
# Save only after the forward AD has been set up
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
#include "torch/csrc/autograd/VariableTypeUtils.h"
|
||||
#include "torch/csrc/autograd/VariableTypeUtilsDependOnOps.h"
|
||||
#include "torch/csrc/autograd/generated/VariableType.h"
|
||||
#include "torch/csrc/autograd/FunctionsManual.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,40 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/runtime/decomposition_registry.h>
|
||||
|
||||
// This is the set of helpers in VariableTypeUtils have a dependency on
|
||||
// native_functions.yaml meaning the file will need to be re-compiled every time
|
||||
// an operator is changed or added. We cannot simply put these functions in
|
||||
// VariableType.h and VariableTypeutils.h, since they are included in files like
|
||||
// ADInplaceOrViewType_X.cpp which don't always want to be recompiled.
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
namespace impl {
|
||||
|
||||
// Depends on torch/csrc/jit/ir/ir.h -> aten/src/ATen/core/interned_strings.h
|
||||
template <class Return, class... Args>
|
||||
Return run_jit_decomposition_with_args_for_jvp(
|
||||
c10::string_view name,
|
||||
const c10::OperatorHandle& opHandle,
|
||||
c10::DispatchKeySet dispatchKeySet,
|
||||
Args... args) {
|
||||
bool has_decomp = jit::has_jit_decomposition(opHandle.schema());
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
has_decomp,
|
||||
"Trying to use forward AD with ",
|
||||
name,
|
||||
" that does not support it"
|
||||
"because it has not been implemented yet and does not have a decomposition.\\nPlease file an issue "
|
||||
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
|
||||
"so that we can prioritize its implementation.");
|
||||
|
||||
return c10::KernelFunction::makeFromBoxedKernel(
|
||||
c10::BoxedKernel::makeFromFunction<&jit::run_jit_decomposition>())
|
||||
.call<Return, Args...>(opHandle, dispatchKeySet, args...);
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
|
|
@ -100,23 +100,5 @@ inline bool isFwGradDefined(const c10::optional<at::Tensor>& t) {
|
|||
return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined();
|
||||
}
|
||||
|
||||
inline bool isFwGradDefinedTensorList(const at::TensorList& variables) {
|
||||
bool ret = false;
|
||||
for (auto& variable : variables) {
|
||||
ret |= isFwGradDefined(variable);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline bool isFwGradDefinedTensorList(
|
||||
const c10::List<c10::optional<at::Tensor>> li) {
|
||||
bool ret = false;
|
||||
for (auto i : c10::irange(li.size())) {
|
||||
auto t = li.get(i);
|
||||
ret |= (t.has_value() && isFwGradDefined(t.value()));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -160,26 +160,6 @@ void RegisterDecomposition(
|
|||
schema_to_decomposition[&schema] = g;
|
||||
}
|
||||
|
||||
void run_jit_decomposition(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack) {
|
||||
const auto& schema = op.schema();
|
||||
// TODO: templatize based on op and keep static trace_exec
|
||||
auto* trace_exec = torch::jit::GetDecompositionExecutor(schema);
|
||||
trace_exec->run((*stack));
|
||||
if (stack->back().isTuple()) {
|
||||
at::IValue tup = stack->back();
|
||||
stack->pop_back();
|
||||
for (const auto& elem : tup.toTuple()->elements()) {
|
||||
stack->push_back(elem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool has_jit_decomposition(const FunctionSchema& schema) {
|
||||
return GetDecompositionFunction(schema).has_value();
|
||||
}
|
||||
|
||||
Function* GetDecompositionExecutor(const FunctionSchema& schema) {
|
||||
auto maybe_func = GetDecompositionFunction(schema);
|
||||
TORCH_INTERNAL_ASSERT(maybe_func);
|
||||
|
|
|
|||
|
|
@ -25,11 +25,5 @@ TORCH_API Function* GetDecompositionExecutor(const char* schema_literal);
|
|||
|
||||
TORCH_API Function* GetDecompositionExecutor(const FunctionSchema& schema);
|
||||
|
||||
TORCH_API void run_jit_decomposition(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack);
|
||||
|
||||
TORCH_API bool has_jit_decomposition(const FunctionSchema& schema);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -39,10 +39,6 @@ import torch._refs.nn.functional
|
|||
import torch._refs.special
|
||||
import torch._refs.linalg
|
||||
|
||||
# Make sure that decompositions used for test_forward_mode_AD and
|
||||
# test_fn_fwgrad_bwgrad are registered to the jit
|
||||
import torch._decomp.decompositions_for_jvp
|
||||
|
||||
import torch._prims as prims # noqa: F401
|
||||
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
|
@ -10168,7 +10164,6 @@ op_db: List[OpInfo] = [
|
|||
assert_jit_shape_analysis=True,
|
||||
assert_autodiffed=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=True),
|
||||
OpInfo('softmax',
|
||||
aliases=('special.softmax', 'nn.functional.softmax',),
|
||||
|
|
@ -10178,7 +10173,6 @@ op_db: List[OpInfo] = [
|
|||
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
|
||||
assert_autodiffed=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=True),
|
||||
# `softmin` supports different dtypes based on whether `dtype` argument,
|
||||
# is passed or not. Hence two OpInfo entries, one with dtype and other without.
|
||||
|
|
@ -10191,7 +10185,6 @@ op_db: List[OpInfo] = [
|
|||
assert_jit_shape_analysis=False,
|
||||
assert_autodiffed=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=False),
|
||||
OpInfo('nn.functional.softmin',
|
||||
variant_test_name="with_dtype",
|
||||
|
|
@ -10200,7 +10193,6 @@ op_db: List[OpInfo] = [
|
|||
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
|
||||
assert_autodiffed=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=False),
|
||||
OpInfo(
|
||||
"nn.functional.cross_entropy",
|
||||
|
|
@ -10209,7 +10201,6 @@ op_db: List[OpInfo] = [
|
|||
sample_inputs_func=sample_inputs_cross_entropy,
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
decorators=(
|
||||
DecorateInfo(
|
||||
toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-3)}),
|
||||
|
|
@ -10301,7 +10292,6 @@ op_db: List[OpInfo] = [
|
|||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
assert_jit_shape_analysis=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
sample_inputs_func=sample_inputs_native_layer_norm,
|
||||
error_inputs_func=error_inputs_native_layer_norm,
|
||||
skips=(
|
||||
|
|
@ -10673,7 +10663,6 @@ op_db: List[OpInfo] = [
|
|||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
decorators=[
|
||||
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
|
||||
# Consider making it a parameter or input, or detaching the gradient
|
||||
|
|
@ -10692,7 +10681,6 @@ op_db: List[OpInfo] = [
|
|||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
assert_jit_shape_analysis=True,
|
||||
decorators=[
|
||||
DecorateInfo(
|
||||
|
|
@ -11732,7 +11720,6 @@ op_db: List[OpInfo] = [
|
|||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
assert_jit_shape_analysis=True,
|
||||
sample_inputs_func=sample_inputs_batch_norm,
|
||||
skips=(
|
||||
|
|
@ -11755,7 +11742,6 @@ op_db: List[OpInfo] = [
|
|||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
decorators=[onlyCUDA, disablecuDNN],
|
||||
skips=(
|
||||
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
|
||||
|
|
@ -14718,7 +14704,6 @@ op_db: List[OpInfo] = [
|
|||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_softmax_variant,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
assert_autodiffed=True),
|
||||
OpInfo(
|
||||
'log_softmax',
|
||||
|
|
@ -14728,7 +14713,6 @@ op_db: List[OpInfo] = [
|
|||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
|
||||
sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
assert_autodiffed=True),
|
||||
UnaryUfuncInfo('logit',
|
||||
aten_backward_name='logit_backward',
|
||||
|
|
@ -15605,7 +15589,6 @@ op_db: List[OpInfo] = [
|
|||
supports_out=False,
|
||||
sample_inputs_func=sample_inputs_nll_loss,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
assert_jit_shape_analysis=True,
|
||||
skips=(
|
||||
# RuntimeError:
|
||||
|
|
|
|||
|
|
@ -990,7 +990,6 @@ op_db: List[OpInfo] = [
|
|||
),
|
||||
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=False,
|
||||
),
|
||||
OpInfo(
|
||||
|
|
@ -1018,7 +1017,6 @@ op_db: List[OpInfo] = [
|
|||
],
|
||||
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=False,
|
||||
),
|
||||
OpInfo(
|
||||
|
|
@ -1039,7 +1037,6 @@ op_db: List[OpInfo] = [
|
|||
),
|
||||
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=False,
|
||||
),
|
||||
OpInfo(
|
||||
|
|
|
|||
Loading…
Reference in a new issue