mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.
```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
virtual ~ViewFunc() {}
/// Returns any SymInts in the saved state.
virtual std::vector<c10::SymInt> get_symints() const { return {}; }
/// Returns the number of SymInts in the saved state.
virtual size_t num_symints() const { return 0; }
/// Returns any tensors in the saved state.
virtual std::vector<at::Tensor> get_tensors() const { return {}; }
/// Returns the number of tensors in the saved state.
virtual size_t num_tensors() const { return 0; }
/// Reapplies the view on the given base using the saved state.
virtual at::Tensor operator()(const at::Tensor&) const = 0;
/// Returns a clone of this ViewFunc, optionally with the specified saved state.
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;
protected:
/// Sets the values of any SymInts in the saved state. The input vector size must
/// match the number of SymInts in the saved state (i.e. the size of the list
/// returned by get_symints()).
virtual void set_symints(std::vector<c10::SymInt>) {}
/// Sets the values of any Tensors in the saved state. The input vector size must
/// match the number of Tensors in the saved state (i.e. the size of the list
/// returned by get_tensors()).
virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`
The templates for these also contain implementations of `ChainedViewFunc` and `ErroringViewFunc`, which are used in a few places within autograd.
Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
{};
virtual ~SliceTensorViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;
private:
int64_t dim;
c10::optional<c10::SymInt> start;
c10::optional<c10::SymInt> end;
c10::SymInt step;
};
...
// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
::std::vector<c10::SymInt> symints;
symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
if(start.has_value()) symints.insert(symints.end(), *(start));
if(end.has_value()) symints.insert(symints.end(), *(end));
symints.push_back(step);
return symints;
}
size_t SliceTensorViewFunc::num_symints() const {
return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}
void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
auto i = 0;
if(start.has_value()) start = symints[i];
i += (start.has_value() ? 1 : 0);
if(end.has_value()) end = symints[i];
i += (end.has_value() ? 1 : 0);
step = symints[i];
}
std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
::std::vector<at::Tensor> tensors;
return tensors;
}
size_t SliceTensorViewFunc::num_tensors() const {
return static_cast<size_t>(0);
}
void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());
}
at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}
std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
std::optional<std::vector<c10::SymInt>> symints,
std::optional<std::vector<at::Tensor>> tensors) const {
auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
if (symints.has_value()) {
output->set_symints(std::move(*(symints)));
}
if (tensors.has_value()) {
output->set_tensors(std::move(*(tensors)));
}
return output;
}
```
The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
| Name | | |
|---|---|---|
| .. | ||
| ao/sparsity | ||
| autograd | ||
| backends/xeon | ||
| benchmark_utils | ||
| bottleneck_test | ||
| cpp | ||
| cpp_api_parity | ||
| cpp_extensions | ||
| custom_backend | ||
| custom_operator | ||
| distributed | ||
| distributions | ||
| dynamo | ||
| edge | ||
| error_messages | ||
| expect | ||
| export | ||
| forward_backward_compatibility | ||
| functorch | ||
| fx | ||
| inductor | ||
| jit | ||
| jit_hooks | ||
| lazy | ||
| mobile | ||
| nn | ||
| onnx | ||
| onnx_caffe2 | ||
| optim | ||
| package | ||
| profiler | ||
| quantization | ||
| scripts | ||
| test_img | ||
| torch_np | ||
| typing | ||
| _test_bazel.py | ||
| allowlist_for_publicAPI.json | ||
| conftest.py | ||
| create_dummy_torchscript_model.py | ||
| delete.py | ||
| HowToWriteTestsUsingFileCheck.md | ||
| linear.py | ||
| load_torchscript_model.py | ||
| minioptest_failures_dict.json | ||
| mkl_verbose.py | ||
| mkldnn_verbose.py | ||
| pytest_shard_custom.py | ||
| run_doctests.sh | ||
| run_test.py | ||
| simulate_nccl_errors.py | ||
| test_ao_sparsity.py | ||
| test_autocast.py | ||
| test_autograd.py | ||
| test_autograd_fallback.py | ||
| test_binary_ufuncs.py | ||
| test_bundled_images.py | ||
| test_bundled_inputs.py | ||
| test_comparison_utils.py | ||
| test_compile_benchmark_util.py | ||
| test_complex.py | ||
| test_content_store.py | ||
| test_cpp_api_parity.py | ||
| test_cpp_extensions_aot.py | ||
| test_cpp_extensions_jit.py | ||
| test_cpp_extensions_open_device_registration.py | ||
| test_cuda.py | ||
| test_cuda_expandable_segments.py | ||
| test_cuda_multigpu.py | ||
| test_cuda_nvml_based_avail.py | ||
| test_cuda_primary_ctx.py | ||
| test_cuda_sanitizer.py | ||
| test_cuda_trace.py | ||
| test_custom_ops.py | ||
| test_dataloader.py | ||
| test_datapipe.py | ||
| test_decomp.py | ||
| test_deploy.py | ||
| test_determination.py | ||
| test_dispatch.py | ||
| test_dlpack.py | ||
| test_dynamic_shapes.py | ||
| test_expanded_weights.py | ||
| test_fake_tensor.py | ||
| test_flop_counter.py | ||
| test_foreach.py | ||
| test_function_schema.py | ||
| test_functional_autograd_benchmark.py | ||
| test_functional_optim.py | ||
| test_functionalization.py | ||
| test_functionalization_of_rng_ops.py | ||
| test_futures.py | ||
| test_fx.py | ||
| test_fx_experimental.py | ||
| test_fx_passes.py | ||
| test_fx_reinplace_pass.py | ||
| test_hub.py | ||
| test_import_stats.py | ||
| test_indexing.py | ||
| test_itt.py | ||
| test_jit.py | ||
| test_jit_autocast.py | ||
| test_jit_disabled.py | ||
| test_jit_fuser.py | ||
| test_jit_fuser_legacy.py | ||
| test_jit_fuser_te.py | ||
| test_jit_legacy.py | ||
| test_jit_llga_fuser.py | ||
| test_jit_profiling.py | ||
| test_jit_simple.py | ||
| test_jit_string.py | ||
| test_jiterator.py | ||
| test_kernel_launch_checks.py | ||
| test_legacy_vmap.py | ||
| test_license.py | ||
| test_linalg.py | ||
| test_logging.py | ||
| test_masked.py | ||
| test_maskedtensor.py | ||
| test_matmul_cuda.py | ||
| test_meta.py | ||
| test_metal.py | ||
| test_mkl_verbose.py | ||
| test_mkldnn.py | ||
| test_mkldnn_fusion.py | ||
| test_mkldnn_verbose.py | ||
| test_mobile_optimizer.py | ||
| test_model_dump.py | ||
| test_model_exports_to_core_aten.py | ||
| test_modules.py | ||
| test_monitor.py | ||
| test_mps.py | ||
| test_multiprocessing.py | ||
| test_multiprocessing_spawn.py | ||
| test_namedtensor.py | ||
| test_namedtuple_return_api.py | ||
| test_native_functions.py | ||
| test_native_mha.py | ||
| test_nestedtensor.py | ||
| test_nn.py | ||
| test_nnapi.py | ||
| test_numba_integration.py | ||
| test_numpy_interop.py | ||
| test_openmp.py | ||
| test_ops.py | ||
| test_ops_fwd_gradients.py | ||
| test_ops_gradients.py | ||
| test_ops_jit.py | ||
| test_optim.py | ||
| test_out_dtype_op.py | ||
| test_overrides.py | ||
| test_package.py | ||
| test_per_overload_api.py | ||
| test_prims.py | ||
| test_proxy_tensor.py | ||
| test_pruning_op.py | ||
| test_public_bindings.py | ||
| test_python_dispatch.py | ||
| test_pytree.py | ||
| test_quantization.py | ||
| test_reductions.py | ||
| test_scatter_gather_ops.py | ||
| test_schema_check.py | ||
| test_segment_reductions.py | ||
| test_serialization.py | ||
| test_set_default_mobile_cpu_allocator.py | ||
| test_shape_ops.py | ||
| test_show_pickle.py | ||
| test_sort_and_select.py | ||
| test_sparse.py | ||
| test_sparse_csr.py | ||
| test_sparse_semi_structured.py | ||
| test_spectral_ops.py | ||
| test_stateless.py | ||
| test_static_runtime.py | ||
| test_subclass.py | ||
| test_sympy_utils.py | ||
| test_tensor_creation_ops.py | ||
| test_tensorboard.py | ||
| test_tensorexpr.py | ||
| test_tensorexpr_pybind.py | ||
| test_testing.py | ||
| test_throughput_benchmark.py | ||
| test_torch.py | ||
| test_transformers.py | ||
| test_type_hints.py | ||
| test_type_info.py | ||
| test_type_promotion.py | ||
| test_typing.py | ||
| test_unary_ufuncs.py | ||
| test_utils.py | ||
| test_view_ops.py | ||
| test_vulkan.py | ||
| test_weak.py | ||
| test_xnnpack_integration.py | ||
| test_xpu.py | ||