Optimize sequence type usage on CUDA [1/n] (#8598)

This commit is contained in:
Hariharan Seshadri 2021-08-05 23:25:52 -07:00 committed by GitHub
parent e791faeca5
commit 484e9de55c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 31 deletions

View file

@ -12,6 +12,7 @@ ONNX_OPERATOR_KERNEL_EX(
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("I", std::vector<MLDataType>{
@ -44,6 +45,7 @@ ONNX_OPERATOR_KERNEL_EX(
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.OutputMemoryType(OrtMemTypeCPUInput, 0)
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
SequenceLength);
@ -63,6 +65,7 @@ ONNX_OPERATOR_KERNEL_EX(
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("I", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
@ -75,13 +78,12 @@ ONNX_OPERATOR_KERNEL_EX(
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 2)
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("I", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
SequenceInsert);
} // namespace cuda
} // namespace onnxruntime
} // namespace cuda
} // namespace onnxruntime

View file

@ -10,26 +10,20 @@
namespace onnxruntime {
namespace cuda {
template <typename DataType>
int64_t ReadIndex(const Tensor& tensor, const char* type_name) {
DataType data{0};
if (!CUDA_CALL(cudaMemcpy(&data, tensor.Data<DataType>(),
sizeof(DataType), cudaMemcpyDeviceToHost))) {
ORT_THROW("Cuda: Failed to read tensor data as type: ", type_name, ".");
}
return static_cast<int64_t>(data);
}
class SequenceAt final : public CudaKernel {
public:
SequenceAt(const OpKernelInfo& info) : CudaKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override {
const TensorSeq* X = context->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "SequenceAt GPU: Got nullptr for sequence input.");
const Tensor* I = context->Input<Tensor>(1);
ORT_ENFORCE(I != nullptr, "SequenceAt GPU: Got nullptr input for index tensor.");
int64_t idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex<int32_t>(*I, "int32_t") : ReadIndex<int64_t>(*I, "int64_t");
int64_t idx = -1;
if (I->IsDataType<int32_t>()) {
idx = static_cast<int64_t>(I->Data<int32_t>()[0]);
} else {
idx = I->Data<int64_t>()[0];
}
int64_t sequence_size = static_cast<int64_t>(X->Size());
if (idx < 0) {
idx = sequence_size + idx;
@ -59,7 +53,6 @@ class SequenceConstruct final : public CudaKernel {
SequenceConstruct(const OpKernelInfo& info) : CudaKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override {
TensorSeq* Y = context->Output<TensorSeq>(0);
ORT_ENFORCE(Y != nullptr, "SequenceConstruct GPU: Got nullptr for output sequence.");
AllocatorPtr alloc;
ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
@ -93,7 +86,6 @@ class SequenceEmpty final : public CudaKernel {
}
Status ComputeInternal(OpKernelContext* context) const override {
TensorSeq* Y = context->Output<TensorSeq>(0);
ORT_ENFORCE(Y != nullptr, "SequenceEmpty GPU: Failed to allocate output tensor sequence.");
#ifdef SHARED_PROVIDER
Y->SetType(DataTypeImpl::GetTypeFromOnnxType(static_cast<int>(dtype_)));
#else
@ -111,14 +103,8 @@ class SequenceLength final : public CudaKernel {
SequenceLength(const OpKernelInfo& info) : CudaKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override {
const TensorSeq* X = context->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "SequenceLength GPU: Input tensor sequence is nullptr.");
Tensor* Y = context->Output(0, {});
ORT_ENFORCE(Y != nullptr, "SequenceLength GPU: Failed to allocate output tensor sequence.");
auto X_size = static_cast<int64_t>(X->Size());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->MutableDataRaw(),
&X_size,
sizeof(int64_t),
cudaMemcpyHostToDevice, Stream()));
Y->MutableData<int64_t>()[0] = static_cast<int64_t>(X->Size());
return Status::OK();
}
}; // SequenceLength
@ -129,7 +115,6 @@ class ConcatFromSequence final : public CudaKernel, public ConcatBase {
Status ComputeInternal(OpKernelContext* context) const override {
const TensorSeq* X = context->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "ConcatFromSequence GPU: Input tensor sequence is nullptr.");
int64_t input_count = static_cast<int64_t>(X->Size());
std::vector<const Tensor*> input_tensors;
for (int64_t i = 0; i < input_count; ++i) {
@ -175,12 +160,16 @@ class SequenceErase final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override {
const TensorSeq* X = context->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "SequenceErase GPU: Got nullptr for sequence input.");
int64_t X_size = static_cast<int64_t>(X->Size());
int64_t idx = X_size - 1;
const Tensor* I = context->Input<Tensor>(1);
if (I != nullptr) {
idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex<int32_t>(*I, "int32_t") : ReadIndex<int64_t>(*I, "int64_t");
if (I->IsDataType<int32_t>()) {
idx = static_cast<int64_t>(I->Data<int32_t>()[0]);
} else {
idx = I->Data<int64_t>()[0];
}
if (idx < 0) {
idx = X_size + idx;
}
@ -218,12 +207,16 @@ class SequenceInsert final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override {
const TensorSeq* S = context->Input<TensorSeq>(0);
ORT_ENFORCE(S != nullptr, "SequenceInsert GPU: Got nullptr for sequence input.");
int64_t S_size = static_cast<int64_t>(S->Size());
int64_t idx = S_size;
const Tensor* I = context->Input<Tensor>(2);
if (I != nullptr) {
idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex<int32_t>(*I, "int32_t") : ReadIndex<int64_t>(*I, "int64_t");
if (I->IsDataType<int32_t>()) {
idx = static_cast<int64_t>(I->Data<int32_t>()[0]);
} else {
idx = I->Data<int64_t>()[0];
}
if (idx < 0) {
idx = S_size + idx;
}