diff --git a/aten/src/ATen/TensorGeometry.h b/aten/src/ATen/TensorGeometry.h index 41f14a15ba9..06a064063c4 100644 --- a/aten/src/ATen/TensorGeometry.h +++ b/aten/src/ATen/TensorGeometry.h @@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry { has_symbolic_sizes_strides_( t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} + explicit TensorGeometry( + std::vector sizes, + std::vector strides, + at::SymInt storage_offset) + : sizes_(std::move(sizes)), + strides_(std::move(strides)), + storage_offset_(std::move(storage_offset)) { + recompute(); + } + // true if the tensor is contiguous bool is_contiguous() const; diff --git a/build_variables.bzl b/build_variables.bzl index 8bd8ad3a8df..a95c03cd0b3 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -138,6 +138,7 @@ core_trainer_sources = [ "torch/csrc/autograd/variable.cpp", "torch/csrc/autograd/utils/warnings.cpp", "torch/csrc/autograd/jit_decomp_interface.cpp", + "torch/csrc/dynamo/compiled_autograd.cpp", "torch/csrc/jit/frontend/name_mangler.cpp", "torch/csrc/jit/ir/type_hashing.cpp", "torch/csrc/jit/serialization/pickler.cpp", diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py index 14e3f2e044c..6c5cd6a9f25 100644 --- a/test/dynamo/test_backward_higher_order_ops.py +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -121,23 +121,27 @@ class _multiply_invoke(torch.nn.Module): out.backward(grad_out) actual = normalize_gm(graph.print_readable(False)) self.assertEqual(x.grad, grad_out * grad_out) - self.assertExpectedInline( - actual, - """\ + if backend in ["aot_eager", "inductor"]: + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): def forward(self, L_inputs_ : list): l_inputs_ = L_inputs_ - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None - new_grad: "f32[s0]" = torch.clone(getitem) + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None + getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None - result: "f32[s0]" = getitem * getitem; getitem = None + new_grad: "f32[2]" = torch.clone(getitem_3) - new_grad_1: "f32[s0]" = torch.clone(result); result = None + result: "f32[2]" = getitem_3 * getitem_3; getitem_3 = None + + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1) """, - ) + ) graph = None @@ -187,26 +191,30 @@ class GraphModule(torch.nn.Module): actual = normalize_gm(graph.print_readable(False)) self.assertEqual(obj.counter, 1) self.assertEqual(x.grad, grad_out + grad_out) - self.assertExpectedInline( - actual, - """\ + if backend in ["aot_eager", "inductor"]: + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"): l_inputs_ = L_inputs_ l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None - new_grad: "f32[s0]" = torch.clone(getitem) + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None + getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None + + new_grad: "f32[2]" = torch.clone(getitem_3) add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter 
+ 1; l_hooks_0_keywords_fn_keywords_obj_counter = None - result: "f32[s0]" = getitem * getitem; getitem = None + result: "f32[2]" = getitem_3 * getitem_3; getitem_3 = None - new_grad_1: "f32[s0]" = torch.clone(result); result = None + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1, add) """, - ) + ) out = fn(x, y) out.backward(grad_out) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index d916c4186a3..2a2b7569009 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2924,7 +2924,6 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { "aot0_le", "aot0_permute_2", "code: CompiledFunctionBackward0 (NodeCall 2)", - "aot0_tangents_1", "aot0_full_default", "aot0_where", "aot0_mm", @@ -2974,20 +2973,17 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { expected_logs = [ "CompiledFunctionBackward1", - "aot1_tangents_1", "aot1_sin_1", - "aot1_primals_2", "aot1_neg", "aot0_tangents_2", "aot1_cos_1", - "aot1_primals_1", "aot0_tangents_1", "CompiledFunctionBackward0", + "aot0_sin_1", "aot0_neg", - "aot0_sin", "aot0_mul", + "aot0_cos_1", "aot0_mul_1", - "aot0_cos", "aot0_add", ] @@ -3618,6 +3614,7 @@ known_failing_tests = { "test_tp_compile_comm_reordering", "test_unwrap_async_collective_tensor_tangent", # Uncategorized + "test_not_implemented_grad", # Dynamo changes the types of exceptions } if not HAS_CUDA: diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index 9d1a3202f14..ed50a02d7b2 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -337,7 +337,9 @@ class DistributedPatternTests(TestCase): self.assertEqual(fw_cnt.frame_count, 1) self.assertEqual(fw_cnt.op_count, 5) self.assertEqual(bw_cnt.frame_count, 2) # grad=None and grad!=None - self.assertEqual(bw_cnt.op_count, 48) + self.assertEqual( + bw_cnt.op_count, 72 + ) # Number of ops in the Dynamo-produced graphs def test_module_backward_hooks_aot(self): m1, inp1 = init_module_bw_hooks(True) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index fa1d0ce4bc9..a2f9c60b222 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -68,6 +68,7 @@ struct TORCH_API ${op} : public ${superclass} { } ${will_release_variables} void compiled_args(CompiledNodeArgs& args) override; + ivalue_list get_packed_args(); variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override; ${saved_variables} ${saved_list_sizes} @@ -107,6 +108,13 @@ static variable_list ${op}_apply_functional( ${body} return grad_inputs; } +static variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args) +{ + auto packed_args = PackedArgs(args); + auto needs_input_grad = packed_args.unpack>(); + ${unpack_ivalues} + return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args}); +} variable_list ${op}::apply(variable_list&& grads) { ${thread_lock} @@ -120,11 +128,42 @@ void ${op}::compiled_args(CompiledNodeArgs& args) { ${compiled_args} } variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) { - ${apply_with_saved_before} - variable_list result = apply(variable_list(grads)); - ${apply_with_saved_after} - return result; + ${apply_with_saved_before} + + static std::once_flag flag; + 
std::call_once(flag, [&](){ + ${compute_schema} + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + interface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema); + }); + + variable_list result; + auto packed_args = get_packed_args(); + auto output_metadata = torch::dynamo::autograd::IValuePacker< + std::vector>>::pack( + torch::dynamo::autograd::get_input_metadata(next_edges())); + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + result = interface->call_function( + saved.get_py_compiler(), + "apply_functional", + name(), + grads, + packed_args, + output_metadata); + + ${apply_with_saved_after} + return result; } +ivalue_list ${op}::get_packed_args() { + PackedArgs packed_args; + ${asserts} + ${unpacks} + ${compute_needs_input_grad} + packed_args.pack(needs_input_grad); + ${get_packed_args} + return std::move(packed_args).vec(); +} + """ ) @@ -993,14 +1032,38 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { f"{T} {x}" for T, x in zip(apply_functional_args_ref_types, apply_functional_args) ] + get_packed_args = "\n".join( + f"packed_args.pack({name});" for name in apply_functional_args + ) + unpack_ivalues = [] + for typ, name in zip(apply_functional_args_ref_types, apply_functional_args): + if typ.endswith("&"): + typ = typ[:-1] + unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();") + + schema_args = [f"std::array"] + for typ in apply_functional_args_ref_types: + if typ.endswith("&"): + typ = typ[:-1] + if typ.startswith("const"): + typ = typ[5:] + schema_args.append(typ.strip()) + compute_schema = ["std::vector schema = {"] + for schema_arg in schema_args: + compute_schema.append( + f" torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type()," + ) + compute_schema.append("};") return template.substitute( unpacks="\n".join(unpack), op=info.op, + compute_schema="\n".join(compute_schema), apply_functional_args=apply_functional_args, apply_functional_args_signature=apply_functional_args_signature, compute_needs_input_grad=compute_needs_input_grad, num_inputs=len(input_name_to_idx), + unpack_ivalues="\n".join(unpack_ivalues), compute_index_ranges=compute_index_ranges, saved_variables=saved_variables, release_variables=release_variables, @@ -1015,4 +1078,5 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { compiled_args=compiled_args, apply_with_saved_before=apply_with_saved_before, apply_with_saved_after=apply_with_saved_after, + get_packed_args=get_packed_args, ) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index fb7017bc6dc..cd4db23e332 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -8,6 +8,7 @@ from collections import defaultdict from typing import Any, Optional, TYPE_CHECKING, Union import torch +import torch.utils._pytree as pytree from torch._dynamo.external_utils import ( call_backward, call_hook, @@ -65,6 +66,39 @@ def maybe_clone(x): return x +# We lazily bind "functional backward" variants for PyTorch built-in autograd +# nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0 +# Each "functional backward" is bound the first time the node's apply_with_saved +# function is called. It's possible to avoid lazy binding and instead bind +# all of this upfront (perhaps at import time) via codegen changes. 
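A minimal sketch of how this namespace is intended to be used (the backward function, its signature, and the op name below are hypothetical, for illustration only; they are not part of this patch):

    import torch
    from torch._dynamo.compiled_autograd import ops

    # a toy "functional backward": incoming grads plus unpacked saved arguments
    def my_mul_backward(grads, needs_input_grad, other):
        (grad,) = grads
        return [grad * other if needs_input_grad[0] else None]

    op = ops.add("MyMulBackward0", my_mul_backward)  # wraps fn in Op and allow_in_graph's it
    assert ops.get("MyMulBackward0") is op
    print(op)                           # torch._dynamo.compiled_autograd.ops.MyMulBackward0
    op([torch.ones(2)], [True], 3.0)    # plain call outside tracing; returns [tensor([3., 3.])]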
+class OpNamespace: + def add(self, name, fn): + assert not hasattr(self, name) + result = Op(name, fn) + torch._dynamo.allow_in_graph(result) + setattr(self, name, result) + return result + + def get(self, name): + return getattr(self, name) + + +class Op: + def __init__(self, name, fn): + self.fn = fn + self.__name__ = name + self.__module__ = "torch._dynamo.compiled_autograd.ops" + + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def __repr__(self): + return self.__module__ + "." + self.__name__ + + +ops = OpNamespace() + + _graph_placeholders = ["inputs", "sizes", "scalars", "hooks"] _impure_targets = OrderedSet( [ @@ -137,7 +171,8 @@ class AutogradCompilerInstance: self.fx_tracer.root = torch.nn.Module() self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) self.fx_tracer.tensor_attrs = {} - args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = ( + self.symnode_proxy_lookup = {} + args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = ( self.fx_tracer.create_proxy("placeholder", name, (), {}) for name in _graph_placeholders ) @@ -160,7 +195,9 @@ class AutogradCompilerInstance: ) for idx, val in enumerate(sizes) ] - self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins) + self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins) + for i, symint in enumerate(sizes): + self.symnode_proxy_lookup[symint.node] = self.sizes_proxy[i] for idx, val in enumerate(scalars): source = self.source("scalars", idx) @@ -182,7 +219,9 @@ class AutogradCompilerInstance: ) else: raise AssertionError("Unexpected scalar type: ", type(val)) - self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins) + self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins) + for i, symval in enumerate(scalars): + self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr] # TODO(jansel): are all these modes needed? self.stack.enter_context(decompose({})) @@ -216,7 +255,6 @@ class AutogradCompilerInstance: ), kwargs={}, ) - with disable_proxy_modes_tracing(): # create fake Tensors grad_ins: list[Optional[torch.Tensor]] = [] @@ -232,6 +270,65 @@ class AutogradCompilerInstance: self.bind_tensors_to_proxies(grad_ins, proxies) return tuple(grad_ins) + # Guess what the outputs should be from the InputMetadata. + # This is not sound in general (we guess contiguous strides + # and no Tensor subclass-ness); we will stop guessing + # the output metadata in a follow-up. 
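For reference, the per-output metadata consumed below is the packed form of the C++ InputMetadata: a (TensorOptions tuple, sizes, is_tensor_subclass) triple, with None standing in for an undefined gradient. A hedged illustration, with values chosen to match the expected graphs in the tests above (once rendered into the FX graph, dtype and layout may show up as their enum integers, e.g. 6 for float32 and 0 for strided):

    metadata = (
        (None, None, torch.device("cpu"), torch.float32, torch.strided, None),
        # requires_grad, memory_format, device, dtype, layout, pinned_memory
        [2],    # sizes
        False,  # is_tensor_subclass
    )
    # guess_output(metadata) would build aten.zeros([2], device=cpu, dtype=float32, layout=strided)
    # guess_output(None) returns None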
+ def guess_output(self, input_metadata): + if input_metadata is None: + return None + tensoroptions, shape, _ = input_metadata + kwargs = {} + names = [ + "requires_grad", + "memory_format", + "device", + "dtype", + "layout", + "pinned_memory", + ] + for name, option in zip(names, tensoroptions): + if option is not None: + kwargs[name] = option + + with disable_proxy_modes_tracing(): + return torch.ops.aten.zeros(shape, **kwargs) + + def bind_function(self, fn_name, fn): + """Binds ops.fn_name = fn""" + ops.add(fn_name, fn) + + def apply_functional(self, fn_name, grads, args, output_metadata): + """Proxies a call to ops.fn_name(grads, *args) into the graph""" + op = ops.get(fn_name) + return self.proxy_call(op, (grads, *args), output_metadata) + + def proxy_call(self, fn, args, output_metadata): + """Proxies a call to fn(*args) into the graph""" + flat_args, _ = pytree.tree_flatten(args) + proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args) + proxy_out = self.fx_tracer.create_proxy( + "call_function", fn, args=proxy_args, kwargs={} + ) + result = [self.guess_output(metadata) for metadata in output_metadata] + self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(len(result))]) + return result + + def validate_outputs(self, _, outputs, args, output_metadata): + """Proxies a call to ops.validate_outputs(outputs, *args) into the graph""" + op = ops.get("validate_outputs") + proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args)) + new_proxy_outputs = self.fx_tracer.create_proxy( + "call_function", op, args=proxy_args, kwargs={} + ) + assert len(output_metadata) == len(outputs) + outputs = [ + None if output is None or metadata is None else self.guess_output(metadata) + for output, metadata in zip(outputs, output_metadata) + ] + self.bind_tensors_to_proxies(outputs, new_proxy_outputs) + return outputs + def proxy_call_hook(self, hook, *args, **kwargs): return self.fx_tracer.create_proxy( "call_function", @@ -314,6 +411,7 @@ class AutogradCompilerInstance: assert nodes[first_getitem_idx] == inputs_users[0] last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 assert nodes[last_getitem_idx] == inputs_users[-1] + # getitem nodes on inputs for i, node in enumerate(inputs_users): if not has_cuda_inputs and node.meta["val"].device.type == "cuda": has_cuda_inputs = True @@ -323,9 +421,13 @@ class AutogradCompilerInstance: is_scalar = len(node.meta["val"].size()) == 0 if is_cpu and is_scalar: node_users = list(node.users.keys()) + # We can only move the cpu scalar if it is not exposed to user code. 
if all( - isinstance(user.target, torch._ops.OpOverload) - and user.target.namespace in ("prims", "aten") + ( + isinstance(user.target, torch._ops.OpOverload) + and user.target.namespace in ("prims", "aten") + ) + or isinstance(user.target, Op) for user in node_users ): # all users are prims/aten, can move safely @@ -335,6 +437,7 @@ class AutogradCompilerInstance: # this is to handle the case where cudagraphs is enabled on a cpu-only graph if has_cuda_inputs: for node in to_move.values(): + verbose_log.debug("Moving node %s from cpu to cuda", node) node.meta["val"] = node.meta["val"].cuda() # return runtime indices we need to move to cuda @@ -368,7 +471,10 @@ class AutogradCompilerInstance: or (node.op == "call_function" and node.target in _impure_targets) ) + before = len(self.fx_tracer.graph.nodes) self.fx_tracer.graph.eliminate_dead_code(is_impure) + after = len(self.fx_tracer.graph.nodes) + verbose_log.debug("DCE removed %d nodes", before - after) def end_capture(self, outputs): self.fx_tracer.create_proxy( @@ -384,6 +490,10 @@ class AutogradCompilerInstance: (self.fx_tracer.create_arg(self.to_proxy(outputs)),), {}, ) + runtime_inputs_to_move: list[int] = [] + if snapshot_cudagraph_enabled(): + runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) + # TODO(rzou): the guessed metadata is incorrect, we will remove it at the end of the PR stack. self.rename_aot_dispatcher_nodes() self.reorder_tensor_pre_hook_nodes() self.reorder_pre_hook_nodes_to_schedule_asap() @@ -402,9 +512,6 @@ class AutogradCompilerInstance: # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and # should prevent these ops from going into the CA graph. self.dce() - runtime_inputs_to_move: list[int] = [] - if snapshot_cudagraph_enabled(): - runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) graph = GraphModule( self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}" @@ -778,8 +885,11 @@ class AutogradCompilerInstance: return [self.to_proxy(x) for x in t] if isinstance(t, tuple): return tuple(self.to_proxy(x) for x in t) - # can it be torch.SymInt as the code used to imply? - assert isinstance(t, torch.Tensor) + if isinstance(t, (torch.SymInt, torch.SymFloat)): + return self.symnode_proxy_lookup[t.node] + if not isinstance(t, torch.Tensor): + # constant types like device, dtype, str + return t proxy_tensor = fetch_object_proxy(self.fx_tracer, t) assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor) return proxy_tensor.proxy diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 3b37b87391f..53e8bcd0645 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -898,6 +898,19 @@ bool has_input_metadata(const Edge& thing) { return thing.is_valid(); } +std::vector> collect_input_metadata( + const edge_list& edges) { + std::vector> input_metadata; + for (const auto& edge : edges) { + if (!edge.is_valid()) { + input_metadata.emplace_back(std::nullopt); + continue; + } + input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr)); + } + return input_metadata; +} + // Given an vector or vector>, validate the // outputs. This involves using the InputMetadata to check the outputs and also // potentially calling .sum_to on the outputs. 
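This per-edge metadata (with invalid edges contributing nullopt/None entries) is what ultimately gets packed and proxied into the compiled autograd graph as the argument of the new validate_outputs call. A rough illustration of the resulting FX node, mirroring the expectedinline graphs updated earlier in this diff:

    # shape of the node in the captured graph (illustrative, taken from the test expectations above)
    validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs(
        [getitem],                                        # candidate gradients, one per output edge
        [((None, None, device(type='cpu'), 6, 0, None),   # packed TensorOptions (6 = float32, 0 = strided)
          [2],                                            # expected sizes
          False)],                                        # is_tensor_subclass
    )
    getitem_3 = validate_outputs[0]                       # validated gradient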
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 4243f1b1d6e..5bf00bac537 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -47,6 +47,8 @@ TORCH_API void validate_outputs( const std::vector>& input_metadata, variable_list& grads, const std::function& format_error); +TORCH_API std::vector> collect_input_metadata( + const edge_list& edges); struct NodeTask { std::weak_ptr base_; diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index ba2f6edbc6c..abd11303eaf 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -34,8 +34,12 @@ using tensor_list = std::vector; using variable_list = std::vector; using edge_list = std::vector; using saved_variable_list = std::vector; +using ivalue_list = std::vector; +using functional_apply_t = std::function< + variable_list(const variable_list&, const std::vector&)>; using IndexRange = std::pair; using torch::dynamo::autograd::CompiledNodeArgs; +using torch::dynamo::autograd::PackedArgs; using torch::dynamo::autograd::SwapSavedVariables; // Custom deleter to prevent stack overflows. @@ -604,6 +608,12 @@ struct TORCH_API Node : std::enable_shared_from_this { std::string("apply_with_saved not implemented: ") + name()); } + // If this node is the AOTBackward node produced by torch.compile. + // Compiled Autograd special-cases on this information. + virtual bool is_aot_backward() const { + return false; + } + protected: /// Performs the `Node`'s actual operation. virtual variable_list apply(variable_list&& inputs) = 0; diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h index 6342bf280a5..4e8bba79a16 100644 --- a/torch/csrc/autograd/function_hook.h +++ b/torch/csrc/autograd/function_hook.h @@ -8,6 +8,7 @@ namespace torch::dynamo::autograd { class CompiledNodeArgs; class SwapSavedVariables; +struct PackedArgs; } // namespace torch::dynamo::autograd // A hook that's called on gradients diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 19151cbaafe..abd8ff30e90 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -288,6 +288,11 @@ auto PyNode::name() const -> std::string { return name; } +bool PyNode::is_aot_backward() const { + py::handle handle(obj); + return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id"); +} + auto PyNode::compiled_autograd_should_lift() const -> bool { pybind11::gil_scoped_acquire gil; static PyObject* attr_name = diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 46faff8e468..2f28c765ab0 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -43,6 +43,8 @@ struct PyNode : public Node { std::string name() const override; bool is_traceable() override; + bool is_aot_backward() const override; + void compiled_args(CompiledNodeArgs& args) override; variable_list apply_with_saved( const variable_list& inputs, diff --git a/torch/csrc/dynamo/compiled_autograd.cpp b/torch/csrc/dynamo/compiled_autograd.cpp new file mode 100644 index 00000000000..7e2aad57618 --- /dev/null +++ b/torch/csrc/dynamo/compiled_autograd.cpp @@ -0,0 +1,27 @@ +#include +#include + +namespace torch::dynamo::autograd { + +std::unique_ptr kPyCompilerInterface; + +const std::unique_ptr& getPyCompilerInterface() { + TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr); + return kPyCompilerInterface; +} + +void 
setPyCompilerInterface(std::unique_ptr&& impl) { + TORCH_INTERNAL_ASSERT(impl != nullptr); + kPyCompilerInterface = std::move(impl); +} + +void resetPyCompilerInterface() { + kPyCompilerInterface.reset(); +} + +std::vector> get_input_metadata( + const edge_list& edges) { + return torch::autograd::collect_input_metadata(edges); +} + +} // namespace torch::dynamo::autograd diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index 383cff14b8e..b00ec6e00a4 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -900,6 +900,506 @@ class SwapSavedVariables { StashedVars stashed_ivalues; }; +// NOTE: [Compiled Autograd and backward functions] +// Built-in autograd nodes have functional apply variants +// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph +// capture wants to take a variant of this function and proxy it into the graph. +// Every autograd node defines an apply_with_saved function, that when invoked, +// proxys a call to a function into the Compiled Autograd graph. +// +// Some requirements that we have are: +// - The proxy'ed function must have inputs that are FX-graphable types. +// - Windows has a DLL symbol limit of 65536. +// - Node::apply_with_saved is in libtorch_cpu which does not have direct access +// to Python +// +// There were multiple ways to skin the cat, but what we end up doing is: +// - for e.g. MulBackward0_apply_functional, we create a new C++ function +// MulBackward0_apply_functional_ivalue that accepts vector. +// - We define how to pack and unpack arbitrary C++ types into IValues. +// - apply_with_saved passes MulBackward0_apply_functional_ivalue and +// the IValue arguments to Python via an indirection. +// In Python, these get proxy'ed into a graph. + +// Helper struct for packing/unpacking an arbitrary C++ type into a single +// IValue. There are various full and partial specializations for IValuePacker +// to handle packing specific types (like TensorOptions) into an IValue. +template +struct IValuePacker { + // Defines how to pack T into an IValue. + static at::IValue pack(const T& t) { + return t; + } + // Defines how to unpack an IValue into T. + static T unpack(const at::IValue& t) { + return t.to(); + } + // Returns the TypePtr for the IValue (this is like the "type" of the IValue). + // We use this when passing the packed IValue from Python to C++. + // In Python, the IValue is just a PyObject* with the native type. + // For example, it may be a Python int, a Python List[int], etc. + // When passing this PyObject* into C++, we need to know how to parse it + // into a C++ type that then gets put into an IValue. + // That's what the TypePtr is for: it contains the information to do the + // parsing. See torch::jit::toIValue for more information. 
+ static at::TypePtr packed_type() { + if constexpr (::std::is_same_v) { + return at::TensorType::get(); + } else if constexpr (::std::is_same_v) { + return at::IntType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymIntType::get(); + } else if constexpr (::std::is_same_v) { + return at::BoolType::get(); + } else if constexpr (::std::is_same_v) { + return at::FloatType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymFloatType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymBoolType::get(); + } else if constexpr (::std::is_same_v) { + return at::LayoutType::get(); + } else if constexpr (::std::is_same_v) { + return at::StringType::get(); + } else if constexpr (::std::is_same_v) { + return at::DeviceObjType::get(); + } else if constexpr (::std::is_same_v) { + return at::NumberType::get(); + } else if constexpr (::std::is_same_v) { + return at::MemoryFormatType::get(); + } else if constexpr (::std::is_same_v) { + return at::ScalarTypeType::get(); + } else { + // If you got here, you have probably added a member of a new type + // to a built-in C++ autograd node. + // Unfortunately, we don't know how to handle this type yet. + // To get this new type to work with Compiled Autograd, please + // either change it to be an IValue-constructible type, or + // define how to pack and unpack an object of this time into an IValue + // by creating a specialization of IValuePacker for this type. + // See NOTE: [Compiled Autograd and backward functions] for context. + TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type"); + return at::NoneType::get(); + } + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const size_t& t) { + // We generally use size_t as the size of a list of Tensors or number of + // dimensions. The number of dimensions generally do not exceed 64 + // (TensorIterator has that limitation), and lists of Tensors generally do + // not exceed the int64_t max (you'd probably run out of RAM or run into + // significant Tensor overhead). If you run into this limitation the fix is + // to figure out how to pack size_t into int64_t. Note that size_t has some + // weird behavior on Mac OS. + uint64_t maximum_value = std::numeric_limits::max(); + TORCH_INTERNAL_ASSERT( + static_cast(t) <= maximum_value, + "size_t too large to pack into IValue"); + return static_cast(t); // pack as int64_t + } + static size_t unpack(const at::IValue& t) { + return static_cast(t.toInt()); + } + static at::TypePtr packed_type() { + return IValuePacker::packed_type(); + } +}; + +template <> +struct IValuePacker> { + static at::IValue pack(const std::vector& t) { + return t; + } + static std::vector unpack(const at::IValue& t) { + // We need this because there's no t.to>() override? 
+ return t.toSymIntVector(); + } + static at::TypePtr packed_type() { + return at::ListType::create(at::SymIntType::get()); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const VariableInfo& t) { + auto tuple = std::make_tuple( + t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty); + return tuple; + } + static VariableInfo unpack(const at::IValue& t) { + auto tuple = t.to, + bool, + bool>>(); + VariableInfo v; + v.layout = std::get<0>(tuple); + v.device = std::get<1>(tuple); + v.scalar_type = std::get<2>(tuple); + v.size = std::get<3>(tuple); + v.requires_grad = std::get<4>(tuple); + v.is_empty = std::get<5>(tuple); + return v; + } + static at::TypePtr packed_type() { + return at::TupleType::create({ + at::LayoutType::get(), + at::DeviceObjType::get(), + at::ScalarTypeType::get(), + at::ListType::create(at::SymIntType::get()), + at::BoolType::get(), + at::BoolType::get(), + }); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const caffe2::TypeMeta& t) { + return at::typeMetaToScalarType(t); // pack as at::ScalarType + } + static caffe2::TypeMeta unpack(const at::IValue& t) { + return caffe2::TypeMeta::fromScalarType(t.to()); + } + static at::TypePtr packed_type() { + return IValuePacker::packed_type(); + } +}; + +inline std::optional optTypeMetaToScalarType( + const std::optional& t) { + if (t.has_value()) { + return at::typeMetaToScalarType(t.value()); + } else { + return std::nullopt; + } +} + +using packed_tensoroptions_t = std::tuple< + std::optional, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional>; + +inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) { + auto tuple = std::make_tuple( + t.requires_grad_opt(), + t.memory_format_opt(), + t.device_opt(), + optTypeMetaToScalarType(t.dtype_opt()), + t.layout_opt(), + t.pinned_memory_opt()); + return tuple; +} +inline at::TensorOptions unpack_TensorOptions( + const packed_tensoroptions_t& tuple) { + at::TensorOptions result; + auto maybe_requires_grad = std::get<0>(tuple); + if (maybe_requires_grad.has_value()) { + result = result.requires_grad(maybe_requires_grad.value()); + } + auto maybe_memory_format = std::get<1>(tuple); + if (maybe_memory_format.has_value()) { + result = result.memory_format(maybe_memory_format.value()); + } + auto maybe_device = std::get<2>(tuple); + if (maybe_device.has_value()) { + result = result.device(maybe_device.value()); + } + auto maybe_dtype = std::get<3>(tuple); + if (maybe_dtype.has_value()) { + result = + result.dtype(caffe2::TypeMeta::fromScalarType(maybe_dtype.value())); + } + auto maybe_layout = std::get<4>(tuple); + if (maybe_layout.has_value()) { + result = result.layout(maybe_layout.value()); + } + auto maybe_pinned_memory = std::get<5>(tuple); + if (maybe_pinned_memory.has_value()) { + result = result.pinned_memory(maybe_pinned_memory.value()); + } + return result; +} + +template <> +struct IValuePacker { + static at::IValue pack(const at::TensorOptions& t) { + return pack_TensorOptions(t); + } + static at::TensorOptions unpack(const at::IValue& t) { + auto tuple = t.to(); + return unpack_TensorOptions(tuple); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {at::OptionalType::create(at::BoolType::get()), + at::OptionalType::create(at::MemoryFormatType::get()), + at::OptionalType::create(at::DeviceObjType::get()), + at::OptionalType::create(at::ScalarTypeType::get()), + at::OptionalType::create(at::LayoutType::get()), + 
at::OptionalType::create(at::BoolType::get())}); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const TypeAndSize& t) { + auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options)); + return tuple; + } + static TypeAndSize unpack(const at::IValue& t) { + auto tuple = + t.to, packed_tensoroptions_t>>(); + TypeAndSize result; + result.sym_sizes = std::get<0>(tuple); + result.options = unpack_TensorOptions(std::get<1>(tuple)); + return result; + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker>::packed_type(), + IValuePacker::packed_type()}); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::optional& t) { + if (t.has_value()) { + return IValuePacker::pack(t.value()); + } else { + return std::nullopt; + } + } + static std::optional unpack(const at::IValue& t) { + if (t.isNone()) { + return std::nullopt; + } else { + return IValuePacker::unpack(t); + } + } + static at::TypePtr packed_type() { + return at::OptionalType::create(IValuePacker::packed_type()); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::vector& t) { + if constexpr (::std::is_constructible_v) { + return t; + } + if (t.empty()) { + auto lst = c10::impl::GenericList(at::AnyType::get()); + return lst; + } + auto type_ptr = IValuePacker::pack(t[0]).type(); + auto lst = c10::impl::GenericList(type_ptr); + for (const auto& elt : t) { + lst.emplace_back(IValuePacker::pack(elt)); + } + return lst; + } + static std::vector unpack(const at::IValue& t) { + if constexpr (::std::is_constructible_v) { + return t.to<::std::vector>(); + } + std::vector result; + auto lst = t.toList(); + for (const at::IValue& elt : lst) { + result.emplace_back(IValuePacker::unpack(elt)); + } + return result; + } + static at::TypePtr packed_type() { + return at::ListType::create(IValuePacker::packed_type()); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const c10::List& t) { + return IValuePacker>::pack(t.vec()); + } + static c10::List unpack(const at::IValue& t) { + return c10::List(IValuePacker>::unpack(t)); + } + static at::TypePtr packed_type() { + return IValuePacker>::packed_type(); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::array& t) { + std::vector result(t.begin(), t.end()); + return IValuePacker>::pack(result); + } + static std::array unpack(const at::IValue& t) { + std::array result; + auto packed = IValuePacker>::unpack(t); + for (size_t i = 0; i < packed.size(); i++) { + result[i] = packed[i]; + } + return result; + } + static at::TypePtr packed_type() { + return IValuePacker>::packed_type(); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const at::TensorGeometry& t) { + auto tuple = std::make_tuple( + t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset()); + return tuple; + } + static at::TensorGeometry unpack(const at::IValue& t) { + auto tuple = t.to, + std::vector, + at::SymInt>>(); + return at::TensorGeometry( + std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple)); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker>::packed_type(), + IValuePacker>::packed_type(), + at::SymIntType::get()}); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const InputMetadata& t) { + TORCH_INTERNAL_ASSERT(!t.is_nested_tensor()); + auto tuple = std::make_tuple( + pack_TensorOptions(t.options()), + t.shape_as_dim_vector().vec(), + t.is_tensor_subclass()); + return 
tuple; + } + static InputMetadata unpack(const at::IValue& t) { + auto tuple = t.to< + std::tuple, bool>>(); + + return InputMetadata( + unpack_TensorOptions(std::get<0>(tuple)), + SymIntSmallVec(std::get<1>(tuple)), + std::get<2>(tuple), + false); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker::packed_type(), + IValuePacker>::packed_type(), + at::BoolType::get()}); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const at::OptionalArray& t) { + return IValuePacker>>::pack(t.list); + } + static at::OptionalArray unpack(const at::IValue& t) { + auto result = IValuePacker>>::unpack(t); + if (result.has_value()) { + return {result.value()}; + } else { + return {}; + } + } + static at::TypePtr packed_type() { + return IValuePacker>>::packed_type(); + } +}; + +// This is a helper struct for packing and unpacking multiple arguments into +// an ivalue_list. It leverages IValuePacker. +struct PackedArgs { + PackedArgs() = default; + + explicit PackedArgs(std::vector stack_) + : stack(std::move(stack_)) {} + + std::vector vec() && { + return std::move(stack); + } + + template + void pack(const T& t) { + stack.emplace_back(IValuePacker::pack(t)); + } + template + T unpack() { + return IValuePacker::unpack(std::move(stack[idx++])); + } + + private: + std::vector stack; + int64_t idx = 0; +}; + +// This is a layer of indirection for calling methods on the Python +// AutogradCompilerInstance (referred to as the "py_compiler") from +// libtorch_cpu (where Python is not available). +// A PyCompilerInterfaceImpl in libtorch_python subclasses it and +// overrides the methods to do the actual calls back to Python. +struct TORCH_API PyCompilerInterface { + PyCompilerInterface() = default; + PyCompilerInterface(const PyCompilerInterface&) = delete; + PyCompilerInterface& operator=(const PyCompilerInterface&) = delete; + PyCompilerInterface(PyCompilerInterface&&) = delete; + PyCompilerInterface& operator=(PyCompilerInterface&&) = delete; + virtual ~PyCompilerInterface() = default; + + // Invokes py_compiler.bind_function(fn_name, fn) + virtual void bind_function( + PyObject* py_compiler, + const std::string& fn_name, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + functional_apply_t fn, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector packed_args_schema) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + + // Invokes py_compiler.method_name(fn_name, inputs, packed_args, + // output_metadata) + virtual variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } +}; + +TORCH_API const std::unique_ptr& getPyCompilerInterface(); +TORCH_API void setPyCompilerInterface( + std::unique_ptr&& impl); +TORCH_API void resetPyCompilerInterface(); + +// including torch/csrc/autograd/engine.h breaks BC by somehow introducing +// symbol resolution issues. Instead requiring downstream users to include +// engine.h to access collect_input_metadata, we provide it here (with a +// different name to avoid ambigous symbols...) 
+TORCH_API std::vector> get_input_metadata( + const edge_list& edges); + } // namespace torch::dynamo::autograd template <> diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 3753e5988cd..12ef964f7e0 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -52,6 +52,118 @@ Notes: namespace torch::dynamo::autograd { using c10::SymInt; +// List[Optional[Tensor]] in Python can't be directly parsed into a +// List[Tensor], so we need to do this conversion manually. +static std::vector toTensorList( + const std::vector>& inputs) { + std::vector result; + result.reserve(inputs.size()); + for (const auto& inp : inputs) { + if (inp.has_value()) { + result.emplace_back(*inp); + } else { + result.emplace_back(); + } + } + return result; +} + +// Binds a function (that represents some backward computation) to Python. +// All of these functions have a common signature, which is +// (in C++) (vector, vector) -> vector +// (in Python) (List[Optional[Tensor]], *packed_args: IValue) -> +// List[Optional[Tensor]] +// +// The vector are the list of gradient Tensors, each of which may be +// undefined (in C++) which corresponds to None (in Python). +static void bind_function( + PyObject* py_compiler, + const std::string& fn_name, + functional_apply_t fn, + std::vector packed_args_schema) { + // This is the function that can be called from Python. + auto py_func = py::cpp_function( + [packed_args_schema = std::move(packed_args_schema), fn = std::move(fn)]( + std::vector>& inputs, + const py::args& py_args) -> py::object { + // py_args is a tuple of PyObject*. + // We need to reconstruct a vector to invoke `fn`. + // To do so, we use the packed_args_schema to convert each PyObject* + // to its corresponding C++ type that can be stored into IValue. 
+ TORCH_INTERNAL_ASSERT(py_args.size() == packed_args_schema.size()); + std::vector args; + args.reserve(py_args.size()); + auto tuple_args = jit::tuple_slice(py_args); + for (uint64_t idx = 0; idx < packed_args_schema.size(); idx++) { + args.emplace_back(jit::toIValue( + tuple_args[idx], packed_args_schema[idx], std::nullopt)); + } + // None in Python corresponds to undefined Tensor in C++ + auto inputs_ = toTensorList(inputs); + auto outputs = fn(inputs_, args); + return jit::toPyObject(at::IValue(outputs)); + }); + py::handle handle(py_compiler); + handle.attr("bind_function")(fn_name, py_func); +} + +// Invokes py_compiler.method_name(fn_name, inputs, packed_args, +// output_metadata) +static variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + // convert ivalue_list -> PyObject* + PyObject* py_packed_args = + PyTuple_New(static_cast(packed_args.size())); + for (const auto i : c10::irange(packed_args.size())) { + py::object obj = jit::toPyObject(packed_args[i]); + Py_INCREF(obj.ptr()); + PyTuple_SET_ITEM(py_packed_args, i, obj.ptr()); + } + + // call the corresponding method on the py_compiler + py::handle handle(py_compiler); + py::object stuff = handle.attr(method_name)( + fn_name, + inputs, + py::handle(py_packed_args), + jit::toPyObject(output_metadata)); + + // Convert the output from PyObject* to vector + auto tmp = py::cast>>(stuff); + return toTensorList(tmp); +} + +struct PyCompilerInterfaceImpl : PyCompilerInterface { + void bind_function( + PyObject* py_compiler, + const std::string& fn_name, + functional_apply_t fn, + std::vector packed_args_schema) override { + return torch::dynamo::autograd::bind_function( + py_compiler, fn_name, std::move(fn), std::move(packed_args_schema)); + } + variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) override { + return torch::dynamo::autograd::call_function( + py_compiler, + method_name, + fn_name, + inputs, + packed_args, + output_metadata); + } +}; + static PyObject* wrap_int_list(const std::vector& inputs) { PyObject* pyinput = PyTuple_New(static_cast(inputs.size())); for (const auto i : c10::irange(inputs.size())) { @@ -88,6 +200,22 @@ static void check(bool result) { check(nullptr); } +static variable_list validate_outputs( + const variable_list& outputs, + const ivalue_list& args) { + auto r = PackedArgs(args); + auto value = r.unpack>>(); + auto new_outputs = outputs; + + torch::autograd::validate_outputs( + value, new_outputs, [&](const std::string& msg) { + std::ostringstream ss; + ss << "[Compiled Autograd Tracing:]" << msg; + return ss.str(); + }); + return new_outputs; +} + // snapshot of python verbose logging toggle static PyObject* python_verbose_logger = nullptr; @@ -657,6 +785,8 @@ static CacheNode* _compiled_autograd_impl( ClosingTHPObjectPtr py_compiler( check(PyObject_CallNoArgs((the_autograd_compiler)))); + setPyCompilerInterface(std::make_unique()); + TraceState state = call_begin_capture( py_compiler, *cache, compiler_call, output_edges.size()); InputBuffers input_buffers; @@ -723,16 +853,48 @@ static CacheNode* _compiled_autograd_impl( SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call); variable_list outputs = call.node->apply_with_saved(inputs, saved); - 
saved.debug_asserts(); saved.before(call.node->next_edges()); - validate_outputs( - call.node->next_edges(), outputs, [&](const std::string& msg) { - std::ostringstream ss; - ss << "[Compiled Autograd Tracing: " << call.node->name() << "] " - << msg; - return ss.str(); - }); + + auto input_metadata = get_input_metadata(call.node->next_edges()); + TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size()); + + // Lazily bind the `validate_outputs` function to Python. + static c10::once_flag flag; + c10::call_once(flag, [&]() { + auto schema = std::vector{IValuePacker< + std::vector>>::packed_type()}; + bind_function( + py_compiler.get(), "validate_outputs", validate_outputs, schema); + }); + + // Don't emit validate_outputs nodes that follow a CompiledBackward node. + // These nodes would otherwise prevent reordering of accumulate_grad + // nodes. + // + // Note that this will not cause correctness issues, because + // 1) AOTAutograd already coerces gradients to have the same metadata as + // the inputs. 2) the AOTAutograd graph already has the necessary + // aten::sum_to nodes in it (so it doesn't need to rely on + // validate_outputs to handle that). + // + // However, we may be dropping some (edge case) safety checks compared to + // eager: a backward that would have errored out in eager may not error + // out in compiled autograd (for example, if the user provided an + // incorrect number of gradients). + if (!call.node->is_aot_backward()) { + PackedArgs args; + args.pack(input_metadata); + ivalue_list input_metadata_state = std::move(args).vec(); + outputs = call_function( + py_compiler, + "validate_outputs", + "validate_outputs", + outputs, + input_metadata_state, + input_metadata_state[0]); + } + saved.after(call.node->next_edges()); saved.debug_asserts(); @@ -761,6 +923,7 @@ static CacheNode* _compiled_autograd_impl( } } + resetPyCompilerInterface(); PyObject* res = check(call_end_capture(py_compiler, state.outputs)); TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple"); TORCH_CHECK(