From b8abdaa286fd161af48af57a675827f4f849914d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 16 Jan 2025 09:22:22 -0300 Subject: [PATCH] Make functionalization `ViewMeta` serializable with pickle. (#143712) Fix: #141974 This PR makes `ViewMeta` sequence, present in functional tensors, serializable with pickle. In order to accomplish that, it makes `ViewMeta` an abstract class with overridable `forward` and `reverse` functions. In this context, each operation that once instanciated `ViewMeta`, should now create a new specialized class that inherits from `ViewMeta. Therefore, this PR also uses codegen for creating these specializations. In summary, these are the changes this PR introduces: - `ViewMeta` is turned into an abstract class (see _FunctionalStorageImpl.cpp_). `forward` and `reverse` are pure virtual functions that need to be implemented. `to_out_index` should be implemented by operations that might return more than 1 output. - New `ViewMeta` specializations for `resize_` and `_unsafe_view` are created (see _FunctionalizeFallbackKernel.h_). - New templates _ViewMetaClasses.{cpp,h}_ are created. They hold the declaration and definition of the `ViewMeta` specializations, which are automatically generated in the ATen codegen (see _gen.py_). - New `_functionalization` Python sub-module is created (see _Module.cpp_). It serves as namespace for the `ViewMeta` specializations and `InverseReturnMode` enum. - New template _ViewMetaClassesPythonBinding.cpp_ is created. It holds the automatically generated Python bindings for the `ViewMeta` specialization, which are generated in the torch codegen (see _generate_code.py_). Note that this PR makes use of codegen at 2 different moments: - ATen codegen (_gen.py_): generates the `ViewMeta` specialized classes. - Torch codegen (_generate_code.py_): generated the Python bindings for them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143712 Approved by: https://github.com/bdhirsh --- .gitignore | 1 + BUILD.bazel | 3 + aten/src/ATen/FunctionalStorageImpl.cpp | 14 +- aten/src/ATen/FunctionalStorageImpl.h | 100 ++++-- aten/src/ATen/FunctionalTensorWrapper.cpp | 74 ++-- aten/src/ATen/FunctionalTensorWrapper.h | 27 +- aten/src/ATen/FunctionalizeFallbackKernel.cpp | 58 +-- aten/src/ATen/FunctionalizeFallbackKernel.h | 58 +++ aten/src/ATen/templates/FunctionalInverses.h | 12 +- .../templates/RegisterFunctionalization.cpp | 2 +- aten/src/ATen/templates/ViewMetaClasses.cpp | 19 + aten/src/ATen/templates/ViewMetaClasses.h | 12 + .../ViewMetaClassesPythonBinding.cpp | 11 + build.bzl | 2 + build_variables.bzl | 1 + caffe2/CMakeLists.txt | 2 + test/dynamo/test_aot_autograd_cache.py | 27 +- test/functorch/test_aotdispatch.py | 1 - tools/setup_helpers/generate_code.py | 33 +- torch/_C/__init__.pyi.in | 1 + torch/_C/_functionalization.pyi | 16 + .../_aot_autograd/autograd_cache.py | 14 - .../collect_metadata_analysis.py | 8 +- .../_aot_autograd/functional_utils.py | 83 ++--- .../_aot_autograd/input_output_analysis.py | 4 +- .../_aot_autograd/runtime_wrappers.py | 8 +- torch/_functorch/_aot_autograd/schemas.py | 30 +- torch/csrc/Module.cpp | 2 + .../python_torch_functions_manual.cpp | 9 - torch/csrc/functionalization/Module.cpp | 71 ++++ torch/csrc/functionalization/Module.h | 36 ++ torchgen/api/functionalization.py | 120 ++++--- torchgen/api/types/signatures.py | 74 +--- torchgen/gen.py | 105 ++++-- torchgen/gen_functionalization_type.py | 338 ++++++++++++++++-- 35 files changed, 951 insertions(+), 425 deletions(-) create mode 100644 aten/src/ATen/FunctionalizeFallbackKernel.h create mode 100644 aten/src/ATen/templates/ViewMetaClasses.cpp create mode 100644 aten/src/ATen/templates/ViewMetaClasses.h create mode 100644 aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp create mode 100644 torch/_C/_functionalization.pyi create mode 100644 torch/csrc/functionalization/Module.cpp create mode 100644 torch/csrc/functionalization/Module.h diff --git a/.gitignore b/.gitignore index 8d4ceaa811c..c81f0734665 100644 --- a/.gitignore +++ b/.gitignore @@ -79,6 +79,7 @@ torch/return_types.pyi torch/nn/functional.pyi torch/utils/data/datapipes/datapipe.pyi torch/csrc/autograd/generated/* +torch/csrc/functionalization/generated/* torch/csrc/lazy/generated/*.[!m]* torch_compile_debug/ # Listed manually because some files in this directory are not generated diff --git a/BUILD.bazel b/BUILD.bazel index df46835f363..893dbfc6cec 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -90,6 +90,8 @@ generated_cpu_cpp = [ "aten/src/ATen/NativeMetaFunctions.h", "aten/src/ATen/RegistrationDeclarations.h", "aten/src/ATen/VmapGeneratedPlumbing.h", + "aten/src/ATen/ViewMetaClasses.h", + "aten/src/ATen/ViewMetaClasses.cpp", "aten/src/ATen/core/aten_interned_strings.h", "aten/src/ATen/core/enum_tag.h", "aten/src/ATen/core/TensorBody.h", @@ -1087,6 +1089,7 @@ test_suite( "aten/src/ATen/templates/LazyNonNativeIr.h", "aten/src/ATen/templates/RegisterDispatchKey.cpp", "aten/src/ATen/templates/RegisterDispatchDefinitions.ini", + "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp", "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml", "aten/src/ATen/native/ts_native_functions.yaml", diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp index a5512818343..cae0ab0ba60 100644 --- a/aten/src/ATen/FunctionalStorageImpl.cpp +++ b/aten/src/ATen/FunctionalStorageImpl.cpp @@ -9,11 +9,6 @@ namespace at::functionalization { -ViewMeta ViewMeta::to_out_idx(int64_t out_idx) { - if (out_idx == this->out_index) return *this; - return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx); -} - // Note [Functionalization: Alias Removal Part 2] // See Note [Functionalization: Alias Removal] for more details. // This function applies a single update from one of the views to the StorageImpl. @@ -47,7 +42,7 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co std::vector tmp_values({base}); tmp_values.reserve(update.view_metas.size()); for (size_t i = 0; i < update.view_metas.size() - 1; ++i) { - at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index); + at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back()); // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided // All of these ops require additional information to recover the sizes of the original tensor. // If need to, we could probably apply this optimization and only bother computing tmp_values @@ -55,9 +50,8 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co tmp_values.push_back(std::move(next_view)); } for(int64_t i = static_cast(update.view_metas.size()) - 1; i >= 0; --i) { - int64_t out_idx = update.view_metas[i].out_index; // Each view inverse is implemented in ViewInverses.cpp. - t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx); + t = update.view_metas[i]->reverse(tmp_values[i], t); } TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); return t; @@ -111,13 +105,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base) TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_)); } -void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector& metas) { +void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector>& metas) { TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage"); if (metas.size() > 1) { for (size_t i = 1; i < metas.size(); ++i) { // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI - TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided, + TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided, "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i, " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today," "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you " diff --git a/aten/src/ATen/FunctionalStorageImpl.h b/aten/src/ATen/FunctionalStorageImpl.h index 3f80171196f..71c259937e9 100644 --- a/aten/src/ATen/FunctionalStorageImpl.h +++ b/aten/src/ATen/FunctionalStorageImpl.h @@ -8,44 +8,89 @@ namespace at::functionalization { // See Note [Functionalization Pass In Core] +enum class InverseReturnMode { + /// Specifies that functional inverses should always return a view. + AlwaysView, + /// Specifies that functional inverses should always return a non-view / copy. + NeverView, + /// Specifies that functional inverses should return a view unless a (copying) + /// scatter + /// inverse exists, in which case that will be used instead. + /// This avoids as_strided() calls that can be difficult for subclasses to + /// handle. + ViewOrScatterInverse, +}; + +#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \ + static const char* name() { \ + return #TYPE; \ + } + +#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \ + using SerializableTuple = std::tuple<__VA_ARGS__>; + // ViewMeta is a class used by the functionalization pass to navigate between // a base tensor and a view tensor. // For example, if I call `b = a.view1(...)` -// the functionalization pass will generate and store a ViewMeta on b that looks -// like: +// the functionalization pass will generate and store a ViewMeta specialization +// for `view1` operation on b that looks like: // -// ViewMeta( -// [](const Tensor& base, int64_t mutated_view_idx) { -// return base.view1(...); -// }, -// [](const at::Tensor& base, const at::Tensor& mutated_view, -// int64_t mutated_view_idx) -> at::Tensor { -// return at::functionalization::impl::view1_inverse(base, mutated_view, -// ...); +// struct TORCH_API view1_ViewMeta : public ViewMeta { +// FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta); +// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( +// bool /* reapply_views */, +// const std::vector&); +// +// view1_ViewMeta(const SerializableTuple& tpl) +// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} +// +// view1_ViewMeta(bool reapply_views, const std::vector& size) +// : ViewMeta(/*has_symbolic_inputs=*/false), +// reapply_views(reapply_views), +// size(size) {} +// +// Tensor forward(const Tensor& base) override { +// return base.view1(...); // } // -// The forward_fn lambda describes how to replay view1 on a tensor. +// Tensor reverse(const Tensor& base, const Tensor& mutated_view) override { +// return at::functionalization::impl::view1_inverse(base, mutated_view, +// ...); +// } // -// The reverse_fn lambda describes how, given a tensor that is already a view, +// SerializableTuple to_serializable_tuple() { +// return std::make_tuple(reapply_views, size); +// } +// +// bool reapply_views; +// std::vector size; +// }; +// +// The forward function describes how to replay view1 on a tensor. +// +// The reverse function describes how, given a tensor that is already a view, // how to get the corresponding base tensor. See Note [Functionalization Pass: // View Inverses] for details. +// +// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type +// representing the `ViewMeta` instance state. Methods that take in/return such +// a type are used for supporting pickle serialization. struct ViewMeta { ViewMeta( - std::function forward, - std::function reverse, bool has_symbolic_inputs, bool is_multi_output = false, bool is_as_strided = false, int64_t out_idx = 0) - : forward_fn(std::move(forward)), - reverse_fn(std::move(reverse)), - out_index(out_idx), + : out_index(out_idx), is_multi_output(is_multi_output), is_as_strided(is_as_strided), has_symbolic_inputs(has_symbolic_inputs) {} - std::function forward_fn; - std::function reverse_fn; + virtual ~ViewMeta() {} + + virtual Tensor forward(const Tensor& base) = 0; + virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0; + // See Note [out_idx in ViewMeta] int64_t out_index; @@ -57,10 +102,17 @@ struct ViewMeta { // Tells us if this view operation has any symbolic inputs bool has_symbolic_inputs; - // Returns a copy of the current ViewMeta, if out_idx matches the current - // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse + // Returns a new ViewMeta with the same forward/reverse // functions, but a new out index. - ViewMeta to_out_idx(int64_t out_idx); + // + // This method should be implemented by those `ViewMeta` that have more than + // one output. + virtual std::shared_ptr to_out_index(int64_t out_index) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "ViewMeta::to_out_index not implemented. ", + "Likely because there's only one output."); + } }; // FunctionalStorageImpl is a subclass of StorageImpl used by the @@ -93,14 +145,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const at::Tensor new_val; // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::vector view_metas; + const std::vector> view_metas; }; explicit FunctionalStorageImpl(const Tensor& value); void add_update( const Tensor& updated_val, - const std::vector& view_metas); + const std::vector>& view_metas); bool apply_updates(); const Tensor& base() { return base_; diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 409f944a88e..4aed2aac4a0 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -129,17 +129,19 @@ void FunctionalTensorWrapper::freeze_storage() const { // - view_value: The output tensor that we need to wrap. // - base: The "base" of the view that `view_value` was generated from. // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic. -FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta) - : c10::TensorImpl( - c10::DispatchKeySet(DispatchKey::Functionalize), - view_value.dtype(), - view_value.device() - ), - value_(view_value), - is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output), - was_storage_changed_(base->was_storage_changed_), - is_symbolic_(base->is_symbolic_) -{ +FunctionalTensorWrapper::FunctionalTensorWrapper( + const Tensor& view_value, + const FunctionalTensorWrapper* base, + const std::shared_ptr& meta) + : c10::TensorImpl( + c10::DispatchKeySet(DispatchKey::Functionalize), + view_value.dtype(), + view_value.device()), + value_(view_value), + is_multi_output_view_( + base->is_multi_output_view_ || meta->is_multi_output), + was_storage_changed_(base->was_storage_changed_), + is_symbolic_(base->is_symbolic_) { TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_)); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); set_constructor_metadata(); @@ -148,11 +150,10 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const view_metas_ = base->view_metas_; // copy } view_metas_.push_back(meta); - maybe_mark_symbolic(meta); + maybe_mark_symbolic(meta.get()); storage_ = base->storage_; // alias this tensor's storage with the base tensor's } - functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const { return static_cast(storage_.unsafeGetStorageImpl()); } @@ -176,18 +177,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const { } // See Note [Functionalization Pass - Inplace View Ops] -void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) { +void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr& meta) { view_metas_.push_back(meta); // Manually track the fact that this tensor recieved a metadata mutation! has_metadata_mutation_ = true; // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation. - maybe_mark_symbolic(meta); + maybe_mark_symbolic(meta.get()); // Note [Functionalization Pass - Inplace View Ops] // So, these ops are special - they're mutation AND view ops. They get special codegen. // An example is transpose_, e.g. `a.transpose_()` // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas. at::AutoDispatchSkipFunctionalize guard; - value_ = meta.forward_fn(value_, meta.out_index); + value_ = meta->forward(value_); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); } @@ -368,15 +369,8 @@ void FunctionalTensorWrapper::sync_() { regenerate_from_base(); } -Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) { - auto t = base; - - // Reapply views to get the viewed tensor from the base in alias_ - for (auto& view_meta: view_metas_) { - t = view_meta.forward_fn(t, view_meta.out_index); - } - - return t; +const std::vector>& FunctionalTensorWrapper::view_metas() const { + return view_metas_; } void FunctionalTensorWrapper::regenerate_from_base() { @@ -385,7 +379,7 @@ void FunctionalTensorWrapper::regenerate_from_base() { auto t = storage_impl->base(); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); - t = apply_view_metas(t); + t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); replace_(t, /*from_lazy_regenerate=*/true); @@ -759,20 +753,28 @@ void freeze_functional_tensor(const Tensor& tensor) { functional_base_impl->freeze_storage(); } -Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { +Tensor create_functional_tensor_with_view_meta( + const at::Tensor& view_to_wrap, + const at::Tensor& base, + const std::shared_ptr& meta, + int64_t out_idx) { TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base); + auto meta_ = meta; if (out_idx != 0) { // Note [out_idx in ViewMeta] // When a view op outputs multiple tensors, each output needs its own separate ViewMeta. // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function. - meta = meta.to_out_idx(out_idx); + meta_ = meta->to_out_index(out_idx); } - return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta); + return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta_); } -std::vector create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) { +std::vector create_functional_tensor_with_view_meta( + ITensorListRef view_to_wrap, + const at::Tensor& base, + const std::shared_ptr& meta) { std::vector outputs(view_to_wrap.size()); int64_t i = 0; for (const auto& tensor : view_to_wrap) { @@ -782,12 +784,22 @@ std::vector create_functional_tensor_with_view_meta(ITensorListRef view_ return outputs; } -void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) { +void mutate_view_meta(const at::Tensor& self, const std::shared_ptr& meta) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); self_impl->mutate_view_meta(meta); } +Tensor apply_view_meta_sequence( + const Tensor& base, + const std::vector>& sequence) { + Tensor r = base; + for (auto& vm : sequence) { + r = vm->forward(r); + } + return r; +} + // Note [Propagating strides in the functionalization pass] // In order to properly compute stride information, the functionalization pass // calls each {view} reference implementations with meta tensors. diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index c418ef39427..f25a5637de3 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { explicit FunctionalTensorWrapper( const Tensor& view_value, const FunctionalTensorWrapper* base, - const functionalization::ViewMeta& meta); + const std::shared_ptr& meta); // Get the underlying, actual tensor, that doesn't know anything about // functionalization. @@ -97,17 +97,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { ->are_all_mutations_under_no_grad_or_inference_mode(); } - void maybe_mark_symbolic(const functionalization::ViewMeta& meta) { - is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs; + void maybe_mark_symbolic(functionalization::ViewMeta* meta) { + is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs; } bool is_symbolic() const { return is_symbolic_; } - // Runs the forward_fn of every ViewMeta collected in the current instance - // to some other base. - Tensor apply_view_metas(const Tensor& base); + // Retrieves the ViewMeta sequence of this tensor. + const std::vector>& view_metas() + const; // Sync's the underlying tensor with its alias, if it's out of date. This // involves two steps: 1) Apply any pending updates/mutations to the alias 2) @@ -144,7 +144,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { // from the base tensor. This method is used by inplace-view ops like // transpose_. It appends a ViewMeta to the existing stack, and refreshes the // tensor by replaying the views off of the alias. - void mutate_view_meta(const at::functionalization::ViewMeta& meta); + void mutate_view_meta( + const std::shared_ptr& meta); // Custom implementation of self.set_(src) void set__impl(const FunctionalTensorWrapper* other); @@ -273,7 +274,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { bool is_symbolic_ = false; size_t generation_ = 0; - std::vector view_metas_; + std::vector> view_metas_; protected: static void copy_tensor_metadata( @@ -365,16 +366,20 @@ TORCH_API void propagate_xla_data_direct( Tensor create_functional_tensor_with_view_meta( const Tensor& view_to_wrap, const Tensor& base, - functionalization::ViewMeta meta, + const std::shared_ptr& meta, int64_t out_idx = 0); std::vector create_functional_tensor_with_view_meta( ITensorListRef view_to_wrap, const Tensor& base, - const functionalization::ViewMeta& meta); + const std::shared_ptr& meta); void mutate_view_meta( const Tensor& self, - const functionalization::ViewMeta& meta); + const std::shared_ptr& meta); + +TORCH_API Tensor apply_view_meta_sequence( + const Tensor& base, + const std::vector>& sequence); void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out); void set_sizes_strides_offset( diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index 36b6f91c1d9..1bf805d134f 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -27,6 +29,31 @@ #include #endif +namespace at::functionalization { + +Tensor resize__ViewMeta::forward(const Tensor& base) { + if (reapply_views) { + return base.as_strided(size, c10::contiguous_strides(size)); + } else { + return at::as_strided_copy(base, size, c10::contiguous_strides(size)); + } +} + +Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) { + return base.as_strided_scatter( + mutated_view, size, c10::contiguous_strides(size)); +} + +Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) { + return at::_unsafe_view_symint(base, size); +} + +Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) { + return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); +} + +} // namespace at::functionalization + namespace { void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) { const auto& schema = op.schema(); @@ -168,19 +195,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch // The output of resizing is equivalent to taking a slice of a larger tensor. // We have to emulate this "slicing" with an as_strided call. auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); - at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( - [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { - if (reapply_views) { - return base.as_strided(size, c10::contiguous_strides(size)); - } else { - return at::as_strided_copy(base, size, c10::contiguous_strides(size)); - } - }, - [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { - return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size)); - }, - /*has_symbolic_inputs=*/false - ); + auto view_meta = std::make_shared( + reapply_views, size.vec()); at::functionalization::impl::mutate_view_meta(self, view_meta); return self; } @@ -299,17 +315,11 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt tmp_output = at::_unsafe_view_symint(self_, size); } - bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); - - at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( - [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { - return at::_unsafe_view_symint(base, size); - }, - [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { - return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); - }, - /*has_symbolic_inputs=*/has_symbolic_inputs - ); + bool has_symbolic_inputs = std::any_of( + size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); + auto view_meta = + std::make_shared( + has_symbolic_inputs, size.vec()); auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta)); // See Note [Propagating strides in the functionalization pass] diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.h b/aten/src/ATen/FunctionalizeFallbackKernel.h new file mode 100644 index 00000000000..cd4f64a70fa --- /dev/null +++ b/aten/src/ATen/FunctionalizeFallbackKernel.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +namespace at::functionalization { + +// `ViewMeta` implementation for `resize_` operation. +struct TORCH_API resize__ViewMeta : public ViewMeta { + FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta); + FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( + bool /* reapply_views */, + const std::vector&); + + resize__ViewMeta(const SerializableTuple& tpl) + : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} + + resize__ViewMeta(bool reapply_views, const std::vector& size) + : ViewMeta(/*has_symbolic_inputs=*/false), + reapply_views(reapply_views), + size(size) {} + + Tensor forward(const Tensor& base) override; + Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; + + SerializableTuple to_serializable_tuple() { + return std::make_tuple(reapply_views, size); + } + + bool reapply_views; + std::vector size; +}; + +// `ViewMeta` implementation for `_unsafe_view` operation. +struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta { + FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta); + FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( + bool /* has_symbolic_inputs */, + const std::vector&); + + _unsafe_view_ViewMeta(const SerializableTuple& tpl) + : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} + + _unsafe_view_ViewMeta( + bool has_symbolic_inputs, + const std::vector& size) + : ViewMeta(has_symbolic_inputs), size(size) {} + + Tensor forward(const Tensor& base) override; + Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; + + SerializableTuple to_serializable_tuple() { + return std::make_tuple(has_symbolic_inputs, size); + } + + std::vector size; +}; + +} // namespace at::functionalization diff --git a/aten/src/ATen/templates/FunctionalInverses.h b/aten/src/ATen/templates/FunctionalInverses.h index 3217e097d7a..b15cd09a6c6 100644 --- a/aten/src/ATen/templates/FunctionalInverses.h +++ b/aten/src/ATen/templates/FunctionalInverses.h @@ -2,22 +2,12 @@ // ${generated_comment} +#include #include namespace at { namespace functionalization { -enum class InverseReturnMode { - /// Specifies that functional inverses should always return a view. - AlwaysView, - /// Specifies that functional inverses should always return a non-view / copy. - NeverView, - /// Specifies that functional inverses should return a view unless a (copying) scatter - /// inverse exists, in which case that will be used instead. - /// This avoids as_strided() calls that can be difficult for subclasses to handle. - ViewOrScatterInverse, -}; - struct FunctionalInverses { ${view_inverse_declarations} diff --git a/aten/src/ATen/templates/RegisterFunctionalization.cpp b/aten/src/ATen/templates/RegisterFunctionalization.cpp index 999c06e2cb8..93848d673f8 100644 --- a/aten/src/ATen/templates/RegisterFunctionalization.cpp +++ b/aten/src/ATen/templates/RegisterFunctionalization.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/ATen/templates/ViewMetaClasses.cpp b/aten/src/ATen/templates/ViewMetaClasses.cpp new file mode 100644 index 00000000000..0fd53171935 --- /dev/null +++ b/aten/src/ATen/templates/ViewMetaClasses.cpp @@ -0,0 +1,19 @@ +// ${generated_comment} + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +${op_headers} +#endif + +namespace at { +namespace functionalization { + +${view_meta_implementations} + +} // namespace functionalization +} // namespace at diff --git a/aten/src/ATen/templates/ViewMetaClasses.h b/aten/src/ATen/templates/ViewMetaClasses.h new file mode 100644 index 00000000000..be2dee2a871 --- /dev/null +++ b/aten/src/ATen/templates/ViewMetaClasses.h @@ -0,0 +1,12 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include + +namespace at { +namespace functionalization { + +${view_meta_declarations} + +} // namespace functionalization +} // namespace at diff --git a/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp b/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp new file mode 100644 index 00000000000..c784e5abe5c --- /dev/null +++ b/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp @@ -0,0 +1,11 @@ +#include +#include + +namespace torch::functionalization { + +void initGenerated(PyObject* module) { + auto functionalization = py::handle(module).cast(); + $view_meta_bindings +} + +} // namespace torch::functionalization diff --git a/build.bzl b/build.bzl index ad8ea1c8cef..6acbe49d790 100644 --- a/build.bzl +++ b/build.bzl @@ -117,6 +117,7 @@ def define_targets(rules): ":LazyNonNativeIr.h", ":RegisterDispatchDefinitions.ini", ":RegisterDispatchKey.cpp", + ":ViewMetaClassesPythonBinding.cpp", ":native_functions.yaml", ":shape_inference.h", ":tags.yaml", @@ -297,6 +298,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [ "torch/csrc/autograd/generated/python_torch_functions_1.cpp", "torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp", + "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" ] GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP diff --git a/build_variables.bzl b/build_variables.bzl index 8bd8ad3a8df..a206c6a4f9a 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -929,6 +929,7 @@ libtorch_python_core_sources = [ "torch/csrc/utils/disable_torch_function.cpp", "torch/csrc/utils/verbose.cpp", "torch/csrc/cpu/Module.cpp", + "torch/csrc/functionalization/Module.cpp", "torch/csrc/instruction_counter/Module.cpp", ] + lazy_tensor_core_python_sources diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 11b590a4817..7e4174a212d 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -310,6 +310,7 @@ set(GENERATED_CXX_PYTHON "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp" + "${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" ) set(GENERATED_H_PYTHON @@ -373,6 +374,7 @@ add_custom_command( "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp" + "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp" ${autograd_python} ${autograd_yaml} ${autograd_templates} diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index d543c7028b0..69f8310c4f6 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -250,11 +250,7 @@ class AOTAutogradCacheTests(InductorTestCase): @functorch_config.patch( {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True} ) - def test_view_replay_bypass(self): - """ - Shoud bypass when view replay is turned on - """ - + def test_view_replay(self): def fn(a): tmp = a.detach() a.mul_(2) @@ -262,10 +258,25 @@ class AOTAutogradCacheTests(InductorTestCase): with torch.autograd._force_original_view_tracking(True): compiled_fn = torch.compile(fn) - compiled_fn(torch.rand(2, 3)) - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) + def run_and_check(miss, hit, bypass): + self._clear_dynamo_and_codecache() + + inp = torch.rand(2, 3) + compiled_inp = inp.clone().detach() + + with torch.autograd._force_original_view_tracking(True): + out = fn(inp) + compiled_out = compiled_fn(compiled_inp) + + self.assertEqual(out, compiled_out) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit) + self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass) + + run_and_check(miss=1, hit=0, bypass=0) + run_and_check(miss=1, hit=1, bypass=0) + run_and_check(miss=1, hit=2, bypass=0) @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", False) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 50ef291417b..fc1efde4883 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -6897,7 +6897,6 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo): { "enable_autograd_cache": True, "strict_autograd_cache": True, - "view_replay_for_aliased_outputs": False, } ) @torch._inductor.config.patch("fx_graph_cache", True) diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index 6e0a64888f0..a57732e5eba 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -189,6 +189,12 @@ def main() -> None: ) options = parser.parse_args() + # Path: aten/src/ATen + aten_path = os.path.dirname(os.path.dirname(options.native_functions_path)) + operator_selector = get_selector( + options.selected_op_list_path, options.operators_yaml_path + ) + generate_code( options.gen_dir, options.native_functions_path, @@ -198,13 +204,32 @@ def main() -> None: options.disable_autograd, options.force_schema_registration, # options.selected_op_list - operator_selector=get_selector( - options.selected_op_list_path, options.operators_yaml_path - ), + operator_selector=operator_selector, + ) + + # Generate the python bindings for functionalization's `ViewMeta` classes. + from torchgen.gen_functionalization_type import ( + gen_functionalization_view_meta_classes, + ) + + functionalization_templates_dir = os.path.join(aten_path, "templates") + functionalization_install_dir = os.path.join( + options.gen_dir, "torch/csrc/functionalization/generated" + ) + + os.makedirs(functionalization_install_dir, exist_ok=True) + assert os.path.isdir(functionalization_install_dir) + assert os.path.isdir(functionalization_templates_dir) + + gen_functionalization_view_meta_classes( + options.native_functions_path or NATIVE_FUNCTIONS_PATH, + options.tags_path or TAGS_PATH, + selector=operator_selector, + install_dir=functionalization_install_dir, + template_dir=functionalization_templates_dir, ) if options.gen_lazy_ts_backend: - aten_path = os.path.dirname(os.path.dirname(options.native_functions_path)) ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml") ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp" ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h" diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 7bd149dd076..b068601f1d2 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -67,6 +67,7 @@ from . import ( _export, _cpu, _dynamo, + _functionalization, _functorch, _lazy, _lazy_ts_backend, diff --git a/torch/_C/_functionalization.pyi b/torch/_C/_functionalization.pyi new file mode 100644 index 00000000000..4e00df97e27 --- /dev/null +++ b/torch/_C/_functionalization.pyi @@ -0,0 +1,16 @@ +from torch import Tensor +from torch.types import _bool + +# Defined in torch/csrc/functionalization/Module.cpp + +class ViewMeta: + has_symbolic_inputs: _bool + +# Returns the list of ViewMeta instances of the given functional tensor. +# +# Although we do have python bindings for their types, we won't +# expose them here, since they should not be used by users. +def get_view_meta_sequence(tensor: Tensor) -> list[ViewMeta]: ... + +# Applies the ViewMeta sequence on top of the given base. +def apply_view_meta_sequence(base: Tensor, sequence: list[ViewMeta]) -> Tensor: ... diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 38092a99225..c112c892922 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -227,19 +227,6 @@ def check_cacheable(gm: torch.fx.GraphModule): check_node_safe(node) -def check_metadata_cacheable(metadata: ViewAndMutationMeta): - """ - When view replay is turned on, we bypass autograd cache if - the output is aliased. - """ - if config.view_replay_for_aliased_outputs: - for info in metadata.output_info: - if info.functional_tensor is not None: - raise BypassAOTAutogradCache( - "Cannot cache a graph with functional tensor" - ) - - class AOTAutogradCacheDetails(FxGraphHashDetails): """ Object to capture all the details for a dynamo graph module relevant to computing @@ -875,7 +862,6 @@ class AOTAutogradCache: def save(key: str, entry: AOTAutogradCacheEntry, remote: bool): """Save a single entry into the cache.""" try: - check_metadata_cacheable(entry.runtime_metadata) content = pickle.dumps(entry) CacheArtifactManager.record_artifact( CacheArtifactType.AOT_AUTOGRAD, key, content diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 0b723535824..1476b181812 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -36,10 +36,10 @@ from .functional_utils import ( has_metadata_mutation, MetadataKey, to_fun, + ViewMetaSequence, was_inductor_storage_resized, ) from .schemas import ( - FunctionalTensorMetadataEq, InputAliasInfo, MutationType, OutputAliasInfo, @@ -604,7 +604,7 @@ from a multi-output view call" # # The FunctionalTensor will be saved if one of the 2 conditions below # is true: - functional_tensor = None + view_meta_sequence = None if ( # 1. If the output_type is either of: # (i) alias_of_intermediate; @@ -636,7 +636,7 @@ from a multi-output view call" and not input_info[base_idx].mutates_metadata ): if isinstance(o, FunctionalTensor): - functional_tensor = FunctionalTensorMetadataEq(o.elem) + view_meta_sequence = ViewMetaSequence(o) out_info = OutputAliasInfo( output_type=output_type, @@ -644,7 +644,7 @@ from a multi-output view call" base_idx=base_idx, dynamic_dims=dynamic_dims, requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, - functional_tensor=functional_tensor, + view_meta_sequence=view_meta_sequence, ) output_info.append(out_info) diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index fb509296705..4add4ae845d 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -13,15 +13,12 @@ from typing import Optional, Tuple import torch from torch import Tensor +from torch._C import _functionalization from torch._logging import getArtifactLogger from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor from torch._subclasses.meta_utils import is_sparse_any -from torch.fx.experimental.symbolic_shapes import ( - definitely_true, - sym_eq, - SymIntEqByExpr, -) +from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._python_dispatch import ( is_traceable_wrapper_subclass, @@ -227,9 +224,9 @@ def gen_alias_from_base( aliased_base_tensor, target_meta_tensor, target_requires_grad, - target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None, + target_view_meta_sequence: Optional[ViewMetaSequence] = None, *, - replay_views, + replay_views: bool, ): # Patch the correct requires_grad field of the output tensor, depending on whether: # (i) the reconstructed output (out) was came from a tensor that requires grad or not; @@ -248,13 +245,11 @@ def gen_alias_from_base( # to replay them (view functions) on the aliased_base_tensor. if ( replay_views - and target_functional_tensor is not None - and not torch._functionalize_is_symbolic(target_functional_tensor.tensor) + and target_view_meta_sequence is not None + and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence) ): - functional_tensor = target_functional_tensor.tensor - - out = torch._functionalize_apply_view_metas( - functional_tensor, aliased_base_tensor + out = _functionalization.apply_view_meta_sequence( + aliased_base_tensor, target_view_meta_sequence.sequence ) # If re-applying the ViewMeta sequence succeeded, there should be no more # problems going forward. We just check we got to the target shape and @@ -315,28 +310,8 @@ def gen_alias_from_base( return aliased_out -def has_same_metadata(t1, t2): - return ( - definitely_true(sym_eq(t1.size(), t2.size())) - and definitely_true(t1.layout == t2.layout) - and ( - is_sparse_any(t1) - or ( - definitely_true(sym_eq(t1.stride(), t2.stride())) - and definitely_true(t1.storage_offset() == t2.storage_offset()) - ) - ) - and t1.is_conj() == t2.is_conj() - and t1.is_neg() == t2.is_neg() - ) - - @dataclass(frozen=True) class MetadataKey: - """ - This should be equal whenever has_same_metadata would return True - """ - size: Tuple[SymIntEqByExpr, ...] layout: torch.layout is_sparse: bool @@ -360,25 +335,45 @@ class MetadataKey: ) -# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata -# after applying all the ViewMeta operations. -class FunctionalTensorMetadataEq: - def __init__(self, tensor: torch.Tensor) -> None: - assert torch._is_functional_tensor(tensor) - self.tensor = tensor +# ViewMeta sequence wrapper for equality comparisons. +# +# Even though we can compare each ViewMeta instance, we compare the resulting +# tensor metadata, instead. That's because the creation of synthetic bases + the +# re-generation of input views might end-up creating a different sequence of +# ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same +# metadata. +# +# Therefore, we store what the end result should look like as serializable +# metadata. +# +# When logging, this class should look like: +# +# ViewMetaSequence(view, select_int, slice_Tensor) +# +# i.e. a parenthesized list of view operations within that ViewMeta sequence. +class ViewMetaSequence: + def __init__(self, tensor: FunctionalTensor) -> None: + assert torch._is_functional_tensor(tensor.elem) + self.sequence = _functionalization.get_view_meta_sequence(tensor.elem) + self.metadata = MetadataKey.make(tensor) + + def __repr__(self) -> str: + suffix = len("_ViewMeta") + types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence) + return f"ViewMetaSequence({types})" def __eq__(self, other: object) -> bool: # If other is None, then it probably means that we weren't able to recreate - # the FunctionalTensorMetadataEq. One of this cases is when we update the - # view metadata by calling: create_synthetic_base_metadata. + # the ViewMeta sequence. One example is when we update the view metadata by + # calling: create_synthetic_base_metadata. if other is None: return True - # Comparison agains any other type is not implemented. - if not isinstance(other, FunctionalTensorMetadataEq): + # Comparison against any other type is not implemented. + if not isinstance(other, ViewMetaSequence): return NotImplemented - return has_same_metadata(self.tensor, other.tensor) + return self.metadata == other.metadata # new_arg and arg here are either: diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 727b3af1e32..faa10e33547 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -75,7 +75,7 @@ def remove_dupe_metadata( dynamic_dims=o.dynamic_dims, base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx], requires_grad=o.requires_grad, - functional_tensor=o.functional_tensor, + view_meta_sequence=o.view_meta_sequence, ) for o in m.output_info ], @@ -226,7 +226,7 @@ def create_synthetic_base_metadata( # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases base_idx=new_base_idx, # type: ignore[arg-type] requires_grad=o.requires_grad, - functional_tensor=o.functional_tensor, + view_meta_sequence=o.view_meta_sequence, ) ) diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 604d6540849..b81d3e92901 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -172,7 +172,7 @@ class AliasOfInputHandler: self.base_idx = info.base_idx self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.requires_grad = info.requires_grad - self.functional_tensor = info.functional_tensor + self.view_meta_sequence = info.view_meta_sequence self.replay_views = config.view_replay_for_aliased_outputs def __call__(self, orig_inputs, fw_outs, out): @@ -181,7 +181,7 @@ class AliasOfInputHandler: aliased_base_tensor, self.unwrap_out(out), self.requires_grad, - self.functional_tensor, + self.view_meta_sequence, replay_views=self.replay_views, ) @@ -209,7 +209,7 @@ class AliasOfIntermediateHandler: self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.requires_grad = info.requires_grad - self.functional_tensor = info.functional_tensor + self.view_meta_sequence = info.view_meta_sequence self.replay_views = config.view_replay_for_aliased_outputs def __call__(self, orig_inputs, fw_outs, out): @@ -218,7 +218,7 @@ class AliasOfIntermediateHandler: aliased_base_tensor, self.unwrap_out(out), self.requires_grad, - self.functional_tensor, + self.view_meta_sequence, replay_views=self.replay_views, ) diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index bab5b7c35ad..22b9941ee40 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -5,7 +5,6 @@ input/output types, metadata, config, function signatures etc. """ import collections -import dataclasses import functools from dataclasses import dataclass, field from enum import Enum @@ -20,10 +19,7 @@ from torch._subclasses.fake_tensor import is_fake from torch.utils._python_dispatch import is_traceable_wrapper_subclass from .. import config -from .functional_utils import ( - _check_if_mutation_can_be_in_graph, - FunctionalTensorMetadataEq, -) +from .functional_utils import _check_if_mutation_can_be_in_graph, ViewMetaSequence from .utils import strict_zip @@ -92,15 +88,14 @@ class OutputAliasInfo: dynamic_dims: Optional[Set[int]] # requires_grad requires_grad: bool - # FunctionalTensorWrapper that represents this output. + # Sequence of ViewMeta objects. # - # Provides us the means to replay views from it. + # Provides us the means to re-run view functions on other tensors. # - # We need to wrap the actual FunctionalTensorWrapper with this class so that - # we only compare the tensor's metadata. That's because with the transformations - # of the model throughout AOTAutograd, the sequence of ViewMeta and the base - # tensor might change. - functional_tensor: Optional[FunctionalTensorMetadataEq] = None + # We need to wrap the actual list of ViewMeta with this class so that + # we compare the ViewMeta elements appropriately, i.e. their type and + # the elements returned by the `as_tuple()` call. + view_meta_sequence: Optional[ViewMetaSequence] = None class MutationType(Enum): @@ -582,17 +577,6 @@ class ViewAndMutationMeta: self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents] # Clear traced tangents at runtime self.traced_tangents = [] - new_output_info = [] - for out in self.output_info: - if config.view_replay_for_aliased_outputs: - new_out = out - else: - # If we're not using view_replay, remove the functional tensor. - # Functional tensors are unfortunately not serializable, - # so doing this is required for AOTAutograd caching. - new_out = dataclasses.replace(out, functional_tensor=None) - new_output_info.append(new_out) - self.output_info = new_output_info for inp_meta in self.subclass_inp_meta: if isinstance(inp_meta, SubclassCreationMeta): inp_meta.make_runtime_safe() diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 59c2e96be54..9f6ac1af89a 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -71,6 +71,7 @@ #include #include #include +#include #include #include #include @@ -1869,6 +1870,7 @@ PyObject* initModule() { torch::instruction_counter::initModule(module); torch::initVerboseBindings(module); ASSERT_TRUE(THPStorage_init(module)); + torch::functionalization::initModule(module); #ifdef USE_CUDA // This will only initialise base classes and attach them to library namespace diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index a4d9eed924b..b125413f2e7 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -633,15 +633,6 @@ void initTorchFunctions(PyObject* module) { at::functionalization::impl::isFunctionalTensor(t)); at::functionalization::impl::mark_mutation_hidden_from_autograd(t); }); - py_module.def( - "_functionalize_apply_view_metas", - [](const at::Tensor& tensor, const at::Tensor& base) { - TORCH_INTERNAL_ASSERT( - at::functionalization::impl::isFunctionalTensor(tensor)); - auto impl = - at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); - return impl->apply_view_metas(base); - }); py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); diff --git a/torch/csrc/functionalization/Module.cpp b/torch/csrc/functionalization/Module.cpp new file mode 100644 index 00000000000..d38cb107805 --- /dev/null +++ b/torch/csrc/functionalization/Module.cpp @@ -0,0 +1,71 @@ +#include +#include + +#include +#include +#include +#include + +namespace torch::functionalization { + +void initModule(PyObject* module) { + auto m = py::handle(module).cast(); + + // Create a `torch._C._functionalization` Python module. + auto functionalization = m.def_submodule( + "_functionalization", "functionalization related pybind."); + + // Retrieve the ViewMeta sequence of a given functional tensor. + functionalization.def("get_view_meta_sequence", [](const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + at::functionalization::impl::isFunctionalTensor(tensor)); + auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); + return impl->view_metas(); + }); + + // Applies the given ViewMeta sequence to the given base. + functionalization.def( + "apply_view_meta_sequence", + [](const at::Tensor& base, + const std::vector>& + sequence) { + return at::functionalization::impl::apply_view_meta_sequence( + base, sequence); + }); + + // Binding for InverseReturnMode. + py::enum_( + functionalization, "InverseReturnMode") + .value("AlwaysView", at::functionalization::InverseReturnMode::AlwaysView) + .value("NeverView", at::functionalization::InverseReturnMode::NeverView) + .value( + "ViewOrScatterInverse", + at::functionalization::InverseReturnMode::ViewOrScatterInverse); + + // Create bindings for the ViewMeta base class. + // + // Needed so that we can take a list of ViewMeta objects as parameter. + // Specifically, in the Python-side, we will have a list of derived ViewMeta + // classes. We need to tell pybind11 that all of those are, in fact, instances + // of different ViewMeta sub-types. + py::class_< + at::functionalization::ViewMeta, + std::shared_ptr>( + functionalization, "ViewMeta") + .def_property_readonly( + "has_symbolic_inputs", + [](const std::shared_ptr& meta) { + return meta->has_symbolic_inputs; + }); + + // Bindings for `ViewMeta` specializations manually implemented. + create_binding_with_pickle( + functionalization); + create_binding_with_pickle( + functionalization); + + // Bindings for `ViewMeta` specializations automatically generated. + initGenerated(functionalization.ptr()); +} + +} // namespace torch::functionalization diff --git a/torch/csrc/functionalization/Module.h b/torch/csrc/functionalization/Module.h new file mode 100644 index 00000000000..2f77fd3098c --- /dev/null +++ b/torch/csrc/functionalization/Module.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +#include +#include + +namespace torch::functionalization { + +// Creates the default bindings for `ViewMeta` specializations. +// +// Defines a constructor using the types in `SerializableTuple`, as well +// as pickle methods. +template +void create_binding_with_pickle(py::module m) { + py::class_, at::functionalization::ViewMeta>( + m, T::name()) + .def(py::init()) + .def( + "as_tuple", + [](const std::shared_ptr& meta) { + return meta->to_serializable_tuple(); + }) + .def(py::pickle( + [](const std::shared_ptr& meta) { + return meta->to_serializable_tuple(); + }, + [](const typename T::SerializableTuple& tpl) { + return std::make_shared(tpl); + })); +} + +void initModule(PyObject* module); +void initGenerated(PyObject* module); + +} // namespace torch::functionalization diff --git a/torchgen/api/functionalization.py b/torchgen/api/functionalization.py index 93667e39b17..367156d9ba7 100644 --- a/torchgen/api/functionalization.py +++ b/torchgen/api/functionalization.py @@ -23,20 +23,13 @@ from torchgen.model import ( # This file describes the translation of JIT schema to API's used -# when creating view lambdas that are used by the functionalization pass. -# There are two types of lambdas: forward lambdas and reverse lambdas. -# These API's mostly follow the dispatcher API, with a few quirks: -# - The lambda capture has to convert reference types to value types -# - While the forward lambda just directly calls into the at::_ops API -# (following the dispatcher convention), the logic here for the reverse lambda +# when creating `ViewMeta` specializations that are used by the functionalization pass. +# These API's mostly follow the dispatcher API, with one difference: +# - While the forward function just directly calls into the at::_ops API +# (following the dispatcher convention), the logic here for the reverse function # is responsible for generating both the call-site, and the declarations # (which are implemented manually in the at::functionalization::impl namespace). -# The lambdas generated for each view op in the functionalization pass are of the form -# [capture_arguments](outer_arguments) -> returns_type { -# return name(inner_arguments); -# } - # Define some specific lambda input arguments. base_binding = Binding( name="base", @@ -46,6 +39,18 @@ base_binding = Binding( ), default=None, ) + +has_symbolic_inputs_binding = Binding( + name="has_symbolic_inputs", + nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)), + argument=Argument( + name="has_symbolic_inputs", + type=BaseType(BaseTy.bool), + default=None, + annotation=None, + ), + default=None, +) mutated_view_binding = Binding( name="mutated_view", nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), @@ -54,11 +59,11 @@ mutated_view_binding = Binding( ), default=None, ) -mutated_view_idx_binding = Binding( - name="mutated_view_idx", - nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)), +out_index_binding = Binding( + name="out_index", + nctype=NamedCType(name="out_index", type=BaseCType(longT)), argument=Argument( - name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None ), default=None, ) @@ -86,8 +91,13 @@ inverse_return_mode_binding = Binding( ) -# The lambda capture itself doesn't have a name. -# The name returned here corresponds to the name of the inner function called by the lambda. +# Name of the `ViewMeta` specialization class created. +def classname(func: FunctionSchema, with_namespace: bool = False) -> str: + namespace = "at::functionalization::" if with_namespace else "" + return f"{namespace}{func.name.unambiguous_name()}_ViewMeta" + + +# Name of the operation called inside the `forward`/`reverse` implementations. def name( g: NativeFunctionsViewGroup, *, @@ -124,24 +134,6 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str: return f"{api_name}_inverse" -def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]: - # capture arguments include all arguments except `self`. - # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture), - # So any reference types (IntArrayRef) need to be converted to value types (vector) - args = func.arguments.flat_all - assert args[0].type == BaseType(BaseTy.Tensor) - non_self_args = args[1:] - non_self_value_bindings = [ - dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args - ] - - all_bindings = [ - inverse_return_mode_binding if is_reverse else reapply_views_binding - ] - all_bindings.extend(non_self_value_bindings) - return all_bindings - - def returns_type(func: FunctionSchema) -> CType: # Assertion: all view ops return tensor-like outputs assert len(func.returns) >= 1 @@ -152,24 +144,49 @@ def returns_type(func: FunctionSchema) -> CType: return BaseCType(tensorT) -def outer_arguments(*, is_reverse: bool) -> list[Binding]: - if is_reverse: - return [base_binding, mutated_view_binding, mutated_view_idx_binding] - else: - return [base_binding, mutated_view_idx_binding] +# Checks whether `func` might return more than one value. +def is_multi_output(func: FunctionSchema) -> bool: + return len(func.returns) > 1 or ( + len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None + ) -def inner_call_index(func: FunctionSchema) -> Binding | None: - # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output. - # When we replay a view op that returns multiple tensors, we need to index into the output appropriately - if len(func.returns) > 1 or ( - len(func.returns) == 1 and func.returns[0].type.is_list_like() - ): - return mutated_view_idx_binding - return None +# `ViewMeta` specialization constructor parameters. +def base_ctor_arguments(func: FunctionSchema) -> list[Binding]: + # All specializations are paremeterized by `has_symbolic_inputs` flag. + arguments = [has_symbolic_inputs_binding] + + # If `func` might return more than 1 value, we also parameterize this specialization + # with the output index. + if is_multi_output(func): + arguments.append(out_index_binding) + + return arguments -def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: +# `ViewMeta` specialized class' constructor arguments. +# +# Values needed specifically by this specialization, that the base class does not need. +# Same as the class' attributes, but non-owning. +def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]: + return attributes(func, owning=False) + + +# `ViewMeta` specialized class' non-static member data. +# +# Essential data for calling the instance's `forward` and `reverse functions. You can +# think of them as values that should be captured from the functionalization kernel. +def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]: + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + return [ + reapply_views_binding, + inverse_return_mode_binding, + *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]], + ] + + +def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: args = func.arguments.flat_all assert args[0].type == BaseType(BaseTy.Tensor) non_self_args = args[1:] @@ -183,13 +200,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: # the reverse lambda does the same, but with an additional "mutated_view" arg # additionally, we have a calling convention: for view ops that return multiple tensor outputs # their corresponding view_inverse function takes in an additional index argument. - index_binding = inner_call_index(func) - if index_binding is not None: + if is_multi_output(func): return [ base_binding, mutated_view_binding, inverse_return_mode_binding, - index_binding, + out_index_binding, ] + non_self_bindings else: return [ diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index d7c60e52d93..f34028a5aa7 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -300,83 +300,11 @@ class ViewInverseSignature: return_type = functionalization.returns_type(self.g.view.func) decls = [ a.decl() - for a in functionalization.inner_arguments( - self.g.view.func, is_reverse=True - ) + for a in functionalization.op_arguments(self.g.view.func, is_reverse=True) ] return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});" -@dataclass(frozen=True) -class FunctionalizationLambda: - g: NativeFunctionsViewGroup - - # are we generating the forward lambda or the reverse lambda? - is_reverse: bool - - def captures(self) -> list[Expr]: - # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments - # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, - # and plumb it into the lambda. - outer_ctx = dispatcher.arguments(self.g.view.func) + [ - functionalization.reapply_views_binding, - functionalization.inverse_return_mode_binding, - ] - capture_bindings = functionalization.capture_arguments( - self.g.view.func, is_reverse=self.is_reverse - ) - # allow_expensive_conversions is set because we want to convert - # some reference types (IntArrayRef) to value types (vector). - capture_exprs = translate.translate( - outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True - ) - return capture_exprs - - def decl(self) -> str: - return_type = functionalization.returns_type(self.g.view.func) - capture_str = ", ".join( - f"{val.type.name} = {val.expr}" for val in self.captures() - ) - decls = [ - a.decl() - for a in functionalization.outer_arguments(is_reverse=self.is_reverse) - ] - return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}" - - def inner_call(self, *, reapply_views: bool | None = None) -> str: - inner_call_name = functionalization.name( - self.g, - is_reverse=self.is_reverse, - include_namespace=True, - reapply_views=reapply_views, - ) - - arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse) - capture_ctx = functionalization.capture_arguments( - self.g.view.func, is_reverse=self.is_reverse - ) - full_ctx = arg_ctx + capture_ctx - - assert self.g.view_copy is not None - call_bindings = functionalization.inner_arguments( - self.g.view_copy.func, is_reverse=self.is_reverse - ) - maybe_index = functionalization.inner_call_index(self.g.view_copy.func) - call_exprs = [ - e.expr for e in translate.translate(full_ctx, call_bindings, method=False) - ] - if not self.is_reverse and maybe_index is not None: - return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];' - else: - return f'{inner_call_name}({", ".join(call_exprs)});' - - @staticmethod - def from_func( - g: NativeFunctionsViewGroup, *, is_reverse: bool - ) -> FunctionalizationLambda: - return FunctionalizationLambda(g, is_reverse) - - @dataclass(frozen=True) class StructuredImplSignature: g: NativeFunctionsGroup diff --git a/torchgen/gen.py b/torchgen/gen.py index e9a10b9c52e..5009495885b 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -45,6 +45,8 @@ from torchgen.gen_functionalization_type import ( gen_functionalization_definition, gen_functionalization_registration, gen_functionalization_view_inverse_declaration, + gen_functionalization_view_meta_classes_decl, + gen_functionalization_view_meta_classes_impl, GenCompositeViewCopyKernel, ) from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing @@ -2577,48 +2579,48 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f }, ) + def gen_op_headers( + g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + ) -> list[str]: + if isinstance(g, NativeFunctionsViewGroup): + # view ops always get a functionalization kernel + headers = [ + f"#include ", + f"#include ", + ] + if g.view_copy is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + elif isinstance(g, NativeFunctionsGroup): + headers = [ + f"#include ", + f"#include ", + f"#include ", + f"#include ", + ] + if g.inplace is not None: + headers += [ + f"#include ", + f"#include ", + ] + if g.mutable is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + else: + return [ + f"#include ", + f"#include ", + ] + def functionalization_env_callable( g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, ) -> dict[str, list[str]]: - def gen_op_headers( - g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, - ) -> list[str]: - if isinstance(g, NativeFunctionsViewGroup): - # view ops always get a functionalization kernel - headers = [ - f"#include ", - f"#include ", - ] - if g.view_copy is not None: - headers += [ - f"#include ", - f"#include ", - ] - return headers - elif isinstance(g, NativeFunctionsGroup): - headers = [ - f"#include ", - f"#include ", - f"#include ", - f"#include ", - ] - if g.inplace is not None: - headers += [ - f"#include ", - f"#include ", - ] - if g.mutable is not None: - headers += [ - f"#include ", - f"#include ", - ] - return headers - else: - return [ - f"#include ", - f"#include ", - ] - return { "ops_headers": gen_op_headers(g), "func_definitions": gen_functionalization_definition( @@ -2684,6 +2686,31 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f }, ) + cpu_fm.write( + "ViewMetaClasses.h", + lambda: { + "view_meta_declarations": list( + concatMap( + lambda g: gen_functionalization_view_meta_classes_decl(selector, g), + view_groups, + ) + ) + }, + ) + + cpu_fm.write( + "ViewMetaClasses.cpp", + lambda: { + "view_meta_implementations": list( + concatMap( + lambda g: gen_functionalization_view_meta_classes_impl(selector, g), + view_groups, + ) + ), + "op_headers": list(concatMap(gen_op_headers, view_groups)), + }, + ) + # Note [view_copy NativeFunctions] # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd # needs to have a corresponding non-aliasing {view}_copy variant. diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index 4f9865d6d3e..2d6ad376872 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,16 +1,15 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Callable, TYPE_CHECKING +from typing import Callable, Optional, TYPE_CHECKING -from torchgen.api import cpp, dispatcher +from torchgen.api import cpp, dispatcher, functionalization from torchgen.api.translate import translate from torchgen.api.types import ( BaseCType, Binding, CType, DispatcherSignature, - FunctionalizationLambda, iTensorListRefT, NativeSignature, OptionalCType, @@ -48,7 +47,7 @@ from torchgen.native_function_generation import ( MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, ) -from torchgen.utils import dataclass_repr +from torchgen.utils import concatMap, dataclass_repr, FileManager if TYPE_CHECKING: @@ -365,6 +364,8 @@ def emit_view_functionalization_body( with native_function_manager(f): call_sig = DispatcherSignature.from_schema(g.view_copy.func) + spec = ViewMetaSpecialization(g, f=f) + # the "view_copy" op name that the functionalization kernels need to call api_name = g.view_copy.func.name.unambiguous_name() # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors) @@ -385,9 +386,6 @@ def emit_view_functionalization_body( for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False) ] - forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False) - reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True) - # The meta API call should use the same arguments, but convert all tensors to meta tensors first. meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) meta_call_args = [ @@ -415,19 +413,7 @@ def emit_view_functionalization_body( : at::functionalization::InverseReturnMode::NeverView ); {symbolic_inputs_check} - at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( - {forward_lambda.decl()} {{ - if (reapply_views) {{ - return {forward_lambda.inner_call(reapply_views=True)} - }} else {{ - return {forward_lambda.inner_call(reapply_views=False)} - }} - }}, - {reverse_lambda.decl()} {{ - return {reverse_lambda.inner_call()} - }}, - /*has_symbolic_inputs=*/{symbolic_inputs_varname} - ); + auto view_meta = {spec.new()}; auto compute_reference_meta = {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) || {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit); @@ -455,7 +441,6 @@ def emit_view_functionalization_body( """ else: - is_multi_output_view = isinstance(f.func.returns[0].type, ListType) return f""" {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ {unwrap_tensor_args_str} @@ -489,21 +474,7 @@ def emit_view_functionalization_body( }} }} {symbolic_inputs_check} - at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( - {forward_lambda.decl()} {{ - if (reapply_views) {{ - return {forward_lambda.inner_call(reapply_views=True)} - }} else {{ - return {forward_lambda.inner_call(reapply_views=False)} - }} - }}, - {reverse_lambda.decl()} {{ - return {reverse_lambda.inner_call()} - }}, - /*has_symbolic_inputs=*/{symbolic_inputs_varname}, - /*is_multi_output=*/{str(is_multi_output_view).lower()}, - /*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()} - ); + auto view_meta = {spec.new()}; auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta); // See Note [Propagating strides in the functionalization pass] if (compute_reference_meta) {{ @@ -771,6 +742,301 @@ def gen_functionalization_view_inverse_declaration( return emit_decl_helper(g) +# Helper class for generating `ViewMeta` specializations. +@dataclass +class ViewMetaSpecialization: + g: NativeFunctionsViewGroup + f: NativeFunction + + @property + def is_multi_output(self) -> bool: + return functionalization.is_multi_output(self.f.func) + + @property + def is_as_strided(self) -> bool: + return str(self.f.func.name) == "as_strided" + + @property + def out_index(self) -> str: + if self.is_multi_output: + return functionalization.out_index_binding.name + return "0" + + @property + def classname(self) -> str: + return functionalization.classname(self.f.func) + + def decl(self) -> list[str]: + base_ctor_arguments = functionalization.base_ctor_arguments(self.f.func) + extra_ctor_arguments = functionalization.extra_ctor_arguments(self.f.func) + attributes = functionalization.attributes(self.f.func) + + # List of types for declaring the `SerializableTuple` type. + serializable_tuple_args = ",\n".join( + f" {binding.type} /* {binding.name} */" + for binding in (base_ctor_arguments + attributes) + ) + + # Arguments used for forwarding the tuple elements to the constructor. + destructure_tuple_args = ", ".join( + f"std::get<{i}>(tpl)" + for i in range(len(base_ctor_arguments) + len(extra_ctor_arguments)) + ) + + # List of constructor parameters + ctor_parameters = ", ".join( + binding.decl() for binding in (base_ctor_arguments + extra_ctor_arguments) + ) + + # Call the base class `ViewMeta` constructor. + # + # Both of `is_multi_output` and `is_as_strided` are known values, given the + # operation schema. + is_multi_output_str = str(self.is_multi_output).lower() + is_as_strided_str = str(self.is_as_strided).lower() + + base_ctor_bindings = ", ".join( + [ + # `has_symbolic_inputs` is always taken as parameter. + functionalization.has_symbolic_inputs_binding.name, + f"/*is_multi_output=*/{is_multi_output_str}", + f"/*is_as_strided=*/{is_as_strided_str}", + # `out_index` is know if the operation returns only one value. Otherwise, + # we also take it as parameter. + f"/*out_index=*/{self.out_index}", + ] + ) + + # Assignments of `extra_ctor_arguments` to their corresponding fields. + # These are extra fields to-be-declared in this specialization. + # + # We need to set `allow_expensive_conversions`, since we are storing owned versions + # of the non-owning arguments. + ctor_assignments = ",\n".join( + f" {e.type.name}({e.expr})" + for e in translate( + extra_ctor_arguments, + attributes, + method=False, + allow_expensive_conversions=True, + ) + ) + + # List of arguments for constructing the `SerializableTuple` from an instance. + tuple_arguments = ", ".join( + binding.name for binding in (base_ctor_arguments + attributes) + ) + + # List of field declarations. + attr_declarations = "\n".join(f" {binding.decl()};" for binding in attributes) + + # Override `to_out_index` if this operation returns more than 1 value. + to_out_index_decl = "" + if self.is_multi_output: + to_out_index_decl = ( + " std::shared_ptr to_out_index(int64_t out_idx) override;" + ) + + return [ + f""" +struct TORCH_API {self.classname} : public ViewMeta {{ + FUNCTIONALIZATION_VIEWMETA_NAME({self.classname}); + FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(\n{serializable_tuple_args}); + + {self.classname}(const SerializableTuple& tpl) + : {self.classname}({destructure_tuple_args}) {{}} + + {self.classname}({ctor_parameters}) + : at::functionalization::ViewMeta({base_ctor_bindings}), +{ctor_assignments} {{}} + + Tensor forward(const Tensor& base) override; + Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; +{to_out_index_decl} + + SerializableTuple to_serializable_tuple() {{ + return std::make_tuple({tuple_arguments}); + }} + +{attr_declarations} +}}; +""" + ] + + # Generate a call to the actual operation. + def opcall(self, is_reverse: bool, reapply_views: bool) -> str: + opname = functionalization.name( + self.g, + is_reverse=is_reverse, + include_namespace=True, + reapply_views=reapply_views, + ) + + # Expected arguments for the operation. + assert self.g.view_copy is not None + op_arguments = functionalization.op_arguments(self.g.view_copy.func, is_reverse) + + # The context is composed by the constructor arguments (which are also + # the field variables stored in the instance), and the `base` tensor. + context = [functionalization.base_binding] + context += functionalization.base_ctor_arguments(self.f.func) + context += functionalization.attributes(self.f.func) + + # If we are generating the call for the reverse function, we also have + # access to `mutated_view` argument. + if is_reverse: + context.append(functionalization.mutated_view_binding) + + arguments = ", ".join( + [e.expr for e in translate(context, op_arguments, method=False)] + ) + + # Index the result if this operation returns multiple values. + maybe_index = "" + if not is_reverse and self.is_multi_output: + maybe_index = f"[{self.out_index}]" + + return f"{opname}({arguments}){maybe_index}" + + def impl(self) -> list[str]: + functions = [ + f""" +at::Tensor {self.classname}::forward(const at::Tensor& base) {{ + if (reapply_views) {{ + return {self.opcall(is_reverse=False, reapply_views=True)}; + }} else {{ + return {self.opcall(is_reverse=False, reapply_views=False)}; + }} +}}""", + f""" +at::Tensor {self.classname}::reverse(const at::Tensor& base, const Tensor& mutated_view) {{ + return {self.opcall(is_reverse=True, reapply_views=True)}; +}}""", + ] + + # If this operation returns multiple values, also generate a `to_out_index` + # implementation. + if self.is_multi_output: + functions.append(f""" +std::shared_ptr {self.classname}::to_out_index(int64_t out_index) {{ + return {self.new("out_index")}; +}} +""") + + return functions + + # Create the Python binding for this specialized class. + def binding(self) -> list[str]: + name = functionalization.classname(self.f.func, with_namespace=True) + return [f" create_binding_with_pickle<{name}>(functionalization);"] + + # Generate an instanciation of this specialized class. + def new(self, out_index: str = "0") -> str: + name = functionalization.classname(self.f.func, with_namespace=True) + ctor_arguments = functionalization.base_ctor_arguments( + self.f.func + ) + functionalization.extra_ctor_arguments(self.f.func) + # Replace the `out_index` parameter with the given `out_index`. + arguments = ", ".join( + binding.name if binding.name != "out_index" else out_index + for binding in ctor_arguments + ) + return f"std::make_shared<{name}>({arguments})" + + # Run the function `run` for both: `view` and `view_inplace` functions. + @staticmethod + def map( + g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]] + ) -> list[str]: + def maybe_run(f: Optional[NativeFunction]) -> list[str]: + if f is None: + return [] + with native_function_manager(f): + return run(ViewMetaSpecialization(g, f)) + + return list(concatMap(maybe_run, (g.view, g.view_inplace))) + + +def gen_functionalization_view_meta_classes_base( + selector: SelectiveBuilder, + g: NativeFunctionsViewGroup, + run: Callable[[ViewMetaSpecialization], list[str]], +) -> list[str]: + if not selector.include_all_operators: + return [] + + if g.composite: + return [] + + return ViewMetaSpecialization.map(g, run) + + +def gen_functionalization_view_meta_classes_decl( + selector: SelectiveBuilder, g: NativeFunctionsViewGroup +) -> list[str]: + return gen_functionalization_view_meta_classes_base( + selector, g, ViewMetaSpecialization.decl + ) + + +def gen_functionalization_view_meta_classes_impl( + selector: SelectiveBuilder, g: NativeFunctionsViewGroup +) -> list[str]: + return gen_functionalization_view_meta_classes_base( + selector, g, ViewMetaSpecialization.impl + ) + + +def gen_functionalization_view_meta_classes_binding( + selector: SelectiveBuilder, g: NativeFunctionsViewGroup +) -> list[str]: + return gen_functionalization_view_meta_classes_base( + selector, g, ViewMetaSpecialization.binding + ) + + +# Generates the Python bindings for the `ViewMeta` specialized classes. +def gen_functionalization_view_meta_classes( + native_functions_path: str, + tags_path: str, + selector: SelectiveBuilder, + install_dir: str, + template_dir: str, +) -> None: + from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml + + # Parse the native_functions.yaml. + # Then, group them into `NativeFunctionsViewGroup`. + # + # This is the same steps we do in gen.py (ATen codegen). + native_functions = parse_native_yaml( + native_functions_path, tags_path + ).native_functions + native_functions_with_view_groups = get_grouped_by_view_native_functions( + native_functions + ) + view_groups = [ + g + for g in native_functions_with_view_groups + if isinstance(g, NativeFunctionsViewGroup) + ] + + fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False) + fm.write( + "ViewMetaClassesPythonBinding.cpp", + lambda: { + "view_meta_bindings": list( + concatMap( + lambda g: gen_functionalization_view_meta_classes_binding( + selector, g + ), + view_groups, + ) + ), + }, + ) + + def gen_functionalization_registration( selector: SelectiveBuilder, g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,