[3/N] Replace at::detail::Array with std::array (#141324)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141324
Approved by: https://github.com/ezyang
commit 259a00b727 (parent e3cb167560)
Author: cyy
Date: 2024-11-24 18:17:32 +00:00
Committed by: PyTorch MergeBot
4 changed files with 38 additions and 37 deletions
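Note: at::detail::Array is ATen's hand-rolled fixed-size array, originally kept because std::array was not usable in CUDA device code. Its rough shape (a from-memory sketch, not the verbatim header) is:

    template <typename T, int size_>
    struct Array {
      T data[size_];
      C10_HOST_DEVICE T operator[](int i) const { return data[i]; }
      C10_HOST_DEVICE T& operator[](int i) { return data[i]; }
      Array() = default;
      // Implicit "fill" constructor: broadcasts x into every slot.
      C10_HOST_DEVICE Array(T x) {
        for (int i = 0; i < size_; i++) {
          data[i] = x;
        }
      }
    };

The implicit fill constructor is the one behavioral difference that matters in this diff: std::array has no scalar constructor, which is why a few call sites below gain explicit loops or braced initialization.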

@@ -29,7 +29,7 @@ struct OffsetCalculator {
   // On CUDA, zero sized array is not allowed, so when we are handling nullary
   // operators, we need to create a size 1 offset to avoid compiler failure.
   // This size 1 offset is just a placeholder, and we will not use it.
-  using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;
+  using offset_type = std::array<stride_t, std::max<int>(NARGS, 1)>;
 
   // if element_sizes is nullptr, then the strides will be in bytes, otherwise
   // the strides will be in # of elements.
@@ -80,7 +80,7 @@ struct TrivialOffsetCalculator {
   // On CUDA, zero sized array is not allowed, so when we are handling nullary
   // operators, we need to create a size 1 offset to avoid compiler failure.
   // This size 1 offset is just a placeholder, and we will not use it.
-  using offset_type = at::detail::Array<index_t, std::max<int>(NARGS, 1)>;
+  using offset_type = std::array<index_t, std::max<int>(NARGS, 1)>;
 
   C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
     offset_type offsets;

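Note: the swap is safe in device code because std::array is an aggregate whose accessors are constexpr (const operator[] since C++14, the non-const overload since C++17), and PyTorch compiles CUDA sources with --expt-relaxed-constexpr, which makes constexpr functions callable from kernels. A minimal sketch (hypothetical kernel, assuming that flag and C++17):

    #include <array>

    __global__ void demo_kernel() {
      std::array<int, 4> offsets{};  // aggregate, zero-initialized
      offsets[0] = 1;                // constexpr operator[] works on device
    }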
@@ -284,7 +284,7 @@ void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f)
   TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
   TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);
 
-  at::detail::Array<char*, ntensors> data;
+  std::array<char*, ntensors> data;
   for (int i = 0; i < ntensors; i++) {
     data[i] = (char*)iter.data_ptr(i);
   }

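Note: `data` is later passed by value into __global__ kernels. std::array<char*, N> is a trivially copyable, standard-layout aggregate just like at::detail::Array, so the kernel-argument ABI is unchanged. Illustrative compile-time checks (not part of the patch):

    #include <array>
    #include <type_traits>

    static_assert(std::is_trivially_copyable_v<std::array<char*, 4>>);
    static_assert(sizeof(std::array<char*, 4>) == 4 * sizeof(char*));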
@@ -80,10 +80,10 @@ struct unroll_load_helper {
 
 template <int current>
 struct multi_outputs_store_helper {
-  template<int ntensors, int num_outputs, typename ...Args>
+  template<typename data_t, typename offsets_t, typename ...Args>
   C10_HOST_DEVICE static void apply(
-      at::detail::Array<char*, ntensors> data,
-      at::detail::Array<uint32_t, num_outputs> offsets,
+      const data_t& data,
+      const offsets_t& offsets,
       thrust::tuple<Args...> ret) {
     using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
     T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
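Note: apply is made generic over data_t/offsets_t rather than hard-coded to one array type; mid-migration ("[3/N]"), callers may still hand it either flavor, and the body only needs operator[]. Illustrative call, with out0/out1 as hypothetical output pointers:

    std::array<char*, 2> data = {out0, out1};
    std::array<uint32_t, 2> offsets = {0, 0};
    multi_outputs_store_helper<0>::apply(data, offsets, thrust::make_tuple(1.0f, 2.0));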
@@ -102,8 +102,8 @@ struct LoadWithoutCast {
 
 template <int N>
 struct LoadWithCast {
-  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
-  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
+  using array_t = std::array<at::ScalarType, std::max<int>(N, 1)>;
+  using size_array_t = std::array<uint32_t, std::max<int>(N, 1)>;
 
   array_t dtypes;
   size_array_t element_sizes;
@@ -133,8 +133,8 @@ struct StoreWithoutCast {
 
 template <int N = 1>
 struct StoreWithCast {
-  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
-  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
+  using array_t = std::array<at::ScalarType, std::max<int>(N, 1)>;
+  using size_array_t = std::array<uint32_t, std::max<int>(N, 1)>;
 
   array_t dtypes;
   size_array_t element_sizes;

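Note: the std::max<int>(N, 1) floor survives the migration even though std::array<T, 0> (unlike the raw `T data[0]` member inside at::detail::Array) is legal C++; keeping the size-1 placeholder preserves the existing behavior. For example:

    #include <algorithm>
    #include <array>
    #include <type_traits>

    static_assert(std::is_same_v<
        std::array<uint32_t, std::max<int>(0, 1)>,
        std::array<uint32_t, 1>>);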
@@ -11,6 +11,7 @@
 #include <ATen/OpMathType.h>
 #include <c10/macros/Macros.h>
 #include <c10/cuda/CUDACachingAllocator.h>
+#include <array>
 #include <functional>
 #include <iosfwd>
 #include <type_traits>
@@ -21,8 +22,6 @@
 namespace at::native {
 
-using at::detail::Array;
-
 static inline int64_t div_up(int64_t a, int64_t b) {
   return (a + b - 1) / b;
 }
@@ -407,7 +406,7 @@ struct ReduceOp {
     index_t input_idx = config.input_idx();
     auto base_offsets1 = output_calc.get(output_idx)[1];
 
-    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
+    using arg_vec_t = std::array<arg_t, output_vec_size>;
     arg_vec_t value;
 
     if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
@@ -422,8 +421,8 @@ struct ReduceOp {
       value = block_x_reduce<output_vec_size>(value, shared_memory);
     }
 
-    using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
-    using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
+    using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>;
+    using offset_vec_t = std::array<index_t, output_vec_size>;
 
     offset_vec_t base_offsets;
     out_ptr_vec_t out;
@@ -480,7 +479,7 @@ struct ReduceOp {
   }
 
   template <int output_vec_size>
-  C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
+  C10_DEVICE std::array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
     if (config.vectorize_input) {
       CUDA_KERNEL_ASSERT(output_vec_size == 1);
       // reduce at the header of input_slice where memory is not aligned,
@@ -561,12 +560,12 @@ struct ReduceOp {
   }
 
   template <int output_vec_size, typename offset_calc_t>
-  C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
+  C10_DEVICE std::array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
     index_t idx = config.input_idx();
     const index_t end = config.num_inputs;
     const index_t stride = config.step_input;
 
-    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
+    using arg_vec_t = std::array<arg_t, output_vec_size>;
     using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;
 
     // Multiple accumulators to remove dependency between unrolled loops.
@@ -634,8 +633,8 @@ struct ReduceOp {
   }
 
   template <int output_vec_size>
-  C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_x_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
-    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
+  C10_DEVICE std::array<arg_t, output_vec_size> block_x_reduce(std::array<arg_t, output_vec_size> value, char* shared_memory) const {
+    using args_vec_t = std::array<arg_t, output_vec_size>;
     int dim_x = blockDim.x;
     args_vec_t* shared = (args_vec_t*)shared_memory;
     if (dim_x > warpSize) {
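Note: the (args_vec_t*)shared_memory cast keeps working after the swap because std::array<arg_t, output_vec_size>, like at::detail::Array, is a standard-layout type whose only storage is the element array, so staging whole vectors in raw shared memory stays layout-equivalent. Sketch of the invariant being relied on:

    #include <array>
    #include <type_traits>

    using args_vec_t = std::array<float, 4>;
    static_assert(std::is_standard_layout_v<args_vec_t>);
    static_assert(sizeof(args_vec_t) == 4 * sizeof(float));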
@@ -668,8 +667,8 @@ struct ReduceOp {
   }
 
   template <int output_vec_size>
-  C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_y_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
-    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
+  C10_DEVICE std::array<arg_t, output_vec_size> block_y_reduce(std::array<arg_t, output_vec_size> value, char* shared_memory) const {
+    using args_vec_t = std::array<arg_t, output_vec_size>;
     args_vec_t* shared = (args_vec_t*)shared_memory;
     shared[config.shared_memory_offset(0)] = value;
     for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
@@ -701,12 +700,12 @@ struct ReduceOp {
   }
 
   template <int output_vec_size, bool can_acc>
-  C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
-    at::detail::Array<out_scalar_t*, output_vec_size> out,
-    at::detail::Array<arg_t, output_vec_size> value,
+  C10_DEVICE std::array<arg_t, output_vec_size> accumulate_in_output(
+    std::array<out_scalar_t*, output_vec_size> out,
+    std::array<arg_t, output_vec_size> value,
     typename std::enable_if_t<can_acc>* = nullptr
   ) const {
-    at::detail::Array<arg_t, output_vec_size> ret;
+    std::array<arg_t, output_vec_size> ret;
 #pragma unroll
     for (int i = 0; i < output_vec_size; i++) {
       ret[i] = ops.combine(*(out[i]), value[i]);
@@ -727,13 +726,13 @@ struct ReduceOp {
   // it's the version of `accumulate_in_output`
   // when accumulation in the output is not possible.
   template <int output_vec_size, bool can_acc>
-  C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
-    at::detail::Array<out_scalar_t*, output_vec_size>,
-    at::detail::Array<arg_t, output_vec_size>,
+  C10_DEVICE std::array<arg_t, output_vec_size> accumulate_in_output(
+    std::array<out_scalar_t*, output_vec_size>,
+    std::array<arg_t, output_vec_size>,
     typename std::enable_if_t<!can_acc>* = nullptr
   ) const {
     CUDA_KERNEL_ASSERT(false);
-    return arg_t {};
+    return {arg_t{}};
   }
 
   // This function should never be called --
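Note: the return statement changes because at::detail::Array's implicit fill constructor let `return arg_t {};` convert a lone scalar into the whole array; std::array is an aggregate, so `return {arg_t{}};` list-initializes element 0 and value-initializes the rest. The value is never read, since CUDA_KERNEL_ASSERT(false) fires first; any well-formed return would do. Reduced host-side illustration:

    #include <array>
    #include <cassert>

    std::array<float, 4> never_called() {
      assert(false);     // stand-in for CUDA_KERNEL_ASSERT(false)
      return {float{}};  // element 0 explicit, the rest value-initialized
    }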
@@ -771,7 +770,7 @@ struct ReduceOp {
   }
 
   template <int output_vec_size>
-  C10_DEVICE void set_results_to_output(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<index_t, output_vec_size> base_offset) const {
+  C10_DEVICE void set_results_to_output(std::array<arg_t, output_vec_size> value, std::array<index_t, output_vec_size> base_offset) const {
     CUDA_KERNEL_ASSERT(final_output);
 #pragma unroll
     for (int i = 0; i < output_vec_size; i++) {
@@ -780,10 +779,10 @@ struct ReduceOp {
   }
 
   template <int output_vec_size>
-  C10_DEVICE at::detail::Array<arg_t, output_vec_size> global_reduce(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
-    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
-    using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
-    using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
+  C10_DEVICE std::array<arg_t, output_vec_size> global_reduce(std::array<arg_t, output_vec_size> value, std::array<arg_t, output_vec_size> *acc, char* shared_memory) const {
+    using arg_vec_t = std::array<arg_t, output_vec_size>;
+    using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>;
+    using offset_vec_t = std::array<index_t, output_vec_size>;
 
     arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
     index_t output_idx = config.output_idx<output_vec_size>();
@@ -808,7 +807,9 @@ struct ReduceOp {
     if (is_last_block_done) {
       __threadfence(); // complete the acquire pattern after atomic
-      value = ident;
+      for (auto &v : value) {
+        v = ident;
+      }
       if (config.should_block_x_reduce()) {
         index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
         index_t step = blockDim.x * blockDim.y;
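Note: `value = ident;` only compiled because of at::detail::Array's scalar fill constructor. std::array::fill would be the idiomatic replacement, but it is constexpr only from C++20, so under C++17 it is presumably not callable from device code even with --expt-relaxed-constexpr; a range-for is (begin()/end() are constexpr since C++17), which would explain the explicit loop. Host-side equivalent of the new code:

    #include <array>

    std::array<float, 4> value;
    void broadcast_ident() {
      for (auto& v : value) {
        v = 0.0f;  // broadcast the identity element by hand
      }
    }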
@@ -832,7 +833,7 @@ struct ReduceOp {
           }
         }
       }
-      value = block_y_reduce(value, shared_memory);
+      value = block_y_reduce<output_vec_size>(value, shared_memory);
       if (config.should_block_x_reduce()) {
         value = block_x_reduce<output_vec_size>(value, shared_memory);
       }
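Note: the call gains an explicit <output_vec_size> because block_y_reduce's template parameter is declared as int, while std::array's size parameter is std::size_t; template argument deduction will not deduce a non-type parameter across that type mismatch (it used to work because at::detail::Array's size parameter is also an int). Reduced illustration with a hypothetical free function standing in for the member:

    #include <array>

    template <int V>
    std::array<float, V> reduce_like(std::array<float, V> v);

    // reduce_like(value);     // error: cannot deduce int V from std::size_t
    // reduce_like<4>(value);  // OK: V supplied explicitly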