mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPS][BE] Use mtl_setBytes to upload bools as is (#141037)
But add a static assert that the size of bool is a single byte, to guard against hard-to-debug corruption if someone decides to typedef it to int.
Fixes https://github.com/pytorch/pytorch/issues/140971
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141037
Approved by: https://github.com/qqaatw, https://github.com/Skylion007
This commit is contained in:
parent
9fac5a16fd
commit
cee3f8541e
1 changed file with 12 additions and 26 deletions
|
|
@ -3,6 +3,8 @@
|
|||
#include <ATen/mps/MPSProfiler.h>
|
||||
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
|
||||
|
||||
// The encoding functors below pass `bool` arguments to mtl_setArgs as-is
// rather than copying them into uint8_t temporaries first. Fail the build if
// bool is ever not exactly one byte, to guard against hard-to-debug corruption.
static_assert(sizeof(bool) == 1);
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
static constexpr int64_t kChunkSize = 65536;
|
||||
|
|
@ -25,11 +27,8 @@ struct FusedAdamEncodingFunctor {
|
|||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize) const {
|
||||
float eps_lv = eps;
|
||||
uint8_t maximize_lv = maximize;
|
||||
|
||||
mtl_setArgs(
|
||||
computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize_lv);
|
||||
computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize);
|
||||
}
|
||||
|
||||
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
|
||||
|
|
@ -41,10 +40,8 @@ struct FusedAdamEncodingFunctor {
|
|||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize) const {
|
||||
uint8_t maximize_lv = maximize;
|
||||
|
||||
mtl_setArgs(
|
||||
computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize_lv);
|
||||
computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -63,9 +60,6 @@ struct FusedSgdEncodingFunctor<true> {
|
|||
const bool nesterov,
|
||||
const bool maximize,
|
||||
const bool is_first_step) const {
|
||||
uint8_t nesterov_lv = nesterov;
|
||||
uint8_t maximize_lv = maximize;
|
||||
uint8_t is_first_step_lv = is_first_step;
|
||||
mtl_setArgs(computeEncoder,
|
||||
tensorArgumentBuffer,
|
||||
metadata_arguments,
|
||||
|
|
@ -73,9 +67,9 @@ struct FusedSgdEncodingFunctor<true> {
|
|||
momentum,
|
||||
lr,
|
||||
dampening,
|
||||
nesterov_lv,
|
||||
maximize_lv,
|
||||
is_first_step_lv);
|
||||
nesterov,
|
||||
maximize,
|
||||
is_first_step);
|
||||
}
|
||||
|
||||
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
|
||||
|
|
@ -88,10 +82,6 @@ struct FusedSgdEncodingFunctor<true> {
|
|||
const bool nesterov,
|
||||
const bool maximize,
|
||||
const bool is_first_step) const {
|
||||
uint8_t nesterov_lv = nesterov;
|
||||
uint8_t maximize_lv = maximize;
|
||||
uint8_t is_first_step_lv = is_first_step;
|
||||
|
||||
mtl_setArgs(computeEncoder,
|
||||
tensorArgumentBuffer,
|
||||
metadata_arguments,
|
||||
|
|
@ -99,9 +89,9 @@ struct FusedSgdEncodingFunctor<true> {
|
|||
momentum,
|
||||
lr,
|
||||
dampening,
|
||||
nesterov_lv,
|
||||
maximize_lv,
|
||||
is_first_step_lv);
|
||||
nesterov,
|
||||
maximize,
|
||||
is_first_step);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -113,9 +103,7 @@ struct FusedSgdEncodingFunctor<false> {
|
|||
const double weight_decay,
|
||||
const double lr,
|
||||
const bool maximize) const {
|
||||
uint8_t maximize_lv = maximize;
|
||||
|
||||
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize_lv);
|
||||
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize);
|
||||
}
|
||||
|
||||
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
|
||||
|
|
@ -124,9 +112,7 @@ struct FusedSgdEncodingFunctor<false> {
|
|||
const double weight_decay,
|
||||
const at::Tensor& lr,
|
||||
const bool maximize) const {
|
||||
uint8_t maximize_lv = maximize;
|
||||
|
||||
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize_lv);
|
||||
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue