From 4c6dc6a1a479dcb9dc3ca9b08c480fdcefd26113 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 28 Sep 2022 17:12:25 +0000 Subject: [PATCH] [BE] Do not use VLA (#85800) [Variable Length Array](https://en.wikipedia.org/wiki/Variable-length_array) is part of C99 standard, but has never been adopted to C++ Also, warning if they are used (surprisingly those warnings can not be turned into errors. Remove code duplication in `OperationUtils.mm` Pull Request resolved: https://github.com/pytorch/pytorch/pull/85800 Approved by: https://github.com/kulinseth, https://github.com/jeanschmidt --- CMakeLists.txt | 1 + aten/src/ATen/native/mps/OperationUtils.mm | 22 ++++--------------- .../native/mps/operations/AdaptivePooling.mm | 12 +++++----- aten/src/ATen/native/mps/operations/Linear.mm | 10 ++++----- .../ATen/native/mps/operations/ReduceOps.mm | 10 +++------ aten/src/ATen/native/mps/operations/Repeat.mm | 9 ++++---- .../native/mps/operations/ScatterGather.mm | 4 ++-- aten/src/ATen/native/mps/operations/Shape.mm | 6 ++--- .../native/mps/operations/TensorCompare.mm | 16 +++++++------- 9 files changed, 35 insertions(+), 55 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e8ab95f4e5a..6becf22d894 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -838,6 +838,7 @@ if(NOT MSVC) append_cxx_flag_if_supported("-Wno-strict-overflow" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-strict-aliasing" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-error=deprecated-declarations" CMAKE_CXX_FLAGS) + append_cxx_flag_if_supported("-Wvla-extension" CMAKE_CXX_FLAGS) if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") string(APPEND CMAKE_CXX_FLAGS " -Wno-range-loop-analysis") string(APPEND CMAKE_CXX_FLAGS " -Wno-pass-failed") diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index e237254cd70..16585d9128b 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -183,19 +183,7 @@ double getMPSScalarValue(const Tensor& t) { } MPSShape* getMPSShape(const Tensor& t) { - const int sz = t.dim(); - const int sz_ = (sz > 0) ? sz : 1; - - NSNumber* numbers[sz_]; - - for (int i = 0; i < sz_; i++) - { - NSInteger sz_i = (i < sz) ? t.size(i) : 1; - - NSNumber* number = [NSNumber numberWithInteger:sz_i]; - numbers[i] = number; - } - return [NSArray arrayWithObjects:numbers count:sz_]; + return getMPSShape(t.sizes()); } MPSShape* getMPSShape(c10::MaybeOwned t) { @@ -207,16 +195,14 @@ MPSShape* getMPSShape(IntArrayRef sizes) { const int sz = sizes.size(); const int sz_ = (sz > 0) ? sz : 1; - NSNumber* numbers[sz_]; + std::vector numbers(sz_); - for (int i = 0; i < sz_; i++) - { + for (int i = 0; i < sz_; i++) { NSInteger sz_i = (i < sz) ? sizes[i] : 1; - NSNumber* number = [NSNumber numberWithInteger:sz_i]; numbers[i] = number; } - return [NSArray arrayWithObjects:numbers count:sz_]; + return [NSArray arrayWithObjects:numbers.data() count:numbers.size()]; } void printTensorNDArray(const Tensor& t) { diff --git a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm index c4184ee3efe..e13deb805bb 100644 --- a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm +++ b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm @@ -97,13 +97,11 @@ Tensor& adaptive_avg_pool2d_out_mps c10::nullopt); } else { Tensor phony_grad = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto num_input_dims = input.sizes().size(); - int64_t phony_shape[num_input_dims]; - for(int i = 0; i < num_input_dims - 2; i++) - phony_shape[i] = input.size(i); - phony_shape[num_input_dims-2] = output_size[0]; - phony_shape[num_input_dims-1] = output_size[1]; - phony_grad.resize_(IntArrayRef(phony_shape, num_input_dims)); + auto input_sizes = input.sizes(); + std::vector phony_shape{input_sizes.begin(), input_sizes.end() -2}; + phony_shape.push_back(output_size[0]); + phony_shape.push_back(output_size[1]); + phony_grad.resize_(IntArrayRef(phony_shape)); output = at::avg_pool2d_backward(input, phony_grad, IntArrayRef({kernel_sizeH, kernel_sizeW}), diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 33ffa1c7bba..ddaa6ce9796 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -152,14 +152,12 @@ Tensor _mps_linear( mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - // Shave off '1' present at the end of the shape + // Shave off '1' present at the end of the shape if(weight_arg.dim() == 1) { // Number of elements in new output shape - auto N = output.dim() - 1; - int64_t out_shape[N]; - for(int i = 0; i < N; i++) - out_shape[i] = output.size(i); - return output.view(IntArrayRef(out_shape, N)); + auto output_sizes = output.sizes(); + std::vector out_shape(output_sizes.begin(), output_sizes.end()-1); + return output.view(IntArrayRef(out_shape)); } else return output; diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 56d1e0fbbaf..36a68fc5331 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -333,12 +333,8 @@ TORCH_IMPL_FUNC(amin_out_mps) Tensor prod_mps(const Tensor &self, c10::optional opt_dtype) { - auto num_dims = self.dim(); - - int64_t dims[num_dims]; - - for(int i = 0; i < num_dims; i++) - dims[i] = i; + std::vector dims(self.dim()); + std::iota(dims.begin(), dims.end(), 0); Tensor output_t = at::native::empty_mps( {}, @@ -348,7 +344,7 @@ Tensor prod_mps(const Tensor &self, c10::optional opt_dtype) { c10::nullopt, c10::nullopt); - reduction_out_mps(self, IntArrayRef(dims, num_dims), false, opt_dtype, const_cast(output_t), MPSReductionType::PROD, "prod_mps"); + reduction_out_mps(self, IntArrayRef(dims), false, opt_dtype, const_cast(output_t), MPSReductionType::PROD, "prod_mps"); return output_t; } diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index 53bcddf405c..8b6b709da64 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -108,16 +108,17 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { num_repeat_dims); // Set output shape - int64_t output_shape[num_repeat_dims]; + std::vector output_shape(num_repeat_dims); bool zero_tensor = false; - for(int i = 0; i < num_repeat_dims; i++) { + for(auto i : c10::irange(num_repeat_dims)) { output_shape[i] = repeats[i] * [apparent_input_shape[i] intValue]; - if(output_shape[i] == 0) + if(output_shape[i] == 0) { zero_tensor = true; + } } Tensor output = at::native::empty_mps( - IntArrayRef(output_shape, num_repeat_dims), + IntArrayRef(output_shape), self.scalar_type(), c10::nullopt, kMPS, diff --git a/aten/src/ATen/native/mps/operations/ScatterGather.mm b/aten/src/ATen/native/mps/operations/ScatterGather.mm index c4943d1242d..a7e2c8b8665 100644 --- a/aten/src/ATen/native/mps/operations/ScatterGather.mm +++ b/aten/src/ATen/native/mps/operations/ScatterGather.mm @@ -358,13 +358,13 @@ void scatter_mps_general // 2. Flatten the values // 3. Scatter into input with add mode - int shape_data[num_input_dims]; + std::vector shape_data(num_input_dims); for(int i = 0; i < num_input_dims; i++) { shape_data[i] = {[scatterInputShape[i] intValue]}; } - MPSGraphTensor* scatterInputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data length:num_input_dims * sizeof(int)] + MPSGraphTensor* scatterInputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)] shape:@[[NSNumber numberWithInt:num_input_dims]] dataType:MPSDataTypeInt32]; diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index 9beafb3a15f..f491f2ff823 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -389,8 +389,8 @@ TORCH_IMPL_FUNC(cat_out_mps) // Create placeholders auto len_tensor_array = inputs.size() - skipped_tensor_indices.size(); - MPSGraphTensor* inputMPSGraphTensors[len_tensor_array]; - MPSGraphTensor* castInputMPSGraphTensors[len_tensor_array]; + std::vector inputMPSGraphTensors(len_tensor_array); + std::vector castInputMPSGraphTensors(len_tensor_array); int graph_tensor_idx = 0; for(const Tensor* tensor : input_tensors) { @@ -411,7 +411,7 @@ TORCH_IMPL_FUNC(cat_out_mps) graph_tensor_idx++; } - auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors + auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors.data() count:len_tensor_array]; // Use concatTensors to concatenate MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index f30d8af15ec..f9beef8fd37 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -104,17 +104,17 @@ void clamp_tensor_out_mps(const Tensor& input_t, auto num_max_dims = max_opt->dim(); auto num_input_dims = input_t.dim(); - int64_t new_min_arr[num_input_dims]; - int64_t new_max_arr[num_input_dims]; + std::vector new_min_arr(num_input_dims); + std::vector new_max_arr(num_input_dims); if(has_min && num_min_dims < num_input_dims) { - fill_new_shape(num_input_dims, num_min_dims, new_min_arr, min_opt->sizes()); - new_min_shape = IntArrayRef({new_min_arr, num_input_dims}); + fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes()); + new_min_shape = IntArrayRef(new_min_arr); } if(has_max && num_max_dims < num_input_dims) { - fill_new_shape(num_input_dims, num_max_dims, new_max_arr, max_opt->sizes()); - new_max_shape = IntArrayRef({new_max_arr, num_input_dims}); + fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes()); + new_max_shape = IntArrayRef(new_max_arr); } Tensor min_opt_tensor; @@ -390,7 +390,7 @@ Tensor where_mps(const Tensor& condition, TORCH_CHECK(max_dim == 0 || !(sum_dims % max_dim), "All inputs of where should have same/compatible number of dims") - int64_t out_arr[max_dim]; + std::vector out_arr(max_dim); // Broadcasted output shape for(int i = 0; i < max_dim; i++) { @@ -402,7 +402,7 @@ Tensor where_mps(const Tensor& condition, out_arr[i] = std::max(cond_num, std::max(self_num, other_num)); } - Tensor ret = empty_mps(IntArrayRef(out_arr, max_dim), + Tensor ret = empty_mps(IntArrayRef(out_arr), self.scalar_type(), c10::nullopt, kMPS,