Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing the state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensor`s in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.
```cpp
/// Base class for view functions, providing reapplication of a view on a new
/// base. Each view op should get a codegenerated subclass of this class
/// containing any state needed to reconstruct the view. The class also
/// provides convenience accessors for saved SymInt / tensor state. This is
/// useful for e.g. fake-ification, where we want to use symbolic values or
/// fake tensors instead.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() {}
  /// Returns any SymInts in the saved state.
  virtual std::vector<c10::SymInt> get_symints() const { return {}; }
  /// Returns the number of SymInts in the saved state.
  virtual size_t num_symints() const { return 0; }
  /// Returns any tensors in the saved state.
  virtual std::vector<at::Tensor> get_tensors() const { return {}; }
  /// Returns the number of tensors in the saved state.
  virtual size_t num_tensors() const { return 0; }
  /// Reapplies the view on the given base using the saved state.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;
  /// Returns a clone of this ViewFunc, optionally with the specified saved
  /// state.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

 protected:
  /// Sets the values of any SymInts in the saved state. The input vector size
  /// must match the number of SymInts in the saved state (i.e. the size of
  /// the list returned by get_symints()).
  virtual void set_symints(std::vector<c10::SymInt>) {}
  /// Sets the values of any Tensors in the saved state. The input vector size
  /// must match the number of Tensors in the saved state (i.e. the size of
  /// the list returned by get_tensors()).
  virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
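To make the contract concrete, here is a minimal usage sketch (not code from the PR): `replay` is a hypothetical helper that inspects a captured `ViewFunc`'s saved state and then replays the view on a new base.
```cpp
// Hypothetical sketch: `view_fn` stands in for a ViewFunc captured by
// autograd for some view tensor; `new_base` is a replacement base tensor.
at::Tensor replay(
    const torch::autograd::ViewFunc& view_fn,
    const at::Tensor& new_base) {
  // Inspect the saved symbolic state (e.g. start/end/step for a slice).
  std::vector<c10::SymInt> symints = view_fn.get_symints();
  TORCH_INTERNAL_ASSERT(symints.size() == view_fn.num_symints());

  // Reapply the view op on the new base using the saved state.
  return view_fn(new_base);
}
```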
New codegen files:
* `torch/csrc/autograd/generated/ViewFuncs.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`
The templates for these also contain implementations of `ChainedViewFunc` and `ErroringViewFunc`, which are used in a few places within autograd; a sketch of the former follows.
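`ChainedViewFunc` composes two view funcs, replaying the second view on the output of the first. The following is an abridged sketch of that shape; the concatenated saved state and the omitted overrides are assumptions for illustration, not the exact template code.
```cpp
// Abridged sketch of ChainedViewFunc: owns two sub-ViewFuncs and replays the
// second view on the output of the first.
struct ChainedViewFunc : public torch::autograd::ViewFunc {
  ChainedViewFunc(
      std::unique_ptr<ViewFunc> first,
      std::unique_ptr<ViewFunc> second)
      : first(std::move(first)), second(std::move(second)) {}

  // Saved state is the concatenation of both sub-funcs' saved state.
  std::vector<c10::SymInt> get_symints() const override {
    auto symints = first->get_symints();
    auto second_symints = second->get_symints();
    symints.insert(
        symints.end(), second_symints.begin(), second_symints.end());
    return symints;
  }

  // Apply the first view, then the second view on its result.
  at::Tensor operator()(const at::Tensor& base) const override {
    return (*second)((*first)(base));
  }

  // (num_symints / tensor accessors / clone_and_set omitted for brevity.)

 private:
  std::unique_ptr<ViewFunc> first;
  std::unique_ptr<ViewFunc> second;
};
```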
Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
  SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step)
      : dim(dim), start(start), end(end), step(step) {}
  virtual ~SliceTensorViewFunc() override {}
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

 protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

 private:
  int64_t dim;
  c10::optional<c10::SymInt> start;
  c10::optional<c10::SymInt> end;
  c10::SymInt step;
};
...

// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
  ::std::vector<c10::SymInt> symints;
  symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
  if(start.has_value()) symints.insert(symints.end(), *(start));
  if(end.has_value()) symints.insert(symints.end(), *(end));
  symints.push_back(step);
  return symints;
}

size_t SliceTensorViewFunc::num_symints() const {
  return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}

void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
  TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
  auto i = 0;
  if(start.has_value()) start = symints[i];
  i += (start.has_value() ? 1 : 0);
  if(end.has_value()) end = symints[i];
  i += (end.has_value() ? 1 : 0);
  step = symints[i];
}

std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
  ::std::vector<at::Tensor> tensors;
  return tensors;
}

size_t SliceTensorViewFunc::num_tensors() const {
  return static_cast<size_t>(0);
}

void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
  TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());
}

at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
  return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}

std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
  if (symints.has_value()) {
    output->set_symints(std::move(*(symints)));
  }
  if (tensors.has_value()) {
    output->set_tensors(std::move(*(tensors)));
  }
  return output;
}
```
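To make the hot-swapping concrete, here is a hypothetical snippet (not part of the PR) that uses the generated struct above to replay the slice with a different `start`; `replay_slice_with_new_start` is an illustrative helper name.
```cpp
// Hypothetical snippet: replay a slice view with a swapped start offset,
// using the generated SliceTensorViewFunc above.
at::Tensor replay_slice_with_new_start(const at::Tensor& new_base) {
  SliceTensorViewFunc view_fn(
      /*dim=*/1, /*start=*/c10::SymInt(1), /*end=*/c10::SymInt(4),
      /*step=*/c10::SymInt(1));

  // Saved SymInt state, in declaration order: start, end, step.
  auto symints = view_fn.get_symints();
  symints[0] = c10::SymInt(0);  // shift the slice's start from 1 to 0

  // clone_and_set returns a new ViewFunc carrying the swapped state...
  auto swapped = view_fn.clone_and_set(std::move(symints));
  // ...which replays the modified view: equivalent to new_base.slice(1, 0, 4, 1).
  return (*swapped)(new_base);
}
```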
The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args: `symint_visitor_fn` / `tensor_visitor_fn`. If these are provided, they are expected to be Python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
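Conceptually, a visitor amounts to mapping each saved value and feeding the results back through `clone_and_set()`. Below is a minimal C++ sketch of that plumbing, with `SymIntVisitor` standing in for the wrapped Python callable (an assumption for illustration, not the actual binding code).
```cpp
// Sketch of what a symint_visitor_fn amounts to underneath: map each saved
// SymInt through the visitor, then clone the ViewFunc with the new state.
using SymIntVisitor = std::function<c10::SymInt(const c10::SymInt&)>;

std::unique_ptr<torch::autograd::ViewFunc> apply_symint_visitor(
    const torch::autograd::ViewFunc& view_fn,
    const SymIntVisitor& visit) {
  std::vector<c10::SymInt> new_symints;
  new_symints.reserve(view_fn.num_symints());
  for (const auto& s : view_fn.get_symints()) {
    new_symints.push_back(visit(s));  // e.g. substitute a symbolic value
  }
  return view_fn.clone_and_set(std::move(new_symints));
}
```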
For testing, there are extensive pre-existing view-replay tests, and I added a test to ensure that hot-swapping works correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
"""
|
|
To run this file by hand from the root of the PyTorch
|
|
repository, run:
|
|
|
|
python -m tools.autograd.gen_autograd \
|
|
aten/src/ATen/native/native_functions.yaml \
|
|
aten/src/ATen/native/tags.yaml \
|
|
$OUTPUT_DIR \
|
|
tools/autograd
|
|
|
|
Where $OUTPUT_DIR is where you would like the files to be
|
|
generated. In the full build system, OUTPUT_DIR is
|
|
torch/csrc/autograd/generated/
|
|
"""
|
|
|
|
# gen_autograd.py generates C++ autograd functions and Python bindings.
|
|
#
|
|
# It delegates to the following scripts:
|
|
#
|
|
# gen_autograd_functions.py: generates subclasses of torch::autograd::Node
|
|
# gen_variable_type.py: generates VariableType.h which contains all tensor methods
|
|
# gen_python_functions.py: generates Python bindings to THPVariable
|
|
#
|
|
|
|
import argparse
|
|
import os
|
|
from typing import List
|
|
|
|
from torchgen.api import cpp
|
|
from torchgen.api.autograd import (
|
|
match_differentiability_info,
|
|
NativeFunctionWithDifferentiabilityInfo,
|
|
)
|
|
from torchgen.gen import parse_native_yaml
|
|
from torchgen.selective_build.selector import SelectiveBuilder
|
|
|
|
from . import gen_python_functions
|
|
from .gen_autograd_functions import (
|
|
gen_autograd_functions_lib,
|
|
gen_autograd_functions_python,
|
|
)
|
|
from .gen_inplace_or_view_type import gen_inplace_or_view_type
|
|
from .gen_trace_type import gen_trace_type
|
|
from .gen_variable_factories import gen_variable_factories
|
|
from .gen_variable_type import gen_variable_type
|
|
from .gen_view_funcs import gen_view_funcs
|
|
from .load_derivatives import load_derivatives
|
|
|
|
|
|
def gen_autograd(
|
|
native_functions_path: str,
|
|
tags_path: str,
|
|
out: str,
|
|
autograd_dir: str,
|
|
operator_selector: SelectiveBuilder,
|
|
disable_autograd: bool = False,
|
|
) -> None:
|
|
# Parse and load derivatives.yaml
|
|
differentiability_infos, used_dispatch_keys = load_derivatives(
|
|
os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
|
|
)
|
|
|
|
template_path = os.path.join(autograd_dir, "templates")
|
|
|
|
native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
|
|
fns = sorted(
|
|
filter(
|
|
operator_selector.is_native_function_selected_for_training, native_funcs
|
|
),
|
|
key=lambda f: cpp.name(f.func),
|
|
)
|
|
fns_with_diff_infos: List[
|
|
NativeFunctionWithDifferentiabilityInfo
|
|
] = match_differentiability_info(fns, differentiability_infos)
|
|
|
|
# Generate VariableType.h/cpp
|
|
if not disable_autograd:
|
|
gen_variable_type(
|
|
out,
|
|
native_functions_path,
|
|
tags_path,
|
|
fns_with_diff_infos,
|
|
template_path,
|
|
used_dispatch_keys,
|
|
)
|
|
|
|
gen_inplace_or_view_type(
|
|
out, native_functions_path, tags_path, fns_with_diff_infos, template_path
|
|
)
|
|
|
|
# operator filter not applied as tracing sources are excluded in selective build
|
|
gen_trace_type(out, native_funcs, template_path)
|
|
# Generate Functions.h/cpp
|
|
gen_autograd_functions_lib(out, differentiability_infos, template_path)
|
|
|
|
# Generate variable_factories.h
|
|
gen_variable_factories(out, native_functions_path, tags_path, template_path)
|
|
|
|
# Generate ViewFuncs.h/cpp
|
|
gen_view_funcs(out, fns_with_diff_infos, template_path)
|
|
|
|
|
|
def gen_autograd_python(
|
|
native_functions_path: str,
|
|
tags_path: str,
|
|
out: str,
|
|
autograd_dir: str,
|
|
) -> None:
|
|
differentiability_infos, _ = load_derivatives(
|
|
os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
|
|
)
|
|
|
|
template_path = os.path.join(autograd_dir, "templates")
|
|
|
|
# Generate Functions.h/cpp
|
|
gen_autograd_functions_python(out, differentiability_infos, template_path)
|
|
|
|
# Generate Python bindings
|
|
deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
|
|
gen_python_functions.gen(
|
|
out, native_functions_path, tags_path, deprecated_path, template_path
|
|
)
|
|
|
|
|
|
def main() -> None:
|
|
parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
|
|
parser.add_argument(
|
|
"native_functions", metavar="NATIVE", help="path to native_functions.yaml"
|
|
)
|
|
parser.add_argument("tags", metavar="NATIVE", help="path to tags.yaml")
|
|
parser.add_argument("out", metavar="OUT", help="path to output directory")
|
|
parser.add_argument(
|
|
"autograd", metavar="AUTOGRAD", help="path to autograd directory"
|
|
)
|
|
args = parser.parse_args()
|
|
gen_autograd(
|
|
args.native_functions,
|
|
args.tags,
|
|
args.out,
|
|
args.autograd,
|
|
SelectiveBuilder.get_nop_selector(),
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|