diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index ac255c5b73..12ea0af6ef 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -458,6 +458,7 @@ Do not modify directly.*
|Concat|*in* inputs:**T**
*out* concat_result:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
@@ -514,7 +515,7 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
|HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
-|Identity|*in* input:**T**
*out* output:**T**
or
*in* input:**V**
*out* output:**V**|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Identity|*in* input:**T**
*out* output:**T**
or
*in* input:**V**
*out* output:**V**|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|If|*in* cond:**B**
*out* outputs:**V**|13+|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -552,8 +553,8 @@ Do not modify directly.*
|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
|||[8, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 7]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)|
-|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
@@ -637,6 +638,12 @@ Do not modify directly.*
|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|SequenceAt|*in* input_sequence:**S**
*in* position:**I**
*out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|SequenceConstruct|*in* inputs:**T**
*out* output_sequence:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|SequenceEmpty|*out* output:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
+|SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
+|SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
+|SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|Shape|*in* data:**T**
*out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc
index be3dda0c64..d29e4b10b6 100644
--- a/onnxruntime/core/framework/provider_bridge_ort.cc
+++ b/onnxruntime/core/framework/provider_bridge_ort.cc
@@ -495,6 +495,7 @@ struct ProviderHostImpl : ProviderHost {
// DataTypeImpl (wrapped)
MLDataType DataTypeImpl__GetType_Tensor() override { return DataTypeImpl::GetType(); }
MLDataType DataTypeImpl__GetType_TensorSeq () override { return DataTypeImpl::GetType(); }
+ MLDataType DataTypeImpl__GetTypeFromOnnxType (int onnx_type) override { return DataTypeImpl::TensorTypeFromONNXEnum(onnx_type)->GetElementType(); }
MLDataType DataTypeImpl__GetType_bool() override { return DataTypeImpl::GetType(); }
MLDataType DataTypeImpl__GetType_int8() override { return DataTypeImpl::GetType(); }
MLDataType DataTypeImpl__GetType_uint8() override { return DataTypeImpl::GetType(); }
@@ -508,6 +509,7 @@ struct ProviderHostImpl : ProviderHost {
MLDataType DataTypeImpl__GetType_double() override { return DataTypeImpl::GetType(); }
MLDataType DataTypeImpl__GetType_BFloat16() override { return DataTypeImpl::GetType(); }
MLDataType DataTypeImpl__GetType_MLFloat16() override { return DataTypeImpl::GetType(); }
+ MLDataType DataTypeImpl__GetType_string() override { return DataTypeImpl::GetType(); }
MLDataType DataTypeImpl__GetTensorType_bool() override { return DataTypeImpl::GetTensorType(); }
MLDataType DataTypeImpl__GetTensorType_int8() override { return DataTypeImpl::GetTensorType(); }
MLDataType DataTypeImpl__GetTensorType_uint8() override { return DataTypeImpl::GetTensorType(); }
@@ -521,7 +523,6 @@ struct ProviderHostImpl : ProviderHost {
MLDataType DataTypeImpl__GetTensorType_double() override { return DataTypeImpl::GetTensorType(); }
MLDataType DataTypeImpl__GetTensorType_BFloat16() override { return DataTypeImpl::GetTensorType(); }
MLDataType DataTypeImpl__GetTensorType_MLFloat16() override { return DataTypeImpl::GetTensorType(); }
-
const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); }
bool DataTypeImpl__IsTensorType(const DataTypeImpl* p) override { return p->IsTensorType(); }
bool DataTypeImpl__IsTensorSequenceType(const DataTypeImpl* p) override { return p->IsTensorSequenceType(); }
@@ -531,6 +532,9 @@ struct ProviderHostImpl : ProviderHost {
const std::vector& DataTypeImpl__AllTensorTypes() override { return DataTypeImpl::AllTensorTypes(); }
const std::vector& DataTypeImpl__AllIEEEFloatTensorTypes() override { return DataTypeImpl::AllIEEEFloatTensorTypes(); }
const std::vector& DataTypeImpl__AllTensorAndSequenceTensorTypes() override { return DataTypeImpl::AllTensorAndSequenceTensorTypes(); }
+ const std::vector& DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes() override { return DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes(); }
+ const std::vector& DataTypeImpl__AllSequenceTensorTypes() override { return DataTypeImpl::AllSequenceTensorTypes(); }
+ const std::vector& DataTypeImpl__AllFixedSizeSequenceTensorTypes() override { return DataTypeImpl::AllFixedSizeSequenceTensorTypes(); }
size_t DataTypeImpl__Size(const DataTypeImpl* p) override { return p->Size(); }
const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) override { return p->AsPrimitiveDataType(); }
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 000d604961..9d22474d2a 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -53,22 +53,12 @@ AllocatorPtr SessionState::GetAllocator(const OrtMemoryInfo& location) const noe
}
AllocatorPtr SessionState::GetAllocator(OrtDevice device) const noexcept {
- AllocatorPtr result;
-
- using AllocatorEntry = std::map,
- OrtMemoryInfoLessThanIgnoreAllocType>::const_reference;
-
- auto entry = std::find_if(allocators_.cbegin(), allocators_.cend(),
- [device](AllocatorEntry& entry) {
- return entry.first.device == device &&
- entry.first.mem_type == OrtMemTypeDefault;
- });
-
- if (entry != allocators_.cend()) {
- result = entry->second(device.Id(), OrtMemTypeDefault);
+ for (const auto& iter : allocators_) {
+ if (iter.first.device == device) {
+ return iter.second(device.Id(), iter.first.mem_type);
+ }
}
-
- return result;
+ return nullptr;
}
void SessionState::CreateGraphInfo() {
diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc
index 648a3be1c8..fa6c4e58c2 100644
--- a/onnxruntime/core/framework/utils.cc
+++ b/onnxruntime/core/framework/utils.cc
@@ -174,8 +174,8 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
return Status::OK();
}
+ auto allocator = session_state.GetAllocator(copy_info.target_device);
if (!target_mlvalue.IsAllocated()) {
- auto allocator = session_state.GetAllocator(copy_info.target_device);
ORT_ENFORCE(allocator != nullptr, "Failed to find allocator for device ", copy_info.target_device.ToString());
ORT_RETURN_IF_ERROR(utils::AllocateHelper(allocator, source_mlvalue, target_mlvalue));
}
@@ -190,9 +190,16 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
}
} else if (source_mlvalue.IsTensorSequence()) {
const TensorSeq& source_tensor_seq = source_mlvalue.Get();
- const TensorSeq& target_tensor_seq = target_mlvalue.Get();
- ORT_ENFORCE(source_tensor_seq.Size() == target_tensor_seq.Size(),
- "source and target tensor sequence have different number of elements.");
+ TensorSeq& target_tensor_seq = const_cast(target_mlvalue.Get());
+ size_t size = 0;
+ while ((size = target_tensor_seq.Size()) < source_tensor_seq.Size()) {
+ if (0 == size) {
+ target_tensor_seq.SetType(source_tensor_seq.DataType());
+ }
+ const Tensor& source_tensor = source_tensor_seq.Get(size);
+ std::unique_ptr target_tensor = std::make_unique(source_tensor.DataType(), source_tensor.Shape(), allocator);
+ target_tensor_seq.Add(std::move(*target_tensor));
+ }
auto source_iter = source_tensor_seq.begin();
auto target_iter = target_tensor_seq.begin();
while (source_iter != source_tensor_seq.end() &&
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 19a18d0dc9..e15a1931fb 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -26,10 +26,39 @@ class Memcpy final : public OpKernel {
Memcpy(const OpKernelInfo& info) : OpKernel{info} {}
Status Compute(OpKernelContext* ctx) const override {
- const auto* X = ctx->Input(0);
- Tensor* Y = ctx->Output(0, X->Shape());
- Status retval = Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
- return retval;
+ auto X_type = ctx->InputType(0);
+ if (X_type->IsTensorType()) {
+ const auto* X = ctx->Input(0);
+ ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
+ Tensor* Y = ctx->Output(0, X->Shape());
+ ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
+ return Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
+ } else if (X_type->IsTensorSequenceType()) {
+ const TensorSeq* X = ctx->Input(0);
+ ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr.");
+ TensorSeq* Y = ctx->Output(0);
+ ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence.");
+ auto X_dtype = X->DataType();
+ Y->SetType(X_dtype);
+ AllocatorPtr alloc;
+ auto status = ctx->GetTempSpaceAllocator(&alloc);
+ if (!status.IsOK()) {
+ return Status(common::ONNXRUNTIME, common::FAIL,
+ "Memcpy cuda: unable to get an allocator.");
+ }
+ auto X_size = X->Size();
+ for (size_t i = 0; i < X_size; ++i) {
+ const Tensor& source_tensor = X->Get(i);
+ std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc);
+ Status retval = Info().GetDataTransferManager().CopyTensor(source_tensor, *target_tensor, Info().GetKernelDef().ExecQueueId());
+ if (!retval.IsOK()) {
+ return retval;
+ }
+ Y->Add(std::move(*target_tensor));
+ }
+ return Status::OK();
+ }
+ return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type.");
}
};
@@ -42,7 +71,7 @@ ONNX_OPERATOR_KERNEL_EX(
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0)
.ExecQueueId(kCudaStreamCopyIn)
- .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()),
Memcpy);
ONNX_OPERATOR_KERNEL_EX(
@@ -53,7 +82,7 @@ ONNX_OPERATOR_KERNEL_EX(
(*KernelDefBuilder::Create())
.OutputMemoryType(OrtMemTypeCPUOutput, 0)
.ExecQueueId(kCudaStreamCopyOut)
- .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()),
Memcpy);
} // namespace cuda
@@ -770,6 +799,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceAt);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceConstruct);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceEmpty);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceLength);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceErase);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceInsert);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Conv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Conv);
@@ -1589,6 +1625,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.cc b/onnxruntime/core/providers/cuda/tensor/identity_op.cc
index 563bb45643..65e4649068 100644
--- a/onnxruntime/core/providers/cuda/tensor/identity_op.cc
+++ b/onnxruntime/core/providers/cuda/tensor/identity_op.cc
@@ -57,7 +57,7 @@ ONNX_OPERATOR_KERNEL_EX(
14,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
- .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes())
+ .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes())
.Alias(0, 0),
IdentityOp);
} // namespace cuda
diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.h b/onnxruntime/core/providers/cuda/tensor/identity_op.h
index 35f695c9d9..be90031ec7 100644
--- a/onnxruntime/core/providers/cuda/tensor/identity_op.h
+++ b/onnxruntime/core/providers/cuda/tensor/identity_op.h
@@ -53,14 +53,11 @@ class IdentityOp final : public CudaKernel {
}
} else if (X_ml_type->IsTensorSequenceType()) {
const TensorSeq* X = context->Input(0);
- if (nullptr == X) {
- return Status(common::ONNXRUNTIME, common::FAIL,
- "IdentityOp cuda: input tensor is missing.");
- }
+ ORT_ENFORCE(X != nullptr, "IdentityOp cuda: input tensor is missing.");
TensorSeq* Y = context->Output(0);
- if (nullptr == Y) {
- return Status(common::ONNXRUNTIME, common::FAIL,
- "IdentityOp cuda: failed to allocate output tensor sequence.");
+ ORT_ENFORCE(Y != nullptr, "IdentityOp cuda: failed to allocate output tensor sequence.");
+ if (X == Y) {
+ return Status::OK();
}
auto X_type = X->DataType();
Y->SetType(X_type);
@@ -73,7 +70,8 @@ class IdentityOp final : public CudaKernel {
auto X_size = X->Size();
for (size_t i = 0; i < X_size; ++i) {
const Tensor& source_tensor = X->Get(i);
- std::unique_ptr target_tensor = Tensor::Create(X_type, source_tensor.Shape(), alloc);
+ std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(),
+ source_tensor.Shape(), alloc);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
source_tensor.DataRaw(),
source_tensor.SizeInBytes(),
diff --git a/onnxruntime/core/providers/cuda/tensor/sequence_op.cc b/onnxruntime/core/providers/cuda/tensor/sequence_op.cc
new file mode 100644
index 0000000000..71f277c48b
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/sequence_op.cc
@@ -0,0 +1,87 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "sequence_op.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+ONNX_OPERATOR_KERNEL_EX(
+ SequenceAt,
+ kOnnxDomain,
+ 11,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+ .TypeConstraint("I", std::vector{
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType()}),
+ SequenceAt);
+
+ONNX_OPERATOR_KERNEL_EX(
+ SequenceConstruct,
+ kOnnxDomain,
+ 11,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+ .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()),
+ SequenceConstruct);
+
+ONNX_OPERATOR_KERNEL_EX(
+ SequenceEmpty,
+ kOnnxDomain,
+ 11,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()),
+ SequenceEmpty);
+
+ONNX_OPERATOR_KERNEL_EX(
+ SequenceLength,
+ kOnnxDomain,
+ 11,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
+ .TypeConstraint("I", DataTypeImpl::GetTensorType()),
+ SequenceLength);
+
+ONNX_OPERATOR_KERNEL_EX(
+ ConcatFromSequence,
+ kOnnxDomain,
+ 11,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()),
+ ConcatFromSequence);
+
+ONNX_OPERATOR_KERNEL_EX(
+ SequenceErase,
+ kOnnxDomain,
+ 11,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
+ .TypeConstraint("I", std::vector{
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType()}),
+ SequenceErase);
+
+ONNX_OPERATOR_KERNEL_EX(
+ SequenceInsert,
+ kOnnxDomain,
+ 11,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
+ .TypeConstraint("I", std::vector{
+ DataTypeImpl::GetTensorType(),
+ DataTypeImpl::GetTensorType()}),
+ SequenceInsert);
+
+} // namespace cuda
+} // namespace onnxruntime
+
+
diff --git a/onnxruntime/core/providers/cuda/tensor/sequence_op.h b/onnxruntime/core/providers/cuda/tensor/sequence_op.h
new file mode 100644
index 0000000000..d5ed1c06eb
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/sequence_op.h
@@ -0,0 +1,276 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/providers/shared_library/provider_api.h"
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cuda/tensor/concat.h"
+#include "core/providers/cuda/tensor/concat_impl.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+template
+int64_t ReadIndex(const Tensor& tensor, const char* type_name) {
+ DataType data{0};
+ if (!CUDA_CALL(cudaMemcpy(&data, tensor.Data(),
+ sizeof(DataType), cudaMemcpyDeviceToHost))) {
+ ORT_THROW("Cuda: Failed to read tensor data as type: ", type_name, ".");
+ }
+ return static_cast(data);
+}
+
+class SequenceAt final : public CudaKernel {
+ public:
+ SequenceAt(const OpKernelInfo& info) : CudaKernel(info) {}
+ Status ComputeInternal(OpKernelContext* context) const override {
+ const TensorSeq* X = context->Input(0);
+ ORT_ENFORCE(X != nullptr, "SequenceAt GPU: Got nullptr for sequence input.");
+ const Tensor* I = context->Input(1);
+ ORT_ENFORCE(I != nullptr, "SequenceAt GPU: Got nullptr input for index tensor.");
+
+ int64_t idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t");
+ int64_t sequence_size = static_cast(X->Size());
+ if (idx < 0) {
+ idx = sequence_size + idx;
+ }
+ ORT_ENFORCE(idx >= 0 && idx < sequence_size, "SequenceAt GPU: Invalid sequence index.");
+
+ const Tensor& source_tensor = X->Get(idx);
+ auto source_type = source_tensor.DataType();
+ const void* source_addr = source_tensor.DataRaw(source_type);
+
+ Tensor* target_tensor = context->Output(0, source_tensor.Shape());
+ ORT_ENFORCE(target_tensor != nullptr, "SequenceAt GPU: Got nullptr for output tensor.");
+ void* target_addr = target_tensor->MutableDataRaw(source_type);
+
+ if (source_addr != target_addr) {
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_addr,
+ source_addr,
+ source_tensor.SizeInBytes(),
+ cudaMemcpyDeviceToDevice, Stream()));
+ }
+ return Status::OK();
+ }
+}; // SequenceAt
+
+class SequenceConstruct final : public CudaKernel {
+ public:
+ SequenceConstruct(const OpKernelInfo& info) : CudaKernel(info) {}
+ Status ComputeInternal(OpKernelContext* context) const override {
+ TensorSeq* Y = context->Output(0);
+ ORT_ENFORCE(Y != nullptr, "SequenceConstruct GPU: Got nullptr for output sequence.");
+
+ AllocatorPtr alloc;
+ ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
+ "SequenceConstruct GPU: Unable to get an allocator.");
+
+ int32_t at = 0;
+ const Tensor* source_tensor = nullptr;
+ while (nullptr != (source_tensor = context->Input(at++))) {
+ if (1 == at) {
+ Y->SetType(source_tensor->DataType());
+ }
+ std::unique_ptr target_tensor = Tensor::Create(source_tensor->DataType(),
+ source_tensor->Shape(), alloc);
+ ORT_ENFORCE(target_tensor, "SequenceConstruct GPU: Failed to allocate new tensor.");
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
+ source_tensor->DataRaw(),
+ source_tensor->SizeInBytes(),
+ cudaMemcpyDeviceToDevice, Stream()));
+ Y->Add(std::move(*target_tensor)); // Add will check type consistency inside
+ }
+ return Status::OK();
+ }
+}; // SequenceConstruct
+
+class SequenceEmpty final : public CudaKernel {
+ public:
+ SequenceEmpty(const OpKernelInfo& info) : CudaKernel(info) {
+ if (!info.GetAttr("dtype", &dtype_).IsOK()) {
+ dtype_ = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+ }
+ }
+ Status ComputeInternal(OpKernelContext* context) const override {
+ TensorSeq* Y = context->Output(0);
+ ORT_ENFORCE(Y != nullptr, "SequenceEmpty GPU: Failed to allocate output tensor sequence.");
+#ifdef SHARED_PROVIDER
+ Y->SetType(DataTypeImpl::GetTypeFromOnnxType(static_cast(dtype_)));
+#else
+ Y->SetType(DataTypeImpl::TensorTypeFromONNXEnum(static_cast(dtype_))->GetElementType());
+#endif
+ return Status::OK();
+ }
+
+ private:
+ int64_t dtype_{};
+}; // SequenceEmpty
+
+class SequenceLength final : public CudaKernel {
+ public:
+ SequenceLength(const OpKernelInfo& info) : CudaKernel(info) {}
+ Status ComputeInternal(OpKernelContext* context) const override {
+ const TensorSeq* X = context->Input(0);
+ ORT_ENFORCE(X != nullptr, "SequenceLength GPU: Input tensor sequence is nullptr.");
+ Tensor* Y = context->Output(0, {});
+ ORT_ENFORCE(Y != nullptr, "SequenceLength GPU: Failed to allocate output tensor sequence.");
+ auto X_size = static_cast(X->Size());
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->MutableDataRaw(),
+ &X_size,
+ sizeof(int64_t),
+ cudaMemcpyHostToDevice, Stream()));
+ return Status::OK();
+ }
+}; // SequenceLength
+
+class ConcatFromSequence final : public CudaKernel, public ConcatBase {
+ public:
+ ConcatFromSequence(const OpKernelInfo& info) : CudaKernel(info), ConcatBase(info, true) {}
+
+ Status ComputeInternal(OpKernelContext* context) const override {
+ const TensorSeq* X = context->Input(0);
+ ORT_ENFORCE(X != nullptr, "ConcatFromSequence GPU: Input tensor sequence is nullptr.");
+ int64_t input_count = static_cast(X->Size());
+ std::vector input_tensors;
+ for (int64_t i = 0; i < input_count; ++i) {
+ input_tensors.push_back(&X->Get(i));
+ }
+ Prepare p;
+ ORT_RETURN_IF_ERROR(PrepareForCompute(context, input_tensors, p));
+ if (0 == p.output_num_elements) {
+ return Status::OK();
+ }
+
+ int64_t initial_output_offset = 0;
+ auto element_bytes = p.output_tensor->DataType()->Size();
+ for (int input_index = 0; input_index < input_count; input_index++) {
+ const auto& prep = p.inputs[input_index];
+ if (prep.num_elements == 0) {
+ continue;
+ }
+ auto input_axis_pitch = prep.axis_pitch;
+ const uint8_t* input = static_cast(prep.tensor->DataRaw());
+
+ auto input_size = prep.num_elements;
+ uint8_t* output = static_cast(p.output_tensor->MutableDataRaw());
+ int64_t cur_out_offset = 0;
+ int64_t cur_in_offset = 0;
+ for (size_t idx_copy = 0, end = input_size / input_axis_pitch; idx_copy < end; ++idx_copy) {
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+ output + (initial_output_offset + cur_out_offset) * element_bytes,
+ input + cur_in_offset * element_bytes, input_axis_pitch * element_bytes,
+ cudaMemcpyHostToDevice, Stream()));
+ cur_out_offset += p.output_axis_pitch;
+ cur_in_offset += input_axis_pitch;
+ }
+ initial_output_offset += input_axis_pitch;
+ }
+ return Status::OK();
+ }
+}; // ConcatFromSequence
+
+class SequenceErase final : public CudaKernel {
+ public:
+ SequenceErase(const OpKernelInfo& info) : CudaKernel(info) {}
+
+ Status ComputeInternal(OpKernelContext* context) const override {
+ const TensorSeq* X = context->Input(0);
+ ORT_ENFORCE(X != nullptr, "SequenceErase GPU: Got nullptr for sequence input.");
+ int64_t X_size = static_cast(X->Size());
+ int64_t idx = X_size - 1;
+ const Tensor* I = context->Input(1);
+ if (I != nullptr) {
+ idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t");
+ if (idx < 0) {
+ idx = X_size + idx;
+ }
+ ORT_ENFORCE(idx >= 0 && idx < X_size, "SequenceErase GPU: Invalid sequence index.");
+ }
+
+ AllocatorPtr alloc;
+ ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
+ "SequenceErase GPU: Unable to get an allocator.");
+ TensorSeq* Y = context->Output(0);
+ ORT_ENFORCE(Y != nullptr, "SequenceErase GPU: Failed to allocate output tensor sequence.");
+ Y->SetType(X->DataType());
+ for (int64_t i = 0; i < X_size; ++i) {
+ if (i == idx) {
+ continue;
+ }
+ const Tensor& source_tensor = X->Get(i);
+ std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(),
+ source_tensor.Shape(), alloc);
+
+ ORT_ENFORCE(target_tensor, "SequenceErase GPU: Failed to allocate new tensor.");
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
+ source_tensor.DataRaw(),
+ source_tensor.SizeInBytes(),
+ cudaMemcpyDeviceToDevice, Stream()));
+ Y->Add(std::move(*target_tensor));
+ }
+ return Status::OK();
+ }
+}; // SequenceErase
+
+class SequenceInsert final : public CudaKernel {
+ public:
+ SequenceInsert(const OpKernelInfo& info) : CudaKernel(info) {}
+
+ Status ComputeInternal(OpKernelContext* context) const override {
+ const TensorSeq* S = context->Input(0);
+ ORT_ENFORCE(S != nullptr, "SequenceInsert GPU: Got nullptr for sequence input.");
+ int64_t S_size = static_cast(S->Size());
+ int64_t idx = S_size;
+ const Tensor* I = context->Input(2);
+ if (I != nullptr) {
+ idx = I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32 ? ReadIndex(*I, "int32_t") : ReadIndex(*I, "int64_t");
+ if (idx < 0) {
+ idx = S_size + idx;
+ }
+ ORT_ENFORCE(idx >= 0 && idx <= S_size, "SequenceInsert GPU: Invalid sequence index.");
+ }
+ const Tensor* X = context->Input(1);
+ ORT_ENFORCE(X != nullptr, "SequenceInsert GPU: Got nullptr for tensor input.");
+
+ AllocatorPtr alloc;
+ ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
+ "SequenceInsert GPU: Unable to get an allocator.");
+
+ TensorSeq* Y = context->Output(0);
+ ORT_ENFORCE(Y != nullptr, "SequenceInsert GPU: Failed to allocate output tensor sequence.");
+ Y->SetType(S->DataType());
+ for (int64_t i = 0; i < S_size; ++i) {
+ if (i == idx) {
+ std::unique_ptr target_tensor = Tensor::Create(X->DataType(),
+ X->Shape(), alloc);
+ ORT_ENFORCE(target_tensor, "SequenceInsert GPU: Failed to allocate new tensor.");
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
+ X->DataRaw(), X->SizeInBytes(),
+ cudaMemcpyDeviceToDevice, Stream()));
+ Y->Add(std::move(*target_tensor));
+ }
+ const Tensor& source_tensor = S->Get(i);
+ std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(),
+ source_tensor.Shape(), alloc);
+ ORT_ENFORCE(target_tensor, "SequenceInsert GPU: Failed to allocate new tensor.");
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
+ source_tensor.DataRaw(),
+ source_tensor.SizeInBytes(),
+ cudaMemcpyDeviceToDevice, Stream()));
+ Y->Add(std::move(*target_tensor)); // Add will check type consistency inside
+ } // for
+ if (idx == S_size) {
+ std::unique_ptr target_tensor = Tensor::Create(X->DataType(),
+ X->Shape(), alloc);
+ ORT_ENFORCE(target_tensor, "SequenceInsert GPU: Failed to allocate new tensor.");
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
+ X->DataRaw(), X->SizeInBytes(),
+ cudaMemcpyDeviceToDevice, Stream()));
+ Y->Add(std::move(*target_tensor));
+ }
+ return Status::OK();
+ }
+}; // SequenceInsert
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index 4c5377c87e..1618e03c35 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -105,6 +105,7 @@ template <>
MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Tensor(); }
template <>
MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_TensorSeq(); }
+MLDataType DataTypeImpl::GetTypeFromOnnxType(int onnx_type) { return Provider_GetHost()->DataTypeImpl__GetTypeFromOnnxType(onnx_type); }
template <>
MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_bool(); }
template <>
@@ -132,6 +133,8 @@ MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTy
template <>
MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_MLFloat16(); }
template <>
+MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_string(); }
+template <>
MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_bool(); }
template <>
MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_int8(); }
diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index e75a8969c7..3468bfbdba 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -423,6 +423,7 @@ struct ProviderHost {
// DataTypeImpl
virtual MLDataType DataTypeImpl__GetType_Tensor() = 0;
virtual MLDataType DataTypeImpl__GetType_TensorSeq() = 0;
+ virtual MLDataType DataTypeImpl__GetTypeFromOnnxType(int) = 0;
virtual MLDataType DataTypeImpl__GetType_bool() = 0;
virtual MLDataType DataTypeImpl__GetType_int8() = 0;
virtual MLDataType DataTypeImpl__GetType_uint8() = 0;
@@ -436,6 +437,7 @@ struct ProviderHost {
virtual MLDataType DataTypeImpl__GetType_double() = 0;
virtual MLDataType DataTypeImpl__GetType_BFloat16() = 0;
virtual MLDataType DataTypeImpl__GetType_MLFloat16() = 0;
+ virtual MLDataType DataTypeImpl__GetType_string() = 0;
virtual MLDataType DataTypeImpl__GetTensorType_bool() = 0;
virtual MLDataType DataTypeImpl__GetTensorType_int8() = 0;
virtual MLDataType DataTypeImpl__GetTensorType_uint8() = 0;
@@ -458,6 +460,9 @@ struct ProviderHost {
virtual const std::vector& DataTypeImpl__AllTensorTypes() = 0;
virtual const std::vector& DataTypeImpl__AllIEEEFloatTensorTypes() = 0;
virtual const std::vector& DataTypeImpl__AllTensorAndSequenceTensorTypes() = 0;
+ virtual const std::vector& DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes() = 0;
+ virtual const std::vector& DataTypeImpl__AllSequenceTensorTypes() = 0;
+ virtual const std::vector& DataTypeImpl__AllFixedSizeSequenceTensorTypes() = 0;
virtual size_t DataTypeImpl__Size(const DataTypeImpl* p) = 0;
virtual const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) = 0;
@@ -1282,6 +1287,7 @@ class DataTypeImpl final {
static MLDataType GetType();
template
static MLDataType GetTensorType();
+ static MLDataType GetTypeFromOnnxType(int);
bool IsTensorType() const { return g_host->DataTypeImpl__IsTensorType(this); }
bool IsTensorSequenceType() const { return g_host->DataTypeImpl__IsTensorSequenceType(this); }
@@ -1292,6 +1298,9 @@ class DataTypeImpl final {
static const std::vector& AllTensorTypes() { return g_host->DataTypeImpl__AllTensorTypes(); }
static const std::vector& AllIEEEFloatTensorTypes() { return g_host->DataTypeImpl__AllIEEEFloatTensorTypes(); }
static const std::vector& AllTensorAndSequenceTensorTypes() { return g_host->DataTypeImpl__AllTensorAndSequenceTensorTypes(); }
+ static const std::vector& AllFixedSizeTensorAndSequenceTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes(); }
+ static const std::vector& AllSequenceTensorTypes() { return g_host->DataTypeImpl__AllSequenceTensorTypes(); }
+ static const std::vector& AllFixedSizeSequenceTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeSequenceTensorTypes(); }
const PrimitiveDataTypeBase* AsPrimitiveDataType() const { return g_host->DataTypeImpl__AsPrimitiveDataType(this); }
diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc
index 07636bbc15..f574fd3eba 100644
--- a/onnxruntime/core/session/environment.cc
+++ b/onnxruntime/core/session/environment.cc
@@ -202,13 +202,24 @@ Status Environment::Initialize(std::unique_ptr logging_
// Register MemCpy schema;
// These ops are internal-only, so register outside of onnx
+ static std::vector all_fixed_size_types = []() {
+ std::vector all_types;
+ std::vector all_tensor_types = OpSchema::all_tensor_types_with_bfloat();
+ std::vector all_sequence_types = OpSchema::all_tensor_sequence_types();
+ all_types.insert(all_types.end(), all_tensor_types.begin(), all_tensor_types.end());
+ all_types.insert(all_types.end(), all_sequence_types.begin(), all_sequence_types.end());
+ all_types.emplace_back("seq(tensor(bfloat16))");
+ all_types.erase(std::remove_if(all_types.begin(), all_types.end(),
+ [](const std::string& s) { return s.find("string") != std::string::npos; }), all_types.end());
+ return all_types; }();
+
ORT_ATTRIBUTE_UNUSED ONNX_OPERATOR_SCHEMA(MemcpyFromHost)
.Input(0, "X", "input", "T")
.Output(0, "Y", "output", "T")
.TypeConstraint(
"T",
- OpSchema::all_tensor_types_with_bfloat(),
- "Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
+ all_fixed_size_types,
+ "Constrain to all fixed size tensor and sequence types. If the dtype attribute is not provided this must be a valid output type.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
.SetDoc(R"DOC(
Internal copy node
@@ -219,8 +230,8 @@ Internal copy node
.Output(0, "Y", "output", "T")
.TypeConstraint(
"T",
- OpSchema::all_tensor_types_with_bfloat(),
- "Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
+ all_fixed_size_types,
+ "Constrain to all fixed size tensor and sequence types. If the dtype attribute is not provided this must be a valid output type.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
.SetDoc(R"DOC(
Internal copy node