mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Optimize sequence type usage on CUDA [1/n] (#8598)
This commit is contained in:
parent
e791faeca5
commit
484e9de55c
2 changed files with 26 additions and 31 deletions
|
|
@ -12,6 +12,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
11,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 1)
|
||||
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("I", std::vector<MLDataType>{
|
||||
|
|
@ -44,6 +45,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
11,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.OutputMemoryType(OrtMemTypeCPUInput, 0)
|
||||
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
|
||||
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
SequenceLength);
|
||||
|
|
@ -63,6 +65,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
11,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 1)
|
||||
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
|
||||
.TypeConstraint("I", std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
|
|
@ -75,13 +78,12 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
11,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 2)
|
||||
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
|
||||
.TypeConstraint("I", std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
SequenceInsert);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,26 +10,20 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename DataType>
|
||||
int64_t ReadIndex(const Tensor& tensor, const char* type_name) {
|
||||
DataType data{0};
|
||||
if (!CUDA_CALL(cudaMemcpy(&data, tensor.Data<DataType>(),
|
||||
sizeof(DataType), cudaMemcpyDeviceToHost))) {
|
||||
ORT_THROW("Cuda: Failed to read tensor data as type: ", type_name, ".");
|
||||
}
|
||||
return static_cast<int64_t>(data);
|
||||
}
|
||||
|
||||
class SequenceAt final : public CudaKernel {
|
||||
public:
|
||||
SequenceAt(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
const TensorSeq* X = context->Input<TensorSeq>(0);
|
||||
ORT_ENFORCE(X != nullptr, "SequenceAt GPU: Got nullptr for sequence input.");
|
||||
const Tensor* I = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(I != nullptr, "SequenceAt GPU: Got nullptr input for index tensor.");
|
||||
|
||||
int64_t idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex<int32_t>(*I, "int32_t") : ReadIndex<int64_t>(*I, "int64_t");
|
||||
int64_t idx = -1;
|
||||
if (I->IsDataType<int32_t>()) {
|
||||
idx = static_cast<int64_t>(I->Data<int32_t>()[0]);
|
||||
} else {
|
||||
idx = I->Data<int64_t>()[0];
|
||||
}
|
||||
|
||||
int64_t sequence_size = static_cast<int64_t>(X->Size());
|
||||
if (idx < 0) {
|
||||
idx = sequence_size + idx;
|
||||
|
|
@ -59,7 +53,6 @@ class SequenceConstruct final : public CudaKernel {
|
|||
SequenceConstruct(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
TensorSeq* Y = context->Output<TensorSeq>(0);
|
||||
ORT_ENFORCE(Y != nullptr, "SequenceConstruct GPU: Got nullptr for output sequence.");
|
||||
|
||||
AllocatorPtr alloc;
|
||||
ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
|
||||
|
|
@ -93,7 +86,6 @@ class SequenceEmpty final : public CudaKernel {
|
|||
}
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
TensorSeq* Y = context->Output<TensorSeq>(0);
|
||||
ORT_ENFORCE(Y != nullptr, "SequenceEmpty GPU: Failed to allocate output tensor sequence.");
|
||||
#ifdef SHARED_PROVIDER
|
||||
Y->SetType(DataTypeImpl::GetTypeFromOnnxType(static_cast<int>(dtype_)));
|
||||
#else
|
||||
|
|
@ -111,14 +103,8 @@ class SequenceLength final : public CudaKernel {
|
|||
SequenceLength(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
const TensorSeq* X = context->Input<TensorSeq>(0);
|
||||
ORT_ENFORCE(X != nullptr, "SequenceLength GPU: Input tensor sequence is nullptr.");
|
||||
Tensor* Y = context->Output(0, {});
|
||||
ORT_ENFORCE(Y != nullptr, "SequenceLength GPU: Failed to allocate output tensor sequence.");
|
||||
auto X_size = static_cast<int64_t>(X->Size());
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->MutableDataRaw(),
|
||||
&X_size,
|
||||
sizeof(int64_t),
|
||||
cudaMemcpyHostToDevice, Stream()));
|
||||
Y->MutableData<int64_t>()[0] = static_cast<int64_t>(X->Size());
|
||||
return Status::OK();
|
||||
}
|
||||
}; // SequenceLength
|
||||
|
|
@ -129,7 +115,6 @@ class ConcatFromSequence final : public CudaKernel, public ConcatBase {
|
|||
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
const TensorSeq* X = context->Input<TensorSeq>(0);
|
||||
ORT_ENFORCE(X != nullptr, "ConcatFromSequence GPU: Input tensor sequence is nullptr.");
|
||||
int64_t input_count = static_cast<int64_t>(X->Size());
|
||||
std::vector<const Tensor*> input_tensors;
|
||||
for (int64_t i = 0; i < input_count; ++i) {
|
||||
|
|
@ -175,12 +160,16 @@ class SequenceErase final : public CudaKernel {
|
|||
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
const TensorSeq* X = context->Input<TensorSeq>(0);
|
||||
ORT_ENFORCE(X != nullptr, "SequenceErase GPU: Got nullptr for sequence input.");
|
||||
int64_t X_size = static_cast<int64_t>(X->Size());
|
||||
int64_t idx = X_size - 1;
|
||||
const Tensor* I = context->Input<Tensor>(1);
|
||||
if (I != nullptr) {
|
||||
idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex<int32_t>(*I, "int32_t") : ReadIndex<int64_t>(*I, "int64_t");
|
||||
if (I->IsDataType<int32_t>()) {
|
||||
idx = static_cast<int64_t>(I->Data<int32_t>()[0]);
|
||||
} else {
|
||||
idx = I->Data<int64_t>()[0];
|
||||
}
|
||||
|
||||
if (idx < 0) {
|
||||
idx = X_size + idx;
|
||||
}
|
||||
|
|
@ -218,12 +207,16 @@ class SequenceInsert final : public CudaKernel {
|
|||
|
||||
Status ComputeInternal(OpKernelContext* context) const override {
|
||||
const TensorSeq* S = context->Input<TensorSeq>(0);
|
||||
ORT_ENFORCE(S != nullptr, "SequenceInsert GPU: Got nullptr for sequence input.");
|
||||
int64_t S_size = static_cast<int64_t>(S->Size());
|
||||
int64_t idx = S_size;
|
||||
const Tensor* I = context->Input<Tensor>(2);
|
||||
if (I != nullptr) {
|
||||
idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex<int32_t>(*I, "int32_t") : ReadIndex<int64_t>(*I, "int64_t");
|
||||
if (I->IsDataType<int32_t>()) {
|
||||
idx = static_cast<int64_t>(I->Data<int32_t>()[0]);
|
||||
} else {
|
||||
idx = I->Data<int64_t>()[0];
|
||||
}
|
||||
|
||||
if (idx < 0) {
|
||||
idx = S_size + idx;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue