[MPS][BE] Use mtl_setBytes to upload bools as is (#141037)

Also add a static_assert that sizeof(bool) is a single byte, to guard against hard-to-debug corruption if someone decides to typedef it to int

Fixes https://github.com/pytorch/pytorch/issues/140971

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141037
Approved by: https://github.com/qqaatw, https://github.com/Skylion007
This commit is contained in:
Nikita Shulga 2024-11-19 23:08:43 +00:00 committed by PyTorch MergeBot
parent 9fac5a16fd
commit cee3f8541e

View file

@@ -3,6 +3,8 @@
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
static_assert(sizeof(bool) == 1);
namespace at::native::mps {
static constexpr int64_t kChunkSize = 65536;
@@ -25,11 +27,8 @@ struct FusedAdamEncodingFunctor {
const double weight_decay,
const double eps,
const bool maximize) const {
float eps_lv = eps;
uint8_t maximize_lv = maximize;
mtl_setArgs(
computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize_lv);
computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize);
}
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
@@ -41,10 +40,8 @@ struct FusedAdamEncodingFunctor {
const double weight_decay,
const double eps,
const bool maximize) const {
uint8_t maximize_lv = maximize;
mtl_setArgs(
computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize_lv);
computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize);
}
};
@@ -63,9 +60,6 @@ struct FusedSgdEncodingFunctor<true> {
const bool nesterov,
const bool maximize,
const bool is_first_step) const {
uint8_t nesterov_lv = nesterov;
uint8_t maximize_lv = maximize;
uint8_t is_first_step_lv = is_first_step;
mtl_setArgs(computeEncoder,
tensorArgumentBuffer,
metadata_arguments,
@@ -73,9 +67,9 @@ struct FusedSgdEncodingFunctor<true> {
momentum,
lr,
dampening,
nesterov_lv,
maximize_lv,
is_first_step_lv);
nesterov,
maximize,
is_first_step);
}
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
@@ -88,10 +82,6 @@ struct FusedSgdEncodingFunctor<true> {
const bool nesterov,
const bool maximize,
const bool is_first_step) const {
uint8_t nesterov_lv = nesterov;
uint8_t maximize_lv = maximize;
uint8_t is_first_step_lv = is_first_step;
mtl_setArgs(computeEncoder,
tensorArgumentBuffer,
metadata_arguments,
@@ -99,9 +89,9 @@ struct FusedSgdEncodingFunctor<true> {
momentum,
lr,
dampening,
nesterov_lv,
maximize_lv,
is_first_step_lv);
nesterov,
maximize,
is_first_step);
}
};
@@ -113,9 +103,7 @@ struct FusedSgdEncodingFunctor<false> {
const double weight_decay,
const double lr,
const bool maximize) const {
uint8_t maximize_lv = maximize;
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize_lv);
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize);
}
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
@@ -124,9 +112,7 @@ struct FusedSgdEncodingFunctor<false> {
const double weight_decay,
const at::Tensor& lr,
const bool maximize) const {
uint8_t maximize_lv = maximize;
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize_lv);
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize);
}
};