use constexpr (#12953)

This commit is contained in:
Weixing Zhang 2022-09-20 14:34:33 -07:00 committed by GitHub
parent dd39f0293d
commit 4113df0e21
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 6 deletions

View file

@@ -101,7 +101,7 @@ TEST(CudaKernelTest, LayerNormGrad_SmallSizeTensor) {
 TEST(CudaKernelTest, LayerNormGrad_SmallSizeTensor_IntermediateAxis) {
 const std::vector<int64_t> X_dims{4, 20, 16, 8};
-const int64_t axis = -2;
+constexpr int64_t axis = -2;
 TestLayerNormGrad(X_dims, LAYER_NORM_GRAD_OP, axis);
 }
@@ -122,7 +122,7 @@ TEST(CudaKernelTest, SimplifiedLayerNormGrad_SmallSizeTensor) {
 TEST(CudaKernelTest, SimplifiedLayerNormGrad_SmallSizeTensor_IntermediateAxis) {
 const std::vector<int64_t> X_dims{4, 20, 16, 8};
-const int64_t axis = -2;
+constexpr int64_t axis = -2;
 TestLayerNormGrad(X_dims, SIMPLIFIED_LAYER_NORM_GRAD_OP, axis);
 }
@@ -239,7 +239,7 @@ TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor) {
 TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_IntermediateAxis) {
 const std::vector<int64_t> X_dims{4, 20, 16, 8};
-const int64_t axis = -2;
+constexpr int64_t axis = -2;
 TestInvertibleLayerNormGrad(X_dims, axis);
 }
@@ -260,7 +260,7 @@ TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_FP16) {
 TEST(CudaKernelTest, InvertibleLayerNormGrad_SmallSizeTensor_IntermediateAxis_FP16) {
 const std::vector<int64_t> X_dims{4, 20, 16, 8};
-const int64_t axis = -2;
+constexpr int64_t axis = -2;
 TestInvertibleLayerNormGrad(X_dims, axis, 2e-3, true);
 }

View file

@@ -135,10 +135,10 @@ Status InvertibleLayerNormGrad<T, U, V>::ComputeInternal(OpKernelContext* p_op_k
 auto bias_grad_data = reinterpret_cast<CudaV*>(bias_grad->template MutableData<V>());
 #ifndef USE_ROCM
-const int part_size = 16;
+constexpr int part_size = 16;
 #else
 // Optimization for ROCm MI100
-const int part_size = 64;
+constexpr int part_size = 64;
 #endif
 auto part_grad_gamma = GetScratchBuffer<CudaU>(part_size * n2);
 auto part_grad_beta = GetScratchBuffer<CudaU>(part_size * n2);