mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensor`s in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.
```cpp
/// Abstract base for reified view functions: an object that can replay a view
/// on a different base tensor. One code-generated subclass is expected per
/// view op, carrying whatever state is required to rebuild that view. The
/// accessors below expose any saved SymInt / tensor state so it can be
/// inspected or replaced, e.g. with symbolic values or fake tensors during
/// fake-ification.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() = default;

  /// SymInts held in the saved state. The base implementation holds none.
  virtual std::vector<c10::SymInt> get_symints() const {
    return {};
  }

  /// Count of SymInts held in the saved state. The base implementation
  /// reports zero.
  virtual size_t num_symints() const {
    return 0;
  }

  /// Tensors held in the saved state. The base implementation holds none.
  virtual std::vector<at::Tensor> get_tensors() const {
    return {};
  }

  /// Count of tensors held in the saved state. The base implementation
  /// reports zero.
  virtual size_t num_tensors() const {
    return 0;
  }

  /// Replays the view on the supplied base using the saved state.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;

  /// Produces a copy of this ViewFunc. When either argument is provided, the
  /// copy's corresponding saved state is replaced with the given values.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

 protected:
  /// Overwrites the saved SymInts. The vector passed in must contain exactly
  /// as many entries as the saved state holds (i.e. the length of the list
  /// returned by get_symints()). No-op in the base class.
  virtual void set_symints(std::vector<c10::SymInt>) {}

  /// Overwrites the saved tensors. The vector passed in must contain exactly
  /// as many entries as the saved state holds (i.e. the length of the list
  /// returned by get_tensors()). No-op in the base class.
  virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
