mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Use IMMA for int8 matmul to leverage Turing Tensor Core (#3413)
Use IMMA for int8 matmul to leverage Turing Tensor Core Format files under onnxruntime/core/providers/cude
This commit is contained in:
parent
de60a14c16
commit
4d71958ccf
42 changed files with 371 additions and 113 deletions
|
|
@ -757,6 +757,11 @@ if (onnxruntime_USE_CUDA)
|
|||
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
|
||||
set(ONNXRUNTIME_CUDA_LIBRARIES ${CUDA_LIBRARIES})
|
||||
list(APPEND ONNXRUNTIME_CUDA_LIBRARIES cublas cudnn curand)
|
||||
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "10.1.0")
|
||||
list(APPEND ONNXRUNTIME_CUDA_LIBRARIES cublasLt)
|
||||
else()
|
||||
message(WARNING "cublasLT is not supported in CUDA with version lower than 10.1.")
|
||||
endif()
|
||||
if (WIN32)
|
||||
link_directories(${onnxruntime_CUDNN_HOME}/lib/x64)
|
||||
|
||||
|
|
|
|||
|
|
@ -159,13 +159,20 @@ class CudaKernel : public OpKernel {
|
|||
return provider_->PerThreadCublasHandle();
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 10010
|
||||
inline cublasLtHandle_t CublasLtHandle() const {
|
||||
return provider_->PerThreadCublasLtHandle();
|
||||
}
|
||||
#endif
|
||||
|
||||
inline cudnnHandle_t CudnnHandle() const {
|
||||
return provider_->PerThreadCudnnHandle();
|
||||
}
|
||||
|
||||
inline curandGenerator_t CurandGenerator() const {
|
||||
return provider_->PerThreadCurandGenerator();
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline const T* GetConstOnes(size_t count) const {
|
||||
return provider_->template GetConstOnes<T>(count);
|
||||
|
|
|
|||
|
|
@ -50,6 +50,9 @@ thread_local std::unique_ptr<CUDAExecutionProvider::PerThreadContextMap> CUDAExe
|
|||
CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, size_t cuda_mem_limit) {
|
||||
CUDA_CALL_THROW(cudaSetDevice(device_id));
|
||||
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
|
||||
#if CUDA_VERSION >= 10010
|
||||
CUBLAS_CALL_THROW(cublasLtCreate(&cublasLt_handle_));
|
||||
#endif
|
||||
CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_));
|
||||
CURAND_CALL_THROW(curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));
|
||||
|
||||
|
|
@ -69,6 +72,14 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
|
|||
LOGS_DEFAULT(ERROR) << "cublasDestroy threw:" << ex.what();
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 10010
|
||||
try {
|
||||
CUBLAS_CALL(cublasLtDestroy(cublasLt_handle_));
|
||||
} catch (const std::exception& ex) {
|
||||
LOGS_DEFAULT(ERROR) << "cublasLtDestroy threw:" << ex.what();
|
||||
}
|
||||
#endif
|
||||
|
||||
try {
|
||||
CUDNN_CALL(cudnnDestroy(cudnn_handle_));
|
||||
} catch (const std::exception& ex) {
|
||||
|
|
@ -743,7 +754,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int8_t, ReduceMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, uint8_t, ReduceMin);
|
||||
|
||||
|
||||
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
|
||||
|
|
|
|||
|
|
@ -44,6 +44,12 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
return GetPerThreadContext().CublasHandle();
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 10010
|
||||
cublasLtHandle_t PerThreadCublasLtHandle() {
|
||||
return GetPerThreadContext().CublasLtHandle();
|
||||
}
|
||||
#endif
|
||||
|
||||
cudnnHandle_t PerThreadCudnnHandle() {
|
||||
return GetPerThreadContext().CudnnHandle();
|
||||
}
|
||||
|
|
@ -95,6 +101,12 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
return cublas_handle_;
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 10010
|
||||
cublasLtHandle_t CublasLtHandle() const {
|
||||
return cublasLt_handle_;
|
||||
}
|
||||
#endif
|
||||
|
||||
cudnnHandle_t CudnnHandle() const {
|
||||
return cudnn_handle_;
|
||||
}
|
||||
|
|
@ -135,6 +147,9 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
|
||||
private:
|
||||
cublasHandle_t cublas_handle_ = nullptr;
|
||||
#if CUDA_VERSION >= 10010
|
||||
cublasLtHandle_t cublasLt_handle_ = nullptr;
|
||||
#endif
|
||||
cudnnHandle_t cudnn_handle_ = nullptr;
|
||||
curandGenerator_t curand_generator_ = nullptr;
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,11 @@
|
|||
#include <curand.h>
|
||||
#include <cudnn.h>
|
||||
|
||||
// support of cublasLt starts 10.1
|
||||
#if CUDA_VERSION >= 10010
|
||||
#include <cublasLt.h>
|
||||
#endif
|
||||
|
||||
#ifdef USE_NCCL
|
||||
#include <nccl.h>
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -115,12 +115,12 @@ Status CudnnFilterDescriptor::Set(const std::vector<int64_t>& filter_dims, cudnn
|
|||
template <typename ElemType>
|
||||
cudnnDataType_t CudnnTensor::GetDataType() {
|
||||
ORT_THROW("cuDNN engine currently supports only single/double/half/int8/uint8 precision data types. Got:",
|
||||
typeid(ElemType).name());
|
||||
typeid(ElemType).name());
|
||||
// Not reachable but GCC complains
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
cudnnDataType_t CudnnTensor::GetDataType<float>() {
|
||||
return CUDNN_DATA_FLOAT;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,11 +29,11 @@ Status ConstantOfShape::ComputeInternal(OpKernelContext* ctx) const {
|
|||
const void* value_ptr = GetValuePtr();
|
||||
const auto element_size = output_tensor->DataType()->Size();
|
||||
|
||||
#define CASE(TYPE) \
|
||||
case sizeof(TYPE): \
|
||||
if (size > 0) { \
|
||||
cuda::Fill(reinterpret_cast<TYPE*>(output_data), *(reinterpret_cast<const TYPE*>(value_ptr)), size); \
|
||||
} \
|
||||
#define CASE(TYPE) \
|
||||
case sizeof(TYPE): \
|
||||
if (size > 0) { \
|
||||
cuda::Fill(reinterpret_cast<TYPE*>(output_data), *(reinterpret_cast<const TYPE*>(value_ptr)), size); \
|
||||
} \
|
||||
break;
|
||||
|
||||
switch (element_size) {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ namespace cuda {
|
|||
|
||||
class ConstantOfShape final : public ConstantOfShapeBase, public CudaKernel {
|
||||
public:
|
||||
explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), CudaKernel(info) {};
|
||||
explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), CudaKernel(info){};
|
||||
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
|
||||
template <typename T>
|
||||
bool RangeImpl(const T start, const T delta, const int count, T* output);
|
||||
|
||||
|
|
|
|||
|
|
@ -18,8 +18,10 @@ GPUDataTransfer::~GPUDataTransfer() {
|
|||
}
|
||||
|
||||
bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
|
||||
return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED
|
||||
|| dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED;
|
||||
return src_device.Type() == OrtDevice::GPU ||
|
||||
src_device.MemType() == OrtDevice::MemType::CUDA_PINNED ||
|
||||
dst_device.Type() == OrtDevice::GPU ||
|
||||
dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED;
|
||||
}
|
||||
|
||||
common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
|
||||
|
|
|
|||
177
onnxruntime/core/providers/cuda/igemm.cc
Normal file
177
onnxruntime/core/providers/cuda/igemm.cc
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "igemm.h"
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_call.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#if CUDA_VERSION >= 10010
|
||||
void LtIgemmTensor(int m,
|
||||
int n,
|
||||
int k,
|
||||
int32_t alpha_matmul,
|
||||
int32_t beta_matmul,
|
||||
const int8_t* a,
|
||||
int lda,
|
||||
const int8_t* b,
|
||||
int ldb,
|
||||
int32_t* c,
|
||||
int ldc,
|
||||
const CudaKernel* cuda_kernel,
|
||||
cublasLtHandle_t lt_handle) {
|
||||
// Create descriptors for the original matrices
|
||||
cublasLtMatrixLayout_t a_desc = nullptr;
|
||||
cublasLtMatrixLayout_t b_desc = nullptr;
|
||||
cublasLtMatrixLayout_t c_desc = nullptr;
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_8I, m, k, lda));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_8I, n, k, ldb));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_32I, m, n, ldc));
|
||||
|
||||
// Set A and C row major order.
|
||||
// No need for B because B need to be transposed
|
||||
cublasLtOrder_t row_order = CUBLASLT_ORDER_ROW;
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row_order, sizeof(row_order)));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row_order, sizeof(row_order)));
|
||||
|
||||
// The tensor operations IGEMM kernels require specialized memory order of data.
|
||||
// Matrix A and Matrix C need to be in CUBLASLT_ORDER_COL32 order
|
||||
// And Matric B needs to be in CUBLASLT_ORDER_COL4_4R2_8C order
|
||||
|
||||
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
|
||||
cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
|
||||
|
||||
// For CUBLASLT_ORDER_COL32 order, Data is ordered in column-major ordered tiles of 32 columns.
|
||||
// The leading dimension is the stride (in elements) to the beginning of next group of 32-columns.
|
||||
|
||||
// For CUBLASLT_ORDER_COL4_4R2_8C, Data is ordered in column-major ordered tiles of composite tiles
|
||||
// with total 32 columns and 8 rows.
|
||||
// A tile is composed of interleaved inner tiles of 4 columns within 4 even or odd rows in an alternating pattern.
|
||||
// The leading dimension is the stride (in elements) to the beginning of the first 32 column x 8 row tile
|
||||
// for the next 32-wide group of columns.
|
||||
int lda_transform = 32 * m;
|
||||
int ldb_transform = 32 * roundoff(n, 8);
|
||||
int ldc_transform = 32 * m;
|
||||
|
||||
// Allocate memory for transform
|
||||
IAllocatorUniquePtr<int8_t> a_transform = cuda_kernel->GetScratchBuffer<int8_t>(roundoff(k, 32) / 32 * lda_transform);
|
||||
IAllocatorUniquePtr<int8_t> b_transform = cuda_kernel->GetScratchBuffer<int8_t>(roundoff(k, 32) / 32 * ldb_transform);
|
||||
IAllocatorUniquePtr<int32_t> c_transform = cuda_kernel->GetScratchBuffer<int32_t>(roundoff(k, 32) / 32 * ldc_transform);
|
||||
|
||||
// Create descriptors for the transformed matrices
|
||||
cublasLtMatrixLayout_t a_transform_desc = nullptr;
|
||||
cublasLtMatrixLayout_t b_transform_desc = nullptr;
|
||||
cublasLtMatrixLayout_t c_transform_desc = nullptr;
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&a_transform_desc, CUDA_R_8I, m, k, lda_transform));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&b_transform_desc, CUDA_R_8I, n, k, ldb_transform));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&c_transform_desc, CUDA_R_32I, m, n, ldc_transform));
|
||||
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(a_transform_desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_ORDER,
|
||||
&order_COL32,
|
||||
sizeof(order_COL32)));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(b_transform_desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_ORDER,
|
||||
&order_COL4_4R2_8C,
|
||||
sizeof(order_COL4_4R2_8C)));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(c_transform_desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_ORDER,
|
||||
&order_COL32,
|
||||
sizeof(order_COL32)));
|
||||
|
||||
cublasLtMatrixTransformDesc_t transform_desc = nullptr;
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F));
|
||||
|
||||
float alpha_transform = 1.0f;
|
||||
float beta_transform = 0.0f;
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixTransform(lt_handle,
|
||||
transform_desc,
|
||||
&alpha_transform,
|
||||
a,
|
||||
a_desc,
|
||||
&beta_transform,
|
||||
nullptr,
|
||||
nullptr,
|
||||
a_transform.get(),
|
||||
a_transform_desc,
|
||||
0));
|
||||
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixTransform(lt_handle,
|
||||
transform_desc,
|
||||
&alpha_transform,
|
||||
b,
|
||||
b_desc,
|
||||
&beta_transform,
|
||||
nullptr,
|
||||
nullptr,
|
||||
b_transform.get(),
|
||||
b_transform_desc,
|
||||
0));
|
||||
|
||||
if (beta_matmul == 1) {
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixTransform(lt_handle,
|
||||
transform_desc,
|
||||
&alpha_transform,
|
||||
c,
|
||||
c_desc,
|
||||
&beta_transform,
|
||||
nullptr,
|
||||
nullptr,
|
||||
c_transform.get(),
|
||||
c_transform_desc,
|
||||
0));
|
||||
}
|
||||
|
||||
// Tensor op igemm kernels only support NT gemm
|
||||
cublasLtMatmulDesc_t matmul_desc = nullptr;
|
||||
cublasOperation_t op_trans = CUBLAS_OP_T;
|
||||
CUBLAS_CALL_THROW(cublasLtMatmulDescCreate(&matmul_desc, CUDA_R_32I));
|
||||
CUBLAS_CALL_THROW(cublasLtMatmulDescSetAttribute(matmul_desc,
|
||||
CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
&op_trans,
|
||||
sizeof(op_trans)));
|
||||
|
||||
CUBLAS_CALL_THROW(cublasLtMatmul(lt_handle,
|
||||
matmul_desc,
|
||||
&alpha_matmul,
|
||||
a_transform.get(),
|
||||
a_transform_desc,
|
||||
b_transform.get(),
|
||||
b_transform_desc,
|
||||
&beta_matmul,
|
||||
c_transform.get(),
|
||||
c_transform_desc,
|
||||
c_transform.get(),
|
||||
c_transform_desc,
|
||||
nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
0));
|
||||
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixTransform(lt_handle,
|
||||
transform_desc,
|
||||
&alpha_transform,
|
||||
c_transform.get(),
|
||||
c_transform_desc,
|
||||
&beta_transform,
|
||||
nullptr,
|
||||
nullptr,
|
||||
c,
|
||||
c_desc,
|
||||
0));
|
||||
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(c_transform_desc));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(b_transform_desc));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(a_transform_desc));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(c_desc));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(b_desc));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(a_desc));
|
||||
CUBLAS_CALL_THROW(cublasLtMatmulDescDestroy(matmul_desc));
|
||||
CUBLAS_CALL_THROW(cublasLtMatrixTransformDescDestroy(transform_desc));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
31
onnxruntime/core/providers/cuda/igemm.h
Normal file
31
onnxruntime/core/providers/cuda/igemm.h
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
inline int roundoff(int v, int d) {
|
||||
return (v + d - 1) / d * d;
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 10010
|
||||
void LtIgemmTensor(int m,
|
||||
int n,
|
||||
int k,
|
||||
int32_t alpha_matmul,
|
||||
int32_t beta_matmul,
|
||||
const int8_t* a,
|
||||
int lda,
|
||||
const int8_t* b,
|
||||
int ldb,
|
||||
int32_t* c,
|
||||
int ldc,
|
||||
const CudaKernel* cuda_kernel,
|
||||
cublasLtHandle_t lt_handle);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
@ -14,7 +14,7 @@ struct BinaryElementwisePreparation {
|
|||
const Tensor* lhs_tensor = nullptr;
|
||||
const Tensor* rhs_tensor = nullptr;
|
||||
Tensor* output_tensor = nullptr;
|
||||
int32_t output_rank_or_simple_broadcast = 0; // for no_broadcast|left_scalar|right_scalar cases, output_rank uses SimpleBroadcast enums
|
||||
int32_t output_rank_or_simple_broadcast = 0; // for no_broadcast|left_scalar|right_scalar cases, output_rank uses SimpleBroadcast enums
|
||||
|
||||
TArray<int64_t> lhs_padded_strides;
|
||||
TArray<int64_t> rhs_padded_strides;
|
||||
|
|
@ -42,8 +42,8 @@ struct BinaryElementwisePreparation {
|
|||
// early return if one operand is scalar
|
||||
if (lhs_shape.Size() == 1 || rhs_shape.Size() == 1) {
|
||||
output_rank_or_simple_broadcast = static_cast<int32_t>(lhs_shape.Size() == 1
|
||||
? SimpleBroadcast::LeftScalar
|
||||
: SimpleBroadcast::RightScalar);
|
||||
? SimpleBroadcast::LeftScalar
|
||||
: SimpleBroadcast::RightScalar);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -34,18 +34,18 @@ namespace cuda {
|
|||
// NOTE that cu files are compiled with nvcc and should not refer to any onnxruntime headers
|
||||
// so struct BinaryElementwisePreparation cannot be used here
|
||||
|
||||
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
void Impl_##name( \
|
||||
int32_t output_rank_or_simple_broadcast, \
|
||||
const TArray<int64_t>* lhs_padded_strides, \
|
||||
const T* lhs_data, \
|
||||
const TArray<int64_t>* rhs_padded_strides, \
|
||||
const T* rhs_data, \
|
||||
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
void Impl_##name( \
|
||||
int32_t output_rank_or_simple_broadcast, \
|
||||
const TArray<int64_t>* lhs_padded_strides, \
|
||||
const T* lhs_data, \
|
||||
const TArray<int64_t>* rhs_padded_strides, \
|
||||
const T* rhs_data, \
|
||||
const TArray<fast_divmod>* fdm_output_strides, \
|
||||
const fast_divmod& fdm_H, \
|
||||
const fast_divmod& fdm_C, \
|
||||
T* output_data, \
|
||||
const fast_divmod& fdm_H, \
|
||||
const fast_divmod& fdm_C, \
|
||||
T* output_data, \
|
||||
size_t count)
|
||||
|
||||
#define BINARY_OP_NAME_EXPR(name, expr) BINARY_ELEMENTWISE_IMPL_DECLARATION(name);
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/cuda/cuda_allocator.h"
|
||||
#include "core/providers/cuda/igemm.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -106,6 +107,28 @@ Status MatMulInteger<int8_t, int8_t>::ComputeInternal(OpKernelContext* ctx) cons
|
|||
beta = 1;
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 10010
|
||||
if (DeviceProp::GetDeviceProps().major >= 7 && DeviceProp::GetDeviceProps().minor >= 5) {
|
||||
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
|
||||
LtIgemmTensor(
|
||||
static_cast<int>(helper.M()),
|
||||
static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.K()),
|
||||
alpha,
|
||||
beta,
|
||||
a_ptr + helper.LeftOffsets()[batch],
|
||||
static_cast<int>(helper.K()),
|
||||
b_ptr + helper.RightOffsets()[batch],
|
||||
static_cast<int>(helper.N()),
|
||||
output_ptr + helper.OutputOffsets()[batch],
|
||||
static_cast<int>(helper.N()),
|
||||
this,
|
||||
Base::CublasLtHandle());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// pad A and B to make their leading dimension be multiples of 32
|
||||
// because cublasGemmEx requires:
|
||||
// 1. leading dimension is multiples of 4
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ namespace cuda {
|
|||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
TopK,
|
||||
kOnnxDomain,
|
||||
1,9,
|
||||
1, 9,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
TopK<false>);
|
||||
|
|
@ -18,7 +18,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
TopK,
|
||||
kOnnxDomain,
|
||||
10,10,
|
||||
10, 10,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder().InputMemoryType<OrtMemTypeCPUInput>(1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
TopK<true>);
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template<bool inputk>
|
||||
template <bool inputk>
|
||||
class TopK final : public CudaKernel {
|
||||
public:
|
||||
TopK(const OpKernelInfo&);
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class lru_unordered_map {
|
|||
lru_list_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
using list_type = std::list<Key, ListAllocator>;
|
||||
using iterator_type = typename list_type::iterator;
|
||||
struct value_type {
|
||||
|
|
@ -126,12 +126,12 @@ struct CudnnConvState {
|
|||
CudnnConvolutionDescriptor conv_desc;
|
||||
|
||||
struct PerfResultParams {
|
||||
decltype(AlgoPerfType().algo) algo;
|
||||
decltype(AlgoPerfType().memory) memory;
|
||||
decltype(AlgoPerfType().algo) algo;
|
||||
decltype(AlgoPerfType().memory) memory;
|
||||
decltype(AlgoPerfType().mathType) mathType;
|
||||
};
|
||||
|
||||
lru_unordered_map<std::vector<int64_t>, PerfResultParams, vector_hash<int64_t>> cached_benchmark_results { MAX_CACHED_ALGO_PERF_RESULTS };
|
||||
lru_unordered_map<std::vector<int64_t>, PerfResultParams, vector_hash<int64_t>> cached_benchmark_results{MAX_CACHED_ALGO_PERF_RESULTS};
|
||||
|
||||
// note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing
|
||||
OrtMutex mutex;
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ namespace cuda {
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()).TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
|
||||
Pool<data_type, pool_type>);
|
||||
|
||||
|
||||
POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9)
|
||||
POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 7, 9)
|
||||
POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9)
|
||||
|
|
@ -64,7 +63,6 @@ POOLING_KERNEL(MaxPool, MLFloat16, MaxPool<8>, 12)
|
|||
POOLING_KERNEL(MaxPool, int8_t, MaxPool<8>, 12)
|
||||
POOLING_KERNEL(MaxPool, uint8_t, MaxPool<8>, 12)
|
||||
|
||||
|
||||
POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1)
|
||||
POOLING_KERNEL(GlobalMaxPool, double, MaxPool<1>, 1)
|
||||
POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1)
|
||||
|
|
@ -167,8 +165,8 @@ Status Pool<T, PoolType>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
cudnnPoolingMode_t mode = CUDNN_POOLING_MAX;
|
||||
if (PoolType::type == onnxruntime::PoolType::kAveragePool) {
|
||||
mode = pool_attrs_.count_include_pad ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
|
||||
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
|
||||
mode = pool_attrs_.count_include_pad ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
|
||||
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
|
||||
}
|
||||
CudnnPoolingDescriptor pooling_desc;
|
||||
ORT_RETURN_IF_ERROR(pooling_desc.Set(mode, kernel_shape, pads, strides));
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "roialign.h"
|
||||
#include "roialign_impl.h"
|
||||
|
|
@ -15,7 +15,7 @@ namespace cuda {
|
|||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>()), \
|
||||
RoiAlign<T>);
|
||||
|
||||
|
|
@ -58,8 +58,7 @@ Status RoiAlign<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
num_roi_cols,
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType*>(Y.template MutableData<T>()),
|
||||
this->mode_ == RoiAlignMode::avg,
|
||||
batch_indices_ptr->template Data<int64_t>()
|
||||
);
|
||||
batch_indices_ptr->template Data<int64_t>());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ namespace cuda {
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>);
|
||||
|
||||
|
||||
// CUDA's reduction descriptor cudnnReduceTensorDescriptor_t is a pointer so
|
||||
// it's safer to wrap it with automatically memory deleter as CudnnReduceDescriptor.
|
||||
// An implicit caster from CudnnReduceDescriptor to cudnnReduceTensorDescriptor_t
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class CudnnRnnBase : public CudaKernel {
|
|||
rnn_mode_ = CUDNN_LSTM;
|
||||
weight_cached_ = false;
|
||||
w_data_cache_ = nullptr;
|
||||
|
||||
|
||||
size_t state_size;
|
||||
cudnn_dropout_desc_.CreateDescriptorIfNeeded();
|
||||
cudnn_dropout_desc_.GetCudnnDropoutStatesSize(CudnnHandle(), state_size);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ namespace cuda {
|
|||
|
||||
template <typename T>
|
||||
class LSTM final : public CudnnRnnBase<T> {
|
||||
|
||||
public:
|
||||
LSTM(const OpKernelInfo& info) : CudnnRnnBase<T>(info) {
|
||||
CudnnRnnBase<T>::SetRNNMode(CUDNN_LSTM);
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
void ReverseBySequence(const int32_t seq_length,
|
||||
const int32_t batch_size,
|
||||
const int32_t input_or_hidden_size,
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ struct TArray {
|
|||
ORT_ENFORCE(size <= capacity, "TArray size was set to ", size, ", exeeding the capacity limit of ", capacity);
|
||||
}
|
||||
|
||||
TArray(const std::vector<T>& vec) : TArray(static_cast<int32_t>(vec.size())) {
|
||||
TArray(const std::vector<T>& vec) : TArray(static_cast<int32_t>(vec.size())) {
|
||||
memcpy(data_, vec.data(), vec.size() * sizeof(T));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
|
|||
const double* beta,
|
||||
double* C, int ldc,
|
||||
long long int strideC,
|
||||
int batch_count){
|
||||
int batch_count) {
|
||||
return cublasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batch_count);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,8 +22,7 @@ static void CalcEffectiveDims(vector<int64_t>& x_dims, vector<int64_t>& y_dims)
|
|||
if (xdim == ydim || xdim == 1) {
|
||||
x_reverse.push_back(xdim);
|
||||
y_reverse.push_back(ydim);
|
||||
}
|
||||
else { // xdim < ydim && xdim > 1, split
|
||||
} else { // xdim < ydim && xdim > 1, split
|
||||
ydim /= xdim;
|
||||
x_reverse.push_back(xdim);
|
||||
y_reverse.push_back(xdim);
|
||||
|
|
@ -44,18 +43,15 @@ static void CalcEffectiveDims(vector<int64_t>& x_dims, vector<int64_t>& y_dims)
|
|||
}
|
||||
if (x_dims.back() == 1) {
|
||||
y_dims.back() *= y_reverse[i];
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
x_dims.push_back(1);
|
||||
y_dims.push_back(y_reverse[i]);
|
||||
}
|
||||
}
|
||||
else { // x_reverse[i] == y_reverse[i]
|
||||
} else { // x_reverse[i] == y_reverse[i]
|
||||
if (x_dims.back() == y_dims.back()) {
|
||||
x_dims.back() *= x_reverse[i];
|
||||
y_dims.back() *= y_reverse[i];
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
x_dims.push_back(x_reverse[i]);
|
||||
y_dims.push_back(y_reverse[i]);
|
||||
}
|
||||
|
|
@ -107,7 +103,6 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
|
|||
input_strides);
|
||||
}
|
||||
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Expand,
|
||||
kOnnxDomain,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,5 @@ Status ExpandImpl(
|
|||
const TArray<fast_divmod>& output_strides,
|
||||
const TArray<int64_t>& input_strides);
|
||||
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -21,25 +21,23 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()
|
||||
})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()
|
||||
}),
|
||||
DataTypeImpl::GetTensorType<int32_t>()})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()}),
|
||||
EyeLike);
|
||||
|
||||
#define TYPED_FUNCTION_CALL(T) \
|
||||
EyeLikeImpl<typename ToCudaType<T>::MappedType>( \
|
||||
offset, \
|
||||
dim1 + 1, \
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType *>(T2->template MutableData<T>()), \
|
||||
diag_count); \
|
||||
break;
|
||||
#define TYPED_FUNCTION_CALL(T) \
|
||||
EyeLikeImpl<typename ToCudaType<T>::MappedType>( \
|
||||
offset, \
|
||||
dim1 + 1, \
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType*>(T2->template MutableData<T>()), \
|
||||
diag_count); \
|
||||
break;
|
||||
|
||||
Status EyeLike::ComputeInternal(OpKernelContext* context) const {
|
||||
const auto* T1 = context->Input<Tensor>(0);
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@ namespace cuda {
|
|||
|
||||
template <typename T>
|
||||
void EyeLikeImpl(
|
||||
size_t offset, // offset of first element in diagnal
|
||||
size_t stripe, // stripe, here it's width + 1
|
||||
T* output_data, // output buffer
|
||||
size_t diag_count // total number of elements in diagnal
|
||||
);
|
||||
size_t offset, // offset of first element in diagnal
|
||||
size_t stripe, // stripe, here it's width + 1
|
||||
T* output_data, // output buffer
|
||||
size_t diag_count // total number of elements in diagnal
|
||||
);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
7, 9,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>()})
|
||||
.Alias(0, 0),
|
||||
IdentityOp<true>);
|
||||
|
|
|
|||
|
|
@ -10,21 +10,20 @@ namespace cuda {
|
|||
|
||||
int NonZeroCalcBlockCount(int64_t x_size);
|
||||
|
||||
cudaError_t NonZeroCalcPrefixSumTempStorageBytes(int* prefix_counts, int number_of_blocks, size_t& );
|
||||
cudaError_t NonZeroCalcPrefixSumTempStorageBytes(int* prefix_counts, int number_of_blocks, size_t&);
|
||||
|
||||
cudaError_t NonZeroInclusivePrefixSum(void* d_temp_storage, size_t temp_storage_bytes, int* prefix_counts, int number_of_blocks);
|
||||
|
||||
// count nonzero elements in each block into counts_in_blocks,
|
||||
// count nonzero elements in each block into counts_in_blocks,
|
||||
// the counts_in_blocks buffer is pre-allocated on gpu first.
|
||||
template<typename InputT>
|
||||
template <typename InputT>
|
||||
cudaError_t NonZeroCountEachBlock(const InputT* x, int64_t x_size, int* counts_in_blocks);
|
||||
|
||||
// output nonzero positions using input x and prefix_counts for each blocks
|
||||
template<typename InputT>
|
||||
template <typename InputT>
|
||||
cudaError_t NonZeroOutputPositions(
|
||||
const InputT *x, int64_t x_size, int x_rank, const TArray<fast_divmod>& x_strides,
|
||||
const InputT* x, int64_t x_size, int x_rank, const TArray<fast_divmod>& x_strides,
|
||||
const int* prefix_counts, int nonzero_elements, int64_t* results);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
class NonZero final: public CudaKernel {
|
||||
public:
|
||||
class NonZero final : public CudaKernel {
|
||||
public:
|
||||
NonZero(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
ReverseSequenceOp);
|
||||
|
||||
#define ReverseSequenceCallCudaImplTypeAs(T, TEqual) \
|
||||
if (X.IsDataType<T>()) { \
|
||||
if (X.IsDataType<T>()) { \
|
||||
CUDA_RETURN_IF_ERROR(ReverseSequenceCudaImpl( \
|
||||
reinterpret_cast<const typename ToCudaType<TEqual>::MappedType*>(X.template Data<T>()), \
|
||||
seq_lengths.Data<int64_t>(), \
|
||||
|
|
|
|||
1
onnxruntime/core/providers/cuda/tensor/scatter_elements.h
Executable file → Normal file
1
onnxruntime/core/providers/cuda/tensor/scatter_elements.h
Executable file → Normal file
|
|
@ -24,4 +24,3 @@ class ScatterElements final : public CudaKernel {
|
|||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
1
onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h
Executable file → Normal file
1
onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h
Executable file → Normal file
|
|
@ -26,4 +26,3 @@ Status ScatterElementsImpl(
|
|||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.OutputMemoryType<OrtMemTypeCPUOutput>(0)
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template<typename Tind, bool dynamic>
|
||||
template <typename Tind, bool dynamic>
|
||||
class Slice final : public CudaKernel, public SliceBase {
|
||||
public:
|
||||
Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic) {}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
Status SplitImpl(const size_t element_size,
|
||||
Status SplitImpl(const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t* split_sizes,
|
||||
|
|
|
|||
|
|
@ -8,16 +8,16 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Tile, \
|
||||
kOnnxDomain, \
|
||||
6, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(1) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Tile, \
|
||||
kOnnxDomain, \
|
||||
6, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(1) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()), \
|
||||
Tile<T>);
|
||||
|
||||
|
|
|
|||
|
|
@ -68,10 +68,10 @@ struct TernaryElementwisePreparation {
|
|||
const Tensor* a_tensor = nullptr;
|
||||
const Tensor* b_tensor = nullptr;
|
||||
const Tensor* c_tensor = nullptr;
|
||||
size_t output_rank_or_simple_broadcast = 0; // for no_broadcast cases, output_rank uses SimpleBroadcast enums
|
||||
TArray<int64_t> a_padded_strides; // for a shape == output shape, this is nullptr
|
||||
TArray<int64_t> b_padded_strides; // for b shape == output shape, this is nullptr
|
||||
TArray<int64_t> c_padded_strides; // for c shape == output shape, this is nullptr
|
||||
size_t output_rank_or_simple_broadcast = 0; // for no_broadcast cases, output_rank uses SimpleBroadcast enums
|
||||
TArray<int64_t> a_padded_strides; // for a shape == output shape, this is nullptr
|
||||
TArray<int64_t> b_padded_strides; // for b shape == output shape, this is nullptr
|
||||
TArray<int64_t> c_padded_strides; // for c shape == output shape, this is nullptr
|
||||
TArray<fast_divmod> fdm_output_strides;
|
||||
|
||||
TernaryElementwisePreparation(const Tensor* a, const Tensor* b, const Tensor* c)
|
||||
|
|
|
|||
Loading…
Reference in a new issue