diff --git a/onnxruntime/core/providers/cuda/tensor/sequence_op.cc b/onnxruntime/core/providers/cuda/tensor/sequence_op.cc index 71f277c48b..e7ad0c44ce 100644 --- a/onnxruntime/core/providers/cuda/tensor/sequence_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/sequence_op.cc @@ -12,6 +12,7 @@ ONNX_OPERATOR_KERNEL_EX( 11, kCudaExecutionProvider, (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("I", std::vector{ @@ -44,6 +45,7 @@ ONNX_OPERATOR_KERNEL_EX( 11, kCudaExecutionProvider, (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) .TypeConstraint("I", DataTypeImpl::GetTensorType()), SequenceLength); @@ -63,6 +65,7 @@ ONNX_OPERATOR_KERNEL_EX( 11, kCudaExecutionProvider, (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) .TypeConstraint("I", std::vector{ DataTypeImpl::GetTensorType(), @@ -75,13 +78,12 @@ ONNX_OPERATOR_KERNEL_EX( 11, kCudaExecutionProvider, (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 2) .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) .TypeConstraint("I", std::vector{ DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), SequenceInsert); -} // namespace cuda -} // namespace onnxruntime - - +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/sequence_op.h b/onnxruntime/core/providers/cuda/tensor/sequence_op.h index d5ed1c06eb..07f83c5e95 100644 --- a/onnxruntime/core/providers/cuda/tensor/sequence_op.h +++ b/onnxruntime/core/providers/cuda/tensor/sequence_op.h @@ -10,26 +10,20 @@ namespace onnxruntime { namespace cuda { -template -int64_t ReadIndex(const Tensor& tensor, const char* type_name) { - DataType data{0}; - if (!CUDA_CALL(cudaMemcpy(&data, tensor.Data(), - sizeof(DataType), cudaMemcpyDeviceToHost))) { - ORT_THROW("Cuda: Failed to read tensor data as type: ", type_name, "."); - } - return static_cast(data); -} - class SequenceAt final : public CudaKernel { public: SequenceAt(const OpKernelInfo& info) : CudaKernel(info) {} Status ComputeInternal(OpKernelContext* context) const override { const TensorSeq* X = context->Input(0); - ORT_ENFORCE(X != nullptr, "SequenceAt GPU: Got nullptr for sequence input."); const Tensor* I = context->Input(1); - ORT_ENFORCE(I != nullptr, "SequenceAt GPU: Got nullptr input for index tensor."); - int64_t idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t"); + int64_t idx = -1; + if (I->IsDataType()) { + idx = static_cast(I->Data()[0]); + } else { + idx = I->Data()[0]; + } + int64_t sequence_size = static_cast(X->Size()); if (idx < 0) { idx = sequence_size + idx; @@ -59,7 +53,6 @@ class SequenceConstruct final : public CudaKernel { SequenceConstruct(const OpKernelInfo& info) : CudaKernel(info) {} Status ComputeInternal(OpKernelContext* context) const override { TensorSeq* Y = context->Output(0); - ORT_ENFORCE(Y != nullptr, "SequenceConstruct GPU: Got nullptr for output sequence."); AllocatorPtr alloc; ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(), @@ -93,7 +86,6 @@ class SequenceEmpty final : public CudaKernel { } Status ComputeInternal(OpKernelContext* context) const override { TensorSeq* Y = context->Output(0); - ORT_ENFORCE(Y != nullptr, "SequenceEmpty GPU: Failed to allocate output tensor sequence."); #ifdef SHARED_PROVIDER Y->SetType(DataTypeImpl::GetTypeFromOnnxType(static_cast(dtype_))); #else @@ -111,14 +103,8 @@ class SequenceLength final : public CudaKernel { SequenceLength(const OpKernelInfo& info) : CudaKernel(info) {} Status ComputeInternal(OpKernelContext* context) const override { const TensorSeq* X = context->Input(0); - ORT_ENFORCE(X != nullptr, "SequenceLength GPU: Input tensor sequence is nullptr."); Tensor* Y = context->Output(0, {}); - ORT_ENFORCE(Y != nullptr, "SequenceLength GPU: Failed to allocate output tensor sequence."); - auto X_size = static_cast(X->Size()); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->MutableDataRaw(), - &X_size, - sizeof(int64_t), - cudaMemcpyHostToDevice, Stream())); + Y->MutableData()[0] = static_cast(X->Size()); return Status::OK(); } }; // SequenceLength @@ -129,7 +115,6 @@ class ConcatFromSequence final : public CudaKernel, public ConcatBase { Status ComputeInternal(OpKernelContext* context) const override { const TensorSeq* X = context->Input(0); - ORT_ENFORCE(X != nullptr, "ConcatFromSequence GPU: Input tensor sequence is nullptr."); int64_t input_count = static_cast(X->Size()); std::vector input_tensors; for (int64_t i = 0; i < input_count; ++i) { @@ -175,12 +160,16 @@ class SequenceErase final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override { const TensorSeq* X = context->Input(0); - ORT_ENFORCE(X != nullptr, "SequenceErase GPU: Got nullptr for sequence input."); int64_t X_size = static_cast(X->Size()); int64_t idx = X_size - 1; const Tensor* I = context->Input(1); if (I != nullptr) { - idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t"); + if (I->IsDataType()) { + idx = static_cast(I->Data()[0]); + } else { + idx = I->Data()[0]; + } + if (idx < 0) { idx = X_size + idx; } @@ -218,12 +207,16 @@ class SequenceInsert final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override { const TensorSeq* S = context->Input(0); - ORT_ENFORCE(S != nullptr, "SequenceInsert GPU: Got nullptr for sequence input."); int64_t S_size = static_cast(S->Size()); int64_t idx = S_size; const Tensor* I = context->Input(2); if (I != nullptr) { - idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t"); + if (I->IsDataType()) { + idx = static_cast(I->Data()[0]); + } else { + idx = I->Data()[0]; + } + if (idx < 0) { idx = S_size + idx; }