From cee3f8541e9fed93095e249d17f791b861548562 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 19 Nov 2024 23:08:43 +0000 Subject: [PATCH] [MPS][BE] Use `mtl_setBytes` to upload bools as is (#141037) But add static assert that size of bool is a single byte, to guard against hard to debug corruptions if someone decides to typedef it to int Fixes https://github.com/pytorch/pytorch/issues/140971 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141037 Approved by: https://github.com/qqaatw, https://github.com/Skylion007 --- .../native/mps/operations/MultiTensorApply.h | 38 ++++++------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/MultiTensorApply.h b/aten/src/ATen/native/mps/operations/MultiTensorApply.h index 7507165bdfe..e97afd49e57 100644 --- a/aten/src/ATen/native/mps/operations/MultiTensorApply.h +++ b/aten/src/ATen/native/mps/operations/MultiTensorApply.h @@ -3,6 +3,8 @@ #include #include +static_assert(sizeof(bool) == 1); + namespace at::native::mps { static constexpr int64_t kChunkSize = 65536; @@ -25,11 +27,8 @@ struct FusedAdamEncodingFunctor { const double weight_decay, const double eps, const bool maximize) const { - float eps_lv = eps; - uint8_t maximize_lv = maximize; - mtl_setArgs( - computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize_lv); + computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize); } void operator()(id& computeEncoder, @@ -41,10 +40,8 @@ struct FusedAdamEncodingFunctor { const double weight_decay, const double eps, const bool maximize) const { - uint8_t maximize_lv = maximize; - mtl_setArgs( - computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize_lv); + computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize); } }; @@ -63,9 +60,6 @@ struct FusedSgdEncodingFunctor { const bool nesterov, const bool maximize, const bool is_first_step) const { - uint8_t nesterov_lv = nesterov; - uint8_t maximize_lv = maximize; - uint8_t is_first_step_lv = is_first_step; mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, @@ -73,9 +67,9 @@ struct FusedSgdEncodingFunctor { momentum, lr, dampening, - nesterov_lv, - maximize_lv, - is_first_step_lv); + nesterov, + maximize, + is_first_step); } void operator()(id& computeEncoder, @@ -88,10 +82,6 @@ struct FusedSgdEncodingFunctor { const bool nesterov, const bool maximize, const bool is_first_step) const { - uint8_t nesterov_lv = nesterov; - uint8_t maximize_lv = maximize; - uint8_t is_first_step_lv = is_first_step; - mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, @@ -99,9 +89,9 @@ struct FusedSgdEncodingFunctor { momentum, lr, dampening, - nesterov_lv, - maximize_lv, - is_first_step_lv); + nesterov, + maximize, + is_first_step); } }; @@ -113,9 +103,7 @@ struct FusedSgdEncodingFunctor { const double weight_decay, const double lr, const bool maximize) const { - uint8_t maximize_lv = maximize; - - mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize_lv); + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize); } void operator()(id& computeEncoder, @@ -124,9 +112,7 @@ struct FusedSgdEncodingFunctor { const double weight_decay, const at::Tensor& lr, const bool maximize) const { - uint8_t maximize_lv = maximize; - - mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize_lv); + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize); } };