pytorch/torch
Li Yu (ads) dabe2a3c3b [Torch] Support meta device in random.fork_rng (#137715)
Summary:
## Why
`random.fork_rng` doesn't support the meta device:
```
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/aps_models/ads/tools/memory_estimator/estimation_dense.py", line 655, in estimate_dense_memory_size
[rank0]:     losses.sum().backward()
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/_tensor.py", line 604, in backward
[rank0]:     return handle_torch_function(
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/overrides.py", line 1718, in handle_torch_function
[rank0]:     result = mode.__torch_function__(public_api, types, args, kwargs)
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/utils/_device.py", line 106, in __torch_function__
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/_tensor.py", line 613, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/utils/checkpoint.py", line 1125, in unpack_hook
[rank0]:     frame.recompute_fn(*args)
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/utils/checkpoint.py", line 1507, in recompute_fn
[rank0]:     with torch.random.fork_rng(
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/runtime/lib/python3.10/contextlib.py", line 135, in __enter__
[rank0]:     return next(self.gen)
[rank0]:   File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/random.py", line 153, in fork_rng
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: torch has no module of `meta`, you should register a module by `torch._register_device_module`.
```

This blocks us from running backward() on models with checkpointing enabled in meta mode.

## What
This diff handles the meta device case in `random.fork_rng`.

Test Plan: Tested with a toy model that has checkpointing on its module: P1641201046

Differential Revision: D64161410

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137715
Approved by: https://github.com/kit1980
2024-10-16 18:00:39 +00:00
_awaits
_C Revert "Expose option to disable CRC-32 computation during torch.save (#137735)" 2024-10-16 17:03:06 +00:00
_C_flatbuffer
_custom_op
_decomp Add decomposition for permute_copy (#130944) 2024-10-15 13:51:20 +00:00
_dispatch
_dynamo Revert "[compiled autograd] Compiled autograd configs in TLS (#137821)" 2024-10-16 16:38:29 +00:00
_export Fix constant returning (#137993) 2024-10-16 16:42:09 +00:00
_functorch Revert "[compiled autograd] Compiled autograd configs in TLS (#137821)" 2024-10-16 16:38:29 +00:00
_higher_order_ops Add host-side Triton TMA support to Dynamo (#137677) 2024-10-16 02:18:48 +00:00
_inductor Remove an unused variable in _inductor/codegen/simd.py (#138000) 2024-10-16 13:54:21 +00:00
_lazy
_library Proper handling of arguments passed by in kwargs inside zip_schema (#137311) 2024-10-04 21:50:31 +00:00
_logging Don't actually import module when checking if its valid (#136548) 2024-09-25 20:47:32 +00:00
_numpy
_prims Fix AOT Graph capture not propagating non_blocking copy parameter to … (#136513) 2024-10-01 00:32:47 +00:00
_prims_common Fix six broken tests in test_ops.py (#136653) 2024-09-30 20:32:55 +00:00
_refs Add decomposition for permute_copy (#130944) 2024-10-15 13:51:20 +00:00
_strobelight Add code pointer to internal Meta implementation (#137984) 2024-10-15 23:35:22 +00:00
_subclasses [fake_tensor][cache] Supports ops with tuple of output tensors (#137935) 2024-10-15 22:15:07 +00:00
_vendor
amp Fix autocast for non-strict export (#137495) 2024-10-16 17:39:00 +00:00
ao torch/ao/quantization/utils.py: Moving eps to targeted device to avoid device mismatch issue (#135204) 2024-10-15 14:58:55 +00:00
autograd Param fixes in docstring (#136097) 2024-09-21 18:56:34 +00:00
backends Clarify opt-einsum usage, fix #127109 (#137596) 2024-10-09 20:31:24 +00:00
compiler [dynamo] add torch.compiler.set_stance (#137504) 2024-10-16 16:18:25 +00:00
contrib
cpu Extend vectorization with SVE(ARM) with Torch Compile (Inductor) (#134672) 2024-10-10 13:20:40 +00:00
csrc Revert "Expose option to disable CRC-32 computation during torch.save (#137735)" 2024-10-16 17:03:06 +00:00
cuda [ROCm] Add AMDSMI support for UUID input (#129741) 2024-10-15 15:56:30 +00:00
distributed Revert "[compiled autograd] Compiled autograd configs in TLS (#137821)" 2024-10-16 16:38:29 +00:00
distributions [BE]: Update mypy to 1.11.2 (#133816) 2024-09-16 19:44:11 +00:00
export [alt] fix unroll in successive unflatten (#137646) 2024-10-12 15:53:52 +00:00
fft
func
futures
fx Fix autocast for non-strict export (#137495) 2024-10-16 17:39:00 +00:00
jit
legacy
lib
linalg docs: clarify alias usage for x parameter in vector_norm function (#136921) 2024-09-30 02:50:06 +00:00
masked Fix memory leak on masked Tensor (#137890) 2024-10-15 18:37:55 +00:00
monitor
mps
mtia [MTIA] Support torch.cuda.get_device_capability equivalent API on MTIA (#135889) 2024-09-17 17:42:56 +00:00
multiprocessing multiprocessing.spawn: allow a grace period when shutdown (#131278) 2024-10-07 12:37:34 +00:00
nested Fix to() on non-contiguous NJTs (#137124) 2024-10-08 15:11:05 +00:00
nn Removed _compile workaround for create_block_mask (#137477) 2024-10-11 19:04:23 +00:00
onnx Revert "[ONNX] Remove ExportTypes (#137789)" 2024-10-15 17:40:06 +00:00
optim RMSprop docs: add missing input "epsilon" (#137854) 2024-10-15 16:40:42 +00:00
package [3.13] fix 3.13 pickle error in torch/package (#136049) 2024-09-14 14:28:09 +00:00
profiler [Profiler] Torch Profiler distributed info is not JSON serializable (#135548) 2024-09-13 02:22:33 +00:00
quantization
signal
sparse [sparse][semi-structured] Add float8 dtype support to 24 sparsity (#136397) 2024-09-27 21:37:34 +00:00
special
testing Upgrade distributed test to g4dn instances (T4 GPUs) (#137161) 2024-10-16 16:42:57 +00:00
utils Add host-side Triton TMA support to Dynamo (#137677) 2024-10-16 02:18:48 +00:00
xpu Use torch.Stream&torch.Event for Dynamo capature (#134850) 2024-10-02 14:15:33 +00:00
__config__.py
__future__.py
__init__.py Revert "[Dynamo] Disable torch function compilation during guard execution and in compiled bytecode (#137669)" 2024-10-15 23:22:58 +00:00
_appdirs.py
_classes.py
_compile.py
_custom_ops.py
_deploy.py
_environment.py Improve is_fbcode functionality (#136871) 2024-09-27 21:19:01 +00:00
_guards.py Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972) 2024-10-01 13:22:10 +00:00
_jit_internal.py
_linalg_utils.py
_lobpcg.py
_lowrank.py
_meta_registrations.py Add meta functions for lerp, addcmul, and addcdiv. (#136909) 2024-10-12 12:40:46 +00:00
_namedtensor_internals.py
_ops.py Add type annotations for higher order ops/flex_attention (#137065) 2024-10-02 04:39:25 +00:00
_python_dispatcher.py
_size_docs.py
_sources.py
_storage_docs.py
_streambase.py Use torch.Stream&torch.Event for Dynamo capature (#134850) 2024-10-02 14:15:33 +00:00
_tensor.py Remove dependency on numpy for serialization for XLA/open registration devices without numpy (#137444) 2024-10-09 19:35:55 +00:00
_tensor_docs.py Revert "Add deterministic path for CUDA cumsum (#136224)" 2024-09-27 12:54:47 +00:00
_tensor_str.py
_thread_safe_fork.py [inductor] parallel compile: add import of thread_safe_fork for internal (#137155) 2024-10-03 17:37:21 +00:00
_torch_docs.py Add torch.squeeze parameter description to declare allowed type (#137485) 2024-10-09 05:29:13 +00:00
_utils.py Remove dependency on numpy for serialization for XLA/open registration devices without numpy (#137444) 2024-10-09 19:35:55 +00:00
_utils_internal.py Log compile ids to pt2_remote_cache and pt2_compile_events (#137431) 2024-10-08 18:04:48 +00:00
_VF.py
_vmap_internals.py
_weights_only_unpickler.py Remove dependency on numpy for serialization for XLA/open registration devices without numpy (#137444) 2024-10-09 19:35:55 +00:00
abi-check.cpp
CMakeLists.txt
custom_class.h
custom_class_detail.h
extension.h
functional.py Clarify opt-einsum usage, fix #127109 (#137596) 2024-10-09 20:31:24 +00:00
hub.py torch.hub: add get_dir/set_dir type hints (#134906) 2024-09-12 03:53:29 +00:00
library.h
library.py Fix custom op bug of clearing dir (#137655) 2024-10-11 04:32:40 +00:00
overrides.py Revert "Introduce torch.sym_sum (#136429)" 2024-10-09 20:08:01 +00:00
py.typed
quasirandom.py
random.py [Torch] Support meta device in random.fork_rng (#137715) 2024-10-16 18:00:39 +00:00
README.txt
return_types.py
script.h
serialization.py Revert "Expose option to disable CRC-32 computation during torch.save (#137735)" 2024-10-16 17:03:06 +00:00
storage.py Fix serialization for torch.uint16, torch.uint32, torch.uint64 (#137184) 2024-10-03 14:56:11 +00:00
torch_version.py
types.py
version.py.tpl

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.