diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ac255c5b73..12ea0af6ef 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -458,6 +458,7 @@ Do not modify directly.* |Concat|*in* inputs:**T**
*out* concat_result:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -514,7 +515,7 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |If|*in* cond:**B**
*out* outputs:**V**|13+|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -552,8 +553,8 @@ Do not modify directly.* |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| |||[8, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 7]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| -|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| @@ -637,6 +638,12 @@ Do not modify directly.* |ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|SequenceAt|*in* input_sequence:**S**
*in* position:**I**
*out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|SequenceConstruct|*in* inputs:**T**
*out* output_sequence:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|SequenceEmpty|*out* output:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |Shape|*in* data:**T**
*out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index be3dda0c64..d29e4b10b6 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -495,6 +495,7 @@ struct ProviderHostImpl : ProviderHost { // DataTypeImpl (wrapped) MLDataType DataTypeImpl__GetType_Tensor() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_TensorSeq () override { return DataTypeImpl::GetType(); } + MLDataType DataTypeImpl__GetTypeFromOnnxType (int onnx_type) override { return DataTypeImpl::TensorTypeFromONNXEnum(onnx_type)->GetElementType(); } MLDataType DataTypeImpl__GetType_bool() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_int8() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_uint8() override { return DataTypeImpl::GetType(); } @@ -508,6 +509,7 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetType_double() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_BFloat16() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_MLFloat16() override { return DataTypeImpl::GetType(); } + MLDataType DataTypeImpl__GetType_string() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetTensorType_bool() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_int8() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_uint8() override { return DataTypeImpl::GetTensorType(); } @@ -521,7 +523,6 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetTensorType_double() override { return DataTypeImpl::GetTensorType(); 
} MLDataType DataTypeImpl__GetTensorType_BFloat16() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_MLFloat16() override { return DataTypeImpl::GetTensorType(); } - const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); } bool DataTypeImpl__IsTensorType(const DataTypeImpl* p) override { return p->IsTensorType(); } bool DataTypeImpl__IsTensorSequenceType(const DataTypeImpl* p) override { return p->IsTensorSequenceType(); } @@ -531,6 +532,9 @@ struct ProviderHostImpl : ProviderHost { const std::vector& DataTypeImpl__AllTensorTypes() override { return DataTypeImpl::AllTensorTypes(); } const std::vector& DataTypeImpl__AllIEEEFloatTensorTypes() override { return DataTypeImpl::AllIEEEFloatTensorTypes(); } const std::vector& DataTypeImpl__AllTensorAndSequenceTensorTypes() override { return DataTypeImpl::AllTensorAndSequenceTensorTypes(); } + const std::vector& DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes() override { return DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes(); } + const std::vector& DataTypeImpl__AllSequenceTensorTypes() override { return DataTypeImpl::AllSequenceTensorTypes(); } + const std::vector& DataTypeImpl__AllFixedSizeSequenceTensorTypes() override { return DataTypeImpl::AllFixedSizeSequenceTensorTypes(); } size_t DataTypeImpl__Size(const DataTypeImpl* p) override { return p->Size(); } const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) override { return p->AsPrimitiveDataType(); } diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 000d604961..9d22474d2a 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -53,22 +53,12 @@ AllocatorPtr SessionState::GetAllocator(const OrtMemoryInfo& location) const noe } AllocatorPtr SessionState::GetAllocator(OrtDevice device) const noexcept { - AllocatorPtr result; - 
- using AllocatorEntry = std::map, - OrtMemoryInfoLessThanIgnoreAllocType>::const_reference; - - auto entry = std::find_if(allocators_.cbegin(), allocators_.cend(), - [device](AllocatorEntry& entry) { - return entry.first.device == device && - entry.first.mem_type == OrtMemTypeDefault; - }); - - if (entry != allocators_.cend()) { - result = entry->second(device.Id(), OrtMemTypeDefault); + for (const auto& iter : allocators_) { + if (iter.first.device == device) { + return iter.second(device.Id(), iter.first.mem_type); + } } - - return result; + return nullptr; } void SessionState::CreateGraphInfo() { diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 648a3be1c8..fa6c4e58c2 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -174,8 +174,8 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, return Status::OK(); } + auto allocator = session_state.GetAllocator(copy_info.target_device); if (!target_mlvalue.IsAllocated()) { - auto allocator = session_state.GetAllocator(copy_info.target_device); ORT_ENFORCE(allocator != nullptr, "Failed to find allocator for device ", copy_info.target_device.ToString()); ORT_RETURN_IF_ERROR(utils::AllocateHelper(allocator, source_mlvalue, target_mlvalue)); } @@ -190,9 +190,16 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, } } else if (source_mlvalue.IsTensorSequence()) { const TensorSeq& source_tensor_seq = source_mlvalue.Get(); - const TensorSeq& target_tensor_seq = target_mlvalue.Get(); - ORT_ENFORCE(source_tensor_seq.Size() == target_tensor_seq.Size(), - "source and target tensor sequence have different number of elements."); + TensorSeq& target_tensor_seq = const_cast(target_mlvalue.Get()); + size_t size = 0; + while ((size = target_tensor_seq.Size()) < source_tensor_seq.Size()) { + if (0 == size) { + target_tensor_seq.SetType(source_tensor_seq.DataType()); + } + const Tensor& source_tensor = 
source_tensor_seq.Get(size); + std::unique_ptr target_tensor = std::make_unique(source_tensor.DataType(), source_tensor.Shape(), allocator); + target_tensor_seq.Add(std::move(*target_tensor)); + } auto source_iter = source_tensor_seq.begin(); auto target_iter = target_tensor_seq.begin(); while (source_iter != source_tensor_seq.end() && diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 19a18d0dc9..e15a1931fb 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -26,10 +26,39 @@ class Memcpy final : public OpKernel { Memcpy(const OpKernelInfo& info) : OpKernel{info} {} Status Compute(OpKernelContext* ctx) const override { - const auto* X = ctx->Input(0); - Tensor* Y = ctx->Output(0, X->Shape()); - Status retval = Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId()); - return retval; + auto X_type = ctx->InputType(0); + if (X_type->IsTensorType()) { + const auto* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); + Tensor* Y = ctx->Output(0, X->Shape()); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); + return Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId()); + } else if (X_type->IsTensorSequenceType()) { + const TensorSeq* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); + TensorSeq* Y = ctx->Output(0); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); + auto X_dtype = X->DataType(); + Y->SetType(X_dtype); + AllocatorPtr alloc; + auto status = ctx->GetTempSpaceAllocator(&alloc); + if (!status.IsOK()) { + return Status(common::ONNXRUNTIME, common::FAIL, + "Memcpy cuda: unable to get an allocator."); + } + auto X_size = X->Size(); + for (size_t i = 0; i < X_size; ++i) { + const Tensor& 
source_tensor = X->Get(i); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); + Status retval = Info().GetDataTransferManager().CopyTensor(source_tensor, *target_tensor, Info().GetKernelDef().ExecQueueId()); + if (!retval.IsOK()) { + return retval; + } + Y->Add(std::move(*target_tensor)); + } + return Status::OK(); + } + return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type."); } }; @@ -42,7 +71,7 @@ ONNX_OPERATOR_KERNEL_EX( (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .ExecQueueId(kCudaStreamCopyIn) - .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()), Memcpy); ONNX_OPERATOR_KERNEL_EX( @@ -53,7 +82,7 @@ ONNX_OPERATOR_KERNEL_EX( (*KernelDefBuilder::Create()) .OutputMemoryType(OrtMemTypeCPUOutput, 0) .ExecQueueId(kCudaStreamCopyOut) - .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()), Memcpy); } // namespace cuda @@ -770,6 +799,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceAt); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceConstruct); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceEmpty); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceLength); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ConcatFromSequence); 
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceErase); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceInsert); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Conv); @@ -1589,6 +1625,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.cc b/onnxruntime/core/providers/cuda/tensor/identity_op.cc index 563bb45643..65e4649068 100644 --- a/onnxruntime/core/providers/cuda/tensor/identity_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/identity_op.cc @@ -57,7 +57,7 @@ ONNX_OPERATOR_KERNEL_EX( 14, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()) .Alias(0, 0), IdentityOp); } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.h b/onnxruntime/core/providers/cuda/tensor/identity_op.h index 35f695c9d9..be90031ec7 100644 --- a/onnxruntime/core/providers/cuda/tensor/identity_op.h +++ b/onnxruntime/core/providers/cuda/tensor/identity_op.h @@ -53,14 +53,11 @@ class IdentityOp final : public CudaKernel { } } else if (X_ml_type->IsTensorSequenceType()) { const TensorSeq* X = context->Input(0); - if (nullptr == X) { - return Status(common::ONNXRUNTIME, common::FAIL, - "IdentityOp 
cuda: input tensor is missing."); - } + ORT_ENFORCE(X != nullptr, "IdentityOp cuda: input tensor is missing."); TensorSeq* Y = context->Output(0); - if (nullptr == Y) { - return Status(common::ONNXRUNTIME, common::FAIL, - "IdentityOp cuda: failed to allocate output tensor sequence."); + ORT_ENFORCE(Y != nullptr, "IdentityOp cuda: failed to allocate output tensor sequence."); + if (X == Y) { + return Status::OK(); } auto X_type = X->DataType(); Y->SetType(X_type); @@ -73,7 +70,8 @@ class IdentityOp final : public CudaKernel { auto X_size = X->Size(); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(X_type, source_tensor.Shape(), alloc); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), + source_tensor.Shape(), alloc); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(), source_tensor.DataRaw(), source_tensor.SizeInBytes(), diff --git a/onnxruntime/core/providers/cuda/tensor/sequence_op.cc b/onnxruntime/core/providers/cuda/tensor/sequence_op.cc new file mode 100644 index 0000000000..71f277c48b --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/sequence_op.cc @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "sequence_op.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + SequenceAt, + kOnnxDomain, + 11, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("I", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SequenceAt); + +ONNX_OPERATOR_KERNEL_EX( + SequenceConstruct, + kOnnxDomain, + 11, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()), + SequenceConstruct); + +ONNX_OPERATOR_KERNEL_EX( + SequenceEmpty, + kOnnxDomain, + 11, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()), + SequenceEmpty); + +ONNX_OPERATOR_KERNEL_EX( + SequenceLength, + kOnnxDomain, + 11, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), + SequenceLength); + +ONNX_OPERATOR_KERNEL_EX( + ConcatFromSequence, + kOnnxDomain, + 11, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()), + ConcatFromSequence); + +ONNX_OPERATOR_KERNEL_EX( + SequenceErase, + kOnnxDomain, + 11, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) + .TypeConstraint("I", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SequenceErase); + +ONNX_OPERATOR_KERNEL_EX( + SequenceInsert, + kOnnxDomain, + 11, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) + .TypeConstraint("I", std::vector{ + 
DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SequenceInsert); + +} // namespace cuda +} // namespace onnxruntime + + diff --git a/onnxruntime/core/providers/cuda/tensor/sequence_op.h b/onnxruntime/core/providers/cuda/tensor/sequence_op.h new file mode 100644 index 0000000000..d5ed1c06eb --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/sequence_op.h @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/tensor/concat.h" +#include "core/providers/cuda/tensor/concat_impl.h" + +namespace onnxruntime { +namespace cuda { + +template +int64_t ReadIndex(const Tensor& tensor, const char* type_name) { + DataType data{0}; + if (!CUDA_CALL(cudaMemcpy(&data, tensor.Data(), + sizeof(DataType), cudaMemcpyDeviceToHost))) { + ORT_THROW("Cuda: Failed to read tensor data as type: ", type_name, "."); + } + return static_cast(data); +} + +class SequenceAt final : public CudaKernel { + public: + SequenceAt(const OpKernelInfo& info) : CudaKernel(info) {} + Status ComputeInternal(OpKernelContext* context) const override { + const TensorSeq* X = context->Input(0); + ORT_ENFORCE(X != nullptr, "SequenceAt GPU: Got nullptr for sequence input."); + const Tensor* I = context->Input(1); + ORT_ENFORCE(I != nullptr, "SequenceAt GPU: Got nullptr input for index tensor."); + + int64_t idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? 
ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t"); + int64_t sequence_size = static_cast(X->Size()); + if (idx < 0) { + idx = sequence_size + idx; + } + ORT_ENFORCE(idx >= 0 && idx < sequence_size, "SequenceAt GPU: Invalid sequence index."); + + const Tensor& source_tensor = X->Get(idx); + auto source_type = source_tensor.DataType(); + const void* source_addr = source_tensor.DataRaw(source_type); + + Tensor* target_tensor = context->Output(0, source_tensor.Shape()); + ORT_ENFORCE(target_tensor != nullptr, "SequenceAt GPU: Got nullptr for output tensor."); + void* target_addr = target_tensor->MutableDataRaw(source_type); + + if (source_addr != target_addr) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_addr, + source_addr, + source_tensor.SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream())); + } + return Status::OK(); + } +}; // SequenceAt + +class SequenceConstruct final : public CudaKernel { + public: + SequenceConstruct(const OpKernelInfo& info) : CudaKernel(info) {} + Status ComputeInternal(OpKernelContext* context) const override { + TensorSeq* Y = context->Output(0); + ORT_ENFORCE(Y != nullptr, "SequenceConstruct GPU: Got nullptr for output sequence."); + + AllocatorPtr alloc; + ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(), + "SequenceConstruct GPU: Unable to get an allocator."); + + int32_t at = 0; + const Tensor* source_tensor = nullptr; + while (nullptr != (source_tensor = context->Input(at++))) { + if (1 == at) { + Y->SetType(source_tensor->DataType()); + } + std::unique_ptr target_tensor = Tensor::Create(source_tensor->DataType(), + source_tensor->Shape(), alloc); + ORT_ENFORCE(target_tensor, "SequenceConstruct GPU: Failed to allocate new tensor."); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(), + source_tensor->DataRaw(), + source_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream())); + Y->Add(std::move(*target_tensor)); // Add will check type consistency inside + } + return Status::OK(); + } +}; // 
SequenceConstruct
+
+// SequenceEmpty: emits an output tensor sequence that contains no tensors and only
+// has its element type set. The element type comes from the "dtype" attribute;
+// when the attribute cannot be read it falls back to FLOAT.
+class SequenceEmpty final : public CudaKernel {
+ public:
+  SequenceEmpty(const OpKernelInfo& info) : CudaKernel(info) {
+    if (!info.GetAttr("dtype", &dtype_).IsOK()) {
+      dtype_ = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+    }
+  }
+
+  // Sets the element type on output 0; no tensors are added to the sequence.
+  Status ComputeInternal(OpKernelContext* context) const override {
+    TensorSeq* Y = context->Output<TensorSeq>(0);
+    ORT_ENFORCE(Y != nullptr, "SequenceEmpty GPU: Failed to allocate output tensor sequence.");
+    // The two build flavours expose different lookups from an ONNX type enum to an MLDataType.
+#ifdef SHARED_PROVIDER
+    Y->SetType(DataTypeImpl::GetTypeFromOnnxType(static_cast<int>(dtype_)));
+#else
+    Y->SetType(DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(dtype_))->GetElementType());
+#endif
+    return Status::OK();
+  }
+
+ private:
+  int64_t dtype_{};  // ONNX TensorProto data type enum value of the sequence elements
+};  // SequenceEmpty
+
+// SequenceLength: writes the number of tensors held by the input sequence into a
+// scalar (shape {}) output tensor as a single int64 value.
+class SequenceLength final : public CudaKernel {
+ public:
+  SequenceLength(const OpKernelInfo& info) : CudaKernel(info) {}
+
+  Status ComputeInternal(OpKernelContext* context) const override {
+    const TensorSeq* X = context->Input<TensorSeq>(0);
+    ORT_ENFORCE(X != nullptr, "SequenceLength GPU: Input tensor sequence is nullptr.");
+    Tensor* Y = context->Output(0, {});
+    ORT_ENFORCE(Y != nullptr, "SequenceLength GPU: Failed to allocate output tensor sequence.");
+    const auto X_size = static_cast<int64_t>(X->Size());
+    // The count is a host-side local, hence the host-to-device copy kind.
+    // NOTE(review): the source is a stack variable and the copy is asynchronous;
+    // confirm the runtime stages pageable host memory before the call returns.
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->MutableDataRaw(),
+                                         &X_size,
+                                         sizeof(int64_t),
+                                         cudaMemcpyHostToDevice, Stream()));
+    return Status::OK();
+  }
+};  // SequenceLength
+
+// ConcatFromSequence: concatenates every tensor of the input sequence into one output tensor.
+// Shape validation and output allocation are delegated to PrepareForCompute, which fills a
+// Prepare record (presumably both come from ConcatBase; the `true` constructor argument
+// presumably selects its sequence mode -- confirm against ConcatBase).
+class ConcatFromSequence final : public CudaKernel, public ConcatBase {
+ public:
+  ConcatFromSequence(const OpKernelInfo& info) : CudaKernel(info), ConcatBase(info, true) {}
+
+  Status ComputeInternal(OpKernelContext* context) const override {
+    const TensorSeq* X = context->Input<TensorSeq>(0);
+    ORT_ENFORCE(X != nullptr, "ConcatFromSequence GPU: Input tensor sequence is nullptr.");
+    const int64_t input_count = static_cast<int64_t>(X->Size());
+    std::vector<const Tensor*> input_tensors;
+    input_tensors.reserve(static_cast<size_t>(input_count));
+    for (int64_t i = 0; i < input_count; ++i) {
+      input_tensors.push_back(&X->Get(i));
+    }
+    Prepare p;
+    ORT_RETURN_IF_ERROR(PrepareForCompute(context, input_tensors, p));
+    if (0 == p.output_num_elements) {
+      return Status::OK();  // nothing to copy
+    }
+
+    const auto element_bytes = p.output_tensor->DataType()->Size();
+    uint8_t* output = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());
+    // Element offset inside one output pitch at which the current input's slices start.
+    int64_t initial_output_offset = 0;
+    for (int64_t input_index = 0; input_index < input_count; input_index++) {
+      const auto& prep = p.inputs[input_index];
+      if (prep.num_elements == 0) {
+        continue;
+      }
+      const auto input_axis_pitch = prep.axis_pitch;
+      const uint8_t* input = static_cast<const uint8_t*>(prep.tensor->DataRaw());
+
+      const auto input_size = prep.num_elements;
+      int64_t cur_out_offset = 0;
+      int64_t cur_in_offset = 0;
+      // Copy one slice of input_axis_pitch elements per iteration; the destination
+      // advances by the (larger) output pitch, the source by the input pitch.
+      for (size_t idx_copy = 0, end = input_size / input_axis_pitch; idx_copy < end; ++idx_copy) {
+        // Source and destination are both tensors owned by this CUDA kernel's inputs/outputs
+        // (sequence elements are filled via device-to-device copies by the other sequence
+        // kernels in this file), so the copy kind must be device-to-device.
+        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+            output + (initial_output_offset + cur_out_offset) * element_bytes,
+            input + cur_in_offset * element_bytes, input_axis_pitch * element_bytes,
+            cudaMemcpyDeviceToDevice, Stream()));
+        cur_out_offset += p.output_axis_pitch;
+        cur_in_offset += input_axis_pitch;
+      }
+      initial_output_offset += input_axis_pitch;
+    }
+    return Status::OK();
+  }
+};  // ConcatFromSequence
+
+// SequenceErase: produces a copy of the input sequence without the tensor at the given
+// position. Input 1 (optional) holds the position as int32 or int64; a negative value
+// counts from the end. Without input 1 the last tensor is dropped.
+class SequenceErase final : public CudaKernel {
+ public:
+  SequenceErase(const OpKernelInfo& info) : CudaKernel(info) {}
+
+  Status ComputeInternal(OpKernelContext* context) const override {
+    const TensorSeq* X = context->Input<TensorSeq>(0);
+    ORT_ENFORCE(X != nullptr, "SequenceErase GPU: Got nullptr for sequence input.");
+    const int64_t X_size = static_cast<int64_t>(X->Size());
+    int64_t idx = X_size - 1;  // default: erase the last element
+    const Tensor* I = context->Input<Tensor>(1);
+    if (I != nullptr) {
+      idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t");
+      if (idx < 0) {
+        idx = X_size + idx;
+      }
+      ORT_ENFORCE(idx >= 0 && idx < X_size, "SequenceErase GPU: Invalid sequence index.");
+    }
+
+    AllocatorPtr alloc;
+    ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
+                "SequenceErase GPU: Unable to get an allocator.");
+    TensorSeq* Y = context->Output<TensorSeq>(0);
+    ORT_ENFORCE(Y != nullptr, "SequenceErase GPU: Failed to allocate output tensor sequence.");
+    Y->SetType(X->DataType());
+    for (int64_t i = 0; i < X_size; ++i) {
+      if (i == idx) {
+        continue;  // this is the erased element
+      }
+      // Every surviving element is cloned into a freshly allocated tensor.
+      const Tensor& source_tensor = X->Get(i);
+      std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(),
+                                                             source_tensor.Shape(), alloc);
+
+      ORT_ENFORCE(target_tensor, "SequenceErase GPU: Failed to allocate new tensor.");
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
+                                           source_tensor.DataRaw(),
+                                           source_tensor.SizeInBytes(),
+                                           cudaMemcpyDeviceToDevice, Stream()));
+      Y->Add(std::move(*target_tensor));
+    }
+    return Status::OK();
+  }
+};  // SequenceErase
+
+// SequenceInsert: produces a copy of the input sequence (input 0) with the tensor from
+// input 1 inserted at the given position. Input 2 (optional) holds the position as int32
+// or int64; a negative value counts from the end. Without input 2 the tensor is appended.
+class SequenceInsert final : public CudaKernel {
+ public:
+  SequenceInsert(const OpKernelInfo& info) : CudaKernel(info) {}
+
+  Status ComputeInternal(OpKernelContext* context) const override {
+    const TensorSeq* S = context->Input<TensorSeq>(0);
+    ORT_ENFORCE(S != nullptr, "SequenceInsert GPU: Got nullptr for sequence input.");
+    const int64_t S_size = static_cast<int64_t>(S->Size());
+    int64_t idx = S_size;  // default: append at the end
+    const Tensor* I = context->Input<Tensor>(2);
+    if (I != nullptr) {
+      idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t");
+      if (idx < 0) {
+        idx = S_size + idx;
+      }
+      // Unlike erase, idx == S_size is valid here: it means "append".
+      ORT_ENFORCE(idx >= 0 && idx <= S_size, "SequenceInsert GPU: Invalid sequence index.");
+    }
+    const Tensor* X = context->Input<Tensor>(1);
+    ORT_ENFORCE(X != nullptr, "SequenceInsert GPU: Got nullptr for tensor input.");
+
+    AllocatorPtr alloc;
+    ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
+                "SequenceInsert GPU: Unable to get an allocator.");
+
+    TensorSeq* Y = context->Output<TensorSeq>(0);
+    ORT_ENFORCE(Y != nullptr, "SequenceInsert GPU: Failed to allocate output tensor sequence.");
+    Y->SetType(S->DataType());
+
+    // Clones `source` into a newly allocated tensor and appends the clone to Y.
+    const auto append_copy = [&](const Tensor& source) -> Status {
+      std::unique_ptr<Tensor> target_tensor = Tensor::Create(source.DataType(),
+                                                             source.Shape(), alloc);
+      ORT_ENFORCE(target_tensor, "SequenceInsert GPU: Failed to allocate new tensor.");
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
+                                           source.DataRaw(), source.SizeInBytes(),
+                                           cudaMemcpyDeviceToDevice, Stream()));
+      Y->Add(std::move(*target_tensor));
+      return Status::OK();
+    };
+
+    for (int64_t i = 0; i < S_size; ++i) {
+      if (i == idx) {
+        ORT_RETURN_IF_ERROR(append_copy(*X));  // new tensor goes in front of element i
+      }
+      ORT_RETURN_IF_ERROR(append_copy(S->Get(i)));
+    }
+    if (idx == S_size) {
+      ORT_RETURN_IF_ERROR(append_copy(*X));  // append case, not reached inside the loop
+    }
+    return Status::OK();
+  }
+};  // SequenceInsert
+
+}  // namespace cuda
+}  // namespace
onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 4c5377c87e..1618e03c35 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -105,6 +105,7 @@ template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Tensor(); } template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_TensorSeq(); } +MLDataType DataTypeImpl::GetTypeFromOnnxType(int onnx_type) { return Provider_GetHost()->DataTypeImpl__GetTypeFromOnnxType(onnx_type); } template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_bool(); } template <> @@ -132,6 +133,8 @@ MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTy template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_MLFloat16(); } template <> +MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_string(); } +template <> MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_bool(); } template <> MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_int8(); } diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index e75a8969c7..3468bfbdba 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -423,6 +423,7 @@ struct ProviderHost { // DataTypeImpl virtual MLDataType DataTypeImpl__GetType_Tensor() = 0; virtual MLDataType DataTypeImpl__GetType_TensorSeq() = 0; + virtual MLDataType DataTypeImpl__GetTypeFromOnnxType(int) = 0; virtual MLDataType 
DataTypeImpl__GetType_bool() = 0; virtual MLDataType DataTypeImpl__GetType_int8() = 0; virtual MLDataType DataTypeImpl__GetType_uint8() = 0; @@ -436,6 +437,7 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetType_double() = 0; virtual MLDataType DataTypeImpl__GetType_BFloat16() = 0; virtual MLDataType DataTypeImpl__GetType_MLFloat16() = 0; + virtual MLDataType DataTypeImpl__GetType_string() = 0; virtual MLDataType DataTypeImpl__GetTensorType_bool() = 0; virtual MLDataType DataTypeImpl__GetTensorType_int8() = 0; virtual MLDataType DataTypeImpl__GetTensorType_uint8() = 0; @@ -458,6 +460,9 @@ struct ProviderHost { virtual const std::vector& DataTypeImpl__AllTensorTypes() = 0; virtual const std::vector& DataTypeImpl__AllIEEEFloatTensorTypes() = 0; virtual const std::vector& DataTypeImpl__AllTensorAndSequenceTensorTypes() = 0; + virtual const std::vector& DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes() = 0; + virtual const std::vector& DataTypeImpl__AllSequenceTensorTypes() = 0; + virtual const std::vector& DataTypeImpl__AllFixedSizeSequenceTensorTypes() = 0; virtual size_t DataTypeImpl__Size(const DataTypeImpl* p) = 0; virtual const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) = 0; @@ -1282,6 +1287,7 @@ class DataTypeImpl final { static MLDataType GetType(); template static MLDataType GetTensorType(); + static MLDataType GetTypeFromOnnxType(int); bool IsTensorType() const { return g_host->DataTypeImpl__IsTensorType(this); } bool IsTensorSequenceType() const { return g_host->DataTypeImpl__IsTensorSequenceType(this); } @@ -1292,6 +1298,9 @@ class DataTypeImpl final { static const std::vector& AllTensorTypes() { return g_host->DataTypeImpl__AllTensorTypes(); } static const std::vector& AllIEEEFloatTensorTypes() { return g_host->DataTypeImpl__AllIEEEFloatTensorTypes(); } static const std::vector& AllTensorAndSequenceTensorTypes() { return g_host->DataTypeImpl__AllTensorAndSequenceTensorTypes(); } + static const 
std::vector& AllFixedSizeTensorAndSequenceTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes(); } + static const std::vector& AllSequenceTensorTypes() { return g_host->DataTypeImpl__AllSequenceTensorTypes(); } + static const std::vector& AllFixedSizeSequenceTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeSequenceTensorTypes(); } const PrimitiveDataTypeBase* AsPrimitiveDataType() const { return g_host->DataTypeImpl__AsPrimitiveDataType(this); } diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 07636bbc15..f574fd3eba 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -202,13 +202,24 @@ Status Environment::Initialize(std::unique_ptr logging_ // Register MemCpy schema; // These ops are internal-only, so register outside of onnx + static std::vector all_fixed_size_types = []() { + std::vector all_types; + std::vector all_tensor_types = OpSchema::all_tensor_types_with_bfloat(); + std::vector all_sequence_types = OpSchema::all_tensor_sequence_types(); + all_types.insert(all_types.end(), all_tensor_types.begin(), all_tensor_types.end()); + all_types.insert(all_types.end(), all_sequence_types.begin(), all_sequence_types.end()); + all_types.emplace_back("seq(tensor(bfloat16))"); + all_types.erase(std::remove_if(all_types.begin(), all_types.end(), + [](const std::string& s) { return s.find("string") != std::string::npos; }), all_types.end()); + return all_types; }(); + ORT_ATTRIBUTE_UNUSED ONNX_OPERATOR_SCHEMA(MemcpyFromHost) .Input(0, "X", "input", "T") .Output(0, "Y", "output", "T") .TypeConstraint( "T", - OpSchema::all_tensor_types_with_bfloat(), - "Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.") + all_fixed_size_types, + "Constrain to all fixed size tensor and sequence types. 
If the dtype attribute is not provided this must be a valid output type.") .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput) .SetDoc(R"DOC( Internal copy node @@ -219,8 +230,8 @@ Internal copy node .Output(0, "Y", "output", "T") .TypeConstraint( "T", - OpSchema::all_tensor_types_with_bfloat(), - "Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.") + all_fixed_size_types, + "Constrain to all fixed size tensor and sequence types. If the dtype attribute is not provided this must be a valid output type.") .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput) .SetDoc(R"DOC( Internal copy node