From 0a89bdf9892b3021aca0bcd3df7388a20e24cfd1 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Fri, 9 Sep 2022 08:00:04 -0700 Subject: [PATCH] Set up aten/src/ATen/functorch directory; move some files there (#84648) This PR: - sets up aten/src/ATen/functorch in PyTorch's build system - Moves {BatchedTensorImpl.h, and BatchedTensorImpl.cpp} there as a test. Test Plan: - functorch build and test should pass Differential Revision: [D39315051](https://our.internmc.facebook.com/intern/diff/D39315051) Pull Request resolved: https://github.com/pytorch/pytorch/pull/84648 Approved by: https://github.com/ezyang --- BUILD.bazel | 1 + aten/src/ATen/CMakeLists.txt | 4 ++-- .../src/ATen/functorch}/BatchedTensorImpl.cpp | 2 +- .../csrc => aten/src/ATen/functorch}/BatchedTensorImpl.h | 8 +++----- buckbuild.bzl | 2 ++ functorch/functorch/csrc/DynamicLayer.cpp | 2 +- functorch/functorch/csrc/Interpreter.cpp | 2 +- functorch/functorch/csrc/LegacyVmapTransforms.h | 2 +- functorch/functorch/csrc/PlumbingHelper.cpp | 2 +- functorch/functorch/csrc/PlumbingHelper.h | 2 +- functorch/functorch/csrc/PyTorchOperatorHacks.cpp | 2 +- functorch/functorch/csrc/TensorWrapper.cpp | 2 +- functorch/functorch/csrc/VmapModeRegistrations.cpp | 2 +- functorch/functorch/csrc/dim/dim.cpp | 2 +- functorch/functorch/csrc/init.cpp | 2 +- setup.py | 1 + 16 files changed, 20 insertions(+), 18 deletions(-) rename {functorch/functorch/csrc => aten/src/ATen/functorch}/BatchedTensorImpl.cpp (98%) rename {functorch/functorch/csrc => aten/src/ATen/functorch}/BatchedTensorImpl.h (95%) diff --git a/BUILD.bazel b/BUILD.bazel index dd417c413a6..2c00e0d1dc5 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -133,6 +133,7 @@ filegroup( name = "aten_base_cpp", srcs = glob([ "aten/src/ATen/*.cpp", + "aten/src/ATen/functorch/*.cpp", "aten/src/ATen/detail/*.cpp", "aten/src/ATen/cpu/*.cpp", ]), diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 286d59f3e97..2380f66ce2b 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -56,8 +56,8 @@ if(NOT BUILD_CAFFE2 AND NOT BUILD_LITE_INTERPRETER) EXCLUDE(ATen_CORE_TEST_SRCS "${ATen_CORE_TEST_SRCS}" ${ATen_CORE_EXCLUDED_TEST_SRCS}) endif() -file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h") -file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp") +file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") +file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp") file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh") file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp") file(GLOB cuda_nvrtc_stub_h "cuda/nvrtc_stub/*.h") diff --git a/functorch/functorch/csrc/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp similarity index 98% rename from functorch/functorch/csrc/BatchedTensorImpl.cpp rename to aten/src/ATen/functorch/BatchedTensorImpl.cpp index 0c41c7096dc..c5d6eb34030 100644 --- a/functorch/functorch/csrc/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -3,7 +3,7 @@ // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include +#include #include #include diff --git a/functorch/functorch/csrc/BatchedTensorImpl.h b/aten/src/ATen/functorch/BatchedTensorImpl.h similarity index 95% rename from functorch/functorch/csrc/BatchedTensorImpl.h rename to aten/src/ATen/functorch/BatchedTensorImpl.h index 0172bbc8bcf..32098960457 100644 --- a/functorch/functorch/csrc/BatchedTensorImpl.h +++ b/aten/src/ATen/functorch/BatchedTensorImpl.h @@ -12,8 +12,6 @@ #include #include -#include - namespace at { namespace functorch { @@ -42,7 +40,7 @@ constexpr int64_t kBatchDimsStackSize = 5; // // bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public) // dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor. -struct BatchedTensorImpl : public c10::TensorImpl { +struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level); // Returns batch dimension of this tensor @@ -136,10 +134,10 @@ inline std::bitset createVmapLevelsBitset(int64_t level) { } // Use this to construct a BatchedTensor from a regular Tensor -FUNCTORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level); +TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level); // Adds a batch dim to `tensor`, returning a BatchedTensor -FUNCTORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level); +TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level); // Certain dispatch keys must be propagated to the BatchedTensor (or, in general, // any wrapper Tensor subclasses). This is because there are methods on Tensor diff --git a/buckbuild.bzl b/buckbuild.bzl index d4349d6a75a..458ff20f3c2 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -807,6 +807,7 @@ def define_buck_targets( ("aten/src", "ATen/*.h"), ("aten/src", "ATen/cpu/**/*.h"), ("aten/src", "ATen/detail/*.h"), + ("aten/src", "ATen/functorch/**/*.h"), ("aten/src", "ATen/quantized/*.h"), ("aten/src", "ATen/vulkan/*.h"), ("aten/src", "ATen/metal/*.h"), @@ -869,6 +870,7 @@ def define_buck_targets( ("", "torch/custom_class_detail.h"), # Add again due to namespace difference from aten_header. ("", "aten/src/ATen/*.h"), + ("", "aten/src/ATen/functorch/**/*.h"), ("", "aten/src/ATen/quantized/*.h"), ], exclude = [ diff --git a/functorch/functorch/csrc/DynamicLayer.cpp b/functorch/functorch/csrc/DynamicLayer.cpp index 19255df6a66..c76a63db122 100644 --- a/functorch/functorch/csrc/DynamicLayer.cpp +++ b/functorch/functorch/csrc/DynamicLayer.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include diff --git a/functorch/functorch/csrc/Interpreter.cpp b/functorch/functorch/csrc/Interpreter.cpp index 48e568b17c2..69500d651b0 100644 --- a/functorch/functorch/csrc/Interpreter.cpp +++ b/functorch/functorch/csrc/Interpreter.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include diff --git a/functorch/functorch/csrc/LegacyVmapTransforms.h b/functorch/functorch/csrc/LegacyVmapTransforms.h index 443c4e867de..00ecd5b60d8 100644 --- a/functorch/functorch/csrc/LegacyVmapTransforms.h +++ b/functorch/functorch/csrc/LegacyVmapTransforms.h @@ -7,7 +7,7 @@ #pragma once #include -#include +#include namespace at { namespace functorch { diff --git a/functorch/functorch/csrc/PlumbingHelper.cpp b/functorch/functorch/csrc/PlumbingHelper.cpp index e75fb82a386..738185b230b 100644 --- a/functorch/functorch/csrc/PlumbingHelper.cpp +++ b/functorch/functorch/csrc/PlumbingHelper.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include namespace at { namespace functorch { diff --git a/functorch/functorch/csrc/PlumbingHelper.h b/functorch/functorch/csrc/PlumbingHelper.h index f0dce14893e..4a1716d921f 100644 --- a/functorch/functorch/csrc/PlumbingHelper.h +++ b/functorch/functorch/csrc/PlumbingHelper.h @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. #pragma once #include -#include +#include #include // NOTE: [vmap plumbing] diff --git a/functorch/functorch/csrc/PyTorchOperatorHacks.cpp b/functorch/functorch/csrc/PyTorchOperatorHacks.cpp index fcd9442d08b..75c33f1e349 100644 --- a/functorch/functorch/csrc/PyTorchOperatorHacks.cpp +++ b/functorch/functorch/csrc/PyTorchOperatorHacks.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/functorch/functorch/csrc/TensorWrapper.cpp b/functorch/functorch/csrc/TensorWrapper.cpp index 22b8136a6b8..c4b6eac2579 100644 --- a/functorch/functorch/csrc/TensorWrapper.cpp +++ b/functorch/functorch/csrc/TensorWrapper.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include diff --git a/functorch/functorch/csrc/VmapModeRegistrations.cpp b/functorch/functorch/csrc/VmapModeRegistrations.cpp index e7407fcc0d6..8548a5c2518 100644 --- a/functorch/functorch/csrc/VmapModeRegistrations.cpp +++ b/functorch/functorch/csrc/VmapModeRegistrations.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/functorch/functorch/csrc/dim/dim.cpp b/functorch/functorch/csrc/dim/dim.cpp index 907554e861c..37cd6a0d8d2 100644 --- a/functorch/functorch/csrc/dim/dim.cpp +++ b/functorch/functorch/csrc/dim/dim.cpp @@ -13,7 +13,7 @@ #include //#include #include -#include +#include #include #include #include diff --git a/functorch/functorch/csrc/init.cpp b/functorch/functorch/csrc/init.cpp index de618fabebd..c7080359ed1 100644 --- a/functorch/functorch/csrc/init.cpp +++ b/functorch/functorch/csrc/init.cpp @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include #include diff --git a/setup.py b/setup.py index 83f262ab600..5d91ab13c1c 100644 --- a/setup.py +++ b/setup.py @@ -996,6 +996,7 @@ def main(): 'include/ATen/cuda/detail/*.cuh', 'include/ATen/cuda/detail/*.h', 'include/ATen/cudnn/*.h', + 'include/ATen/functorch/*.h', 'include/ATen/ops/*.h', 'include/ATen/hip/*.cuh', 'include/ATen/hip/*.h',