diff --git a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal
index 602b32b5208..2006e768d82 100644
--- a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal
+++ b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal
@@ -1,5 +1,12 @@
 #include <metal_stdlib>
 
+using metal::max;
+#if __METAL_VERSION__ >= 310
+bfloat max(bfloat a, bfloat b) {
+  return a > b ? a : b;
+}
+#endif
+
 #define kmaxThreadGroups 32
 #define kmaxTensors 32
 #define chunk_size 65536
@@ -81,26 +88,28 @@ inline void adam_math_amsgrad(
   if (weight_decay != 0) {
     switch (adam_mode) {
       case ADAM_MODE::ORIGINAL:
-        grad += param * weight_decay;
+        grad += T(param * weight_decay);
         break;
       case ADAM_MODE::ADAMW:
-        param -= lr * weight_decay * param;
+        param -= T(lr * weight_decay * param);
         break;
     }
   }
 
-  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
-  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
+  exp_avg = T(beta1 * exp_avg + (1 - beta1) * grad);
+  exp_avg_sq = T(beta2 * exp_avg_sq + (1 - beta2) * grad * grad);
   const float casted_state_steps = static_cast<float>(state_steps);
-  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
-  const T step_size = lr / bias_correction1;
-  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
-  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
-  max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq);
+  const auto bias_correction1 =
+      1 - metal::precise::pow(beta1, casted_state_steps);
+  const auto step_size = lr / bias_correction1;
+  const auto bias_correction2 =
+      1 - metal::precise::pow(beta2, casted_state_steps);
+  const auto bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
+  max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq);
 
-  const T denom =
+  const auto denom =
       (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
-  param -= step_size * exp_avg / denom;
+  param -= T(step_size * exp_avg / denom);
   grad = grad_;
 }
 
@@ -127,24 +136,26 @@ inline void adam_math(
   if (weight_decay != 0) {
     switch (adam_mode) {
      case ADAM_MODE::ORIGINAL:
-        grad += param * weight_decay;
+        grad += T(param * weight_decay);
         break;
       case ADAM_MODE::ADAMW:
-        param -= lr * weight_decay * param;
+        param -= T(lr * weight_decay * param);
         break;
     }
   }
 
-  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
-  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
+  exp_avg = T(beta1 * exp_avg + (1 - beta1) * grad);
+  exp_avg_sq = T(beta2 * exp_avg_sq + (1 - beta2) * grad * grad);
   const float casted_state_steps = static_cast<float>(state_steps);
-  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
-  const T step_size = lr / bias_correction1;
-  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
-  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
-  const T denom =
+  const auto bias_correction1 =
+      1 - metal::precise::pow(beta1, casted_state_steps);
+  const auto step_size = lr / bias_correction1;
+  const auto bias_correction2 =
+      1 - metal::precise::pow(beta2, casted_state_steps);
+  const auto bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
+  const auto denom =
       (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
-  param -= step_size * exp_avg / denom;
+  param -= T(step_size * exp_avg / denom);
   grad = grad_;
 }
 
@@ -295,6 +306,11 @@ REGISTER_ADAM_OPS_QUART(float, float);
 REGISTER_ADAM_OPS_QUART(float, half);
 REGISTER_ADAM_OPS_QUART(half, float);
 REGISTER_ADAM_OPS_QUART(half, half);
+#if __METAL_VERSION__ >= 310
+REGISTER_ADAM_OPS_QUART(float, bfloat);
+REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
+REGISTER_ADAM_OPS_QUART(bfloat, float);
+#endif
 
 template <typename T>
 inline void sgd_momentum_math(
@@ -310,22 +326,22 @@ inline void sgd_momentum_math(
     const uint8_t is_first_step) {
   auto grad_ = grad;
   if (maximize) {
-    grad_ *= -1.0;
+    grad_ *= T(-1.0);
   }
   if (weight_decay != 0) {
-    grad_ += weight_decay * param;
+    grad_ += T(weight_decay * param);
   }
 
   momentum_buffer = is_first_step
       ? grad_
-      : (momentum * momentum_buffer + (1 - dampening) * grad_);
+      : T(momentum * momentum_buffer + (1 - dampening) * grad_);
 
   if (nesterov) {
-    grad_ += momentum * momentum_buffer;
+    grad_ += T(momentum * momentum_buffer);
   } else {
     grad_ = momentum_buffer;
   }
-  param -= lr * grad_;
+  param -= T(lr * grad_);
 }
 
 template <typename T>
@@ -337,13 +353,13 @@ inline void sgd_math(
     const uint8_t maximize) {
   auto grad_ = grad;
   if (maximize) {
-    grad_ *= -1.0;
+    grad_ *= T(-1.0);
   }
   if (weight_decay != 0) {
-    grad_ += weight_decay * param;
+    grad_ += T(weight_decay * param);
   }
 
-  param -= lr * grad_;
+  param -= T(lr * grad_);
 }
 
 template <typename T>
@@ -444,3 +460,7 @@ REGISTER_FUSED_SGD_OP(float);
 REGISTER_FUSED_SGD_OP(half);
 REGISTER_FUSED_SGD_MOMENTUM_OP(float);
 REGISTER_FUSED_SGD_MOMENTUM_OP(half);
+#if __METAL_VERSION__ >= 310
+REGISTER_FUSED_SGD_OP(bfloat);
+REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
+#endif
diff --git a/aten/src/ATen/native/mps/operations/MultiTensorApply.h b/aten/src/ATen/native/mps/operations/MultiTensorApply.h
index db48d8ec075..cb8d65a129c 100644
--- a/aten/src/ATen/native/mps/operations/MultiTensorApply.h
+++ b/aten/src/ATen/native/mps/operations/MultiTensorApply.h
@@ -131,9 +131,9 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam
   TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
 
   for (const auto& d : c10::irange(depth)) {
-    TORCH_CHECK(tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
-                    tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
-                "Only float and half are supported");
+    const auto scalar_type = tensor_lists[d][0].scalar_type();
+    TORCH_CHECK(scalar_type == kFloat || scalar_type == kHalf || scalar_type == kBFloat16,
+                "Only float, bfloat and half are supported");
   }
 
   id<MTLDevice> device = MPSDevice::getInstance()->device();
diff --git a/test/test_optim.py b/test/test_optim.py
index 9a239b384bc..1ee77312c8f 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1027,8 +1027,11 @@ class TestOptimRenewed(TestCase):
         if _get_device_type(device) == "mps" and dtype not in (
             torch.float16,
             torch.float32,
+            torch.bfloat16,
         ):
-            self.skipTest("MPS supports only torch.float16 and torch.float32")
+            self.skipTest(
+                "MPS supports only torch.float16, torch.float32 and torch.bfloat16"
+            )
         self._test_derived_optimizers(device, dtype, optim_info, "fused")
 
     @optims(
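Taken together, the new kernel registrations, the relaxed dtype check in `multi_tensor_apply_for_fused_optimizer`, and the updated test gate mean a fused optimizer step on MPS should now accept bfloat16 tensors. Below is a minimal smoke-test sketch of that path from Python; it is not part of this patch, the tensor shape, loss, and hyperparameters are illustrative, and it assumes a Metal 3.1-capable device since the bfloat variants are only compiled when `__METAL_VERSION__ >= 310`.

```python
import torch

if torch.backends.mps.is_available():
    # bfloat16 parameter on the MPS device; before this change the fused path
    # rejected it with "Only float and half are supported".
    param = torch.randn(
        1024, device="mps", dtype=torch.bfloat16, requires_grad=True
    )

    # fused=True routes the step through the Metal fused-optimizer kernels
    # touched by this diff (FusedOptimizerOps.metal).
    opt = torch.optim.Adam([param], lr=1e-3, fused=True)

    loss = (param * param).sum()
    loss.backward()
    opt.step()       # should dispatch one of the newly registered bfloat variants
    opt.zero_grad()
```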