New codegen files:
* `torch/csrc/autograd/generated/ViewFuncs.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`
The templates for these also contain impls for `ChainedViewFunc` and `ErroringViewFunc`, which are used in a few places within autograd.
Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
{};
virtual ~SliceTensorViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;
private:
int64_t dim;
c10::optional<c10::SymInt> start;
c10::optional<c10::SymInt> end;
c10::SymInt step;
};
...
// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
::std::vector<c10::SymInt> symints;
symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
if(start.has_value()) symints.insert(symints.end(), *(start));
if(end.has_value()) symints.insert(symints.end(), *(end));
symints.push_back(step);
return symints;
}
// Number of saved SymInts: one for step, plus one each for start and end when
// they hold a value.
size_t SliceTensorViewFunc::num_symints() const {
  const int optional_count =
      (start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0);
  return static_cast<size_t>(optional_count + 1);
}
// Replaces the saved SymInts from `symints`, consumed in the same order that
// get_symints() emits them: start (if set), end (if set), then step. An unset
// start / end stays unset and consumes no entry.
void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
  TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
  size_t next = 0;
  if (start.has_value()) {
    start = symints[next];
    ++next;
  }
  if (end.has_value()) {
    end = symints[next];
    ++next;
  }
  step = symints[next];
}
// This view func saves no tensors, so the returned list is always empty.
std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
  return ::std::vector<at::Tensor>{};
}
// No tensors are saved by this view func, so the count is always zero.
size_t SliceTensorViewFunc::num_tensors() const {
  constexpr size_t kSavedTensorCount = 0;
  return kSavedTensorCount;
}
// There is no tensor state to overwrite; this only checks that the caller
// passed as many tensors as num_tensors() reports.
void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
  const size_t expected = num_tensors();
  TORCH_INTERNAL_ASSERT(tensors.size() == expected);
}
// Replays the slice on `input_base` by calling the slice.Tensor op with the
// saved dim / start / end / step.
at::Tensor SliceTensorViewFunc::operator()(
    const at::Tensor& input_base) const {
  return at::_ops::slice_Tensor::call(
      input_base,
      dim,
      start,
      end,
      step);
}
// Builds a new SliceTensorViewFunc from the current saved arguments, then
// overwrites its SymInt and / or tensor state with whichever of the optional
// replacements was supplied. The replacement vectors are moved into the copy.
std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  auto copy = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
  if (symints) {
    copy->set_symints(std::move(*symints));
  }
  if (tensors) {
    copy->set_tensors(std::move(*tensors));
  }
  return copy;
}
```
The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
308 lines
8.1 KiB
Python
308 lines
8.1 KiB
Python
# @lint-ignore-every FBCODEBZLADDLOADS
|
|
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
|
|
|
|
# shared by internal and OSS BUCK
|
|
def define_tools_targets(
        python_binary,
        python_library,
        python_test,
        third_party,
        torchgen_deps,
        contacts = []):
    """Declares the build targets for the tools directory.

    The rule implementations are injected by the caller so the same target
    list can be shared between different BUCK setups.

    Args:
        python_binary: rule function invoked to declare each binary target.
        python_library: rule function invoked to declare each library target.
        python_test: rule function invoked to declare each test target.
        third_party: function called with a package name (here "pyyaml");
            its result is placed in `deps` lists.
        torchgen_deps: dependency entry placed in the `deps` of targets that
            need torchgen.
        contacts: value forwarded as `contacts` to every python_test target.
            It is only forwarded, never mutated, by this function.
    """
    python_library(
        name = "substitutelib",
        srcs = ["substitute.py"],
        base_module = "",
    )

    python_binary(
        name = "substitute",
        main_module = "substitute",
        visibility = ["PUBLIC"],
        deps = [
            ":substitutelib",
        ],
    )

    python_library(
        name = "jit",
        # @lint-ignore BUCKRESTRICTEDSYNTAX
        srcs = glob([
            "jit/*.py",
            "jit/templates/*",
        ]),
        base_module = "tools",
        visibility = ["PUBLIC"],
        deps = [
            torchgen_deps,
        ],
    )

    python_binary(
        name = "gen_unboxing_bin",
        main_module = "tools.jit.gen_unboxing",
        visibility = [
            "PUBLIC",
        ],
        deps = [
            ":jit",
        ],
    )

    python_library(
        name = "gen_selected_mobile_ops_header",
        srcs = ["lite_interpreter/gen_selected_mobile_ops_header.py"],
        base_module = "tools",
        visibility = ["PUBLIC"],
    )

    python_library(
        name = "gen_oplist_lib",
        srcs = subdir_glob([
            ("code_analyzer", "gen_oplist.py"),
            ("code_analyzer", "gen_op_registration_allowlist.py"),
        ]),
        base_module = "tools.code_analyzer",
        tests = [
            ":gen_oplist_test",
        ],
        visibility = ["PUBLIC"],
        deps = [
            ":gen_selected_mobile_ops_header",
            torchgen_deps,
            third_party("pyyaml"),
        ],
    )

    python_binary(
        name = "gen_oplist",
        main_module = "tools.code_analyzer.gen_oplist",
        visibility = ["PUBLIC"],
        deps = [
            ":gen_oplist_lib",
        ],
    )

    python_library(
        name = "gen_operators_yaml_lib",
        srcs = subdir_glob([
            ("code_analyzer", "gen_operators_yaml.py"),
            ("code_analyzer", "gen_op_registration_allowlist.py"),
        ]),
        base_module = "",
        tests = [
            ":gen_operators_yaml_test",
        ],
        deps = [
            third_party("pyyaml"),
            torchgen_deps,
        ],
    )

    python_binary(
        name = "gen_operators_yaml",
        main_module = "gen_operators_yaml",
        visibility = ["PUBLIC"],
        deps = [
            ":gen_operators_yaml_lib",
        ],
    )

    python_library(
        name = "autograd",
        # @lint-ignore BUCKRESTRICTEDSYNTAX
        srcs = glob(
            ["autograd/*.py"],
        ),
        base_module = "tools",
        resources = [
            "autograd/deprecated.yaml",
            "autograd/derivatives.yaml",
            "autograd/templates/ADInplaceOrViewType.cpp",
            "autograd/templates/Functions.cpp",
            "autograd/templates/Functions.h",
            "autograd/templates/TraceType.cpp",
            "autograd/templates/VariableType.cpp",
            "autograd/templates/VariableType.h",
            "autograd/templates/ViewFuncs.cpp",
            "autograd/templates/ViewFuncs.h",
            "autograd/templates/annotated_fn_args.py.in",
            "autograd/templates/python_enum_tag.cpp",
            "autograd/templates/python_fft_functions.cpp",
            "autograd/templates/python_functions.cpp",
            "autograd/templates/python_functions.h",
            "autograd/templates/python_linalg_functions.cpp",
            "autograd/templates/python_nested_functions.cpp",
            "autograd/templates/python_nn_functions.cpp",
            "autograd/templates/python_return_types.h",
            "autograd/templates/python_return_types.cpp",
            "autograd/templates/python_sparse_functions.cpp",
            "autograd/templates/python_special_functions.cpp",
            "autograd/templates/python_torch_functions.cpp",
            "autograd/templates/python_variable_methods.cpp",
            "autograd/templates/variable_factories.h",
        ],
        visibility = ["PUBLIC"],
        deps = [
            third_party("pyyaml"),
            torchgen_deps,
        ],
    )

    python_library(
        name = "generate_code",
        srcs = [
            "setup_helpers/generate_code.py",
        ],
        base_module = "tools",
        deps = [
            ":autograd",
            ":jit",
            torchgen_deps,
        ],
    )

    python_binary(
        name = "generate_code_bin",
        main_module = "tools.setup_helpers.generate_code",
        # Windows does not support inplace:
        # https://github.com/facebook/buck/issues/2161.
        #
        # Note that //arvr/mode/embedded/win/clang-aarch64-release sets
        # its target platform to
        # ovr_config//platform/embedded:clang-aarch64-linux-release, hence
        # that is why we are selecting that OS to trigger this behavior.
        package_style = select({
            "DEFAULT": "inplace",
            "ovr_config//os:linux-arm64": "standalone",
        }),
        visibility = ["PUBLIC"],
        # Because Windows does not support inplace packaging, we need to
        # ensure it is unzipped before executing it, otherwise it will not
        # be able to find any resources using path manipulation.
        #
        # See note above about why the OS is Linux here and not Windows.
        zip_safe = select({
            "DEFAULT": True,
            "ovr_config//os:linux-arm64": False,
        }),
        deps = [
            ":generate_code",
        ],
    )

    python_library(
        name = "gen-version-header-lib",
        srcs = [
            "setup_helpers/gen_version_header.py",
        ],
        base_module = "",
        deps = [],
    )

    python_binary(
        name = "gen-version-header",
        main_module = "setup_helpers.gen_version_header",
        visibility = ["PUBLIC"],
        deps = [
            ":gen-version-header-lib",
        ],
    )

    python_library(
        name = "gen_aten_vulkan_spv_lib",
        srcs = [
            "gen_vulkan_spv.py",
        ],
        base_module = "tools",
        deps = [
            torchgen_deps,
        ],
    )

    python_binary(
        name = "gen_aten_vulkan_spv_bin",
        main_module = "tools.gen_vulkan_spv",
        visibility = [
            "PUBLIC",
        ],
        deps = [
            ":gen_aten_vulkan_spv_lib",
        ],
    )

    python_test(
        name = "vulkan_codegen_test",
        srcs = [
            "test/test_vulkan_codegen.py",
        ],
        contacts = contacts,
        visibility = ["PUBLIC"],
        deps = [
            ":gen_aten_vulkan_spv_lib",
        ],
    )

    python_test(
        name = "selective_build_test",
        srcs = [
            "test/test_selective_build.py",
        ],
        contacts = contacts,
        visibility = ["PUBLIC"],
        deps = [
            torchgen_deps,
        ],
    )

    python_test(
        name = "gen_oplist_test",
        srcs = [
            "test/gen_oplist_test.py",
        ],
        contacts = contacts,
        visibility = ["PUBLIC"],
        deps = [
            ":gen_oplist_lib",
        ],
    )

    python_test(
        name = "gen_operators_yaml_test",
        srcs = [
            "test/gen_operators_yaml_test.py",
        ],
        visibility = ["PUBLIC"],
        contacts = contacts,
        deps = [
            ":gen_operators_yaml_lib",
        ],
    )

    python_test(
        name = "test_codegen",
        srcs = [
            "test/test_codegen.py",
        ],
        contacts = contacts,
        visibility = ["PUBLIC"],
        deps = [
            torchgen_deps,
            ":autograd",
        ],
    )

    python_test(
        name = "test_torchgen_executorch",
        srcs = [
            "test/test_executorch_gen.py",
            "test/test_executorch_signatures.py",
            "test/test_executorch_types.py",
            "test/test_executorch_unboxing.py",
        ],
        contacts = contacts,
        visibility = ["PUBLIC"],
        deps = [
            torchgen_deps,
        ],
    )
|