Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00.
PoC demonstrating vmap + NT (nested tensor) support based on the [design doc](https://docs.google.com/document/d/1dVVk6TOqz93PLTIneU2T3xaxCs9qZ0MaJyCvOAp_bC0). This PR:

* Allows `BatchedTensorImpl`s to contain NTs
* Introduces a `BatchedNestedTensor` dispatch key for NT-specific batching rules
* Provides a batching rule fallback that unbinds the NTs, performs the computation on each constituent, and rebinds the results into an NT

Restrictions:

* Only supports one level of vmap
* Only supports vmapping over dim=0 for NTs
* For operations with mixed NT / dense inputs, support is also limited to dim=0 for the dense inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106786
Approved by: https://github.com/zou3519
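The fallback in the commit message (unbind the NTs, run the op on each constituent, rebind the results) can be illustrated with a minimal plain-Python sketch. It is not PyTorch's implementation: `Nested` and `nested_fallback` are hypothetical names, Python lists of different lengths stand in for NT constituents, and plain sequences stand in for dense inputs, which are unbound along dim 0 only to mirror the restriction above.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


# Hypothetical stand-in for a nested tensor (NT): a batch whose constituents
# may have different shapes (here, lists of different lengths).
# This is NOT torch.nested; it only models the "ragged batch" idea.
@dataclass
class Nested:
    components: List[Any]

    def unbind(self) -> List[Any]:
        # Split along the batch dimension (dim=0) into its constituents.
        return list(self.components)


def nested_fallback(fn: Callable[..., Any], *args: Any) -> Nested:
    """Hypothetical unbind -> compute -> rebind fallback (illustrative only).

    Nested arguments are unbound into their constituents. Dense arguments,
    modelled as plain sequences, are unbound along their first dimension
    only, mirroring the dim=0 restriction for mixed NT / dense inputs.
    """
    if not any(isinstance(a, Nested) for a in args):
        raise TypeError("fallback expects at least one Nested argument")
    unbound = [a.unbind() if isinstance(a, Nested) else list(a) for a in args]
    # Every argument must have the same batch size along dim 0.
    sizes = {len(u) for u in unbound}
    if len(sizes) != 1:
        raise ValueError(f"mismatched batch sizes: {sorted(sizes)}")
    # Apply the per-example function to each slice of the batch.
    results = [fn(*slices) for slices in zip(*unbound)]
    # Rebind the per-example results into a new nested batch.
    return Nested(results)


# Usage: a ragged batch (lengths 3 and 1) mixed with a dense per-example scale.
x = Nested([[1, 2, 3], [4]])
scale = [10, 100]
out = nested_fallback(lambda row, s: [v * s for v in row], x, scale)
print(out.components)  # [[10, 20, 30], [400]]
```

The loop over constituents is what makes a fallback like this generic: any function that works on a single constituent works on the whole ragged batch, at the cost of one call per batch element.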
| Name |
|---|
| api |
| decompositions |
| dest |
| executorch |
| operator_versions |
| selective_build |
| shape_functions |
| static_runtime |
| __init__.py |
| BUCK.oss |
| BUILD.bazel |
| build.bzl |
| code_template.py |
| context.py |
| gen.py |
| gen_backend_stubs.py |
| gen_executorch.py |
| gen_functionalization_type.py |
| gen_lazy_tensor.py |
| gen_vmap_plumbing.py |
| local.py |
| model.py |
| native_function_generation.py |
| utils.py |
| yaml_utils.py |