pytorch/test
Joel Schlosser 9ec8dd2467 Reify view_func() closures as ViewFuncs (#118404)
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing the state needed to reconstruct the view. The `ViewFunc` API allows querying and hot-swapping any `SymInt`s or `Tensor`s in that state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.

```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() {}
  /// Returns any SymInts in the saved state.
  virtual std::vector<c10::SymInt> get_symints() const { return {}; }
  /// Returns the number of SymInts in the saved state.
  virtual size_t num_symints() const { return 0; }
  /// Returns any tensors in the saved state.
  virtual std::vector<at::Tensor> get_tensors() const { return {}; }
  /// Returns the number of tensors in the saved state.
  virtual size_t num_tensors() const { return 0; }
  /// Reapplies the view on the given base using the saved state.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;
  /// Returns a clone of this ViewFunc, optionally with the specified saved state.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

protected:
  /// Sets the values of any SymInts in the saved state. The input vector size must
  /// match the number of SymInts in the saved state (i.e. the size of the list
  /// returned by get_symints()).
  virtual void set_symints(std::vector<c10::SymInt>) {}
  /// Sets the values of any Tensors in the saved state. The input vector size must
  /// match the number of Tensors in the saved state (i.e. the size of the list
  /// returned by get_tensors()).
  virtual void set_tensors(std::vector<at::Tensor>) {}
};
```

New codegen files:
* `torch/csrc/autograd/generated/ViewFuncs.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`

The templates for these also contain impls for `ChainedViewFunc` and `ErroringViewFunc`, which are used in a few places within autograd.
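
As a rough illustration of what a chained view function does, here is a minimal sketch that composes two view functions and concatenates their saved state. It is not the generated `ChainedViewFunc`: `MiniViewFunc`, `MiniNarrow`, `MiniChained`, `SymIntLike` and `TensorLike` are hypothetical stand-ins (plain `int64_t` and `std::vector<int>` instead of `c10::SymInt` and `at::Tensor`) so the snippet compiles without PyTorch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Stand-ins (assumptions): int64_t for c10::SymInt, std::vector<int> for at::Tensor.
using SymIntLike = int64_t;
using TensorLike = std::vector<int>;

// Trimmed-down, hypothetical analogue of the ViewFunc base class.
struct MiniViewFunc {
  virtual ~MiniViewFunc() = default;
  virtual std::vector<SymIntLike> get_symints() const = 0;
  virtual TensorLike operator()(const TensorLike& base) const = 0;
};

// Hypothetical 1-D narrow: keeps `length` elements starting at `start`.
// No bounds checking; the caller must pass a large enough base.
struct MiniNarrow : MiniViewFunc {
  MiniNarrow(SymIntLike start, SymIntLike length) : start(start), length(length) {}
  std::vector<SymIntLike> get_symints() const override { return {start, length}; }
  TensorLike operator()(const TensorLike& base) const override {
    return TensorLike(base.begin() + start, base.begin() + start + length);
  }
  SymIntLike start;
  SymIntLike length;
};

// Composes two view functions: applies `first`, then `second` on its result.
struct MiniChained : MiniViewFunc {
  MiniChained(std::unique_ptr<MiniViewFunc> first, std::unique_ptr<MiniViewFunc> second)
      : first(std::move(first)), second(std::move(second)) {}
  // Saved state is the concatenation of both halves, with first's values leading.
  std::vector<SymIntLike> get_symints() const override {
    auto symints = first->get_symints();
    auto rest = second->get_symints();
    symints.insert(symints.end(), rest.begin(), rest.end());
    return symints;
  }
  TensorLike operator()(const TensorLike& base) const override {
    return (*second)((*first)(base));
  }
  std::unique_ptr<MiniViewFunc> first;
  std::unique_ptr<MiniViewFunc> second;
};
```

Chaining `MiniNarrow(2, 6)` with `MiniNarrow(1, 3)` and applying the result to a ten-element base replays both views in order, which is the property a chain of views needs when it is reapplied on a new base.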

Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
  SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
  {};
  virtual ~SliceTensorViewFunc() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  int64_t dim;
  c10::optional<c10::SymInt> start;
  c10::optional<c10::SymInt> end;
  c10::SymInt step;
};
...

// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
  ::std::vector<c10::SymInt> symints;
  symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
  if(start.has_value()) symints.insert(symints.end(), *(start));
  if(end.has_value()) symints.insert(symints.end(), *(end));
  symints.push_back(step);
  return symints;
}

size_t SliceTensorViewFunc::num_symints() const {
  return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}

void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
  TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
  auto i = 0;
  if(start.has_value()) start = symints[i];
  i += (start.has_value() ? 1 : 0);
  if(end.has_value()) end = symints[i];
  i += (end.has_value() ? 1 : 0);
  step = symints[i];
}

std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
  ::std::vector<at::Tensor> tensors;
  return tensors;
}

size_t SliceTensorViewFunc::num_tensors() const {
  return static_cast<size_t>(0);
}

void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
  TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());

}

at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
  return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}

std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
  if (symints.has_value()) {
    output->set_symints(std::move(*(symints)));
  }
  if (tensors.has_value()) {
    output->set_tensors(std::move(*(tensors)));
  }
  return output;
}
```

The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args, `symint_visitor_fn` and `tensor_visitor_fn`. If these are defined, they are expected to be Python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
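
The visitor-based hot-swap can be pictured as: read the saved values, map the visitor over them, and build a clone holding the results. The sketch below shows that flow for SymInt-like state only. It is an assumption about the general shape, not the actual binding code; `MiniViewFunc`, `MiniNarrowViewFunc`, `swap_symints` and `SymIntLike` are hypothetical names, and `int64_t` stands in for `c10::SymInt`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

// Stand-in (assumption) for c10::SymInt, which can also hold a symbolic value.
using SymIntLike = int64_t;

// Trimmed-down, hypothetical analogue of ViewFunc that tracks SymInt-like state only.
struct MiniViewFunc {
  virtual ~MiniViewFunc() = default;
  virtual std::vector<SymIntLike> get_symints() const = 0;
  virtual std::unique_ptr<MiniViewFunc> clone_and_set(
      std::optional<std::vector<SymIntLike>> symints = std::nullopt) const = 0;
};

// Hypothetical narrow-like view with two saved SymInt-like values.
struct MiniNarrowViewFunc : MiniViewFunc {
  MiniNarrowViewFunc(int64_t dim, SymIntLike start, SymIntLike length)
      : dim(dim), start(start), length(length) {}
  std::vector<SymIntLike> get_symints() const override { return {start, length}; }
  std::unique_ptr<MiniViewFunc> clone_and_set(
      std::optional<std::vector<SymIntLike>> symints = std::nullopt) const override {
    auto out = std::make_unique<MiniNarrowViewFunc>(dim, start, length);
    if (symints.has_value()) {
      // The replacement list must match the saved state one-to-one.
      assert(symints->size() == 2);
      out->start = (*symints)[0];
      out->length = (*symints)[1];
    }
    return out;
  }
  int64_t dim;
  SymIntLike start;
  SymIntLike length;
};

// Applies `visitor` to each saved value and returns a clone holding the results.
// The original view function is left untouched.
std::unique_ptr<MiniViewFunc> swap_symints(
    const MiniViewFunc& fn,
    const std::function<SymIntLike(const SymIntLike&)>& visitor) {
  std::vector<SymIntLike> swapped;
  for (const auto& s : fn.get_symints()) {
    swapped.push_back(visitor(s));
  }
  return fn.clone_and_set(std::move(swapped));
}
```

Because the swap goes through `clone_and_set()`, the original object keeps its saved state, which fits the protected `set_symints()` / `set_tensors()` design in the base class above.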

For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
2024-02-14 22:00:43 +00:00
ao/sparsity Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
autograd
backends/xeon
benchmark_utils
bottleneck_test
cpp Fix C++20 build (#112333) 2024-02-13 05:10:19 +00:00
cpp_api_parity
cpp_extensions Enable optional tensorList fallback to cpu. (#119273) 2024-02-07 03:54:13 +00:00
custom_backend
custom_operator Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
distributed [DTensor] Enable Adamax foreach optimizer (#119850) 2024-02-14 20:43:00 +00:00
distributions Bugfix to MixtureSameFamily's _pad_mixture_dimension (#118947) 2024-02-06 16:24:22 +00:00
dynamo Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
edge
error_messages
expect Add pixel_shuffle to core aten decomps (#119899) 2024-02-14 21:01:11 +00:00
export Windows Dynamo Error Removal CI Check (#115969) 2024-02-14 21:14:36 +00:00
forward_backward_compatibility Revert "[RELAND] Remove deprecated fbgemm operators (#112153)" 2024-02-01 18:35:19 +00:00
functorch Windows Dynamo Error Removal CI Check (#115969) 2024-02-14 21:14:36 +00:00
fx Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
inductor Handle aliases correctly in foreach (#119508) 2024-02-14 21:21:28 +00:00
jit Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
jit_hooks
lazy Implement shallow copy functions for FunctionalTensorWrapper. (#118783) 2024-02-08 17:15:46 +00:00
mobile [BE][Ez]: FURB129: remove unneeded readlines() (#119796) 2024-02-13 21:21:22 +00:00
nn Integrate swap_tensors into nn.Module.load_state_dict (#117913) 2024-02-09 22:32:29 +00:00
onnx Avoid performing replacements when it would unrefine ranges (#117356) 2024-02-13 15:56:59 +00:00
onnx_caffe2
optim ReduceLROnPlateau init _last_lr (#119366) (#119556) 2024-02-09 19:35:02 +00:00
package
profiler Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
quantization Windows Dynamo Error Removal CI Check (#115969) 2024-02-14 21:14:36 +00:00
scripts
test_img
torch_np Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
typing Fix signatures of torch.{add, sub, mul} (#118398) 2024-01-30 22:18:15 +00:00
_test_bazel.py
allowlist_for_publicAPI.json Autograd doc cleanup (#118500) 2024-01-29 21:51:33 +00:00
conftest.py Various CI settings (#117668) 2024-01-26 00:17:29 +00:00
create_dummy_torchscript_model.py
delete.py
HowToWriteTestsUsingFileCheck.md
linear.py
load_torchscript_model.py
minioptest_failures_dict.json
mkl_verbose.py
mkldnn_verbose.py
pytest_shard_custom.py Reduce pytest prints (#117069) 2024-01-23 18:39:30 +00:00
run_doctests.sh
run_test.py Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
simulate_nccl_errors.py
test_ao_sparsity.py
test_autocast.py Remove incorrect usages of skipIfTorchDynamo (#117114) 2024-01-10 22:25:31 +00:00
test_autograd.py Reify view_func() closures as ViewFuncs (#118404) 2024-02-14 22:00:43 +00:00
test_autograd_fallback.py Update to TorchFix 0.4.0 (#119424) 2024-02-12 23:30:12 +00:00
test_binary_ufuncs.py Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
test_bundled_images.py
test_bundled_inputs.py
test_comparison_utils.py
test_compile_benchmark_util.py
test_complex.py
test_content_store.py Windows Dynamo Error Removal CI Check (#115969) 2024-02-14 21:14:36 +00:00
test_cpp_api_parity.py
test_cpp_extensions_aot.py
test_cpp_extensions_jit.py
test_cpp_extensions_open_device_registration.py Enable optional tensorList fallback to cpu. (#119273) 2024-02-07 03:54:13 +00:00
test_cuda.py add test cases for GradScaler on CPU (#109994) 2024-02-02 21:49:07 +00:00
test_cuda_expandable_segments.py
test_cuda_multigpu.py add GradScaler on CPU (#109993) 2024-01-29 23:42:35 +00:00
test_cuda_nvml_based_avail.py
test_cuda_primary_ctx.py
test_cuda_sanitizer.py
test_cuda_trace.py
test_custom_ops.py Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
test_dataloader.py [BE]: Apply RUF025 dict.fromkeys preview rule (#118637) 2024-01-30 20:46:54 +00:00
test_datapipe.py Replace follow_imports = silent with normal (#118414) 2024-01-27 02:44:11 +00:00
test_decomp.py Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
test_deploy.py
test_determination.py
test_dispatch.py
test_dlpack.py Improve uint{16,32,64} dlpack/numpy compatibility (#116808) 2024-01-11 17:01:54 +00:00
test_dynamic_shapes.py Avoid performing replacements when it would unrefine ranges (#117356) 2024-02-13 15:56:59 +00:00
test_expanded_weights.py Make variables in dict LazyTrackers (not lazily guarded yet) and avoid using DICT_KEYS guard (#117625) 2024-02-02 14:38:08 +00:00
test_fake_tensor.py Update to TorchFix 0.4.0 (#119424) 2024-02-12 23:30:12 +00:00
test_flop_counter.py
test_foreach.py
test_function_schema.py
test_functional_autograd_benchmark.py
test_functional_optim.py
test_functionalization.py
test_functionalization_of_rng_ops.py
test_futures.py
test_fx.py Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
test_fx_experimental.py Revert "[codemod] markDynamoStrictTest batch 16 (#117218)" 2024-01-12 03:06:20 +00:00
test_fx_passes.py Update to TorchFix 0.4.0 (#119424) 2024-02-12 23:30:12 +00:00
test_fx_reinplace_pass.py
test_hub.py [BE][Ez]: FURB129: remove unneeded readlines() (#119796) 2024-02-13 21:21:22 +00:00
test_import_stats.py
test_indexing.py Support builtin callable with object arguments in dynamo (#118678) 2024-01-31 17:54:08 +00:00
test_itt.py
test_jit.py Remove unnecessary skipIfTorchDynamo from test_jit_fuser_te (#118728) 2024-02-12 20:55:29 +00:00
test_jit_autocast.py
test_jit_disabled.py
test_jit_fuser.py
test_jit_fuser_legacy.py
test_jit_fuser_te.py Remove unnecessary skipIfTorchDynamo from test_jit_fuser_te (#118728) 2024-02-12 20:55:29 +00:00
test_jit_legacy.py
test_jit_llga_fuser.py
test_jit_profiling.py
test_jit_simple.py
test_jit_string.py
test_jiterator.py
test_kernel_launch_checks.py
test_legacy_vmap.py Skip some slow tests (under Dynamo) (#117389) 2024-01-12 22:18:07 +00:00
test_license.py
test_linalg.py [ROCm] enable hipsolver backend for linalg.eigh (#115177) 2024-02-08 22:03:27 +00:00
test_logging.py
test_masked.py
test_maskedtensor.py
test_matmul_cuda.py Enable scaled_mm on sm89 devices (#118881) 2024-02-03 00:44:03 +00:00
test_meta.py Update to TorchFix 0.4.0 (#119424) 2024-02-12 23:30:12 +00:00
test_metal.py
test_mkl_verbose.py
test_mkldnn.py
test_mkldnn_fusion.py [codemod] markDynamoStrictTest batch 22 (#117729) 2024-01-18 16:59:26 +00:00
test_mkldnn_verbose.py
test_mobile_optimizer.py
test_model_dump.py
test_model_exports_to_core_aten.py
test_modules.py [easy] Add testing utilties for torch.nn.utils.set_swap_module_params_on_conversion (#118023) 2024-02-07 18:55:44 +00:00
test_monitor.py Change dynamo_test_failures.py to silently run skipped tests (#117401) 2024-01-17 02:48:19 +00:00
test_mps.py [MPS] Add naive std_mean implementation (#119777) 2024-02-13 21:51:29 +00:00
test_multiprocessing.py
test_multiprocessing_spawn.py
test_namedtensor.py Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
test_namedtuple_return_api.py Revert "[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)" 2024-01-18 23:40:30 +00:00
test_native_functions.py
test_native_mha.py
test_nestedtensor.py Fix meta registration for _flash_attention_forward() (#119812) 2024-02-14 02:38:53 +00:00
test_nn.py Fixed an issue where nn.Linear would cause an internal int underflow … (#119221) 2024-02-08 21:06:34 +00:00
test_nnapi.py
test_numba_integration.py
test_numpy_interop.py Replace follow_imports = silent with normal (#118414) 2024-01-27 02:44:11 +00:00
test_openmp.py
test_ops.py Reify view_func() closures as ViewFuncs (#118404) 2024-02-14 22:00:43 +00:00
test_ops_fwd_gradients.py
test_ops_gradients.py
test_ops_jit.py
test_optim.py Migrate load_state_dict hook tests to OptimizerInfo (#119310) 2024-02-07 16:00:01 +00:00
test_out_dtype_op.py [export] Remove torch._export.export (#119095) 2024-02-08 21:22:04 +00:00
test_overrides.py
test_package.py
test_per_overload_api.py
test_prims.py
test_proxy_tensor.py Add a decomposition for isin() (#115390) 2024-02-14 03:03:42 +00:00
test_pruning_op.py
test_public_bindings.py Add torch.dtype instances to the public API (#119307) 2024-02-07 02:57:49 +00:00
test_python_dispatch.py Update to TorchFix 0.4.0 (#119424) 2024-02-12 23:30:12 +00:00
test_pytree.py [BE]: Apply RUF025 dict.fromkeys preview rule (#118637) 2024-01-30 20:46:54 +00:00
test_quantization.py
test_reductions.py Make and/or on uint8 tensors properly return 0x00 or 0x01 (#117827) 2024-01-22 17:30:22 +00:00
test_scatter_gather_ops.py
test_schema_check.py Update to TorchFix 0.4.0 (#119424) 2024-02-12 23:30:12 +00:00
test_segment_reductions.py
test_serialization.py [CI] Install dill in ci (#116214) 2024-01-24 23:42:35 +00:00
test_set_default_mobile_cpu_allocator.py
test_shape_ops.py
test_show_pickle.py
test_sort_and_select.py Removed an internal assertion for the optional stable value and inste… (#117414) 2024-01-17 02:25:21 +00:00
test_sparse.py
test_sparse_csr.py [SparseCsr] Remove triton sdpa skip after triton pin update (#109601) 2024-02-08 16:40:25 +00:00
test_sparse_semi_structured.py Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
test_spectral_ops.py [ROCm] Enable float16/complex32 fft tests on ROCm (#117296) 2024-02-13 22:35:32 +00:00
test_stateless.py
test_static_runtime.py
test_subclass.py
test_sympy_utils.py
test_tensor_creation_ops.py Unskip test_complex_type_conversions (#118694) 2024-01-31 08:04:15 +00:00
test_tensorboard.py
test_tensorexpr.py
test_tensorexpr_pybind.py
test_testing.py [BE]: Apply RUF025 dict.fromkeys preview rule (#118637) 2024-01-30 20:46:54 +00:00
test_throughput_benchmark.py
test_torch.py Update to TorchFix 0.4.0 (#119424) 2024-02-12 23:30:12 +00:00
test_transformers.py Add linux cpu test for 3.12 (#117853) 2024-02-14 20:52:23 +00:00
test_type_hints.py
test_type_info.py
test_type_promotion.py Improve uint{16,32,64} dlpack/numpy compatibility (#116808) 2024-01-11 17:01:54 +00:00
test_typing.py More progress on type checking ValueRanges (#118870) 2024-02-05 20:29:25 +00:00
test_unary_ufuncs.py Remove some unnecessary skipIfTorchDynamo (#118725) 2024-01-31 18:18:17 +00:00
test_utils.py [ROCm] Hipify trie re-engineering and adding unit tests (#118433) 2024-02-02 16:04:59 +00:00
test_view_ops.py
test_vulkan.py
test_weak.py
test_xnnpack_integration.py
test_xpu.py [2/2] Intel GPU Runtime Upstreaming for Stream (#117619) 2024-02-10 03:39:42 +00:00