[MPS] Expand fused forloop to bfloat16 (#141104)
For MacOS 14+. Running the following script (adapted from the one mentioned in https://github.com/pytorch/pytorch/pull/127242):

```python
import itertools

import torch
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark


def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
    fn(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        foreach=False,
        capturable=False,
        fused=fused,
        amsgrad=amsgrad,
        beta1=0.9,
        beta2=0.99,
        lr=1e-3,
        weight_decay=.0,
        eps=1e-5,
        maximize=False,
        grad_scale=None,
        found_inf=None,
    )
    torch.mps.synchronize()


device, dtype = "mps", torch.bfloat16
results = []

for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]):
    print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
    params, grads, exp_avgs, exp_avg_sqs = [
        [torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)]
        for _ in range(4)
    ]
    max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else []
    state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)]
    fn = adamw.adamw if adamWflag else adam.adam

    for fused in [True, False]:
        t = benchmark.Timer(
            stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
            label=f'Fused Adam on {device} using {dtype}',
            sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
            globals=locals(),
            description=f"Fused: {fused}",
        ).blocked_autorange(min_run_time=5)
        results.append(t)

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```

produces the following results on an M4 Pro running MacOS 15:

```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
                                                                         |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10       |       283     |       2810
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10      |       277     |       2430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10      |       285     |       2400
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10     |       278     |       2250
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10      |       504     |       2700
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10     |       478     |       2600
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10     |       506     |       2500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10    |       482     |       2300
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10    |      2089     |       4190
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10   |      1940     |       3800
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10   |      2100     |       3770
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10  |      1950     |       3600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50       |       842     |      14000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50      |       835     |      11800
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50      |       845     |      11700
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50     |       855     |      11000
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50      |      1410     |      14000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50     |      1350     |      12000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50     |      1400     |      12000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50    |      1340     |      11000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50    |      9767     |      20400
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50   |      8991     |      18600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50   |      9803     |      18300
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50  |      9070     |      17600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100      |      1600     |      27000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100     |      1600     |      24100
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100     |      1600     |      23500
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100    |      1600     |      21800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100     |      2740     |      26000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100    |      2580     |      24000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100    |      2730     |      25000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100   |      2600     |      23000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100   |     19350     |      39000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100  |     17780     |      37300
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100  |     19400     |      37000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 |     17900     |      35500

Times are in microseconds (us).
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007

ghstack dependencies: #141089, #141090, #141092, #141103
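For context, a minimal user-facing sketch of what this enables (a sketch only, assuming an Apple-silicon machine on MacOS 14 or newer and a PyTorch build that includes this change; the module, sizes, and hyperparameters are arbitrary placeholders):

```python
import torch

# Hypothetical toy module; any module with bfloat16 parameters on "mps" is handled the same way.
model = torch.nn.Linear(128, 128, device="mps", dtype=torch.bfloat16)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)  # fused path now covers bfloat16 on MPS

x = torch.randn(32, 128, device="mps", dtype=torch.bfloat16)
loss = model(x).sum()
loss.backward()
opt.step()        # one fused Metal kernel launch per chunk instead of a Python for-loop over tensors
opt.zero_grad()
torch.mps.synchronize()
```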
This commit is contained in:
parent
989888236e
commit
d8b4406e12
3 changed files with 56 additions and 33 deletions
@@ -1,5 +1,12 @@
 #include <metal_stdlib>
 
+using metal::max;
+#if __METAL_VERSION__ >= 310
+bfloat max(bfloat a, bfloat b) {
+  return a > b ? a : b;
+}
+#endif
+
 #define kmaxThreadGroups 32
 #define kmaxTensors 32
 #define chunk_size 65536
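The `__METAL_VERSION__ >= 310` guard is there because the `bfloat` type (and hence this `max` overload) only exists from Metal 3.1 onward, which lines up with the MacOS 14+ requirement in the commit message. A rough, hedged way to probe whether your own setup can even allocate bfloat16 tensors on MPS (plain PyTorch, no assumptions beyond an MPS-enabled build):

```python
import torch


def mps_bfloat16_available() -> bool:
    """Rough capability probe: try to materialize a bfloat16 tensor on the MPS device."""
    if not torch.backends.mps.is_available():
        return False
    try:
        torch.zeros(1, device="mps", dtype=torch.bfloat16)
        return True
    except (RuntimeError, TypeError):
        return False


print(mps_bfloat16_available())
```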
@@ -81,26 +88,28 @@ inline void adam_math_amsgrad(
   if (weight_decay != 0) {
     switch (adam_mode) {
       case ADAM_MODE::ORIGINAL:
-        grad += param * weight_decay;
+        grad += T(param * weight_decay);
         break;
       case ADAM_MODE::ADAMW:
-        param -= lr * weight_decay * param;
+        param -= T(lr * weight_decay * param);
         break;
     }
   }
 
-  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
-  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
+  exp_avg = T(beta1 * exp_avg + (1 - beta1) * grad);
+  exp_avg_sq = T(beta2 * exp_avg_sq + (1 - beta2) * grad * grad);
   const float casted_state_steps = static_cast<float>(state_steps);
-  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
-  const T step_size = lr / bias_correction1;
-  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
-  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
-  max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq);
+  const auto bias_correction1 =
+      1 - metal::precise::pow(beta1, casted_state_steps);
+  const auto step_size = lr / bias_correction1;
+  const auto bias_correction2 =
+      1 - metal::precise::pow(beta2, casted_state_steps);
+  const auto bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
+  max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq);
 
-  const T denom =
+  const auto denom =
       (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
-  param -= step_size * exp_avg / denom;
+  param -= T(step_size * exp_avg / denom);
   grad = grad_;
 }
 
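For readers who prefer not to parse the shader, here is a rough Python transcription of the per-element update that `adam_math_amsgrad` performs. It illustrates the update rule only, not the actual kernel; names mirror the Metal code, and the intermediates are effectively computed in float and narrowed back to the parameter type, which is what the new `T(...)` casts make explicit for half/bfloat16:

```python
import math


def adam_step_amsgrad(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, step,
                      lr, beta1, beta2, eps, weight_decay, adamw):
    # Decoupled (AdamW) vs. L2 (original Adam) weight decay, as in the ADAM_MODE switch.
    if weight_decay != 0:
        if adamw:
            param -= lr * weight_decay * param
        else:
            grad += weight_decay * param

    # Exponential moving averages of the gradient and its square.
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad

    # Bias corrections and the AMSGrad running maximum of the second moment.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2_sqrt = math.sqrt(1 - beta2 ** step)
    max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq)

    denom = math.sqrt(max_exp_avg_sq) / bias_correction2_sqrt + eps
    param -= (lr / bias_correction1) * exp_avg / denom
    return param, exp_avg, exp_avg_sq, max_exp_avg_sq
```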
@@ -127,24 +136,26 @@ inline void adam_math(
   if (weight_decay != 0) {
     switch (adam_mode) {
       case ADAM_MODE::ORIGINAL:
-        grad += param * weight_decay;
+        grad += T(param * weight_decay);
         break;
       case ADAM_MODE::ADAMW:
-        param -= lr * weight_decay * param;
+        param -= T(lr * weight_decay * param);
         break;
     }
   }
 
-  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
-  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
+  exp_avg = T(beta1 * exp_avg + (1 - beta1) * grad);
+  exp_avg_sq = T(beta2 * exp_avg_sq + (1 - beta2) * grad * grad);
   const float casted_state_steps = static_cast<float>(state_steps);
-  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
-  const T step_size = lr / bias_correction1;
-  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
-  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
-  const T denom =
+  const auto bias_correction1 =
+      1 - metal::precise::pow(beta1, casted_state_steps);
+  const auto step_size = lr / bias_correction1;
+  const auto bias_correction2 =
+      1 - metal::precise::pow(beta2, casted_state_steps);
+  const auto bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
+  const auto denom =
       (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
-  param -= step_size * exp_avg / denom;
+  param -= T(step_size * exp_avg / denom);
   grad = grad_;
 }
 
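The non-AMSGrad variant above differs only in dividing by `sqrt(exp_avg_sq)` instead of the running maximum. A hedged way to sanity-check the fused path against the existing for-loop implementation (tolerances are deliberately loose because bfloat16 carries only about three decimal digits; the sizes, step count, and hyperparameters are arbitrary):

```python
import torch

device, dtype = "mps", torch.bfloat16
torch.manual_seed(0)

ref_params = [torch.randn(1024, device=device, dtype=dtype).requires_grad_(True) for _ in range(4)]
fused_params = [p.detach().clone().requires_grad_(True) for p in ref_params]


def run(params, fused):
    opt = torch.optim.Adam(params, lr=1e-3, fused=fused)
    for _ in range(5):
        for p in params:
            p.grad = torch.full_like(p, 0.1)  # fixed gradient so both runs see identical inputs
        opt.step()
        opt.zero_grad()


run(ref_params, fused=False)    # baseline for-loop implementation
run(fused_params, fused=True)   # fused Metal kernel

for a, b in zip(ref_params, fused_params):
    assert torch.allclose(a, b, atol=1e-2, rtol=1e-2)
```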
@@ -295,6 +306,11 @@ REGISTER_ADAM_OPS_QUART(float, float);
 REGISTER_ADAM_OPS_QUART(float, half);
 REGISTER_ADAM_OPS_QUART(half, float);
 REGISTER_ADAM_OPS_QUART(half, half);
+#if __METAL_VERSION__ >= 310
+REGISTER_ADAM_OPS_QUART(float, bfloat);
+REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
+REGISTER_ADAM_OPS_QUART(bfloat, float);
+#endif
 
 template <typename T>
 inline void sgd_momentum_math(
@@ -310,22 +326,22 @@ inline void sgd_momentum_math(
     const uint8_t is_first_step) {
   auto grad_ = grad;
   if (maximize) {
-    grad_ *= -1.0;
+    grad_ *= T(-1.0);
   }
   if (weight_decay != 0) {
-    grad_ += weight_decay * param;
+    grad_ += T(weight_decay * param);
   }
 
   momentum_buffer = is_first_step
       ? grad_
-      : (momentum * momentum_buffer + (1 - dampening) * grad_);
+      : T(momentum * momentum_buffer + (1 - dampening) * grad_);
   if (nesterov) {
-    grad_ += momentum * momentum_buffer;
+    grad_ += T(momentum * momentum_buffer);
   } else {
     grad_ = momentum_buffer;
   }
 
-  param -= lr * grad_;
+  param -= T(lr * grad_);
 }
 
 template <typename T>
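As with the Adam kernels, the change here only keeps intermediates in float and casts the results back to the parameter type; the update rule itself is untouched. Roughly, the per-element math (an illustration only, with `dampening`, `nesterov`, and `is_first_step` mirroring the kernel arguments):

```python
def sgd_momentum_step(param, grad, momentum_buffer, lr, momentum, weight_decay,
                      dampening, nesterov, maximize, is_first_step):
    g = -grad if maximize else grad
    if weight_decay != 0:
        g += weight_decay * param

    # The buffer is seeded with the (adjusted) gradient on the first step.
    momentum_buffer = g if is_first_step else momentum * momentum_buffer + (1 - dampening) * g

    if nesterov:
        g += momentum * momentum_buffer
    else:
        g = momentum_buffer

    param -= lr * g
    return param, momentum_buffer
```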
@@ -337,13 +353,13 @@ inline void sgd_math(
     const uint8_t maximize) {
   auto grad_ = grad;
   if (maximize) {
-    grad_ *= -1.0;
+    grad_ *= T(-1.0);
   }
   if (weight_decay != 0) {
-    grad_ += weight_decay * param;
+    grad_ += T(weight_decay * param);
   }
 
-  param -= lr * grad_;
+  param -= T(lr * grad_);
 }
 
 template <typename T>
@@ -444,3 +460,7 @@ REGISTER_FUSED_SGD_OP(float);
 REGISTER_FUSED_SGD_OP(half);
 REGISTER_FUSED_SGD_MOMENTUM_OP(float);
 REGISTER_FUSED_SGD_MOMENTUM_OP(half);
+#if __METAL_VERSION__ >= 310
+REGISTER_FUSED_SGD_OP(bfloat);
+REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
+#endif
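With these registrations, the fused SGD kernels (with and without momentum) also become selectable for bfloat16 tensors. A minimal, hedged usage sketch analogous to the AdamW one above (module and hyperparameters are arbitrary; assumes a build containing the MPS fused SGD kernels):

```python
import torch

model = torch.nn.Linear(64, 64, device="mps", dtype=torch.bfloat16)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, fused=True)

x = torch.randn(16, 64, device="mps", dtype=torch.bfloat16)
model(x).sum().backward()
opt.step()        # dispatches to the fused SGD-with-momentum Metal kernel
opt.zero_grad()
```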
@@ -131,9 +131,9 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam
 
   TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
   for (const auto& d : c10::irange(depth)) {
-    TORCH_CHECK(tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
-                    tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
-                "Only float and half are supported");
+    const auto scalar_type = tensor_lists[d][0].scalar_type();
+    TORCH_CHECK(scalar_type == kFloat || scalar_type == kHalf || scalar_type == kBFloat16,
+                "Only float, bfloat and half are supported");
   }
 
   id<MTLDevice> device = MPSDevice::getInstance()->device();
@@ -1027,8 +1027,11 @@ class TestOptimRenewed(TestCase):
         if _get_device_type(device) == "mps" and dtype not in (
             torch.float16,
             torch.float32,
+            torch.bfloat16,
         ):
-            self.skipTest("MPS supports only torch.float16 and torch.float32")
+            self.skipTest(
+                "MPS supports only torch.float16, torch.float32 and torch.bfloat16"
+            )
         self._test_derived_optimizers(device, dtype, optim_info, "fused")
 
     @optims(