[MPS] Expand fused forloop to bfloat16 (#141104)

For macOS 14+.
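
With this change a fused optimizer step can run directly on bf16 MPS tensors. A minimal sketch (illustrative only, not part of the PR; shapes and values are arbitrary):
```python
import torch

# bf16 parameter living on the MPS device
param = torch.rand(128, device="mps", dtype=torch.bfloat16, requires_grad=True)
opt = torch.optim.AdamW([param], lr=1e-3, fused=True)  # fused path, now usable with bf16

param.sum().backward()
opt.step()
torch.mps.synchronize()  # wait for the MPS kernels to finish
```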

Running the following script (adapted from the one mentioned in https://github.com/pytorch/pytorch/pull/127242):
```python
import torch
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark
import itertools

# Runs one functional optimizer step, then blocks until the MPS stream
# drains, so the timer captures actual kernel execution time.
def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
    fn(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        foreach=False,
        capturable=False,
        fused=fused,
        amsgrad=amsgrad,
        beta1=0.9,
        beta2=0.99,
        lr=1e-3,
        weight_decay=0.0,
        eps=1e-5,
        maximize=False,
        grad_scale=None,
        found_inf=None,
    )
    torch.mps.synchronize()

device, dtype = "mps", torch.bfloat16

results = []

for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]):
    print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
    params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
    max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else []
    state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)]
    fn = adamw.adamw if adamWflag else adam.adam

    for fused in [True, False]:
        t = benchmark.Timer(
                stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                label=f'Fused Adam on {device} using {dtype}',
                sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                globals=locals(),
                description=f"Fused: {fused}",
            ).blocked_autorange(min_run_time=5)
        results.append(t)

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```

Produces the following results on an M4 Pro running macOS 15:
```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
                                                                          |  Fused: True  |  Fused: False
1 threads: ----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10        |       283     |      2810
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10       |       277     |      2430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10       |       285     |      2400
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10      |       278     |      2250
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10       |       504     |      2700
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10      |       478     |      2600
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10      |       506     |      2500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10     |       482     |      2300
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10     |      2089     |      4190
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10    |      1940     |      3800
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10    |      2100     |      3770
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10   |      1950     |      3600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50        |       842     |     14000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50       |       835     |     11800
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50       |       845     |     11700
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50      |       855     |     11000
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50       |      1410     |     14000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50      |      1350     |     12000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50      |      1400     |     12000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50     |      1340     |     11000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50     |      9767     |     20400
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50    |      8991     |     18600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50    |      9803     |     18300
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50   |      9070     |     17600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100       |      1600     |     27000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100      |      1600     |     24100
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100      |      1600     |     23500
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100     |      1600     |     21800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100      |      2740     |     26000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100     |      2580     |     24000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100     |      2730     |     25000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100    |      2600     |     23000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100    |     19350     |     39000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100   |     17780     |     37300
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100   |     19400     |     37000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100  |     17900     |     35500
Times are in microseconds (us).
```
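The fused path wins in every configuration; the largest gains (over 10x) come from many small tensors, where the unfused implementation pays per-tensor dispatch overhead, while for the largest tensors the gap narrows to roughly 2x.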
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #141089, #141090, #141092, #141103
3 changed files with 56 additions and 33 deletions

The kernel changes, in the Metal shader implementing the fused optimizers:

```diff
@@ -1,5 +1,12 @@
 #include <metal_stdlib>
+using metal::max;
+#if __METAL_VERSION__ >= 310
+bfloat max(bfloat a, bfloat b) {
+  return a > b ? a : b;
+}
+#endif
 #define kmaxThreadGroups 32
 #define kmaxTensors 32
 #define chunk_size 65536
```
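`metal::max` has no `bfloat` overload, so one is defined by hand, and the added `using metal::max;` keeps the unqualified `max(...)` calls below resolving for `float`/`half`. The `__METAL_VERSION__ >= 310` guard is what ties this to macOS 14+: the `bfloat` type only exists as of Metal 3.1.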
```diff
@@ -81,26 +88,28 @@ inline void adam_math_amsgrad(
   if (weight_decay != 0) {
     switch (adam_mode) {
       case ADAM_MODE::ORIGINAL:
-        grad += param * weight_decay;
+        grad += T(param * weight_decay);
         break;
       case ADAM_MODE::ADAMW:
-        param -= lr * weight_decay * param;
+        param -= T(lr * weight_decay * param);
         break;
     }
   }
-  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
-  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
+  exp_avg = T(beta1 * exp_avg + (1 - beta1) * grad);
+  exp_avg_sq = T(beta2 * exp_avg_sq + (1 - beta2) * grad * grad);
   const float casted_state_steps = static_cast<float>(state_steps);
-  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
-  const T step_size = lr / bias_correction1;
-  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
-  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
-  max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq);
-  const T denom =
+  const auto bias_correction1 =
+      1 - metal::precise::pow(beta1, casted_state_steps);
+  const auto step_size = lr / bias_correction1;
+  const auto bias_correction2 =
+      1 - metal::precise::pow(beta2, casted_state_steps);
+  const auto bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
+  max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq);
+  const auto denom =
       (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
-  param -= step_size * exp_avg / denom;
+  param -= T(step_size * exp_avg / denom);
   grad = grad_;
 }
@@ -127,24 +136,26 @@ inline void adam_math(
   if (weight_decay != 0) {
     switch (adam_mode) {
       case ADAM_MODE::ORIGINAL:
-        grad += param * weight_decay;
+        grad += T(param * weight_decay);
         break;
       case ADAM_MODE::ADAMW:
-        param -= lr * weight_decay * param;
+        param -= T(lr * weight_decay * param);
         break;
     }
   }
-  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
-  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
+  exp_avg = T(beta1 * exp_avg + (1 - beta1) * grad);
+  exp_avg_sq = T(beta2 * exp_avg_sq + (1 - beta2) * grad * grad);
   const float casted_state_steps = static_cast<float>(state_steps);
-  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
-  const T step_size = lr / bias_correction1;
-  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
-  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
-  const T denom =
+  const auto bias_correction1 =
+      1 - metal::precise::pow(beta1, casted_state_steps);
+  const auto step_size = lr / bias_correction1;
+  const auto bias_correction2 =
+      1 - metal::precise::pow(beta2, casted_state_steps);
+  const auto bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
+  const auto denom =
       (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
-  param -= step_size * exp_avg / denom;
+  param -= T(step_size * exp_avg / denom);
   grad = grad_;
 }
```
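Both Adam kernels follow the same pattern: the arithmetic happens in the scalar type of the hyperparameters and only the stores back into `T`-typed state are narrowed explicitly via `T(...)`, presumably because those float-valued expressions do not convert to `bfloat` implicitly. For reference, the per-element math `adam_math` implements corresponds to this plain-Python sketch (a hypothetical helper, not part of the PR):
```python
import math

def adam_step(param, grad, exp_avg, exp_avg_sq, step,
              lr=1e-3, beta1=0.9, beta2=0.99, eps=1e-5, weight_decay=0.0):
    # weight decay folded into the gradient (ADAM_MODE::ORIGINAL)
    if weight_decay != 0:
        grad = grad + param * weight_decay
    # first/second moment updates; in the kernel, these assignments
    # are the ones narrowed back to T
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
    # bias corrections computed from the step count
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    step_size = lr / bias_correction1
    denom = math.sqrt(exp_avg_sq) / math.sqrt(bias_correction2) + eps
    param = param - step_size * exp_avg / denom
    return param, exp_avg, exp_avg_sq
```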
```diff
@@ -295,6 +306,11 @@ REGISTER_ADAM_OPS_QUART(float, float);
 REGISTER_ADAM_OPS_QUART(float, half);
 REGISTER_ADAM_OPS_QUART(half, float);
 REGISTER_ADAM_OPS_QUART(half, half);
+#if __METAL_VERSION__ >= 310
+REGISTER_ADAM_OPS_QUART(float, bfloat);
+REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
+REGISTER_ADAM_OPS_QUART(bfloat, float);
+#endif
 
 template <typename T>
 inline void sgd_momentum_math(
```
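As with the `max` overload, the `bfloat` instantiations are fenced behind `__METAL_VERSION__ >= 310`, so the existing `float`/`half` kernels keep building against older Metal toolchains.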
```diff
@@ -310,22 +326,22 @@ inline void sgd_momentum_math(
     const uint8_t is_first_step) {
   auto grad_ = grad;
   if (maximize) {
-    grad_ *= -1.0;
+    grad_ *= T(-1.0);
   }
   if (weight_decay != 0) {
-    grad_ += weight_decay * param;
+    grad_ += T(weight_decay * param);
   }
   momentum_buffer = is_first_step
       ? grad_
-      : (momentum * momentum_buffer + (1 - dampening) * grad_);
+      : T(momentum * momentum_buffer + (1 - dampening) * grad_);
   if (nesterov) {
-    grad_ += momentum * momentum_buffer;
+    grad_ += T(momentum * momentum_buffer);
   } else {
     grad_ = momentum_buffer;
   }
-  param -= lr * grad_;
+  param -= T(lr * grad_);
 }
 
 template <typename T>
@@ -337,13 +353,13 @@ inline void sgd_math(
     const uint8_t maximize) {
   auto grad_ = grad;
   if (maximize) {
-    grad_ *= -1.0;
+    grad_ *= T(-1.0);
   }
   if (weight_decay != 0) {
-    grad_ += weight_decay * param;
+    grad_ += T(weight_decay * param);
   }
-  param -= lr * grad_;
+  param -= T(lr * grad_);
 }
 
 template <typename T>
```
```diff
@@ -444,3 +460,7 @@ REGISTER_FUSED_SGD_OP(float);
 REGISTER_FUSED_SGD_OP(half);
 REGISTER_FUSED_SGD_MOMENTUM_OP(float);
 REGISTER_FUSED_SGD_MOMENTUM_OP(half);
+#if __METAL_VERSION__ >= 310
+REGISTER_FUSED_SGD_OP(bfloat);
+REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
+#endif
```

On the host side, the dtype check in `multi_tensor_apply_for_fused_optimizer` now accepts `BFloat16`:

```diff
@@ -131,9 +131,9 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name,
   TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
   for (const auto& d : c10::irange(depth)) {
-    TORCH_CHECK(tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
-                    tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
-                "Only float and half are supported");
+    const auto scalar_type = tensor_lists[d][0].scalar_type();
+    TORCH_CHECK(scalar_type == kFloat || scalar_type == kHalf || scalar_type == kBFloat16,
+                "Only float, bfloat and half are supported");
   }
 
   id<MTLDevice> device = MPSDevice::getInstance()->device();
```

And the MPS dtype skip in `TestOptimRenewed` now admits bfloat16:

```diff
@@ -1027,8 +1027,11 @@ class TestOptimRenewed(TestCase):
         if _get_device_type(device) == "mps" and dtype not in (
             torch.float16,
             torch.float32,
+            torch.bfloat16,
         ):
-            self.skipTest("MPS supports only torch.float16 and torch.float32")
+            self.skipTest(
+                "MPS supports only torch.float16, torch.float32 and torch.bfloat16"
+            )
         self._test_derived_optimizers(device, dtype, optim_info, "fused")
 
     @optims(
```