mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Optimize Where CUDA kernel for UniLMV2. (#3799)
Co-authored-by: Vincent Wang <weicwang@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
65bfece19d
commit
c222ed6327
4 changed files with 208 additions and 151 deletions
|
|
@ -21,6 +21,12 @@ enum class SimpleBroadcast : int32_t {
|
|||
RightPerChannelBatchN = (int32_t)-5,
|
||||
};
|
||||
|
||||
enum class BroadcastIndexType : int32_t {
|
||||
NoBroadcast = (int32_t)0,
|
||||
Scalar = (int32_t)1,
|
||||
NeedCompute = (int32_t)2,
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class IConstantBuffer {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -73,6 +73,9 @@ struct TernaryElementwisePreparation {
|
|||
TArray<int64_t> b_padded_strides; // for b shape == output shape, this is nullptr
|
||||
TArray<int64_t> c_padded_strides; // for c shape == output shape, this is nullptr
|
||||
TArray<fast_divmod> fdm_output_strides;
|
||||
BroadcastIndexType a_index_type = BroadcastIndexType::NoBroadcast;
|
||||
BroadcastIndexType b_index_type = BroadcastIndexType::NoBroadcast;
|
||||
BroadcastIndexType c_index_type = BroadcastIndexType::NoBroadcast;
|
||||
|
||||
TernaryElementwisePreparation(const Tensor* a, const Tensor* b, const Tensor* c)
|
||||
: a_tensor(a), b_tensor(b), c_tensor(c) {}
|
||||
|
|
@ -108,16 +111,34 @@ struct TernaryElementwisePreparation {
|
|||
}
|
||||
};
|
||||
|
||||
if (a_shape != output_shape) {
|
||||
bool has_need_compute = false;
|
||||
if (a_shape.Size() == 1) {
|
||||
a_index_type = BroadcastIndexType::Scalar;
|
||||
} else if (a_shape != output_shape) {
|
||||
padder(a_rank, a_shape, a_padded_strides);
|
||||
a_index_type = BroadcastIndexType::NeedCompute;
|
||||
has_need_compute = true;
|
||||
}
|
||||
|
||||
if (b_shape != output_shape) {
|
||||
padder(b_rank, b_shape, b_padded_strides);
|
||||
if (b_shape.Size() == 1) {
|
||||
b_index_type = BroadcastIndexType::Scalar;
|
||||
} else if (b_shape != output_shape) {
|
||||
padder(b_rank, b_shape, b_padded_strides);
|
||||
b_index_type = BroadcastIndexType::NeedCompute;
|
||||
has_need_compute = true;
|
||||
}
|
||||
|
||||
if (c_shape != output_shape) {
|
||||
if (c_shape.Size() == 1) {
|
||||
c_index_type = BroadcastIndexType::Scalar;
|
||||
} else if (c_shape != output_shape) {
|
||||
padder(c_rank, c_shape, c_padded_strides);
|
||||
c_index_type = BroadcastIndexType::NeedCompute;
|
||||
has_need_compute = true;
|
||||
}
|
||||
|
||||
if (!has_need_compute) {
|
||||
output_rank_or_simple_broadcast = static_cast<size_t>(SimpleBroadcast::NoBroadcast);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TensorPitches output_pitches(output_shape.GetDims());
|
||||
|
|
@ -154,10 +175,13 @@ Status Where<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
WhereImpl<CudaT>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.a_index_type,
|
||||
prepare.a_padded_strides,
|
||||
reinterpret_cast<const bool*>(prepare.a_tensor->template Data<bool>()),
|
||||
prepare.b_index_type,
|
||||
prepare.b_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.b_tensor->template Data<T>()),
|
||||
prepare.c_index_type,
|
||||
prepare.c_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.c_tensor->template Data<T>()),
|
||||
prepare.fdm_output_strides,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
// broadcast by computing output coordinate from offset, using fast_divmod
|
||||
template <typename T, bool cond_need_compute, bool x_need_compute, bool y_need_compute>
|
||||
template <typename T, BroadcastIndexType CondIndexType, BroadcastIndexType XIndexType, BroadcastIndexType YIndexType, int NumThreadsPerBlock, int NumElementsPerThread>
|
||||
__global__ void _TenaryElementWise(
|
||||
size_t output_rank,
|
||||
const TArray<int64_t> cond_padded_strides,
|
||||
|
|
@ -22,184 +22,208 @@ __global__ void _TenaryElementWise(
|
|||
const TArray<fast_divmod> fdm_output_strides,
|
||||
T* output_data,
|
||||
CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
CUDA_LONG cond_index = (cond_need_compute ? 0 : id);
|
||||
CUDA_LONG x_index = (x_need_compute ? 0 : id);
|
||||
CUDA_LONG y_index = (y_need_compute ? 0 : id);
|
||||
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
bool cond_value[NumElementsPerThread];
|
||||
T x_value[NumElementsPerThread];
|
||||
T y_value[NumElementsPerThread];
|
||||
|
||||
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
|
||||
CUDA_LONG offset = id;
|
||||
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
|
||||
if (dim >= output_rank) {
|
||||
break;
|
||||
CUDA_LONG id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NumElementsPerThread; i++) {
|
||||
if (id < N) {
|
||||
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
|
||||
CUDA_LONG cond_index = (CondIndexType == BroadcastIndexType::NoBroadcast ? id : 0);
|
||||
CUDA_LONG x_index = (XIndexType == BroadcastIndexType::NoBroadcast ? id : 0);
|
||||
CUDA_LONG y_index = (YIndexType == BroadcastIndexType::NoBroadcast ? id : 0);
|
||||
CUDA_LONG offset = id;
|
||||
#pragma unroll
|
||||
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
|
||||
if (dim >= output_rank) {
|
||||
break;
|
||||
}
|
||||
|
||||
int q, r;
|
||||
fdm_output_strides[dim].divmod(offset, q, r);
|
||||
|
||||
if (CondIndexType == BroadcastIndexType::NeedCompute) {
|
||||
cond_index += static_cast<int>(cond_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
if (XIndexType == BroadcastIndexType::NeedCompute) {
|
||||
x_index += static_cast<int>(x_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
if (YIndexType == BroadcastIndexType::NeedCompute) {
|
||||
y_index += static_cast<int>(y_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
offset = r;
|
||||
}
|
||||
|
||||
cond_value[i] = cond_data[cond_index];
|
||||
x_value[i] = x_data[x_index];
|
||||
y_value[i] = y_data[y_index];
|
||||
id += NumThreadsPerBlock;
|
||||
}
|
||||
|
||||
int q, r;
|
||||
fdm_output_strides[dim].divmod(offset, q, r);
|
||||
|
||||
if (cond_need_compute) {
|
||||
cond_index += static_cast<int>(cond_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
if (x_need_compute) {
|
||||
x_index += static_cast<int>(x_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
if (y_need_compute) {
|
||||
y_index += static_cast<int>(y_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
offset = r;
|
||||
}
|
||||
|
||||
output_data[id] = cond_data[cond_index] ? x_data[x_index] : y_data[y_index];
|
||||
id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NumElementsPerThread; i++) {
|
||||
if (id < N) {
|
||||
output_data[id] = cond_value[i] ? x_value[i] : y_value[i];
|
||||
id += NumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for scalar broadcast or non-broadcast case
|
||||
template <typename T>
|
||||
template <typename T, BroadcastIndexType CondIndexType, BroadcastIndexType XIndexType, BroadcastIndexType YIndexType, int NumThreadsPerBlock, int NumElementsPerThread>
|
||||
__global__ void _TenaryElementWiseSimple(
|
||||
const bool* cond_data,
|
||||
const T* x_data,
|
||||
const T* y_data,
|
||||
T* output_data,
|
||||
CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
output_data[id] = cond_data[id] ? x_data[id] : y_data[id];
|
||||
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
bool cond_value[NumElementsPerThread];
|
||||
T x_value[NumElementsPerThread];
|
||||
T y_value[NumElementsPerThread];
|
||||
|
||||
CUDA_LONG id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NumElementsPerThread; i++) {
|
||||
if (id < N) {
|
||||
cond_value[i] = cond_data[CondIndexType == BroadcastIndexType::NoBroadcast ? id : 0];
|
||||
x_value[i] = x_data[XIndexType == BroadcastIndexType::NoBroadcast ? id : 0];
|
||||
y_value[i] = y_data[YIndexType == BroadcastIndexType::NoBroadcast ? id : 0];
|
||||
id += NumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
|
||||
id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NumElementsPerThread; i++) {
|
||||
if (id < N) {
|
||||
output_data[id] = cond_value[i] ? x_value[i] : y_value[i];
|
||||
id += NumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \
|
||||
case Y_INDEX_TYPE: { \
|
||||
_TenaryElementWiseSimple<T, \
|
||||
COND_INDEX_TYPE, \
|
||||
X_INDEX_TYPE, \
|
||||
Y_INDEX_TYPE, \
|
||||
GridDim::maxThreadsPerBlock, \
|
||||
GridDim::maxElementsPerThread> \
|
||||
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(cond_data, \
|
||||
x_data, \
|
||||
y_data, \
|
||||
output_data, \
|
||||
N); \
|
||||
} break
|
||||
|
||||
#define HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \
|
||||
case X_INDEX_TYPE: { \
|
||||
switch(Y_INDEX_TYPE_VAL) { \
|
||||
HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NoBroadcast); \
|
||||
HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::Scalar); \
|
||||
} \
|
||||
} break
|
||||
|
||||
#define HANDLE_COND_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE_VAL, Y_INDEX_TYPE_VAL) \
|
||||
case COND_INDEX_TYPE: { \
|
||||
switch(X_INDEX_TYPE_VAL) { \
|
||||
HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, BroadcastIndexType::NoBroadcast, Y_INDEX_TYPE_VAL); \
|
||||
HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, BroadcastIndexType::Scalar, Y_INDEX_TYPE_VAL); \
|
||||
} \
|
||||
} break
|
||||
|
||||
#define HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \
|
||||
case Y_INDEX_TYPE: { \
|
||||
_TenaryElementWise<T, \
|
||||
COND_INDEX_TYPE, \
|
||||
X_INDEX_TYPE, \
|
||||
Y_INDEX_TYPE, \
|
||||
GridDim::maxThreadsPerBlock, \
|
||||
GridDim::maxElementsPerThread> \
|
||||
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(output_rank_or_simple_broadcast, \
|
||||
cond_padded_strides, \
|
||||
cond_data, \
|
||||
x_padded_strides, \
|
||||
x_data, \
|
||||
y_padded_strides, \
|
||||
y_data, \
|
||||
fdm_output_strides, \
|
||||
output_data, \
|
||||
N); \
|
||||
} break
|
||||
|
||||
#define HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \
|
||||
case X_INDEX_TYPE: { \
|
||||
switch(Y_INDEX_TYPE_VAL) { \
|
||||
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NoBroadcast); \
|
||||
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::Scalar); \
|
||||
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NeedCompute); \
|
||||
} \
|
||||
} break
|
||||
|
||||
#define HANDLE_COND_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE_VAL, Y_INDEX_TYPE_VAL) \
|
||||
case COND_INDEX_TYPE: { \
|
||||
switch(X_INDEX_TYPE_VAL) { \
|
||||
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::NoBroadcast, Y_INDEX_TYPE_VAL); \
|
||||
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::Scalar, Y_INDEX_TYPE_VAL); \
|
||||
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::NeedCompute, Y_INDEX_TYPE_VAL); \
|
||||
} \
|
||||
} break
|
||||
|
||||
template <typename T>
|
||||
void WhereImpl(
|
||||
size_t output_rank_or_simple_broadcast,
|
||||
BroadcastIndexType cond_index_type,
|
||||
const TArray<int64_t>& cond_padded_strides,
|
||||
const bool* cond_data,
|
||||
BroadcastIndexType x_index_type,
|
||||
const TArray<int64_t>& x_padded_strides,
|
||||
const T* x_data,
|
||||
BroadcastIndexType y_index_type,
|
||||
const TArray<int64_t>& y_padded_strides,
|
||||
const T* y_data,
|
||||
const TArray<fast_divmod>& fdm_output_strides,
|
||||
T* output_data,
|
||||
size_t count) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(count) / GridDim::maxThreadsPerBlock));
|
||||
int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
|
||||
CUDA_LONG N = static_cast<CUDA_LONG>(count);
|
||||
|
||||
if (output_rank_or_simple_broadcast == static_cast<size_t>(SimpleBroadcast::NoBroadcast)) {
|
||||
_TenaryElementWiseSimple<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
cond_data,
|
||||
x_data,
|
||||
y_data,
|
||||
output_data,
|
||||
N);
|
||||
switch(cond_index_type) {
|
||||
HANDLE_COND_INDEX_TYPE_SIMPLE(BroadcastIndexType::NoBroadcast, x_index_type, y_index_type);
|
||||
HANDLE_COND_INDEX_TYPE_SIMPLE(BroadcastIndexType::Scalar, x_index_type, y_index_type);
|
||||
}
|
||||
} else {
|
||||
if (cond_padded_strides.size_ && x_padded_strides.size_ && y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, true, true, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (cond_padded_strides.size_ && x_padded_strides.size_ && !y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, true, true, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (cond_padded_strides.size_ && !x_padded_strides.size_ && y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, true, false, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (!cond_padded_strides.size_ && x_padded_strides.size_ && y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, false, true, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (cond_padded_strides.size_ && !x_padded_strides.size_ && !y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, true, false, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (!cond_padded_strides.size_ && x_padded_strides.size_ && !y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, false, true, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (!cond_padded_strides.size_ && !x_padded_strides.size_ && y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, false, false, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else {
|
||||
_TenaryElementWise<T, false, false, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
}
|
||||
switch(cond_index_type) {
|
||||
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::NoBroadcast, x_index_type, y_index_type);
|
||||
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::Scalar, x_index_type, y_index_type);
|
||||
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::NeedCompute, x_index_type, y_index_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template void WhereImpl<T>(size_t output_rank_or_simple_broadcast, \
|
||||
const TArray<int64_t>& cond_padded_strides, \
|
||||
const bool* cond_data, \
|
||||
const TArray<int64_t>& x_padded_strides, \
|
||||
const T* x_data, \
|
||||
const TArray<int64_t>& y_padded_strides, \
|
||||
const T* y_data, \
|
||||
const TArray<fast_divmod>& fdm_output_strides, \
|
||||
T* output_data, \
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template void WhereImpl<T>(size_t output_rank_or_simple_broadcast, \
|
||||
BroadcastIndexType cond_index_type, \
|
||||
const TArray<int64_t>& cond_padded_strides, \
|
||||
const bool* cond_data, \
|
||||
BroadcastIndexType x_index_type, \
|
||||
const TArray<int64_t>& x_padded_strides, \
|
||||
const T* x_data, \
|
||||
BroadcastIndexType y_index_type, \
|
||||
const TArray<int64_t>& y_padded_strides, \
|
||||
const T* y_data, \
|
||||
const TArray<fast_divmod>& fdm_output_strides, \
|
||||
T* output_data, \
|
||||
size_t count);
|
||||
|
||||
SPECIALIZED_IMPL(uint8_t)
|
||||
|
|
@ -209,4 +233,4 @@ SPECIALIZED_IMPL(float)
|
|||
SPECIALIZED_IMPL(half)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -12,10 +12,13 @@ namespace cuda {
|
|||
template <typename T>
|
||||
void WhereImpl(
|
||||
size_t output_rank_or_simple_broadcast,
|
||||
BroadcastIndexType cond_index_type,
|
||||
const TArray<int64_t>& cond_padded_strides,
|
||||
const bool* cond_data,
|
||||
BroadcastIndexType x_index_type,
|
||||
const TArray<int64_t>& x_padded_strides,
|
||||
const T* x_data,
|
||||
BroadcastIndexType y_index_type,
|
||||
const TArray<int64_t>& y_padded_strides,
|
||||
const T* y_data,
|
||||
const TArray<fast_divmod>& fdm_output_strides,
|
||||
|
|
|
|||
Loading…
Reference in a new issue