mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-13 01:09:22 +00:00
Implement reduce_matrix_columns() to optimize ReduceSum (#5639)
Implement reduce_matrix_columns() to optimize ReduceSum.
This commit is contained in:
parent
c46515cd56
commit
858040faaa
16 changed files with 965 additions and 472 deletions
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/reduction/reduction_functions.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "core/common/optional.h"
|
||||
#include "core/framework/tensor_shape.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
ApplicableMatrixReduction get_applicable_matrix_reduction(
|
||||
const cudnnReduceTensorOp_t cudnn_reduce_op,
|
||||
const std::vector<int64_t>& dims, const std::vector<int64_t>& original_axes,
|
||||
int& m_out, int& n_out) {
|
||||
if (cudnn_reduce_op != CUDNN_REDUCE_TENSOR_ADD) {
|
||||
return ApplicableMatrixReduction::None;
|
||||
}
|
||||
|
||||
const auto rank = gsl::narrow<int64_t>(dims.size());
|
||||
|
||||
// min and max of single contiguous range of axes
|
||||
const auto minmax_axes = [&]() -> optional<std::pair<int64_t, int64_t>> {
|
||||
// empty axes means reduce all dimensions
|
||||
if (original_axes.empty()) {
|
||||
return std::make_pair(int64_t{0}, rank - 1);
|
||||
}
|
||||
|
||||
// normalize axis values and sort
|
||||
const std::vector<int64_t> axes = [&original_axes, rank]() {
|
||||
std::vector<int64_t> result(original_axes);
|
||||
std::for_each(
|
||||
result.begin(), result.end(),
|
||||
[rank](int64_t& axis) { axis = HandleNegativeAxis(axis, rank); });
|
||||
std::sort(result.begin(), result.end());
|
||||
return result;
|
||||
}();
|
||||
|
||||
for (auto a = axes.begin(), b = axes.begin() + 1;
|
||||
b != axes.end();
|
||||
++a, ++b) {
|
||||
ORT_ENFORCE(*a != *b, "axes must not contain duplicate values");
|
||||
if (*a + 1 != *b) { // not contiguous
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(axes.front(), axes.back());
|
||||
}();
|
||||
|
||||
if (!minmax_axes.has_value()) {
|
||||
return ApplicableMatrixReduction::None;
|
||||
}
|
||||
|
||||
const auto& min_axis = minmax_axes.value().first;
|
||||
const auto& max_axis = minmax_axes.value().second;
|
||||
|
||||
// axes from beginning means row reduction, axes to end means column reduction
|
||||
// currently we don't handle axes from beginning to end, but that could be either
|
||||
const bool axes_from_beginning = min_axis == 0;
|
||||
const bool axes_to_end = max_axis == rank - 1;
|
||||
|
||||
// handle axes anchored to one of beginning or end, not both
|
||||
if (axes_from_beginning == axes_to_end) {
|
||||
return ApplicableMatrixReduction::None;
|
||||
}
|
||||
|
||||
const int64_t m_end_axis = axes_from_beginning ? max_axis + 1 : min_axis;
|
||||
|
||||
const TensorShape& shape = TensorShape::ReinterpretBaseType(dims);
|
||||
|
||||
const auto m = shape.SizeToDimension(m_end_axis);
|
||||
const auto n = shape.SizeFromDimension(m_end_axis);
|
||||
|
||||
ORT_ENFORCE(m > 0 && n > 0, "shape must not have negative dimensions: ", shape);
|
||||
|
||||
if (m > std::numeric_limits<int>::max() ||
|
||||
n > std::numeric_limits<int>::max()) {
|
||||
return ApplicableMatrixReduction::None;
|
||||
}
|
||||
|
||||
m_out = gsl::narrow_cast<int>(m);
|
||||
n_out = gsl::narrow_cast<int>(n);
|
||||
|
||||
return axes_from_beginning
|
||||
? ApplicableMatrixReduction::Rows
|
||||
: ApplicableMatrixReduction::Columns;
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,51 +1,123 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/reduction/reduction_functions.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/atomic/common.cuh"
|
||||
#include "reduction_functions.h"
|
||||
#include "reduction_utils.cuh"
|
||||
|
||||
#define NUM_ELEMENTS_PER_THREAD 4
|
||||
#define NUM_WARPS_PER_BLOCK 8
|
||||
#define MAX_NUM_BLOCKS 256
|
||||
|
||||
#define ALL_ONE_MASK 0xFFFFFFFF
|
||||
#define ONE_MASK 0x00000001
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "core/providers/cuda/reduction/reduction_utils.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
std::pair<int, int> compute_block_size(int size) {
|
||||
int x = GPU_WARP_SIZE;
|
||||
int y = std::min(NUM_WARPS_PER_BLOCK, std::max(1, size / (NUM_ELEMENTS_PER_THREAD * GPU_WARP_SIZE)));
|
||||
return std::make_pair(x, y);
|
||||
namespace detail {
|
||||
constexpr auto MAX_NUM_ELEMENTS_PER_THREAD = 4;
|
||||
constexpr auto MAX_NUM_WARPS_PER_BLOCK = 8;
|
||||
constexpr auto MAX_NUM_BLOCKS_IN_GRID_ROW = 256;
|
||||
constexpr auto MAX_NUM_GRID_ROWS = 32768;
|
||||
|
||||
dim3 compute_block_dim(int num_cols) {
|
||||
const int x = GPU_WARP_SIZE;
|
||||
const int y = std::min(MAX_NUM_WARPS_PER_BLOCK, std::max(1, num_cols / (MAX_NUM_ELEMENTS_PER_THREAD * x)));
|
||||
return dim3(x, y);
|
||||
}
|
||||
|
||||
int compute_grid_size(int size) {
|
||||
const auto block = compute_block_size(size);
|
||||
return std::min(MAX_NUM_BLOCKS, std::max(1, size / (NUM_ELEMENTS_PER_THREAD * block.first * block.second)));
|
||||
std::pair<dim3, dim3> compute_grid_and_block_dims(int num_rows, int num_cols) {
|
||||
const auto block_dim = compute_block_dim(num_cols);
|
||||
const auto grid_x =
|
||||
std::min<int>(
|
||||
MAX_NUM_BLOCKS_IN_GRID_ROW,
|
||||
std::max<int>(1, num_cols / (MAX_NUM_ELEMENTS_PER_THREAD * block_dim.x * block_dim.y)));
|
||||
const auto grid_y = std::min(MAX_NUM_GRID_ROWS, num_rows);
|
||||
const dim3 grid_dim(grid_x, grid_y);
|
||||
return {grid_dim, block_dim};
|
||||
}
|
||||
|
||||
int compute_reduction_buffer_size(int element_size, int size) {
|
||||
const int num_blocks = compute_grid_size(size);
|
||||
return static_cast<int>(num_blocks * element_size + sizeof(int));
|
||||
uintptr_t round_up_to_aligned(uintptr_t original, size_t alignment) {
|
||||
assert((alignment & (alignment - 1)) == 0);
|
||||
const size_t alignment_mask = ~(alignment - 1);
|
||||
return (original + alignment - 1) & alignment_mask;
|
||||
}
|
||||
|
||||
template<typename TIn, typename TOut, typename TOp, typename TFinalOp, bool DivideResultBySize>
|
||||
__global__ void reduce_all_kernel(const int size, const TIn * data, TOut* output, TOut* buffer) {
|
||||
extern __shared__ unsigned char shared_memory_[];
|
||||
TOut* shared_memory = reinterpret_cast<TOut*>(shared_memory_);
|
||||
// Thread-level indexes:
|
||||
/**
|
||||
* call_reduce_matrix_columns() intermediate buffer layout
|
||||
*
|
||||
* Given buffer element type TBuf, the intermediate buffer layout looks like this:
|
||||
*
|
||||
* -----
|
||||
* m * num_blocks_per_row * sizeof(TBuf) bytes for block reductions per row
|
||||
* alignment padding bytes as needed
|
||||
* m * sizeof(int) bytes for block done counts per row
|
||||
* -----
|
||||
*/
|
||||
|
||||
size_t compute_reduce_matrix_columns_intermediate_buffer_size(
|
||||
int element_size, int num_rows, int num_cols) {
|
||||
ORT_ENFORCE(element_size >= 0 && num_rows >= 0 && num_cols >= 0);
|
||||
|
||||
const auto grid_dim = compute_grid_and_block_dims(num_rows, num_cols).first;
|
||||
|
||||
size_t buffer_size{};
|
||||
|
||||
// at the beginning, for sizing purposes, assume we are aligned
|
||||
buffer_size += static_cast<size_t>(num_rows) * grid_dim.x * element_size;
|
||||
|
||||
buffer_size = round_up_to_aligned(buffer_size, alignof(int));
|
||||
buffer_size += static_cast<size_t>(num_rows) * sizeof(int);
|
||||
|
||||
// add padding to give us room to align
|
||||
buffer_size += alignof(max_align_t) - 1;
|
||||
|
||||
return buffer_size;
|
||||
}
|
||||
|
||||
template <typename TBuf>
|
||||
Status get_reduction_buffers(
|
||||
int num_rows, int num_cols, void* buffer, size_t buffer_size,
|
||||
TBuf*& block_reductions_buffer, int*& block_done_counts_buffer) {
|
||||
const auto grid_dim = compute_grid_and_block_dims(num_rows, num_cols).first;
|
||||
|
||||
const uintptr_t begin_addr = reinterpret_cast<uintptr_t>(buffer);
|
||||
const uintptr_t block_reductions_addr =
|
||||
round_up_to_aligned(begin_addr, alignof(TBuf));
|
||||
const uintptr_t block_done_counts_buffer_addr =
|
||||
round_up_to_aligned(
|
||||
block_reductions_addr + static_cast<size_t>(num_rows) * grid_dim.x * sizeof(TBuf), alignof(int));
|
||||
const uintptr_t end_addr =
|
||||
block_done_counts_buffer_addr + static_cast<size_t>(num_rows) * sizeof(int);
|
||||
const size_t required_size = end_addr - begin_addr;
|
||||
|
||||
ORT_RETURN_IF_NOT(
|
||||
required_size <= buffer_size,
|
||||
"Buffer size is too small (", buffer_size, " bytes). ",
|
||||
"At least ", required_size, " bytes are needed from the given base address (", buffer, ").");
|
||||
|
||||
block_reductions_buffer = reinterpret_cast<TBuf*>(block_reductions_addr);
|
||||
block_done_counts_buffer = reinterpret_cast<int*>(block_done_counts_buffer_addr);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename TIn, typename TOut, typename TBuf, typename TOp, typename TFinalOp, bool DivideResultBySize>
|
||||
__device__ void reduce_all(
|
||||
const int num_elements, const TIn* const input, TOut* const output,
|
||||
TBuf* const block_reductions_buffer, int* const block_done_count_buffer) {
|
||||
extern __shared__ unsigned char shared_memory_bytes[];
|
||||
TBuf* shared_memory = reinterpret_cast<TBuf*>(shared_memory_bytes);
|
||||
// Thread-level indices:
|
||||
// Linear index of thread in block.
|
||||
const int tid_in_block = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
// Total number of threads in a 2-D block.
|
||||
const int num_threads_in_block = blockDim.x * blockDim.y;
|
||||
|
||||
// Warp-level indexes:
|
||||
// Warp-level indices:
|
||||
// Warp index of thread.
|
||||
const int wid_in_block = tid_in_block / GPU_WARP_SIZE;
|
||||
// Lane index of thread.
|
||||
|
|
@ -53,35 +125,35 @@ __global__ void reduce_all_kernel(const int size, const TIn * data, TOut* output
|
|||
// Warp count per block.
|
||||
const int num_warps_in_block = num_threads_in_block / GPU_WARP_SIZE;
|
||||
|
||||
// Grid-level indexes:
|
||||
// Linear index of block in grid.
|
||||
const int bid_in_grid = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
// Linear index of thread in grid.
|
||||
const int tid_in_grid = bid_in_grid * (blockDim.x * blockDim.y) + tid_in_block;
|
||||
// Total number of blocks in a 2-D grid.
|
||||
const int num_blocks_in_grid = gridDim.x * gridDim.y;
|
||||
// Total number of threads in a 2-D grid with 2-D blocks.
|
||||
const int num_threads_in_grid = num_blocks_in_grid * num_threads_in_block;
|
||||
// Grid-level indices:
|
||||
// Linear index of block in grid row.
|
||||
const int bid_in_grid_row = blockIdx.x;
|
||||
// Linear index of thread in grid row.
|
||||
const int tid_in_grid_row = bid_in_grid_row * (blockDim.x * blockDim.y) + tid_in_block;
|
||||
// Total number of blocks in a grid row.
|
||||
const int num_blocks_in_grid_row = gridDim.x;
|
||||
// Total number of threads in a grid row with 2-D blocks.
|
||||
const int num_threads_in_grid_row = num_blocks_in_grid_row * num_threads_in_block;
|
||||
|
||||
// Thread-level reduction (storage change: global memory -> register).
|
||||
// One thread reduces NUM_ELEMENTS_PER_THREAD elements to a thread register
|
||||
// One thread reduces MAX_NUM_ELEMENTS_PER_THREAD elements to a thread register
|
||||
// in one iteration.
|
||||
TOut value = 0;
|
||||
for (int id = tid_in_grid; id < size; id += NUM_ELEMENTS_PER_THREAD * num_threads_in_grid) {
|
||||
TOut v[NUM_ELEMENTS_PER_THREAD];
|
||||
TBuf value = 0;
|
||||
for (int id = tid_in_grid_row; id < num_elements; id += MAX_NUM_ELEMENTS_PER_THREAD * num_threads_in_grid_row) {
|
||||
TBuf v[MAX_NUM_ELEMENTS_PER_THREAD];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ELEMENTS_PER_THREAD; i++) {
|
||||
int offset = id + i * num_threads_in_grid;
|
||||
if (offset < size) {
|
||||
v[i] = TOut(TOp()(data[offset]));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < MAX_NUM_ELEMENTS_PER_THREAD; i++) {
|
||||
const int offset = id + i * num_threads_in_grid_row;
|
||||
if (offset < num_elements) {
|
||||
v[i] = TOp()(TBuf(input[offset]));
|
||||
} else {
|
||||
v[i] = TOut(0.0f);
|
||||
v[i] = TBuf(0);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ELEMENTS_PER_THREAD; i++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < MAX_NUM_ELEMENTS_PER_THREAD; i++) {
|
||||
value += v[i];
|
||||
}
|
||||
}
|
||||
|
|
@ -95,31 +167,30 @@ __global__ void reduce_all_kernel(const int size, const TIn * data, TOut* output
|
|||
// Warp-level reduction (storage change: register -> register).
|
||||
// The values in a warp will be summed up to a scalar. After warp-level
|
||||
// reduction, each block holds num_warps_in_block values in the shared memory.
|
||||
TOut value_ = value;
|
||||
#pragma unroll
|
||||
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
|
||||
value_ += WARP_SHFL_DOWN(value_, stride);
|
||||
value += WARP_SHFL_DOWN(value, stride);
|
||||
}
|
||||
|
||||
// Return early if only one warp is used for reduction.
|
||||
// Given a fixed amount of threads, we perfer threads over warps over blocks so that we never have cases such as
|
||||
// Given a fixed amount of threads, we prefer threads over warps over blocks so that we never have cases such as
|
||||
// 1. two blocks and each of them has only 1 warp (32 threads).
|
||||
// 2. two warps and each of them has only 2 threads.
|
||||
if (num_warps_in_block == 1) {
|
||||
if (tid_in_grid == 0) {
|
||||
if (tid_in_grid_row == 0) {
|
||||
// Compilation time if-else branch controlled by template argument can be
|
||||
// optimized out, so there will be no branch in real computation phase.
|
||||
if (DivideResultBySize) {
|
||||
output[0] = TFinalOp()(value_ / TOut(size));
|
||||
output[0] = TFinalOp()(TOut(value) / TOut(num_elements));
|
||||
} else {
|
||||
output[0] = TFinalOp()(value_);
|
||||
output[0] = TFinalOp()(TOut(value));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (lid_in_block == 0) {
|
||||
shared_memory[wid_in_block] = value_;
|
||||
shared_memory[wid_in_block] = value;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
|
@ -129,10 +200,10 @@ __global__ void reduce_all_kernel(const int size, const TIn * data, TOut* output
|
|||
// Note that the values are stored in the shared memory.
|
||||
// Here we assume that the size of shared_memory is smaller
|
||||
// than num_warps_in_block, so we just keep halving the number
|
||||
// of threads in each iteartion. Our assumption is always true because
|
||||
// of threads in each iteration. Our assumption is always true because
|
||||
// the size of shared_memory equals to the number of warps.
|
||||
#pragma unroll
|
||||
for (int stride = NUM_WARPS_PER_BLOCK / 2; stride > 0; stride /= 2) {
|
||||
for (int stride = MAX_NUM_WARPS_PER_BLOCK / 2; stride > 0; stride /= 2) {
|
||||
if (tid_in_block + stride < num_warps_in_block) {
|
||||
shared_memory[tid_in_block] += shared_memory[tid_in_block + stride];
|
||||
}
|
||||
|
|
@ -140,47 +211,46 @@ __global__ void reduce_all_kernel(const int size, const TIn * data, TOut* output
|
|||
}
|
||||
|
||||
// Return early if only one block is used for reduction.
|
||||
if (num_blocks_in_grid == 1) {
|
||||
if (tid_in_grid == 0) {
|
||||
if (num_blocks_in_grid_row == 1) {
|
||||
if (tid_in_grid_row == 0) {
|
||||
// Compilation time if-else branch controlled by template argument can be
|
||||
// optimized out, so there will be no branch in real computation phase.
|
||||
if (DivideResultBySize) {
|
||||
output[0] = TFinalOp()(shared_memory[0] / TOut(size));
|
||||
output[0] = TFinalOp()(TOut(shared_memory[0]) / TOut(num_elements));
|
||||
} else {
|
||||
output[0] = TFinalOp()(shared_memory[0]);
|
||||
output[0] = TFinalOp()(TOut(shared_memory[0]));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (tid_in_block == 0) {
|
||||
buffer[bid_in_grid] = shared_memory[0];
|
||||
block_reductions_buffer[bid_in_grid_row] = shared_memory[0];
|
||||
}
|
||||
|
||||
__threadfence();
|
||||
__syncthreads();
|
||||
|
||||
// Grid-level reduciton. We use the last block to sum up values
|
||||
// stored in the global buffer.
|
||||
// Grid-level reduction. We use the last block to sum up values
|
||||
// stored in the global block_reductions_buffer.
|
||||
__shared__ bool is_last_block_done;
|
||||
|
||||
if (tid_in_block == 0) {
|
||||
int* p_lock = reinterpret_cast<int*>(buffer + num_blocks_in_grid);
|
||||
int count = atomicAdd(p_lock, 1);
|
||||
is_last_block_done = (count == (num_blocks_in_grid - 1));
|
||||
const int count = atomicAdd(block_done_count_buffer, 1);
|
||||
is_last_block_done = (count == (num_blocks_in_grid_row - 1));
|
||||
}
|
||||
|
||||
// All threads in each block see if they belong the last active block
|
||||
// (i.e., the value of is_last_block_done).
|
||||
__syncthreads();
|
||||
|
||||
// Only the block which saw that count equals to num_blocks_in_grid - 1 can
|
||||
// Only the block which saw that count equals to num_blocks_in_grid_row - 1 can
|
||||
// enter the following block.
|
||||
if (is_last_block_done) {
|
||||
const int pow2_bound = least_pow2_bound(num_blocks_in_grid);
|
||||
const int pow2_bound = least_pow2_bound(num_blocks_in_grid_row);
|
||||
for (int stride = pow2_bound / 2; stride > 0; stride /= 2) {
|
||||
if (tid_in_block < stride && tid_in_block + stride < num_blocks_in_grid) {
|
||||
buffer[tid_in_block] += buffer[tid_in_block + stride];
|
||||
if (tid_in_block < stride && tid_in_block + stride < num_blocks_in_grid_row) {
|
||||
block_reductions_buffer[tid_in_block] += block_reductions_buffer[tid_in_block + stride];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
|
@ -190,107 +260,124 @@ __global__ void reduce_all_kernel(const int size, const TIn * data, TOut* output
|
|||
// Compilation time if-else branch controlled by template argument can be
|
||||
// optimized out, so there will be no branch in real computation phase.
|
||||
if (DivideResultBySize) {
|
||||
output[0] = TFinalOp()(buffer[0] / TOut(size));
|
||||
output[0] = TFinalOp()(TOut(block_reductions_buffer[0]) / TOut(num_elements));
|
||||
} else {
|
||||
output[0] = TFinalOp()(buffer[0]);
|
||||
output[0] = TFinalOp()(TOut(block_reductions_buffer[0]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TIn, typename TOut, typename TOp, typename TFinalOp, bool DivideResultBySize>
|
||||
void call_reduce_all_kernel(const TIn *data, TOut *output, int size, TOut *buffer) {
|
||||
const auto block_size = compute_block_size(size);
|
||||
const int num_blocks = compute_grid_size(size);
|
||||
const dim3 block(block_size.first, block_size.second, 1);
|
||||
const dim3 grid(num_blocks, 1, 1);
|
||||
template <typename TIn, typename TOut, typename TBuf, typename TOp, typename TFinalOp, bool DivideResultBySize>
|
||||
__global__ void reduce_matrix_columns_kernel(
|
||||
const int num_rows, const int num_cols, const TIn* const input, TOut* const output,
|
||||
TBuf* const block_reductions_buffer, int* const block_done_counts_buffer) {
|
||||
const int num_blocks_in_grid_row = gridDim.x;
|
||||
const int row_id_in_grid = blockIdx.y;
|
||||
const int num_grid_rows = gridDim.y;
|
||||
|
||||
// If more than one blocks are used, then inter-blocks reduction is needed.
|
||||
if (num_blocks != 1) {
|
||||
CUDA_CALL_THROW(cudaMemsetAsync(buffer + num_blocks, 0, sizeof(int)));
|
||||
// one row per iteration
|
||||
// row_id is int64_t to avoid int overflow in offset calculations
|
||||
for (int64_t row_id = row_id_in_grid; row_id < num_rows; row_id += num_grid_rows) {
|
||||
const TIn* const row_data = input + row_id * num_cols;
|
||||
TOut* const row_output = output + row_id;
|
||||
TBuf* const row_block_reductions_buffer = block_reductions_buffer + row_id * num_blocks_in_grid_row;
|
||||
int* const row_block_done_counts_buffer = block_done_counts_buffer + row_id;
|
||||
|
||||
reduce_all<TIn, TOut, TBuf, TOp, TFinalOp, DivideResultBySize>(
|
||||
num_cols, row_data, row_output,
|
||||
row_block_reductions_buffer, row_block_done_counts_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TIn, typename TOut, typename TOp, typename TFinalOp, bool DivideResultBySize>
|
||||
Status call_reduce_matrix_columns(
|
||||
const TIn* input, TOut* output, const int num_rows, const int num_cols, void* buffer, size_t buffer_size) {
|
||||
ORT_ENFORCE(num_rows >= 0 && num_cols >= 0);
|
||||
|
||||
using TBuf = AccumulationType_t<TIn>;
|
||||
|
||||
const auto grid_and_block_dims = compute_grid_and_block_dims(num_rows, num_cols);
|
||||
const dim3& grid_dim = grid_and_block_dims.first;
|
||||
const dim3& block_dim = grid_and_block_dims.second;
|
||||
|
||||
TBuf* block_reductions_buffer;
|
||||
int* block_done_counts_buffer;
|
||||
ORT_RETURN_IF_ERROR(get_reduction_buffers(
|
||||
num_rows, num_cols, buffer, buffer_size,
|
||||
block_reductions_buffer, block_done_counts_buffer));
|
||||
|
||||
// If more than one block is used per grid row, then inter-block reduction is needed.
|
||||
if (grid_dim.x > 1) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(block_done_counts_buffer, 0, num_rows * sizeof(int)));
|
||||
}
|
||||
|
||||
const int shared_mem_size = sizeof(TOut) * block_size.first * block_size.second / GPU_WARP_SIZE;
|
||||
reduce_all_kernel<TIn, TOut, TOp, TFinalOp, DivideResultBySize><<<grid, block, shared_mem_size>>>(size, data, output, buffer);
|
||||
const int shared_mem_size = sizeof(TBuf) * block_dim.x * block_dim.y / GPU_WARP_SIZE;
|
||||
reduce_matrix_columns_kernel<TIn, TOut, TBuf, TOp, TFinalOp, DivideResultBySize>
|
||||
<<<grid_dim, block_dim, shared_mem_size>>>(
|
||||
num_rows, num_cols, input, output, block_reductions_buffer, block_done_counts_buffer);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
Status reduce_sum(
|
||||
const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) {
|
||||
return detail::call_reduce_matrix_columns<TIn, TOut, Identity, Identity, false>(
|
||||
input, output, 1, size, buffer, buffer_size);
|
||||
}
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_sum(const TIn* data, TOut* output, int size, TOut* buffer) {
|
||||
call_reduce_all_kernel<TIn, TOut, Cast<TOut, TIn>, Identity<TOut>, false>(
|
||||
data, output, size, buffer);
|
||||
Status reduce_square_sum(
|
||||
const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) {
|
||||
return detail::call_reduce_matrix_columns<TIn, TOut, Square, Identity, false>(
|
||||
input, output, 1, size, buffer, buffer_size);
|
||||
}
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_square_sum(const TIn* data, TOut* output, int size, TOut* buffer) {
|
||||
call_reduce_all_kernel<TIn, TOut, Square<TOut, TIn>, Identity<TOut>, false>(
|
||||
data, output, size, buffer);
|
||||
Status reduce_l2_norm(
|
||||
const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) {
|
||||
return detail::call_reduce_matrix_columns<TIn, TOut, Square, Sqrt, false>(
|
||||
input, output, 1, size, buffer, buffer_size);
|
||||
}
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_l2_norm(const TIn* data, TOut* output, int size, TOut* buffer) {
|
||||
call_reduce_all_kernel<TIn, TOut, Square<TOut, TIn>, Sqrt<TOut>, false>(
|
||||
data, output, size, buffer);
|
||||
Status reduce_mean(
|
||||
const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) {
|
||||
return detail::call_reduce_matrix_columns<TIn, TOut, Identity, Identity, true>(
|
||||
input, output, 1, size, buffer, buffer_size);
|
||||
}
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_mean(const TIn* data, TOut* output, int size, TOut* buffer) {
|
||||
call_reduce_all_kernel<TIn, TOut, Cast<TOut, TIn>, Identity<TOut>, true>(
|
||||
data, output, size, buffer);
|
||||
}
|
||||
#define INSTANTIATE_REDUCE_SUM(TIn, TOut) \
|
||||
template Status reduce_sum<TIn, TOut>(const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size)
|
||||
INSTANTIATE_REDUCE_SUM(half, float);
|
||||
INSTANTIATE_REDUCE_SUM(float, float);
|
||||
INSTANTIATE_REDUCE_SUM(double, double);
|
||||
#undef INSTANTIATE_REDUCE_SUM
|
||||
|
||||
template void reduce_sum<half, float>(
|
||||
const half* data, float* output, int size, float* buffer);
|
||||
template void reduce_sum<float, float>(
|
||||
const float* data, float* output, int size, float* buffer);
|
||||
template void reduce_sum<double, double>(
|
||||
const double* data, double* output, int size, double* buffer);
|
||||
#define INSTANTIATE_REDUCE_SQUARE_SUM(TIn, TOut) \
|
||||
template Status reduce_square_sum<TIn, TOut>(const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size)
|
||||
INSTANTIATE_REDUCE_SQUARE_SUM(half, float);
|
||||
INSTANTIATE_REDUCE_SQUARE_SUM(float, float);
|
||||
INSTANTIATE_REDUCE_SQUARE_SUM(double, double);
|
||||
#undef INSTANTIATE_REDUCE_SQUARE_SUM
|
||||
|
||||
template void reduce_square_sum<half, float>(
|
||||
const half* data, float* output, int size, float* buffer);
|
||||
template void reduce_square_sum<float, float>(
|
||||
const float* data, float* output, int size, float* buffer);
|
||||
template void reduce_square_sum<double, double>(
|
||||
const double* data, double* output, int size, double* buffer);
|
||||
#define INSTANTIATE_REDUCE_L2_NORM(TIn, TOut) \
|
||||
template Status reduce_l2_norm<TIn, TOut>(const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size)
|
||||
INSTANTIATE_REDUCE_L2_NORM(half, float);
|
||||
INSTANTIATE_REDUCE_L2_NORM(float, float);
|
||||
INSTANTIATE_REDUCE_L2_NORM(double, double);
|
||||
#undef INSTANTIATE_REDUCE_L2_NORM
|
||||
|
||||
template void reduce_l2_norm<half, float>(
|
||||
const half* data, float* output, int size, float* buffer);
|
||||
template void reduce_l2_norm<float, float>(
|
||||
const float* data, float* output, int size, float* buffer);
|
||||
template void reduce_l2_norm<double, double>(
|
||||
const double* data, double* output, int size, double* buffer);
|
||||
|
||||
template void reduce_mean<half, float>(
|
||||
const half* data, float* output, int size, float* buffer);
|
||||
template void reduce_mean<float, float>(
|
||||
const float* data, float* output, int size, float* buffer);
|
||||
template void reduce_mean<double, double>(
|
||||
const double* data, double* output, int size, double* buffer);
|
||||
|
||||
bool is_matrix_row_reduction(
|
||||
const cudnnReduceTensorOp_t cudnn_reduce_op,
|
||||
const int m,
|
||||
const int n,
|
||||
const size_t rank,
|
||||
std::vector<int64_t> axes) {
|
||||
if (m < 1)
|
||||
return false;
|
||||
|
||||
if (n < 1)
|
||||
return false;
|
||||
|
||||
if (rank < 2)
|
||||
return false;
|
||||
|
||||
if (cudnn_reduce_op != CUDNN_REDUCE_TENSOR_ADD)
|
||||
return false;
|
||||
|
||||
//empty axes, default reduction
|
||||
if (axes.size() < 1)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
#define INSTANTIATE_REDUCE_MEAN(TIn, TOut) \
|
||||
template Status reduce_mean<TIn, TOut>(const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size)
|
||||
INSTANTIATE_REDUCE_MEAN(half, float);
|
||||
INSTANTIATE_REDUCE_MEAN(float, float);
|
||||
INSTANTIATE_REDUCE_MEAN(double, double);
|
||||
#undef INSTANTIATE_REDUCE_MEAN
|
||||
|
||||
namespace detail {
|
||||
template <typename TIn, typename TOut, typename TBuf>
|
||||
__global__ void reduce_matrix_rows_kernel(const TIn* input, TOut* output, int m, int n) {
|
||||
constexpr int x_load_count_per_thread = 1;
|
||||
|
|
@ -304,8 +391,8 @@ __global__ void reduce_matrix_rows_kernel(const TIn* input, TOut* output, int m,
|
|||
const int tid_in_block = threadIdx.x + blockDim.x * threadIdx.y;
|
||||
|
||||
// Shape is blockDim.y-by-blockDim.x and element type is TBuf.
|
||||
extern __shared__ unsigned char shared_memory_[];
|
||||
TBuf* shared_memory = reinterpret_cast<TBuf*>(shared_memory_);
|
||||
extern __shared__ unsigned char shared_memory_bytes[];
|
||||
TBuf* shared_memory = reinterpret_cast<TBuf*>(shared_memory_bytes);
|
||||
|
||||
// to prevent int overflow in index calculation for input size m*n
|
||||
const int64_t n_int64 = static_cast<int64_t>(n);
|
||||
|
|
@ -348,11 +435,14 @@ __global__ void reduce_matrix_rows_kernel(const TIn* input, TOut* output, int m,
|
|||
}
|
||||
}
|
||||
|
||||
// This function reduces the given input tensor along all but the last axis.
|
||||
// For example, [N, C, H, W]-tensor may lead to a output [W]-tensor.
|
||||
// It's implementation is in reduction_ops.cu and called in reduction_ops.cc.
|
||||
template <typename TIn, typename TOut, typename TBuf>
|
||||
void call_reduce_matrix_rows(const TIn* input, TOut* output, int m, int n) {
|
||||
Status call_reduce_matrix_rows(const TIn* input, TOut* output, int m, int n, bool reset_initial_output) {
|
||||
ORT_ENFORCE(m >= 0 && n >= 0);
|
||||
|
||||
if (reset_initial_output) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output, 0, n * sizeof(TOut)));
|
||||
}
|
||||
|
||||
constexpr int max_num_threads_in_block = 512;
|
||||
constexpr int max_num_blocks_in_grid = 512;
|
||||
constexpr int load_count_per_thread = 4;
|
||||
|
|
@ -367,22 +457,36 @@ void call_reduce_matrix_rows(const TIn* input, TOut* output, int m, int n) {
|
|||
|
||||
reduce_matrix_rows_kernel<TIn, TOut, TBuf><<<grid, block, block.y * block.x * sizeof(TBuf)>>>(
|
||||
input, output, m, n);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_matrix_rows(const TIn* data, TOut* output, int m, int n) {
|
||||
call_reduce_matrix_rows<TIn, TOut, TOut>(data, output, m, n);
|
||||
Status reduce_matrix_rows(const TIn* input, TOut* output, int m, int n, bool reset_initial_output) {
|
||||
using TBuf = AccumulationType_t<TIn>;
|
||||
return detail::call_reduce_matrix_rows<TIn, TOut, TBuf>(input, output, m, n, reset_initial_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
void reduce_matrix_rows<half, half>(const half* data, half* output, int m, int n) {
|
||||
call_reduce_matrix_rows<half, half, float>(data, output, m, n);
|
||||
#define INSTANTIATE_REDUCE_MATRIX_ROWS(T) \
|
||||
template Status reduce_matrix_rows<T, T>(const T* input, T* output, int m, int n, bool reset_initial_output)
|
||||
INSTANTIATE_REDUCE_MATRIX_ROWS(half);
|
||||
INSTANTIATE_REDUCE_MATRIX_ROWS(float);
|
||||
INSTANTIATE_REDUCE_MATRIX_ROWS(double);
|
||||
#undef INSTANTIATE_REDUCE_MATRIX_ROWS
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
Status reduce_matrix_columns(const TIn* input, TOut* output, int m, int n, void* buffer, size_t buffer_size) {
|
||||
return detail::call_reduce_matrix_columns<TIn, TOut, Identity, Identity, false>(
|
||||
input, output, m, n, buffer, buffer_size);
|
||||
}
|
||||
|
||||
template void reduce_matrix_rows<float, float>(
|
||||
const float* data, float* output, int m, int n);
|
||||
template void reduce_matrix_rows<double, double>(
|
||||
const double* data, double* output, int m, int n);
|
||||
#define INSTANTIATE_REDUCE_MATRIX_COLUMNS(T) \
|
||||
template Status reduce_matrix_columns<T, T>(const T* input, T* output, int m, int n, void* buffer, size_t buffer_size)
|
||||
INSTANTIATE_REDUCE_MATRIX_COLUMNS(half);
|
||||
INSTANTIATE_REDUCE_MATRIX_COLUMNS(float);
|
||||
INSTANTIATE_REDUCE_MATRIX_COLUMNS(double);
|
||||
#undef INSTANTIATE_REDUCE_MATRIX_COLUMNS
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,35 +2,106 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/accumulation_type.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
int compute_reduction_buffer_size(int element_size, int size);
|
||||
namespace detail {
|
||||
size_t compute_reduce_matrix_columns_intermediate_buffer_size(
|
||||
int element_size, int num_rows, int num_cols);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the size in bytes of the intermediate buffer needed by reduce_matrix_columns().
|
||||
* @tparam TIn The input data type.
|
||||
* @param m The number of matrix rows.
|
||||
* @param n The number of matrix columns.
|
||||
* @return The size of the intermediate buffer.
|
||||
*/
|
||||
template <typename TIn>
|
||||
size_t compute_reduce_matrix_columns_buffer_size(int m, int n) {
|
||||
using TBuf = AccumulationType_t<TIn>;
|
||||
return detail::compute_reduce_matrix_columns_intermediate_buffer_size(
|
||||
sizeof(TBuf), m, n);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the size in bytes of the intermediate buffer needed by the reduce_x() functions.
|
||||
* @tparam TIn The input data type.
|
||||
* @param size The number of elements.
|
||||
* @return The size of the intermediate buffer.
|
||||
*/
|
||||
template <typename TIn>
|
||||
size_t compute_reduction_buffer_size(int size) {
|
||||
using TBuf = AccumulationType_t<TIn>;
|
||||
return detail::compute_reduce_matrix_columns_intermediate_buffer_size(
|
||||
sizeof(TBuf), 1, size);
|
||||
}
|
||||
|
||||
/** Computes the sum of the given elements. */
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_sum(const TIn* input, TOut* output, int size, TOut* buffer);
|
||||
Status reduce_sum(const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size);
|
||||
|
||||
/** Computes the sum of the squares of the given elements. */
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_square_sum(const TIn* input, TOut* output, int size, TOut* buffer);
|
||||
Status reduce_square_sum(const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size);
|
||||
|
||||
/** Computes the L2 norm of the given elements. */
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_l2_norm(const TIn* input, TOut* output, int size, TOut* buffer);
|
||||
Status reduce_l2_norm(const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size);
|
||||
|
||||
/** Computes the mean of the given elements. */
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_mean(const TIn* data, TOut* output, int size, TOut* buffer);
|
||||
Status reduce_mean(const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size);
|
||||
|
||||
// Determine if a CUDNN reduction can be computed by reduce_matrix_rows.
|
||||
bool is_matrix_row_reduction(
|
||||
enum class ApplicableMatrixReduction {
|
||||
// can use reduce_matrix_rows()
|
||||
Rows,
|
||||
// can use reduce_matrix_columns()
|
||||
Columns,
|
||||
// no optimized matrix reduction function applies
|
||||
None,
|
||||
};
|
||||
|
||||
/**
|
||||
* Determines whether a cuDNN reduction can be computed by an optimized matrix reduction function.
|
||||
* @param cudnn_reduce_op The cuDNN reduction op type.
|
||||
* @param dims The input dimensions.
|
||||
* @param axes The reduction axes.
|
||||
* @param[out] m If matrix reduction is possible, the number of matrix rows to use.
|
||||
* @param[out] n If matrix reduction is possible, the number of matrix columns to use.
|
||||
* @return The type of matrix reduction that can be done.
|
||||
*/
|
||||
ApplicableMatrixReduction get_applicable_matrix_reduction(
|
||||
const cudnnReduceTensorOp_t cudnn_reduce_op,
|
||||
const int m,
|
||||
const int n,
|
||||
const size_t rank,
|
||||
std::vector<int64_t> axes);
|
||||
const std::vector<int64_t>& dims, const std::vector<int64_t>& axes,
|
||||
int& m, int& n);
|
||||
|
||||
/**
|
||||
* Reduces the rows in a row-major matrix to a single row containing the sum of each column.
|
||||
* @param input The input data.
|
||||
* @param output The output data.
|
||||
* @param m The number of matrix rows.
|
||||
* @param n The number of matrix columns.
|
||||
* @param reset_initial_output Whether to reset (i.e., zero) the output values first.
|
||||
*/
|
||||
template <typename TIn, typename TOut>
|
||||
void reduce_matrix_rows(const TIn* data, TOut* output, int m, int n);
|
||||
Status reduce_matrix_rows(const TIn* input, TOut* output, int m, int n, bool reset_initial_output = true);
|
||||
|
||||
/**
|
||||
* Reduces the columns in a row-major matrix to a single column containing the sum of each row.
|
||||
* @param input The input data.
|
||||
* @param output The output data.
|
||||
* @param m The number of matrix rows.
|
||||
* @param n The number of matrix columns.
|
||||
* @param buffer The intermediate buffer.
|
||||
* @param buffer_size The size of the intermediate buffer in bytes.
|
||||
*/
|
||||
template <typename TIn, typename TOut>
|
||||
Status reduce_matrix_columns(const TIn* input, TOut* output, int m, int n, void* buffer, size_t buffer_size);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "reduction_ops.h"
|
||||
#include "core/providers/cuda/reduction/reduction_ops.h"
|
||||
|
||||
#include "core/framework/data_types_internal.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
|
@ -27,7 +29,7 @@ namespace cuda {
|
|||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
11, 12, \
|
||||
11, 12, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
|
|
@ -135,22 +137,27 @@ Status ReduceKernel<allow_multi_axes>::ReduceKernelShared(
|
|||
cudnnReduceTensorOp_t cudnn_reduce_op,
|
||||
std::vector<int64_t>& output_dims) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
typedef typename ToCudaType<OutT>::MappedType CudaOutT;
|
||||
cudnnDataType_t cudnn_type_X = CudnnTensor::GetDataType<CudaT>();
|
||||
const auto rank = input_shape.NumDimensions();
|
||||
|
||||
// Block of fast matrix row reduction.
|
||||
const auto stride = input_shape[input_shape.NumDimensions() - 1];
|
||||
const auto reduction_size = input_shape.Size() / stride;
|
||||
if (fast_reduction_ && reduction_size <= std::numeric_limits<int>::max() && stride <= std::numeric_limits<int>::max() &&
|
||||
is_matrix_row_reduction(cudnn_reduce_op,
|
||||
static_cast<int>(reduction_size),
|
||||
static_cast<int>(stride), rank, axes_)) {
|
||||
reduce_matrix_rows(
|
||||
reinterpret_cast<const CudaT*>(X),
|
||||
reinterpret_cast<CudaT*>(Y),
|
||||
static_cast<int>(reduction_size),
|
||||
static_cast<int>(stride));
|
||||
return Status::OK();
|
||||
// Block of fast matrix reduction.
|
||||
if (fast_reduction_) {
|
||||
int m{}, n{};
|
||||
const auto applicable_matrix_reduction = get_applicable_matrix_reduction(
|
||||
cudnn_reduce_op, input_shape.GetDims(), axes_, m, n);
|
||||
switch (applicable_matrix_reduction) {
|
||||
case ApplicableMatrixReduction::Rows: {
|
||||
return reduce_matrix_rows(
|
||||
reinterpret_cast<const CudaT*>(X),
|
||||
reinterpret_cast<CudaOutT*>(Y),
|
||||
m, n, false);
|
||||
}
|
||||
case ApplicableMatrixReduction::Columns:
|
||||
// don't call reduce_matrix_columns() since it will reset initial output data
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& input_dims = input_shape.GetDims();
|
||||
|
|
@ -335,11 +342,8 @@ Status PrepareForReduce(const Tensor* X,
|
|||
ORT_ENFORCE(nullptr != X);
|
||||
|
||||
const TensorShape& input_shape = input_shape_override ? *input_shape_override : X->Shape();
|
||||
int64_t rank = static_cast<int64_t>(input_shape.NumDimensions());
|
||||
prepare_reduce_metadata.rank = rank;
|
||||
const int64_t rank = gsl::narrow<int64_t>(input_shape.NumDimensions());
|
||||
prepare_reduce_metadata.input_count = input_shape.Size();
|
||||
prepare_reduce_metadata.stride = (rank > 0) ? input_shape[input_shape.NumDimensions() - 1] : 1;
|
||||
prepare_reduce_metadata.contiguous_axes = false;
|
||||
|
||||
if (rank > 8) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "cuDNN only supports up to 8-D tensors in reduction");
|
||||
|
|
@ -349,36 +353,15 @@ Status PrepareForReduce(const Tensor* X,
|
|||
std::vector<bool> reduced(rank, false);
|
||||
prepare_reduce_metadata.output_dims.reserve(input_dims.size());
|
||||
if (axes.size() > 0) {
|
||||
int64_t reduced_axis;
|
||||
std::vector<uint64_t> reduced_axes(axes.size());
|
||||
prepare_reduce_metadata.output_dims = input_dims;
|
||||
for (size_t i = 0; i < axes.size(); i++) {
|
||||
reduced_axis = axes[i];
|
||||
const int64_t axis = HandleNegativeAxis(reduced_axis, rank);
|
||||
for (auto axis : axes) {
|
||||
axis = HandleNegativeAxis(axis, rank);
|
||||
ORT_ENFORCE(input_dims[axis] != 0,
|
||||
"Can't reduce on dim with value of 0 if 'keepdims' is false. "
|
||||
"Invalid output shape would be produced. input_shape:",
|
||||
input_shape);
|
||||
prepare_reduce_metadata.output_dims[axis] = 1;
|
||||
reduced[axis] = true;
|
||||
reduced_axes[i] = axis;
|
||||
}
|
||||
|
||||
bool contiguous_axes = true;
|
||||
std::sort(reduced_axes.begin(), reduced_axes.end());
|
||||
for (size_t i = 0; i < reduced_axes.size(); i++) {
|
||||
if (reduced_axes[i] != i) {
|
||||
contiguous_axes = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
int64_t stride = 1;
|
||||
if (contiguous_axes) {
|
||||
for (size_t s = rank - 1; s >= reduced_axes.size(); s--) {
|
||||
stride *= input_dims[s];
|
||||
}
|
||||
prepare_reduce_metadata.stride = stride;
|
||||
prepare_reduce_metadata.contiguous_axes = true;
|
||||
}
|
||||
} else {
|
||||
// no axes provided (i.e.) default axes => reduce on all dims
|
||||
|
|
@ -416,10 +399,6 @@ Status PrepareForReduce(const Tensor* X,
|
|||
|
||||
prepare_reduce_metadata.output_count = TensorShape(prepare_reduce_metadata.output_dims).Size();
|
||||
|
||||
if (prepare_reduce_metadata.rank == 0) {
|
||||
prepare_reduce_metadata.rank = 1;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -438,8 +417,6 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
|
|||
std::vector<int64_t>& output_dims = prepare_reduce_metadata.output_dims;
|
||||
std::vector<int64_t>& input_dims_cudnn = prepare_reduce_metadata.input_dims_cudnn;
|
||||
std::vector<int64_t>& output_dims_cudnn = prepare_reduce_metadata.output_dims_cudnn;
|
||||
int64_t rank = prepare_reduce_metadata.rank;
|
||||
int64_t stride = prepare_reduce_metadata.stride;
|
||||
|
||||
// special case when there is a dim value of 0 in the shape.
|
||||
if (input_count == 0) {
|
||||
|
|
@ -447,6 +424,31 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Block of fast matrix reduction.
|
||||
if (fast_reduction) {
|
||||
int m{}, n{};
|
||||
const auto applicable_matrix_reduction = get_applicable_matrix_reduction(
|
||||
cudnn_reduce_op, input_shape.GetDims(), axes, m, n);
|
||||
switch (applicable_matrix_reduction) {
|
||||
case ApplicableMatrixReduction::Rows: {
|
||||
return reduce_matrix_rows(
|
||||
reinterpret_cast<const CudaT*>(input.template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(output.template MutableData<T>()),
|
||||
m, n);
|
||||
}
|
||||
case ApplicableMatrixReduction::Columns: {
|
||||
const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<CudaT>(m, n);
|
||||
auto buffer = cuda_ep.GetScratchBuffer<void>(buffer_size_bytes);
|
||||
return reduce_matrix_columns(
|
||||
reinterpret_cast<const CudaT*>(input.template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(output.template MutableData<T>()),
|
||||
m, n, buffer.get(), buffer_size_bytes);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// This reduction keep adding values to this buffer. If a non-zero value, say 1000, is here, the sum will start with 1000.
|
||||
// Therefore zeroing out the memory is required
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes()));
|
||||
|
|
@ -454,22 +456,6 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
|
|||
IAllocatorUniquePtr<float> temp_X;
|
||||
cudnnDataType_t cudnn_type_X = CudnnTensor::GetDataType<CudaT>();
|
||||
|
||||
// Block of fast matrix row reduction.
|
||||
// It relies on new atomicAdd for half type, so old CUDA can't use it.
|
||||
const auto reduction_size = input_count / stride;
|
||||
if (!std::is_same<T, int8_t>::value && !std::is_same<T, uint8_t>::value) {
|
||||
if (fast_reduction && reduction_size <= std::numeric_limits<int>::max() && stride <= std::numeric_limits<int>::max() &&
|
||||
prepare_reduce_metadata.contiguous_axes &&
|
||||
is_matrix_row_reduction(cudnn_reduce_op, static_cast<int>(reduction_size), static_cast<int>(stride), rank, axes)) {
|
||||
reduce_matrix_rows(
|
||||
reinterpret_cast<const CudaT*>(input.template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(output.template MutableData<T>()),
|
||||
static_cast<int>(reduction_size),
|
||||
static_cast<int>(stride));
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
if (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same<T, MLFloat16>::value) {
|
||||
// ArgMax/ArgMin with FP16 are not supported by cudnn, so convert input to fp32 then call cudnn
|
||||
temp_X = cuda_ep.GetScratchBuffer<float>(input_count);
|
||||
|
|
@ -677,7 +663,7 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, cudnnRe
|
|||
Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims);
|
||||
bool fast_reduction = fast_reduction_;
|
||||
if (fast_reduction) {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto ctx_internal = dynamic_cast<OpKernelContextInternal*>(ctx);
|
||||
if (ctx_internal && ctx_internal->GetUseDeterministicCompute())
|
||||
fast_reduction = false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,9 +33,6 @@ struct PrepareReduceMetadata {
|
|||
std::vector<int64_t> squeezed_output_dims;
|
||||
std::vector<int64_t> input_dims_cudnn;
|
||||
std::vector<int64_t> output_dims_cudnn;
|
||||
int64_t rank;
|
||||
int64_t stride;
|
||||
bool contiguous_axes;
|
||||
};
|
||||
|
||||
template <bool allow_multi_axes>
|
||||
|
|
|
|||
|
|
@ -18,51 +18,26 @@ __forceinline__ __host__ __device__ int least_pow2_bound(int value) {
|
|||
return static_cast<int>(++value_);
|
||||
}
|
||||
|
||||
template <typename TAccumulated, typename TValue>
|
||||
struct Cast {
|
||||
__forceinline__ __device__ TAccumulated operator()(const TValue& value) {
|
||||
return TAccumulated(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TAccumulated, typename TValue>
|
||||
struct Square {
|
||||
__forceinline__ __device__ TAccumulated operator()(const TValue& value) {
|
||||
return TAccumulated(value) * TAccumulated(value);
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T operator()(const T& value) {
|
||||
return value * value;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TAccumulated, typename TValue>
|
||||
struct Abs {
|
||||
__forceinline__ __device__ TAccumulated operator()(const TValue& value) {
|
||||
TAccumulated value_ = TAccumulated(value);
|
||||
return value_ > TAccumulated(0) ? value_ : -value_;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T operator()(const T& value) {
|
||||
return _Sqrt(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Identity {
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T operator()(const T& value) {
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ToBuffer {
|
||||
typedef T Type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ToBuffer<half> {
|
||||
typedef float Type;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,9 +8,6 @@
|
|||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/providers/cpu/reduction/reduction_test_cases.h"
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/reduction/reduction_functions.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -1944,80 +1941,6 @@ TEST(ReductionOpTest, ArgMin_int32_select_last) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kNGraphExecutionProvider});
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
||||
void test_reduce_apis(int64_t size, float relative_error_tolerance = 1e-4f) {
|
||||
float output_sum = 0;
|
||||
float output_square_sum = 0;
|
||||
float output_mean = 0;
|
||||
float expected_output_sum = 0;
|
||||
float expected_output_square_sum = 0;
|
||||
float expected_output_mean = 0;
|
||||
const std::vector<int64_t> shape = {size};
|
||||
RandomValueGenerator random_value_generator{};
|
||||
const auto input = random_value_generator.Uniform<float>(shape, 0.1f, 1.0f);
|
||||
for (const auto input_value : input) {
|
||||
expected_output_sum += input_value;
|
||||
expected_output_square_sum += input_value * input_value;
|
||||
expected_output_mean += input_value / float(size);
|
||||
}
|
||||
const int buffer_size_in_byte = onnxruntime::cuda::compute_reduction_buffer_size(
|
||||
static_cast<int>(sizeof(float)), static_cast<int>(size));
|
||||
|
||||
float* device_input = NULL;
|
||||
float* device_output_sum = NULL;
|
||||
float* device_output_square_sum = NULL;
|
||||
float* device_output_mean = NULL;
|
||||
float* buffer = NULL;
|
||||
|
||||
cudaMalloc((void**)&device_input, size * sizeof(float));
|
||||
cudaMalloc((void**)&device_output_sum, 1 * sizeof(float));
|
||||
cudaMalloc((void**)&device_output_square_sum, 1 * sizeof(float));
|
||||
cudaMalloc((void**)&device_output_mean, 1 * sizeof(float));
|
||||
cudaMalloc((void**)&buffer, buffer_size_in_byte);
|
||||
|
||||
cudaMemcpy(device_input, input.data(), size * sizeof(float), cudaMemcpyHostToDevice);
|
||||
|
||||
onnxruntime::cuda::reduce_sum(device_input,
|
||||
device_output_sum,
|
||||
static_cast<int>(size),
|
||||
buffer);
|
||||
onnxruntime::cuda::reduce_square_sum(device_input,
|
||||
device_output_square_sum,
|
||||
static_cast<int>(size), buffer);
|
||||
onnxruntime::cuda::reduce_mean(
|
||||
device_input,
|
||||
device_output_mean,
|
||||
static_cast<int>(size),
|
||||
buffer);
|
||||
|
||||
cudaMemcpy(&output_sum, device_output_sum, 1 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(&output_square_sum, device_output_square_sum, 1 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(&output_mean, device_output_mean, 1 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
cudaFree(device_input);
|
||||
cudaFree(buffer);
|
||||
cudaFree(device_output_sum);
|
||||
cudaFree(device_output_square_sum);
|
||||
cudaFree(device_output_mean);
|
||||
|
||||
EXPECT_LT(std::abs(output_sum - expected_output_sum) / expected_output_sum, relative_error_tolerance);
|
||||
EXPECT_LT(std::abs(output_square_sum - expected_output_square_sum) / expected_output_square_sum,
|
||||
relative_error_tolerance);
|
||||
EXPECT_LT(std::abs(output_mean - expected_output_mean) / expected_output_mean, relative_error_tolerance);
|
||||
}
|
||||
|
||||
TEST(ReduceApiTest, Sum) {
|
||||
test_reduce_apis(3);
|
||||
test_reduce_apis(19);
|
||||
test_reduce_apis(123);
|
||||
test_reduce_apis(1128);
|
||||
test_reduce_apis(5566);
|
||||
test_reduce_apis(941736, 2e-4f);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TEST(ReductionOpTest, ArgMin_int32_neg_axis) {
|
||||
OpTester test("ArgMin");
|
||||
test.AddAttribute("axis", (int64_t)(-3));
|
||||
|
|
|
|||
308
onnxruntime/test/providers/cuda/reduction_functions_test.cc
Normal file
308
onnxruntime/test/providers/cuda/reduction_functions_test.cc
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "core/providers/cuda/reduction/reduction_functions.h"
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
namespace {
|
||||
struct DeviceMemoryDeleter {
|
||||
template <typename T>
|
||||
void operator()(T* p) {
|
||||
cudaFree(p);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<T, DeviceMemoryDeleter> AllocateDeviceMemory(size_t n = 1) {
|
||||
T* p{};
|
||||
cudaMalloc(&p, n * sizeof(T));
|
||||
return std::unique_ptr<T, DeviceMemoryDeleter>(p);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CheckDeviceValues(size_t n, const T* d_actual, const T* expected, float relative_error_tolerance) {
|
||||
std::vector<T> actual(n);
|
||||
cudaMemcpy(actual.data(), d_actual, n * sizeof(T), cudaMemcpyDeviceToHost);
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
EXPECT_LE(std::abs(actual[i] - expected[i]) / expected[i], relative_error_tolerance)
|
||||
<< "i: " << i << ", actual[i]: " << actual[i] << ", expected[i]: " << expected[i];
|
||||
}
|
||||
}
|
||||
|
||||
void TestReduceRowToScalarApis(int size, float relative_error_tolerance = 1e-4f) {
|
||||
SCOPED_TRACE(MakeString("size: ", size));
|
||||
|
||||
float expected_output_sum = 0;
|
||||
float expected_output_square_sum = 0;
|
||||
float expected_output_mean = 0;
|
||||
const std::vector<int64_t> shape = {size};
|
||||
RandomValueGenerator random_value_generator{};
|
||||
const auto input = random_value_generator.Uniform<float>(shape, 0.1f, 1.0f);
|
||||
for (const auto input_value : input) {
|
||||
expected_output_sum += input_value;
|
||||
expected_output_square_sum += input_value * input_value;
|
||||
expected_output_mean += input_value / float(size);
|
||||
}
|
||||
const auto buffer_size_in_bytes =
|
||||
cuda::compute_reduction_buffer_size<float>(size);
|
||||
|
||||
auto device_input = AllocateDeviceMemory<float>(size);
|
||||
auto device_output_sum = AllocateDeviceMemory<float>();
|
||||
auto device_output_square_sum = AllocateDeviceMemory<float>();
|
||||
auto device_output_mean = AllocateDeviceMemory<float>();
|
||||
auto buffer = AllocateDeviceMemory<char>(buffer_size_in_bytes);
|
||||
|
||||
cudaMemcpy(device_input.get(), input.data(), size * sizeof(float), cudaMemcpyHostToDevice);
|
||||
|
||||
ASSERT_STATUS_OK(cuda::reduce_sum(
|
||||
device_input.get(),
|
||||
device_output_sum.get(),
|
||||
size,
|
||||
buffer.get(),
|
||||
buffer_size_in_bytes));
|
||||
ASSERT_STATUS_OK(cuda::reduce_square_sum(
|
||||
device_input.get(),
|
||||
device_output_square_sum.get(),
|
||||
size,
|
||||
buffer.get(),
|
||||
buffer_size_in_bytes));
|
||||
ASSERT_STATUS_OK(cuda::reduce_mean(
|
||||
device_input.get(),
|
||||
device_output_mean.get(),
|
||||
size,
|
||||
buffer.get(),
|
||||
buffer_size_in_bytes));
|
||||
|
||||
ASSERT_TRUE(CUDA_CALL(cudaDeviceSynchronize()));
|
||||
|
||||
CheckDeviceValues(1, device_output_sum.get(), &expected_output_sum, relative_error_tolerance);
|
||||
CheckDeviceValues(1, device_output_square_sum.get(), &expected_output_square_sum, relative_error_tolerance);
|
||||
CheckDeviceValues(1, device_output_mean.get(), &expected_output_mean, relative_error_tolerance);
|
||||
}
|
||||
|
||||
void TestReduceRowsToRow(int m, int n, bool reset_initial_output, float relative_error_tolerance = 1e-4f) {
|
||||
SCOPED_TRACE(MakeString("m: ", m, ", n:", n, ", reset_initial_output: ", reset_initial_output));
|
||||
|
||||
const TensorShape shape{m, n};
|
||||
RandomValueGenerator random{};
|
||||
const auto values = random.Uniform<float>(shape.GetDims(), 1.0f, 10.0f);
|
||||
const auto initial_value = reset_initial_output ? 0.0f : 5.0f;
|
||||
const std::vector<float> expected_row =
|
||||
[m, n, &values, initial_value]() {
|
||||
std::vector<float> row(n, initial_value);
|
||||
for (int i = 0; i < m; ++i) {
|
||||
for (int j = 0; j < n; ++j) {
|
||||
row[j] += values[i * n + j];
|
||||
}
|
||||
}
|
||||
return row;
|
||||
}();
|
||||
|
||||
auto d_in = AllocateDeviceMemory<float>(m * n);
|
||||
auto d_out = AllocateDeviceMemory<float>(n);
|
||||
|
||||
cudaMemcpy(d_in.get(), values.data(), m * n * sizeof(float), cudaMemcpyHostToDevice);
|
||||
|
||||
if (!reset_initial_output) {
|
||||
// manually initialize output data
|
||||
cuda::Fill(d_out.get(), initial_value, n);
|
||||
}
|
||||
|
||||
ASSERT_STATUS_OK(cuda::reduce_matrix_rows(
|
||||
d_in.get(), d_out.get(),
|
||||
m, n,
|
||||
reset_initial_output));
|
||||
|
||||
ASSERT_TRUE(CUDA_CALL(cudaDeviceSynchronize()));
|
||||
|
||||
CheckDeviceValues(n, d_out.get(), expected_row.data(), relative_error_tolerance);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> ExpectedReduceMatrixColumnsOutput(
|
||||
int m, int n, const std::vector<T>& values) {
|
||||
std::vector<T> column(m);
|
||||
for (int i = 0; i < m; ++i) {
|
||||
for (int j = 0; j < n; ++j) {
|
||||
column[i] += values[i * n + j];
|
||||
}
|
||||
}
|
||||
return column;
|
||||
}
|
||||
|
||||
void TestReduceColumnsToColumn(int m, int n, float relative_error_tolerance = 1e-4f) {
|
||||
SCOPED_TRACE(MakeString("m: ", m, ", n:", n));
|
||||
|
||||
const TensorShape shape{m, n};
|
||||
RandomValueGenerator random{};
|
||||
const auto values = random.Uniform<float>(shape.GetDims(), 1.0f, 10.0f);
|
||||
const auto expected_column = ExpectedReduceMatrixColumnsOutput(m, n, values);
|
||||
|
||||
auto d_in = AllocateDeviceMemory<float>(m * n);
|
||||
auto d_out = AllocateDeviceMemory<float>(m);
|
||||
|
||||
cudaMemcpy(d_in.get(), values.data(), m * n * sizeof(float), cudaMemcpyHostToDevice);
|
||||
|
||||
size_t buffer_size_in_bytes =
|
||||
cuda::compute_reduce_matrix_columns_buffer_size<float>(m, n);
|
||||
auto d_buffer = AllocateDeviceMemory<char>(buffer_size_in_bytes);
|
||||
|
||||
ASSERT_STATUS_OK(cuda::reduce_matrix_columns(
|
||||
d_in.get(), d_out.get(),
|
||||
m, n,
|
||||
d_buffer.get(), buffer_size_in_bytes));
|
||||
|
||||
ASSERT_TRUE(CUDA_CALL(cudaDeviceSynchronize()));
|
||||
|
||||
CheckDeviceValues(m, d_out.get(), expected_column.data(), relative_error_tolerance);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(ReductionFunctionsTest, ReduceRowToScalar) {
|
||||
TestReduceRowToScalarApis(3);
|
||||
TestReduceRowToScalarApis(19);
|
||||
TestReduceRowToScalarApis(123);
|
||||
TestReduceRowToScalarApis(1128);
|
||||
TestReduceRowToScalarApis(5566);
|
||||
TestReduceRowToScalarApis(941736, 2e-4f);
|
||||
}
|
||||
|
||||
TEST(ReductionFunctionsTest, ReduceRowsToRow) {
|
||||
for (int m : {3, 193, 2945}) {
|
||||
for (int n : {3, 193, 2945}) {
|
||||
TestReduceRowsToRow(m, n, true);
|
||||
TestReduceRowsToRow(m, n, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ReductionFunctionsTest, ReduceColumnsToColumn) {
|
||||
for (int m : {3, 193, 2945}) {
|
||||
for (int n : {3, 193, 2945}) {
|
||||
TestReduceColumnsToColumn(m, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ReductionFunctionsTest, BufferOffsets) {
|
||||
const int m = 2048;
|
||||
const int n = 1024;
|
||||
|
||||
const size_t max_buffer_offset = 15;
|
||||
|
||||
const size_t buffer_size_in_bytes =
|
||||
cuda::compute_reduce_matrix_columns_buffer_size<double>(m, n) + max_buffer_offset;
|
||||
|
||||
auto d_input = AllocateDeviceMemory<double>(m * n);
|
||||
auto d_output = AllocateDeviceMemory<double>(m);
|
||||
auto d_buffer = AllocateDeviceMemory<char>(buffer_size_in_bytes);
|
||||
|
||||
RandomValueGenerator random{};
|
||||
const float relative_error_tolerance = 1e-4f;
|
||||
|
||||
for (size_t buffer_offset = 1; buffer_offset <= max_buffer_offset; ++buffer_offset) {
|
||||
SCOPED_TRACE(MakeString("buffer offset: ", buffer_offset));
|
||||
|
||||
const auto input = random.Uniform<double>({m, n}, 1.0, 10.0);
|
||||
cudaMemcpy(d_input.get(), input.data(), m * n * sizeof(double), cudaMemcpyHostToDevice);
|
||||
|
||||
ASSERT_STATUS_OK(cuda::reduce_matrix_columns(
|
||||
d_input.get(), d_output.get(),
|
||||
m, n,
|
||||
d_buffer.get() + buffer_offset,
|
||||
buffer_size_in_bytes - buffer_offset));
|
||||
|
||||
const auto expected_column = ExpectedReduceMatrixColumnsOutput(m, n, input);
|
||||
CheckDeviceValues(m, d_output.get(), expected_column.data(), relative_error_tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ReductionFunctionsTest, InvalidBufferSize) {
|
||||
const int m = 2048;
|
||||
const int n = 1024;
|
||||
|
||||
// this should be too small
|
||||
const size_t buffer_size_in_bytes =
|
||||
cuda::compute_reduce_matrix_columns_buffer_size<float>(m, n) / 10;
|
||||
|
||||
auto d_input = AllocateDeviceMemory<float>(m * n);
|
||||
auto d_output = AllocateDeviceMemory<float>(m);
|
||||
auto d_buffer = AllocateDeviceMemory<char>(buffer_size_in_bytes);
|
||||
|
||||
RandomValueGenerator random{};
|
||||
const auto input = random.Uniform<float>({m, n}, 1.0, 10.0);
|
||||
cudaMemcpy(d_input.get(), input.data(), m * n * sizeof(float), cudaMemcpyHostToDevice);
|
||||
|
||||
const auto status =
|
||||
cuda::reduce_matrix_columns(d_input.get(), d_output.get(), m, n, d_buffer.get(), buffer_size_in_bytes);
|
||||
ASSERT_FALSE(status.IsOK());
|
||||
}
|
||||
|
||||
TEST(ReductionFunctionsTest, GetApplicableMatrixReduction) {
|
||||
const cudnnReduceTensorOp_t valid_op_type = CUDNN_REDUCE_TENSOR_ADD;
|
||||
int m{}, n{};
|
||||
|
||||
// contiguous axes from beginning
|
||||
EXPECT_EQ(
|
||||
cuda::get_applicable_matrix_reduction(
|
||||
valid_op_type, {2, 4, 8, 16}, {0, 1}, m, n),
|
||||
cuda::ApplicableMatrixReduction::Rows);
|
||||
EXPECT_EQ(m, 2 * 4);
|
||||
EXPECT_EQ(n, 8 * 16);
|
||||
|
||||
// contiguous axes to end
|
||||
EXPECT_EQ(
|
||||
cuda::get_applicable_matrix_reduction(
|
||||
valid_op_type, {2, 4, 8, 16}, {1, 2, 3}, m, n),
|
||||
cuda::ApplicableMatrixReduction::Columns);
|
||||
EXPECT_EQ(m, 2);
|
||||
EXPECT_EQ(n, 4 * 8 * 16);
|
||||
|
||||
// single axis
|
||||
EXPECT_EQ(
|
||||
cuda::get_applicable_matrix_reduction(
|
||||
valid_op_type, {2, 4, 8, 16}, {3}, m, n),
|
||||
cuda::ApplicableMatrixReduction::Columns);
|
||||
EXPECT_EQ(m, 2 * 4 * 8);
|
||||
EXPECT_EQ(n, 16);
|
||||
|
||||
// unsupported axes
|
||||
EXPECT_EQ(
|
||||
cuda::get_applicable_matrix_reduction(
|
||||
valid_op_type, {2, 4, 8, 16}, {0, 1, 2, 3}, m, n),
|
||||
cuda::ApplicableMatrixReduction::None);
|
||||
EXPECT_EQ(
|
||||
cuda::get_applicable_matrix_reduction(
|
||||
valid_op_type, {2, 4, 8, 16}, {}, m, n),
|
||||
cuda::ApplicableMatrixReduction::None);
|
||||
EXPECT_EQ(
|
||||
cuda::get_applicable_matrix_reduction(
|
||||
valid_op_type, {2, 4, 8, 16, 32, 64}, {0, 1, 3, 4}, m, n),
|
||||
cuda::ApplicableMatrixReduction::None);
|
||||
EXPECT_EQ(
|
||||
cuda::get_applicable_matrix_reduction(
|
||||
valid_op_type, {2, 4, 8, 16}, {1, 2}, m, n),
|
||||
cuda::ApplicableMatrixReduction::None);
|
||||
|
||||
// invalid op type
|
||||
EXPECT_EQ(
|
||||
cuda::get_applicable_matrix_reduction(
|
||||
CUDNN_REDUCE_TENSOR_MAX, {2, 4, 8, 16}, {0, 1}, m, n),
|
||||
cuda::ApplicableMatrixReduction::None);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
|
|
@ -8,12 +8,12 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// These variables control the behavior of GetTestRandomSeed().
|
||||
// These environment variables control the behavior of GetTestRandomSeed().
|
||||
namespace test_random_seed_env_vars {
|
||||
// Specifies a fixed seed value to return.
|
||||
// If set, this has the highest precedence.
|
||||
constexpr const char* kValue = "ORT_TEST_RANDOM_SEED_VALUE";
|
||||
// If set (and not using a fixed value), specifies that a new seed value is returned each time.
|
||||
// If set to 1 (and not using a fixed value), specifies that a new seed value is returned each time.
|
||||
// The default behavior is to return the same cached seed value per process.
|
||||
// This is useful when repeatedly running flaky tests to reproduce errors.
|
||||
constexpr const char* kDoNotCache = "ORT_TEST_RANDOM_SEED_DO_NOT_CACHE";
|
||||
|
|
|
|||
|
|
@ -24,10 +24,13 @@ static void TestReduceSum(const std::vector<int64_t>& X_dims,
|
|||
|
||||
// create rand inputs
|
||||
RandomValueGenerator random{};
|
||||
std::vector<float> X_data = random.Uniform<float>(X_dims, -10.0f, 10.0f);
|
||||
const bool is_positive = random.Uniform<int>({1}, 0, 2)[0] == 0;
|
||||
const float range_begin = is_positive ? 1.0f : -10.0f;
|
||||
const float range_end = is_positive ? 10.0f : -1.0f;
|
||||
const std::vector<float> X_data = random.Uniform<float>(X_dims, range_begin, range_end);
|
||||
test.AddInput<float>("X", X_dims, X_data);
|
||||
|
||||
std::vector<float> Y_data = FillZeros<float>(Y_dims);
|
||||
const std::vector<float> Y_data = FillZeros<float>(Y_dims);
|
||||
test.AddOutput<float>("Y", Y_dims, Y_data);
|
||||
|
||||
test.CompareWithCPU(kGpuExecutionProvider, per_sample_tolerance, relative_per_sample_tolerance);
|
||||
|
|
@ -62,8 +65,7 @@ TEST(CudaKernelTest, ReduceSum_MidTensor) {
|
|||
std::vector<int64_t> Y_dims{3072};
|
||||
std::vector<int64_t> axes{0, 1};
|
||||
bool keepdims = false;
|
||||
double per_sample_tolerance = 4e-4;
|
||||
TestReduceSum(X_dims, Y_dims, axes, keepdims, per_sample_tolerance);
|
||||
TestReduceSum(X_dims, Y_dims, axes, keepdims);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, ReduceSum_LargeTensor) {
|
||||
|
|
@ -71,9 +73,31 @@ TEST(CudaKernelTest, ReduceSum_LargeTensor) {
|
|||
std::vector<int64_t> Y_dims{30528};
|
||||
std::vector<int64_t> axes{0, 1};
|
||||
bool keepdims = false;
|
||||
double per_sample_tolerance = 5e-4;
|
||||
double relative_per_sample_tolerance = 5e-2;
|
||||
TestReduceSum(X_dims, Y_dims, axes, keepdims, per_sample_tolerance, relative_per_sample_tolerance);
|
||||
TestReduceSum(X_dims, Y_dims, axes, keepdims);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, ReduceSum_SmallTensorTrailingAxes) {
|
||||
std::vector<int64_t> X_dims{128, 2, 128};
|
||||
std::vector<int64_t> Y_dims{128};
|
||||
std::vector<int64_t> axes{1, 2};
|
||||
bool keepdims = false;
|
||||
TestReduceSum(X_dims, Y_dims, axes, keepdims);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, ReduceSum_MidTensorTrailingAxes) {
|
||||
std::vector<int64_t> X_dims{3072, 2, 512};
|
||||
std::vector<int64_t> Y_dims{3072};
|
||||
std::vector<int64_t> axes{1, 2};
|
||||
bool keepdims = false;
|
||||
TestReduceSum(X_dims, Y_dims, axes, keepdims);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, ReduceSum_LargeTensorTrailingAxes) {
|
||||
std::vector<int64_t> X_dims{30528, 4, 512};
|
||||
std::vector<int64_t> Y_dims{30528};
|
||||
std::vector<int64_t> axes{1, 2};
|
||||
bool keepdims = false;
|
||||
TestReduceSum(X_dims, Y_dims, axes, keepdims);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -11,16 +11,16 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_VERSIONED_TYPED_TWO_TYPES(Class, T, Tin, domain, startver, endver) \
|
||||
ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX( \
|
||||
Class, \
|
||||
domain, \
|
||||
startver, endver, \
|
||||
T, Tin, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("Tin", DataTypeImpl::GetTensorType<Tin>()), \
|
||||
#define REGISTER_KERNEL_VERSIONED_TYPED_TWO_TYPES(Class, T, Tin, domain, startver, endver) \
|
||||
ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX( \
|
||||
Class, \
|
||||
domain, \
|
||||
startver, endver, \
|
||||
T, Tin, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("Tin", DataTypeImpl::GetTensorType<Tin>()), \
|
||||
Class<T, Tin>);
|
||||
|
||||
#define REGISTER_KERNEL_TYPED_TWO_TYPES(Class, T, Tin, domain, version) \
|
||||
|
|
@ -113,16 +113,17 @@ Status SoftmaxCrossEntropyLoss<T, Tin>::ComputeInternal(OpKernelContext* ctx) co
|
|||
auto normalize_factor_data = GetScratchBuffer<T>(1);
|
||||
if (reduction_ == ReductionType::MEAN) {
|
||||
// Compute buffer size in byte for reduction APIs.
|
||||
const auto buffer_size = static_cast<size_t>(
|
||||
compute_reduction_buffer_size(
|
||||
static_cast<int>(sizeof(T)), static_cast<int>(N_D)));
|
||||
const auto buffer_size =
|
||||
compute_reduction_buffer_size<T>(static_cast<int>(N_D));
|
||||
// Allocate reduction buffer whose size is buffer_size bytes.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(
|
||||
buffer_size);
|
||||
reduce_sum(weight_data_nd_data,
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N_D),
|
||||
reinterpret_cast<T*>(reduction_buffer.get()));
|
||||
ORT_RETURN_IF_ERROR(reduce_sum(
|
||||
weight_data_nd_data,
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N_D),
|
||||
reduction_buffer.get(),
|
||||
buffer_size));
|
||||
} else {
|
||||
const T normalize_factor = static_cast<T>(1);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice));
|
||||
|
|
@ -213,16 +214,17 @@ Status SoftmaxCrossEntropyLossGrad<T, Tin>::ComputeInternal(OpKernelContext* ctx
|
|||
auto normalize_factor_data = GetScratchBuffer<T>(1);
|
||||
if (reduction_ == ReductionType::MEAN) {
|
||||
// Compute buffer size in byte for reduction APIs.
|
||||
const auto buffer_size = static_cast<size_t>(
|
||||
compute_reduction_buffer_size(
|
||||
static_cast<int>(sizeof(T)), static_cast<int>(N_D)));
|
||||
const auto buffer_size =
|
||||
compute_reduction_buffer_size<T>(static_cast<int>(N_D));
|
||||
// Allocate reduction buffer whose size is buffer_size bytes.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(
|
||||
buffer_size);
|
||||
reduce_sum(weight_data_nd_data,
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N_D),
|
||||
reinterpret_cast<T*>(reduction_buffer.get()));
|
||||
ORT_RETURN_IF_ERROR(reduce_sum(
|
||||
weight_data_nd_data,
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N_D),
|
||||
reduction_buffer.get(),
|
||||
buffer_size));
|
||||
} else {
|
||||
const T normalize_factor = static_cast<T>(1);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice));
|
||||
|
|
|
|||
|
|
@ -175,16 +175,17 @@ Status SparseSoftmaxCrossEntropy<T, Tin>::ComputeInternal(OpKernelContext* ctx)
|
|||
cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice);
|
||||
} else {
|
||||
// Compute buffer size in byte for reduction APIs.
|
||||
const auto buffer_size = static_cast<size_t>(
|
||||
compute_reduction_buffer_size(
|
||||
static_cast<int>(sizeof(T)), static_cast<int>(N)));
|
||||
const auto buffer_size =
|
||||
compute_reduction_buffer_size<T>(static_cast<int>(N));
|
||||
// Allocate reduction buffer whose size is buffer_size bytes.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(
|
||||
buffer_size);
|
||||
reduce_sum(weight_data,
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N),
|
||||
reinterpret_cast<T*>(reduction_buffer.get()));
|
||||
ORT_RETURN_IF_ERROR(reduce_sum(
|
||||
weight_data,
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N),
|
||||
reduction_buffer.get(),
|
||||
buffer_size));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -251,16 +252,17 @@ Status SparseSoftmaxCrossEntropyGrad<T, Tin>::ComputeInternal(OpKernelContext* c
|
|||
cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice);
|
||||
} else {
|
||||
// Compute buffer size in byte for reduction APIs.
|
||||
const auto buffer_size = static_cast<size_t>(
|
||||
compute_reduction_buffer_size(
|
||||
static_cast<int>(sizeof(T)), static_cast<int>(N)));
|
||||
const auto buffer_size =
|
||||
compute_reduction_buffer_size<T>(static_cast<int>(N));
|
||||
// Allocate reduction buffer whose size is buffer_size bytes.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(
|
||||
buffer_size);
|
||||
reduce_sum(weight_data,
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N),
|
||||
reinterpret_cast<T*>(reduction_buffer.get()));
|
||||
ORT_RETURN_IF_ERROR(reduce_sum(
|
||||
weight_data,
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N),
|
||||
reduction_buffer.get(),
|
||||
buffer_size));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -46,24 +46,24 @@ std::vector<std::pair<int, int>> GenerateLambExtraAliasMapping() {
|
|||
}
|
||||
|
||||
// TODO: Once Schema is checked in to onnx lets fix this to match that
|
||||
#define REGISTER_LAMB_KERNEL_TYPED(T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
LambOptimizer, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T1##_##T2##_##T3##_##T4##_##T_GRAD_NORM##_##T_MIXED_PRECISION_FP, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.Alias(GenerateLambExtraAliasMapping()) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) /* Keep do_update in CPU */ \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(4) /* Keep iteration_count in CPU */ \
|
||||
.OutputMemoryType<OrtMemTypeCPUOutput>(0) /* Keep iteration_count in CPU */ \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()) \
|
||||
.TypeConstraint("T3", DataTypeImpl::GetTensorType<T3>()) \
|
||||
.TypeConstraint("T4", DataTypeImpl::GetTensorType<T4>()) \
|
||||
.TypeConstraint("T_MIXED_PRECISION_FP", DataTypeImpl::GetTensorType<T_MIXED_PRECISION_FP>()) \
|
||||
.TypeConstraint("T_GRAD_NORM", DataTypeImpl::GetTensorType<T_GRAD_NORM>()), \
|
||||
#define REGISTER_LAMB_KERNEL_TYPED(T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
LambOptimizer, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T1##_##T2##_##T3##_##T4##_##T_GRAD_NORM##_##T_MIXED_PRECISION_FP, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.Alias(GenerateLambExtraAliasMapping()) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) /* Keep do_update in CPU */ \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(4) /* Keep iteration_count in CPU */ \
|
||||
.OutputMemoryType<OrtMemTypeCPUOutput>(0) /* Keep iteration_count in CPU */ \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()) \
|
||||
.TypeConstraint("T3", DataTypeImpl::GetTensorType<T3>()) \
|
||||
.TypeConstraint("T4", DataTypeImpl::GetTensorType<T4>()) \
|
||||
.TypeConstraint("T_MIXED_PRECISION_FP", DataTypeImpl::GetTensorType<T_MIXED_PRECISION_FP>()) \
|
||||
.TypeConstraint("T_GRAD_NORM", DataTypeImpl::GetTensorType<T_GRAD_NORM>()), \
|
||||
LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>);
|
||||
|
||||
REGISTER_LAMB_KERNEL_TYPED(float, float, MLFloat16, float, MLFloat16, MLFloat16)
|
||||
|
|
@ -112,7 +112,6 @@ Status copy_inputs_to_outputs(
|
|||
const int group_count,
|
||||
const int input_group_size,
|
||||
const int output_group_size) {
|
||||
|
||||
const Tensor* step_tensor = ctx->Input<Tensor>(4);
|
||||
if (step_tensor) {
|
||||
const int64_t* step_data = step_tensor->template Data<int64_t>();
|
||||
|
|
@ -204,10 +203,10 @@ Status launch_lamb_compute_direction(
|
|||
for (int i = 0; i < group_count; ++i) {
|
||||
if (tensor_sizes[i] > max_tensor_size) {
|
||||
// For the first iteration (indexed by 0), the update count should be 2.
|
||||
const float alpha_correction = do_bias_correction ?
|
||||
onnxruntime::contrib::compute_bias_correction_coefficient(alphas[i], update_count) : 1.f;
|
||||
const float beta_correction = do_bias_correction ?
|
||||
onnxruntime::contrib::compute_bias_correction_coefficient(betas[i], update_count) : 1.f;
|
||||
const float alpha_correction =
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alphas[i], update_count) : 1.f;
|
||||
const float beta_correction =
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(betas[i], update_count) : 1.f;
|
||||
|
||||
LambComputeDirection(
|
||||
p_ws[i],
|
||||
|
|
@ -248,9 +247,9 @@ Status launch_lamb_compute_direction(
|
|||
|
||||
// For the first iteration (indexed by 0), the update count should be 1.
|
||||
const float alpha_correction =
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alpha, update_count) : 1.f;
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(alpha, update_count) : 1.f;
|
||||
const float beta_correction =
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(beta, update_count) : 1.f;
|
||||
do_bias_correction ? onnxruntime::contrib::compute_bias_correction_coefficient(beta, update_count) : 1.f;
|
||||
|
||||
typedef LambMultiTensorComputeDirectionFunctor<CudaT2, CudaT3, CudaT4, CudaT_GRAD_NORM> LambStage1;
|
||||
LambStage1 lamb_stage1;
|
||||
|
|
@ -274,7 +273,8 @@ Status launch_lamb_reduction(
|
|||
std::vector<CudaTNorm*>& p_d_norms,
|
||||
std::vector<const CudaTIn1*>& p_ws,
|
||||
std::vector<CudaTIn2*>& p_ds,
|
||||
CudaTNorm* reduction_buffer) {
|
||||
void* reduction_buffer,
|
||||
size_t reduction_buffer_size) {
|
||||
ORT_ENFORCE(group_count == static_cast<int>(tensor_sizes.size()));
|
||||
|
||||
ORT_ENFORCE(group_count == static_cast<int>(p_w_norms.size()));
|
||||
|
|
@ -292,16 +292,18 @@ Status launch_lamb_reduction(
|
|||
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
for (int i = 0; i < group_count; ++i) {
|
||||
if (tensor_sizes[i] > max_tensor_size) {
|
||||
reduce_square_sum(
|
||||
ORT_RETURN_IF_ERROR(reduce_square_sum(
|
||||
p_ws[i],
|
||||
p_w_norms[i],
|
||||
tensor_sizes[i],
|
||||
reduction_buffer);
|
||||
reduce_square_sum(
|
||||
reduction_buffer,
|
||||
reduction_buffer_size));
|
||||
ORT_RETURN_IF_ERROR(reduce_square_sum(
|
||||
p_ds[i],
|
||||
p_d_norms[i],
|
||||
tensor_sizes[i],
|
||||
reduction_buffer);
|
||||
reduction_buffer,
|
||||
reduction_buffer_size));
|
||||
} else {
|
||||
std::vector<void*> ptrs(tensor_count_per_group);
|
||||
ptrs[0] = const_cast<CudaTIn1*>(p_ws[i]); // weight tensor
|
||||
|
|
@ -406,12 +408,13 @@ Status launch_lamb_update(
|
|||
// Only launch multi-tensor function if we have at least one tensor in the buckets.
|
||||
if (tensor_sizes_in_bucket.size() > 0 && buckets.size() > 0) {
|
||||
typedef LambMultiTensorUpdateFunctor<
|
||||
CudaT1, CudaT2, CudaT3, CudaT_MIXED_PRECISION_FP> LambStage2;
|
||||
CudaT1, CudaT2, CudaT3, CudaT_MIXED_PRECISION_FP>
|
||||
LambStage2;
|
||||
LambStage2 lamb_stage2;
|
||||
|
||||
launch_multi_tensor_functor<
|
||||
tensor_count_per_group, LambStage2,
|
||||
const CudaT1*, const float, const float>(
|
||||
tensor_count_per_group, LambStage2,
|
||||
const CudaT1*, const float, const float>(
|
||||
2048 * 32,
|
||||
tensor_sizes_in_bucket,
|
||||
buckets,
|
||||
|
|
@ -540,13 +543,11 @@ Status LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>::Compute
|
|||
}
|
||||
|
||||
// Allocate a buffer in byte for reduction API calls.
|
||||
const auto buffer_size = static_cast<size_t>(
|
||||
compute_reduction_buffer_size(
|
||||
static_cast<int>(sizeof(T2)), max_tensor_size));
|
||||
const auto reduction_buffer_size =
|
||||
compute_reduction_buffer_size<CudaT2>(max_tensor_size);
|
||||
|
||||
// Allocate reduction buffer whose size is buffer_size bytes.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(buffer_size);
|
||||
CudaT2* reduction_data = reinterpret_cast<CudaT2*>(reduction_buffer.get());
|
||||
// Allocate reduction buffer whose size is reduction_buffer_size bytes.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(reduction_buffer_size);
|
||||
|
||||
// Input tensors' pointers.
|
||||
std::vector<const CudaT2*> p_ws(group_count);
|
||||
|
|
@ -626,8 +627,7 @@ Status LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>::Compute
|
|||
p_w_mixed_precision_news[group_index] = w_mixed_precision_new != nullptr ? reinterpret_cast<CudaT_MIXED_PRECISION_FP*>(w_mixed_precision_new->template MutableData<T_MIXED_PRECISION_FP>()) : nullptr;
|
||||
}
|
||||
|
||||
|
||||
launch_lamb_compute_direction(
|
||||
ORT_RETURN_IF_ERROR(launch_lamb_compute_direction(
|
||||
step_data ? *step_data : 0,
|
||||
group_count,
|
||||
loss_scale_data,
|
||||
|
|
@ -637,18 +637,19 @@ Status LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>::Compute
|
|||
p_ds,
|
||||
p_m1_news, p_m2_news,
|
||||
alpha_, beta_, lambda_, epsilon_,
|
||||
do_bias_correction_);
|
||||
do_bias_correction_));
|
||||
|
||||
launch_lamb_reduction(
|
||||
ORT_RETURN_IF_ERROR(launch_lamb_reduction(
|
||||
group_count,
|
||||
tensor_sizes,
|
||||
p_w_norms,
|
||||
p_d_norms,
|
||||
p_ws,
|
||||
p_ds,
|
||||
reduction_data);
|
||||
reduction_buffer.get(),
|
||||
reduction_buffer_size));
|
||||
|
||||
launch_lamb_update(
|
||||
ORT_RETURN_IF_ERROR(launch_lamb_update(
|
||||
group_count,
|
||||
eta_data,
|
||||
ratio_min_,
|
||||
|
|
@ -660,7 +661,7 @@ Status LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>::Compute
|
|||
p_ds,
|
||||
p_w_news,
|
||||
p_g_news,
|
||||
p_w_mixed_precision_news);
|
||||
p_w_mixed_precision_news));
|
||||
|
||||
if (step_tensor) {
|
||||
Tensor* step_tensor_new = ctx->Output(0, step_tensor->Shape());
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ Status ReduceAllL2<TIn, TOut>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
CudaTOut* p_output = reinterpret_cast<CudaTOut*>(output->template MutableData<TOut>());
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(p_output, 0, sizeof(CudaTOut)));
|
||||
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto ctx_internal = dynamic_cast<OpKernelContextInternal*>(ctx);
|
||||
bool deterministic = ctx_internal && ctx_internal->GetUseDeterministicCompute();
|
||||
|
||||
if (!deterministic) {
|
||||
|
|
@ -65,32 +65,36 @@ Status ReduceAllL2<TIn, TOut>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
// alternate path only for deterministic compute ..
|
||||
typedef AccumulationType_t<CudaTOut> CudaTAcc;
|
||||
|
||||
// find scratch buffer size needed by 'reduce_square_sum' for each tensor
|
||||
int scratch_size = 0;
|
||||
// find reduction buffer size needed by 'reduce_square_sum' for each tensor
|
||||
size_t reduction_buffer_size = 0;
|
||||
for (int i = 0; i < total_tensor_count; ++i) {
|
||||
scratch_size = std::max(scratch_size, compute_reduction_buffer_size(sizeof(CudaTAcc), tensor_sizes[i]));
|
||||
reduction_buffer_size =
|
||||
std::max(reduction_buffer_size, compute_reduction_buffer_size<CudaTAcc>(tensor_sizes[i]));
|
||||
}
|
||||
|
||||
// enlarge scratch buffer size for 'reduce_sum' over tensor square norms
|
||||
scratch_size = std::max(scratch_size, compute_reduction_buffer_size(sizeof(CudaTAcc), total_tensor_count));
|
||||
|
||||
// add head room for final output and square norms of each tensor
|
||||
scratch_size += (1 + total_tensor_count) * sizeof(CudaTAcc);
|
||||
// enlarge reduction buffer size for 'reduce_sum' over tensor square norms
|
||||
reduction_buffer_size =
|
||||
std::max(reduction_buffer_size, compute_reduction_buffer_size<CudaTAcc>(total_tensor_count));
|
||||
|
||||
// create GPU scratch space and zero target for each tensor square norm
|
||||
auto scratch_buffer = GetScratchBuffer<uint8_t>(scratch_size);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(scratch_buffer.get(), 0, sizeof(CudaTAcc) * (1 + total_tensor_count)));
|
||||
auto reduction_buffer = GetScratchBuffer<void>(reduction_buffer_size);
|
||||
|
||||
CudaTAcc* p_global_sqnorm = reinterpret_cast<CudaTAcc*>(scratch_buffer.get());
|
||||
// buffer for final output and square norms of each tensor
|
||||
auto results_buffer = GetScratchBuffer<CudaTAcc>(1 + total_tensor_count);
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(results_buffer.get(), 0, sizeof(CudaTAcc) * (1 + total_tensor_count)));
|
||||
|
||||
CudaTAcc* p_global_sqnorm = results_buffer.get();
|
||||
CudaTAcc* p_tensor_sqnorm = p_global_sqnorm + 1;
|
||||
CudaTAcc* p_reduce_buffer = p_tensor_sqnorm + total_tensor_count;
|
||||
|
||||
// perform reduction l2norm = sqrt[sum(tensor[i][j]**2)] for i,j over all tensor elements
|
||||
for (int i = 0; i < total_tensor_count; ++i) {
|
||||
CudaTIn* p_tensor_i = reinterpret_cast<CudaTIn*>(grouped_tensor_pointers[i][0]);
|
||||
reduce_square_sum(p_tensor_i, p_tensor_sqnorm + i, tensor_sizes[i], p_reduce_buffer);
|
||||
ORT_RETURN_IF_ERROR(reduce_square_sum(
|
||||
p_tensor_i, p_tensor_sqnorm + i, tensor_sizes[i], reduction_buffer.get(), reduction_buffer_size));
|
||||
}
|
||||
reduce_sum(p_tensor_sqnorm, p_global_sqnorm, total_tensor_count, p_reduce_buffer);
|
||||
ORT_RETURN_IF_ERROR(reduce_sum(
|
||||
p_tensor_sqnorm, p_global_sqnorm, total_tensor_count, reduction_buffer.get(), reduction_buffer_size));
|
||||
ScalarSqrt(p_global_sqnorm, p_output);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,18 +6,19 @@
|
|||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/atomic/common.cuh"
|
||||
#include "core/providers/cuda/reduction/reduction_utils.cuh"
|
||||
#include "core/providers/cuda/shared_inc/accumulation_type.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template<typename Tin, typename Tout>
|
||||
__global__ void _ScalarSqrtImpl(Tin* input, Tout* output) {
|
||||
template <typename Tin, typename Tout>
|
||||
__global__ void ScalarSqrtKernel(Tin* input, Tout* output) {
|
||||
*output = (Tout)_Sqrt(*input);
|
||||
};
|
||||
|
||||
template<typename Tin, typename Tout>
|
||||
template <typename Tin, typename Tout>
|
||||
void ScalarSqrt(Tin* input, Tout* output) {
|
||||
_ScalarSqrtImpl<<<1, 1, 0>>>(input, output);
|
||||
ScalarSqrtKernel<<<1, 1, 0>>>(input, output);
|
||||
};
|
||||
|
||||
template void ScalarSqrt(float* input, float* output);
|
||||
|
|
@ -25,7 +26,7 @@ template void ScalarSqrt(half* input, half* output);
|
|||
template void ScalarSqrt(float* input, half* output);
|
||||
|
||||
template <typename TIn, typename TOut, typename TBuf, typename TInOp, typename TOutOp>
|
||||
__global__ void _MultiTensorReduceImpl(ChunkGroup<1> chunk_group, TOut* output) {
|
||||
__global__ void MultiTensorReduceKernel(ChunkGroup<1> chunk_group, TOut* output) {
|
||||
const int group_index = chunk_group.block_index_to_tensor_group_index[blockIdx.x];
|
||||
const int tensor_size = chunk_group.tensor_sizes[group_index];
|
||||
const int chunk_size = chunk_group.chunk_size;
|
||||
|
|
@ -76,7 +77,7 @@ __global__ void _MultiTensorReduceImpl(ChunkGroup<1> chunk_group, TOut* output)
|
|||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
atomic_add(w_norm, TOutOp()(shared_memory[0]));
|
||||
atomic_add(w_norm, TOutOp()(TOut(shared_memory[0])));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -91,13 +92,13 @@ void MultiTensorReduce(ChunkGroup<1> chunk_group, TOut* output) {
|
|||
ORT_ENFORCE(thread_count % GPU_WARP_SIZE == 0);
|
||||
ORT_ENFORCE((thread_count & (thread_count - 1)) == 0);
|
||||
|
||||
_MultiTensorReduceImpl<TIn, TOut, TBuf, TInOp, TOutOp><<<chunk_group.chunk_count, thread_count, shared_memory_size>>>(chunk_group, output);
|
||||
MultiTensorReduceKernel<TIn, TOut, TBuf, TInOp, TOutOp><<<chunk_group.chunk_count, thread_count, shared_memory_size>>>(chunk_group, output);
|
||||
}
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
void MultiTensorReduceL2<TIn, TOut>::operator()(ChunkGroup<1> chunk_group, TOut* output) {
|
||||
typedef typename ToBuffer<TIn>::Type TBuf;
|
||||
MultiTensorReduce<TIn, TOut, TBuf, Square<TBuf, TIn>, Cast<TOut, TBuf>>(chunk_group, output);
|
||||
using TBuf = AccumulationType_t<TIn>;
|
||||
MultiTensorReduce<TIn, TOut, TBuf, Square, Identity>(chunk_group, output);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MULTI_TENSOR_REDUCTION_L2_FUNCTOR(TIn, TOut) \
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ Status ReduceKernel<allow_multi_axes>::ComputeImplEx(OpKernelContext* ctx, cudnn
|
|||
Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims);
|
||||
bool fast_reduction = fast_reduction_;
|
||||
if (fast_reduction) {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto ctx_internal = dynamic_cast<OpKernelContextInternal*>(ctx);
|
||||
if (ctx_internal && ctx_internal->GetUseDeterministicCompute())
|
||||
fast_reduction = false;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue