mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Torch] Support meta device in random.fork_rng (#137715)
Summary: ## Why random.fork_rng doesn't support meta device: ``` [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/aps_models/ads/tools/memory_estimator/estimation_dense.py", line 655, in estimate_dense_memory_size [rank0]: losses.sum().backward() [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/_tensor.py", line 604, in backward [rank0]: return handle_torch_function( [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/overrides.py", line 1718, in handle_torch_function [rank0]: result = mode.__torch_function__(public_api, types, args, kwargs) [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/utils/_device.py", line 106, in __torch_function__ [rank0]: return func(*args, **kwargs) [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/_tensor.py", line 613, in backward [rank0]: torch.autograd.backward( [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/autograd/__init__.py", line 347, in backward [rank0]: _engine_run_backward( [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/autograd/graph.py", line 825, in _engine_run_backward [rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/utils/checkpoint.py", line 1125, in unpack_hook [rank0]: frame.recompute_fn(*args) [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/utils/checkpoint.py", line 1507, in recompute_fn [rank0]: with torch.random.fork_rng( [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/runtime/lib/python3.10/contextlib.py", line 135, in __enter__ [rank0]: return next(self.gen) [rank0]: File "/data/users/lyu1/fbsource/buck-out/v2/gen/fbcode/581363ebaea3320a/aps_models/ads/tools/memory_estimator/__memory_estimator__/memory_estimator-inplace#link-tree/torch/random.py", line 153, in fork_rng [rank0]: raise RuntimeError( [rank0]: RuntimeError: torch has no module of `meta`, you should register a module by `torch._register_device_module`. ``` This blocks us from running backward() on model with checkpoint enabled in meta mode. ## What This diff handles the case of meta device in random.fork_rng. Test Plan: Tested with toy model which has checkpoint on its module: P1641201046 Differential Revision: D64161410 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137715 Approved by: https://github.com/kit1980
This commit is contained in:
parent
a47bb4a393
commit
dabe2a3c3b
1 changed files with 4 additions and 0 deletions
|
|
@ -147,6 +147,10 @@ def fork_rng(
|
|||
see details in [Note: support the custom device with privateuse1]
|
||||
"""
|
||||
|
||||
if device_type == "meta":
|
||||
yield
|
||||
return
|
||||
|
||||
device_type = torch.device(device_type).type
|
||||
device_mod = getattr(torch, device_type, None)
|
||||
if device_mod is None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue