2025-01-21 00:17:30 +00:00
|
|
|
from typing import Optional
|
2023-01-11 23:37:20 +00:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from torch import Tensor
|
2023-01-20 21:43:29 +00:00
|
|
|
from torch.autograd.grad_mode import no_grad
|
Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428
Also improves hook registration type hints:
```python
from typing import Any, Dict, Tuple
from torch import nn
from torch.optim import Adam, Adagrad, Optimizer
linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)
def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
def pre_hook_fn_return_modified(
optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
return inputs, kwargs
def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
optimizer.register_step_post_hook(hook_fn) # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_none) # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified) # OK
optimizer.register_step_post_hook(hook_fn_other_optimizer) # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
2023-07-26 11:56:42 +00:00
|
|
|
from typing_extensions import TypeAlias
|
2023-01-11 23:37:20 +00:00
|
|
|
|
2025-01-21 00:17:30 +00:00
|
|
|
def _get_foreach_kernels_supported_devices() -> list[str]:
|
2023-11-13 19:37:21 +00:00
|
|
|
r"""Return the device type list that supports foreach kernels."""
|
2023-08-09 07:50:59 +00:00
|
|
|
return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
|
2023-06-07 13:59:20 +00:00
|
|
|
|
2025-01-21 00:17:30 +00:00
|
|
|
def _get_fused_kernels_supported_devices() -> list[str]:
|
2023-11-13 19:37:21 +00:00
|
|
|
r"""Return the device type list that supports fused kernels in optimizer."""
|
[MPS] Fused Adam & AdamW (#127242)
Summary:
This PR adds fused Adam and AdamW implementations.
Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory:
**Fast math enabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
| Fused: True | Fused: False
1 threads: -----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 89
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 90
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 83
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 12 | 94
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 88
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 12 | 90
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 100
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 27 | 100
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 23 | 100
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 27 | 100
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 23 | 98
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 480
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 72 | 450
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 82 | 450
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 73 | 420
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 91 | 500
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 83 | 400
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 94 | 500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 78 | 400
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 170 | 500
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 140 | 600
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 170 | 600
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 140 | 500
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 250 | 890
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 220 | 850
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 250 | 830
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 220 | 770
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 270 | 870
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 230 | 840
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 270 | 810
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 240 | 800
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 400 | 1000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 360 | 2000
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 430 | 2000
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 360 | 1300
Times are in milliseconds (ms).
```
**Fast math disabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
| Fused: True | Fused: False
1 threads: -----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 84
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 84
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 79
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 93
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 10 | 90
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 91
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 81
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 34 | 100
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 31 | 100
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 34 | 95
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 31 | 100
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 94 | 500
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 430
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 92 | 430
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 81 | 390
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 98 | 500
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 88 | 430
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 100 | 500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 88 | 400
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 210 | 500
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 190 | 610
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 210 | 510
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 190 | 500
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 300 | 900
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 260 | 850
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 295 | 900
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 260 | 800
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 320 | 910
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 280 | 900
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 320 | 900
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 300 | 900
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 500 | 2000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 480 | 2000
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 540 | 1500
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 480 | 1200
Times are in milliseconds (ms).
```
```python
def profile_fused_adam():
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark
import itertools
def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
fn(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
foreach=False,
capturable=False,
fused=fused,
amsgrad=amsgrad,
beta1=0.9,
beta2=0.99,
lr=1e-3,
weight_decay=.0,
eps=1e-5,
maximize=False,
grad_scale=None,
found_inf=None,
)
torch.mps.synchronize()
device = "mps"
results = []
for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
if adamWflag:
fn = adamw.adamw
else:
fn = adam.adam
for fused in [True, False]:
t = benchmark.Timer(
stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
label='Fused Adam',
sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
globals=locals(),
description= f"Fused: {fused}",
).blocked_autorange(min_run_time=5)
results.append(t)
compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
2024-06-18 19:59:50 +00:00
|
|
|
return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]
|
2023-01-11 23:37:20 +00:00
|
|
|
|
2025-01-21 00:17:30 +00:00
|
|
|
TensorListList: TypeAlias = list[list[Optional[Tensor]]]
|
|
|
|
|
Indices: TypeAlias = list[int]
|
2024-05-17 06:57:49 +00:00
|
|
|
_foreach_supported_types = [torch.Tensor]
|
|
|
|
|
|
Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428
Also improves hook registration type hints:
```python
from typing import Any, Dict, Tuple
from torch import nn
from torch.optim import Adam, Adagrad, Optimizer
linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)
def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
def pre_hook_fn_return_modified(
optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
return inputs, kwargs
def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
optimizer.register_step_post_hook(hook_fn) # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_none) # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified) # OK
optimizer.register_step_post_hook(hook_fn_other_optimizer) # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
2023-07-26 11:56:42 +00:00
|
|
|
|
2023-01-20 21:43:29 +00:00
|
|
|
# This util function splits tensors into groups by device and dtype, which is useful before sending
|
|
|
|
|
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
|
2023-01-11 23:37:20 +00:00
|
|
|
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
|
|
|
|
|
# - tensorlists CAN be None
|
|
|
|
|
# - all tensors in the first specified list cannot be None
|
|
|
|
|
# - given an index i, all specified tensorlist[i]s match in dtype and device
|
2023-01-18 04:02:36 +00:00
|
|
|
# with_indices (bool, optional): whether to also record, for each group, the indices its tensors had in the original input lists (returned as the second element of each dictionary value).
|
|
|
|
|
# It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
|
|
|
|
|
# Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
|
|
|
|
|
# original input tensorlists, replacing Nones/literals WILL NOT propagate, and manual propagation
|
|
|
|
|
# may be necessary. Check out torch/optim/sgd.py for an example.
|
2023-01-20 21:43:29 +00:00
|
|
|
@no_grad()
|
Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors, see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.
Also, using `std::hash` for string should result in a faster hash function.
(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>
This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:29 +00:00
|
|
|
def _group_tensors_by_device_and_dtype(
|
Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428
Also improves hook registration type hints:
```python
from typing import Any, Dict, Tuple
from torch import nn
from torch.optim import Adam, Adagrad, Optimizer
linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)
def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
def pre_hook_fn_return_modified(
optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
return inputs, kwargs
def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
return None
optimizer.register_step_post_hook(hook_fn) # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_none) # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified) # OK
optimizer.register_step_post_hook(hook_fn_other_optimizer) # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
2023-07-26 11:56:42 +00:00
|
|
|
tensorlistlist: TensorListList,
|
Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors, see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.
Also, using `std::hash` for string should result in a faster hash function.
(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>
This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:29 +00:00
|
|
|
with_indices: bool = False,
|
2025-01-21 00:17:30 +00:00
|
|
|
) -> dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]]:
|
2024-06-04 18:19:28 +00:00
|
|
|
return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
|
2023-01-25 19:27:31 +00:00
|
|
|
|
2024-03-01 16:38:19 +00:00
|
|
|
def _device_has_foreach_support(device: torch.device) -> bool:
|
|
|
|
|
return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
|
|
|
|
|
|
|
|
|
|
|
2025-01-21 00:17:30 +00:00
|
|
|
def _has_foreach_support(tensors: list[Tensor], device: torch.device) -> bool:
|
2024-05-17 06:57:49 +00:00
|
|
|
return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)
|