Implement Sequence Ops GPU (#7863)

This commit is contained in:
RandySheriffH 2021-06-07 15:30:26 -07:00 committed by GitHub
parent 9e4dc08483
commit 1a5ee11dbd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 477 additions and 42 deletions

View file

@ -458,6 +458,7 @@ Do not modify directly.*
|Concat|*in* inputs:**T**<br> *out* concat_result:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ConcatFromSequence|*in* input_sequence:**S**<br> *out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|ConstantOfShape|*in* input:**T1**<br> *out* output:**T2**|9+|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
@ -514,7 +515,7 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|If|*in* cond:**B**<br> *out* outputs:**V**|13+|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@ -552,8 +553,8 @@ Do not modify directly.*
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16)|
|||[8, 9]|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 7]|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16)|
|MemcpyFromHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|MemcpyToHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|MemcpyFromHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|MemcpyToHost|*in* X:**T**<br> *out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Min|*in* data_0:**T**<br> *out* min:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
@ -637,6 +638,12 @@ Do not modify directly.*
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|SequenceAt|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SequenceConstruct|*in* inputs:**T**<br> *out* output_sequence:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SequenceEmpty|*out* output:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceErase|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceInsert|*in* input_sequence:**S**<br> *in* tensor:**T**<br> *in* position:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceLength|*in* input_sequence:**S**<br> *out* length:**I**|11+|**I** = tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|Shape|*in* data:**T**<br> *out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

View file

@ -495,6 +495,7 @@ struct ProviderHostImpl : ProviderHost {
// DataTypeImpl (wrapped)
MLDataType DataTypeImpl__GetType_Tensor() override { return DataTypeImpl::GetType<Tensor>(); }
MLDataType DataTypeImpl__GetType_TensorSeq () override { return DataTypeImpl::GetType<TensorSeq>(); }
MLDataType DataTypeImpl__GetTypeFromOnnxType (int onnx_type) override { return DataTypeImpl::TensorTypeFromONNXEnum(onnx_type)->GetElementType(); }
MLDataType DataTypeImpl__GetType_bool() override { return DataTypeImpl::GetType<bool>(); }
MLDataType DataTypeImpl__GetType_int8() override { return DataTypeImpl::GetType<int8_t>(); }
MLDataType DataTypeImpl__GetType_uint8() override { return DataTypeImpl::GetType<uint8_t>(); }
@ -508,6 +509,7 @@ struct ProviderHostImpl : ProviderHost {
MLDataType DataTypeImpl__GetType_double() override { return DataTypeImpl::GetType<double>(); }
MLDataType DataTypeImpl__GetType_BFloat16() override { return DataTypeImpl::GetType<BFloat16>(); }
MLDataType DataTypeImpl__GetType_MLFloat16() override { return DataTypeImpl::GetType<MLFloat16>(); }
MLDataType DataTypeImpl__GetType_string() override { return DataTypeImpl::GetType<std::string>(); }
MLDataType DataTypeImpl__GetTensorType_bool() override { return DataTypeImpl::GetTensorType<bool>(); }
MLDataType DataTypeImpl__GetTensorType_int8() override { return DataTypeImpl::GetTensorType<int8_t>(); }
MLDataType DataTypeImpl__GetTensorType_uint8() override { return DataTypeImpl::GetTensorType<uint8_t>(); }
@ -521,7 +523,6 @@ struct ProviderHostImpl : ProviderHost {
MLDataType DataTypeImpl__GetTensorType_double() override { return DataTypeImpl::GetTensorType<double>(); }
MLDataType DataTypeImpl__GetTensorType_BFloat16() override { return DataTypeImpl::GetTensorType<BFloat16>(); }
MLDataType DataTypeImpl__GetTensorType_MLFloat16() override { return DataTypeImpl::GetTensorType<MLFloat16>(); }
const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); }
bool DataTypeImpl__IsTensorType(const DataTypeImpl* p) override { return p->IsTensorType(); }
bool DataTypeImpl__IsTensorSequenceType(const DataTypeImpl* p) override { return p->IsTensorSequenceType(); }
@ -531,6 +532,9 @@ struct ProviderHostImpl : ProviderHost {
const std::vector<MLDataType>& DataTypeImpl__AllTensorTypes() override { return DataTypeImpl::AllTensorTypes(); }
const std::vector<MLDataType>& DataTypeImpl__AllIEEEFloatTensorTypes() override { return DataTypeImpl::AllIEEEFloatTensorTypes(); }
const std::vector<MLDataType>& DataTypeImpl__AllTensorAndSequenceTensorTypes() override { return DataTypeImpl::AllTensorAndSequenceTensorTypes(); }
const std::vector<MLDataType>& DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes() override { return DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes(); }
const std::vector<MLDataType>& DataTypeImpl__AllSequenceTensorTypes() override { return DataTypeImpl::AllSequenceTensorTypes(); }
const std::vector<MLDataType>& DataTypeImpl__AllFixedSizeSequenceTensorTypes() override { return DataTypeImpl::AllFixedSizeSequenceTensorTypes(); }
size_t DataTypeImpl__Size(const DataTypeImpl* p) override { return p->Size(); }
const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) override { return p->AsPrimitiveDataType(); }

View file

@ -53,22 +53,12 @@ AllocatorPtr SessionState::GetAllocator(const OrtMemoryInfo& location) const noe
}
AllocatorPtr SessionState::GetAllocator(OrtDevice device) const noexcept {
AllocatorPtr result;
using AllocatorEntry = std::map<OrtMemoryInfo, std::function<AllocatorPtr(int id, OrtMemType mem_type)>,
OrtMemoryInfoLessThanIgnoreAllocType>::const_reference;
auto entry = std::find_if(allocators_.cbegin(), allocators_.cend(),
[device](AllocatorEntry& entry) {
return entry.first.device == device &&
entry.first.mem_type == OrtMemTypeDefault;
});
if (entry != allocators_.cend()) {
result = entry->second(device.Id(), OrtMemTypeDefault);
for (const auto& iter : allocators_) {
if (iter.first.device == device) {
return iter.second(device.Id(), iter.first.mem_type);
}
}
return result;
return nullptr;
}
void SessionState::CreateGraphInfo() {

View file

@ -174,8 +174,8 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
return Status::OK();
}
auto allocator = session_state.GetAllocator(copy_info.target_device);
if (!target_mlvalue.IsAllocated()) {
auto allocator = session_state.GetAllocator(copy_info.target_device);
ORT_ENFORCE(allocator != nullptr, "Failed to find allocator for device ", copy_info.target_device.ToString());
ORT_RETURN_IF_ERROR(utils::AllocateHelper(allocator, source_mlvalue, target_mlvalue));
}
@ -190,9 +190,16 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
}
} else if (source_mlvalue.IsTensorSequence()) {
const TensorSeq& source_tensor_seq = source_mlvalue.Get<TensorSeq>();
const TensorSeq& target_tensor_seq = target_mlvalue.Get<TensorSeq>();
ORT_ENFORCE(source_tensor_seq.Size() == target_tensor_seq.Size(),
"source and target tensor sequence have different number of elements.");
TensorSeq& target_tensor_seq = const_cast<TensorSeq&>(target_mlvalue.Get<TensorSeq>());
size_t size = 0;
while ((size = target_tensor_seq.Size()) < source_tensor_seq.Size()) {
if (0 == size) {
target_tensor_seq.SetType(source_tensor_seq.DataType());
}
const Tensor& source_tensor = source_tensor_seq.Get(size);
std::unique_ptr<Tensor> target_tensor = std::make_unique<Tensor>(source_tensor.DataType(), source_tensor.Shape(), allocator);
target_tensor_seq.Add(std::move(*target_tensor));
}
auto source_iter = source_tensor_seq.begin();
auto target_iter = target_tensor_seq.begin();
while (source_iter != source_tensor_seq.end() &&

View file

@ -26,10 +26,39 @@ class Memcpy final : public OpKernel {
Memcpy(const OpKernelInfo& info) : OpKernel{info} {}
Status Compute(OpKernelContext* ctx) const override {
const auto* X = ctx->Input<Tensor>(0);
Tensor* Y = ctx->Output(0, X->Shape());
Status retval = Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
return retval;
auto X_type = ctx->InputType(0);
if (X_type->IsTensorType()) {
const auto* X = ctx->Input<Tensor>(0);
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
Tensor* Y = ctx->Output(0, X->Shape());
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
return Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId());
} else if (X_type->IsTensorSequenceType()) {
const TensorSeq* X = ctx->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr.");
TensorSeq* Y = ctx->Output<TensorSeq>(0);
ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence.");
auto X_dtype = X->DataType();
Y->SetType(X_dtype);
AllocatorPtr alloc;
auto status = ctx->GetTempSpaceAllocator(&alloc);
if (!status.IsOK()) {
return Status(common::ONNXRUNTIME, common::FAIL,
"Memcpy cuda: unable to get an allocator.");
}
auto X_size = X->Size();
for (size_t i = 0; i < X_size; ++i) {
const Tensor& source_tensor = X->Get(i);
std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc);
Status retval = Info().GetDataTransferManager().CopyTensor(source_tensor, *target_tensor, Info().GetKernelDef().ExecQueueId());
if (!retval.IsOK()) {
return retval;
}
Y->Add(std::move(*target_tensor));
}
return Status::OK();
}
return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type.");
}
};
@ -42,7 +71,7 @@ ONNX_OPERATOR_KERNEL_EX(
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0)
.ExecQueueId(kCudaStreamCopyIn)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()),
Memcpy);
ONNX_OPERATOR_KERNEL_EX(
@ -53,7 +82,7 @@ ONNX_OPERATOR_KERNEL_EX(
(*KernelDefBuilder::Create())
.OutputMemoryType(OrtMemTypeCPUOutput, 0)
.ExecQueueId(kCudaStreamCopyOut)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()),
Memcpy);
} // namespace cuda
@ -770,6 +799,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceAt);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceConstruct);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceEmpty);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceLength);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceErase);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceInsert);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Conv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Conv);
@ -1589,6 +1625,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceAt)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceConstruct)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceEmpty)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceLength)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, ConcatFromSequence)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceErase)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, SequenceInsert)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Conv)>,

View file

@ -57,7 +57,7 @@ ONNX_OPERATOR_KERNEL_EX(
14,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes())
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes())
.Alias(0, 0),
IdentityOp<false>);
} // namespace cuda

View file

@ -53,14 +53,11 @@ class IdentityOp final : public CudaKernel {
}
} else if (X_ml_type->IsTensorSequenceType()) {
const TensorSeq* X = context->Input<TensorSeq>(0);
if (nullptr == X) {
return Status(common::ONNXRUNTIME, common::FAIL,
"IdentityOp cuda: input tensor is missing.");
}
ORT_ENFORCE(X != nullptr, "IdentityOp cuda: input tensor is missing.");
TensorSeq* Y = context->Output<TensorSeq>(0);
if (nullptr == Y) {
return Status(common::ONNXRUNTIME, common::FAIL,
"IdentityOp cuda: failed to allocate output tensor sequence.");
ORT_ENFORCE(Y != nullptr, "IdentityOp cuda: failed to allocate output tensor sequence.");
if (X == Y) {
return Status::OK();
}
auto X_type = X->DataType();
Y->SetType(X_type);
@ -73,7 +70,8 @@ class IdentityOp final : public CudaKernel {
auto X_size = X->Size();
for (size_t i = 0; i < X_size; ++i) {
const Tensor& source_tensor = X->Get(i);
std::unique_ptr<Tensor> target_tensor = Tensor::Create(X_type, source_tensor.Shape(), alloc);
std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(),
source_tensor.Shape(), alloc);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
source_tensor.DataRaw(),
source_tensor.SizeInBytes(),

View file

@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "sequence_op.h"
namespace onnxruntime {
namespace cuda {
// Kernel registrations for the sequence operators, all in the ONNX domain at
// opset 11 for the CUDA execution provider. The implementing classes come from
// sequence_op.h. In the constraints below, "S" is the sequence type, "T" the
// element tensor type and "I" the index/length tensor type.

// SequenceAt: sequence (S) + position (I: int32 or int64) -> tensor (T).
ONNX_OPERATOR_KERNEL_EX(
SequenceAt,
kOnnxDomain,
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("I", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
SequenceAt);
// SequenceConstruct: tensors (T) -> sequence (S).
ONNX_OPERATOR_KERNEL_EX(
SequenceConstruct,
kOnnxDomain,
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()),
SequenceConstruct);
// SequenceEmpty: no inputs -> sequence (S).
ONNX_OPERATOR_KERNEL_EX(
SequenceEmpty,
kOnnxDomain,
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()),
SequenceEmpty);
// SequenceLength: sequence (S) -> int64 length tensor (I).
ONNX_OPERATOR_KERNEL_EX(
SequenceLength,
kOnnxDomain,
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
SequenceLength);
// ConcatFromSequence: sequence (S) -> concatenated tensor. Only "S" is
// constrained here.
ONNX_OPERATOR_KERNEL_EX(
ConcatFromSequence,
kOnnxDomain,
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()),
ConcatFromSequence);
// SequenceErase: sequence (S) + position (I: int32 or int64) -> sequence (S).
ONNX_OPERATOR_KERNEL_EX(
SequenceErase,
kOnnxDomain,
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("I", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
SequenceErase);
// SequenceInsert: sequence (S) + tensor + position (I: int32 or int64) ->
// sequence (S).
// NOTE(review): no constraint is declared for the inserted tensor's type "T";
// confirm that leaving it unconstrained is intended.
ONNX_OPERATOR_KERNEL_EX(
SequenceInsert,
kOnnxDomain,
11,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("I", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
SequenceInsert);
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,276 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cuda/tensor/concat.h"
#include "core/providers/cuda/tensor/concat_impl.h"
namespace onnxruntime {
namespace cuda {
// Reads one element of type DataType from the start of `tensor`'s buffer via a
// device-to-host copy and returns it widened to int64_t.
//
// tensor:    tensor whose first element is read; its buffer is treated as
//            device memory (the copy kind is cudaMemcpyDeviceToHost).
// type_name: textual name of DataType, used only in the failure message.
//
// Throws via ORT_THROW when the CUDA copy fails.
// NOTE(review): this uses the blocking cudaMemcpy rather than an async copy on
// the kernel's stream; confirm the index producer is ordered before this read.
template <typename DataType>
int64_t ReadIndex(const Tensor& tensor, const char* type_name) {
DataType data{0};
if (!CUDA_CALL(cudaMemcpy(&data, tensor.Data<DataType>(),
sizeof(DataType), cudaMemcpyDeviceToHost))) {
ORT_THROW("Cuda: Failed to read tensor data as type: ", type_name, ".");
}
return static_cast<int64_t>(data);
}
// CUDA kernel for SequenceAt: copies the tensor found at a given position of
// the input sequence into the output tensor.
class SequenceAt final : public CudaKernel {
 public:
  SequenceAt(const OpKernelInfo& info) : CudaKernel(info) {}

  // Input 0 is the tensor sequence; input 1 is the position (int32 or int64).
  // A negative position is offset by the sequence length before validation.
  // Returns OK on success, or the CUDA error status if the copy fails.
  Status ComputeInternal(OpKernelContext* context) const override {
    const TensorSeq* X = context->Input<TensorSeq>(0);
    ORT_ENFORCE(X != nullptr, "SequenceAt GPU: Got nullptr for sequence input.");
    const Tensor* I = context->Input<Tensor>(1);
    ORT_ENFORCE(I != nullptr, "SequenceAt GPU: Got nullptr input for index tensor.");

    // Fetch the position with the width that matches the index tensor's type.
    int64_t idx = 0;
    if (I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
      idx = ReadIndex<int32_t>(*I, "int32_t");
    } else {
      idx = ReadIndex<int64_t>(*I, "int64_t");
    }

    const int64_t sequence_size = static_cast<int64_t>(X->Size());
    if (idx < 0) {
      idx += sequence_size;
    }
    ORT_ENFORCE(idx >= 0 && idx < sequence_size, "SequenceAt GPU: Invalid sequence index.");

    const Tensor& source_tensor = X->Get(idx);
    const auto element_type = source_tensor.DataType();
    const void* source_addr = source_tensor.DataRaw(element_type);
    Tensor* target_tensor = context->Output(0, source_tensor.Shape());
    ORT_ENFORCE(target_tensor != nullptr, "SequenceAt GPU: Got nullptr for output tensor.");
    void* target_addr = target_tensor->MutableDataRaw(element_type);

    // Nothing to transfer when source and destination are the same address.
    if (source_addr == target_addr) {
      return Status::OK();
    }
    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_addr,
                                         source_addr,
                                         source_tensor.SizeInBytes(),
                                         cudaMemcpyDeviceToDevice, Stream()));
    return Status::OK();
  }
};  // SequenceAt
// CUDA kernel for SequenceConstruct: builds the output sequence from the
// kernel's tensor inputs, copying each one into a newly allocated tensor.
class SequenceConstruct final : public CudaKernel {
 public:
  SequenceConstruct(const OpKernelInfo& info) : CudaKernel(info) {}

  // Walks the inputs starting at index 0 until Input<Tensor>() yields nullptr.
  // The first input determines the element type recorded on the sequence.
  Status ComputeInternal(OpKernelContext* context) const override {
    TensorSeq* Y = context->Output<TensorSeq>(0);
    ORT_ENFORCE(Y != nullptr, "SequenceConstruct GPU: Got nullptr for output sequence.");
    AllocatorPtr alloc;
    ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
                "SequenceConstruct GPU: Unable to get an allocator.");

    for (int32_t input_index = 0;; ++input_index) {
      const Tensor* source_tensor = context->Input<Tensor>(input_index);
      if (source_tensor == nullptr) {
        break;  // no more inputs
      }
      if (input_index == 0) {
        Y->SetType(source_tensor->DataType());
      }
      std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor->DataType(),
                                                             source_tensor->Shape(), alloc);
      ORT_ENFORCE(target_tensor, "SequenceConstruct GPU: Failed to allocate new tensor.");
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
                                           source_tensor->DataRaw(),
                                           source_tensor->SizeInBytes(),
                                           cudaMemcpyDeviceToDevice, Stream()));
      Y->Add(std::move(*target_tensor));  // Add will check type consistency inside
    }
    return Status::OK();
  }
};  // SequenceConstruct
// CUDA kernel for SequenceEmpty: produces a sequence with no elements whose
// element type comes from the "dtype" attribute.
class SequenceEmpty final : public CudaKernel {
 public:
  // Reads the "dtype" attribute; falls back to FLOAT when it cannot be read.
  SequenceEmpty(const OpKernelInfo& info) : CudaKernel(info) {
    const bool has_dtype = info.GetAttr("dtype", &dtype_).IsOK();
    if (!has_dtype) {
      dtype_ = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
    }
  }

  // Only sets the element type on the output sequence; nothing is added to it.
  Status ComputeInternal(OpKernelContext* context) const override {
    TensorSeq* Y = context->Output<TensorSeq>(0);
    ORT_ENFORCE(Y != nullptr, "SequenceEmpty GPU: Failed to allocate output tensor sequence.");
    const int onnx_elem_type = static_cast<int>(dtype_);
    // The two build flavours resolve the ONNX enum to an MLDataType through
    // different entry points.
#ifdef SHARED_PROVIDER
    Y->SetType(DataTypeImpl::GetTypeFromOnnxType(onnx_elem_type));
#else
    Y->SetType(DataTypeImpl::TensorTypeFromONNXEnum(onnx_elem_type)->GetElementType());
#endif
    return Status::OK();
  }

 private:
  int64_t dtype_{};  // ONNX TensorProto data type enum value
};  // SequenceEmpty
// CUDA kernel for SequenceLength: writes the number of tensors held by the
// input sequence into a scalar int64 output tensor.
class SequenceLength final : public CudaKernel {
public:
SequenceLength(const OpKernelInfo& info) : CudaKernel(info) {}
// Input 0 is the tensor sequence; output 0 is requested with an empty shape.
// Returns OK on success, or the CUDA error status if the copy fails.
Status ComputeInternal(OpKernelContext* context) const override {
const TensorSeq* X = context->Input<TensorSeq>(0);
ORT_ENFORCE(X != nullptr, "SequenceLength GPU: Input tensor sequence is nullptr.");
Tensor* Y = context->Output(0, {});
ORT_ENFORCE(Y != nullptr, "SequenceLength GPU: Failed to allocate output tensor sequence.");
// The length is computed on the host and then copied into the output buffer.
auto X_size = static_cast<int64_t>(X->Size());
// NOTE(review): the source is a stack local handed to an async copy; this
// presumably relies on CUDA staging pageable host memory before the call
// returns. Confirm against the CUDA stream synchronization rules.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->MutableDataRaw(),
&X_size,
sizeof(int64_t),
cudaMemcpyHostToDevice, Stream()));
return Status::OK();
}
}; // SequenceLength
// CUDA kernel for ConcatFromSequence: concatenates every tensor of the input
// sequence into a single output tensor, reusing ConcatBase for shape checks
// and output allocation.
class ConcatFromSequence final : public CudaKernel, public ConcatBase {
 public:
  ConcatFromSequence(const OpKernelInfo& info) : CudaKernel(info), ConcatBase(info, true) {}

  // Input 0 is the tensor sequence. Returns OK on success, the status from
  // PrepareForCompute if preparation fails, or the CUDA error status if a
  // copy fails.
  Status ComputeInternal(OpKernelContext* context) const override {
    const TensorSeq* X = context->Input<TensorSeq>(0);
    ORT_ENFORCE(X != nullptr, "ConcatFromSequence GPU: Input tensor sequence is nullptr.");
    int64_t input_count = static_cast<int64_t>(X->Size());

    // Collect pointers to the sequence elements for ConcatBase.
    std::vector<const Tensor*> input_tensors;
    input_tensors.reserve(static_cast<size_t>(input_count));
    for (int64_t i = 0; i < input_count; ++i) {
      input_tensors.push_back(&X->Get(i));
    }

    Prepare p;
    ORT_RETURN_IF_ERROR(PrepareForCompute(context, input_tensors, p));
    if (0 == p.output_num_elements) {
      return Status::OK();
    }

    int64_t initial_output_offset = 0;
    auto element_bytes = p.output_tensor->DataType()->Size();
    // The output base pointer does not change between inputs.
    uint8_t* output = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());
    for (int64_t input_index = 0; input_index < input_count; input_index++) {
      const auto& prep = p.inputs[input_index];
      if (prep.num_elements == 0) {
        continue;
      }
      auto input_axis_pitch = prep.axis_pitch;
      const uint8_t* input = static_cast<const uint8_t*>(prep.tensor->DataRaw());
      auto input_size = prep.num_elements;
      int64_t cur_out_offset = 0;
      int64_t cur_in_offset = 0;
      // Copy the input one axis-pitch-sized slice at a time, stepping the
      // destination by the output's axis pitch after each slice.
      for (size_t idx_copy = 0, end = input_size / input_axis_pitch; idx_copy < end; ++idx_copy) {
        // Both the sequence element and the kernel output are device buffers
        // (sibling sequence kernels copy elements device-to-device), so the
        // transfer kind must be device-to-device rather than host-to-device.
        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
            output + (initial_output_offset + cur_out_offset) * element_bytes,
            input + cur_in_offset * element_bytes, input_axis_pitch * element_bytes,
            cudaMemcpyDeviceToDevice, Stream()));
        cur_out_offset += p.output_axis_pitch;
        cur_in_offset += input_axis_pitch;
      }
      // The next input starts right after this input's slice in each output row.
      initial_output_offset += input_axis_pitch;
    }
    return Status::OK();
  }
};  // ConcatFromSequence
// CUDA kernel for SequenceErase: emits a copy of the input sequence with the
// tensor at one position left out.
class SequenceErase final : public CudaKernel {
 public:
  SequenceErase(const OpKernelInfo& info) : CudaKernel(info) {}

  // Input 0 is the tensor sequence; input 1 (optional) is the position to
  // drop. With no position the last element is dropped. A negative position
  // is offset by the sequence length before validation.
  Status ComputeInternal(OpKernelContext* context) const override {
    const TensorSeq* X = context->Input<TensorSeq>(0);
    ORT_ENFORCE(X != nullptr, "SequenceErase GPU: Got nullptr for sequence input.");
    const int64_t X_size = static_cast<int64_t>(X->Size());

    int64_t idx = X_size - 1;  // default: erase the last element
    const Tensor* position = context->Input<Tensor>(1);
    if (position != nullptr) {
      if (position->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
        idx = ReadIndex<int32_t>(*position, "int32_t");
      } else {
        idx = ReadIndex<int64_t>(*position, "int64_t");
      }
      if (idx < 0) {
        idx += X_size;
      }
      ORT_ENFORCE(idx >= 0 && idx < X_size, "SequenceErase GPU: Invalid sequence index.");
    }

    AllocatorPtr alloc;
    ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
                "SequenceErase GPU: Unable to get an allocator.");
    TensorSeq* Y = context->Output<TensorSeq>(0);
    ORT_ENFORCE(Y != nullptr, "SequenceErase GPU: Failed to allocate output tensor sequence.");
    Y->SetType(X->DataType());

    // Copy every element except the erased one into a newly allocated tensor.
    for (int64_t element = 0; element < X_size; ++element) {
      if (element == idx) {
        continue;
      }
      const Tensor& source_tensor = X->Get(element);
      std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(),
                                                             source_tensor.Shape(), alloc);
      ORT_ENFORCE(target_tensor, "SequenceErase GPU: Failed to allocate new tensor.");
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
                                           source_tensor.DataRaw(),
                                           source_tensor.SizeInBytes(),
                                           cudaMemcpyDeviceToDevice, Stream()));
      Y->Add(std::move(*target_tensor));
    }
    return Status::OK();
  }
};  // SequenceErase
// CUDA kernel for SequenceInsert: produces a new tensor sequence that holds a
// copy of every element of the input sequence plus a copy of the tensor input
// placed at the selected position.
class SequenceInsert final : public CudaKernel {
 public:
  SequenceInsert(const OpKernelInfo& info) : CudaKernel(info) {}

  // Input 0 is the sequence, input 1 the tensor to insert, optional input 2
  // the insert position (int32 or int64, read through ReadIndex). Returns
  // Status::OK() on success or the failing status of a CUDA copy.
  Status ComputeInternal(OpKernelContext* context) const override {
    const TensorSeq* S = context->Input<TensorSeq>(0);
    ORT_ENFORCE(S != nullptr, "SequenceInsert GPU: Got nullptr for sequence input.");
    int64_t S_size = static_cast<int64_t>(S->Size());

    // When no position is supplied the tensor goes after the last element.
    int64_t idx = S_size;
    const Tensor* I = context->Input<Tensor>(2);
    if (I != nullptr) {
      if (I->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
        idx = ReadIndex<int32_t>(*I, "int32_t");
      } else {
        idx = ReadIndex<int64_t>(*I, "int64_t");
      }
      // A negative position counts back from the end of the sequence.
      if (idx < 0) {
        idx = S_size + idx;
      }
      ORT_ENFORCE(idx >= 0 && idx <= S_size, "SequenceInsert GPU: Invalid sequence index.");
    }

    const Tensor* X = context->Input<Tensor>(1);
    ORT_ENFORCE(X != nullptr, "SequenceInsert GPU: Got nullptr for tensor input.");

    AllocatorPtr alloc;
    ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc).IsOK(),
                "SequenceInsert GPU: Unable to get an allocator.");

    TensorSeq* Y = context->Output<TensorSeq>(0);
    ORT_ENFORCE(Y != nullptr, "SequenceInsert GPU: Failed to allocate output tensor sequence.");
    Y->SetType(S->DataType());

    // Walk the positions 0..S_size inclusive. The extra final position exists
    // only so that an insert at the very end is handled by the same branch as
    // an insert in the middle.
    for (int64_t i = 0; i <= S_size; ++i) {
      if (i == idx) {
        std::unique_ptr<Tensor> target_tensor = Tensor::Create(X->DataType(),
                                                               X->Shape(), alloc);
        ORT_ENFORCE(target_tensor, "SequenceInsert GPU: Failed to allocate new tensor.");
        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
                                             X->DataRaw(), X->SizeInBytes(),
                                             cudaMemcpyDeviceToDevice, Stream()));
        Y->Add(std::move(*target_tensor));
      }
      if (i == S_size) {
        break;  // past the last existing element: nothing left to clone
      }
      const Tensor& source_tensor = S->Get(i);
      std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(),
                                                             source_tensor.Shape(), alloc);
      ORT_ENFORCE(target_tensor, "SequenceInsert GPU: Failed to allocate new tensor.");
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target_tensor->MutableDataRaw(),
                                           source_tensor.DataRaw(),
                                           source_tensor.SizeInBytes(),
                                           cudaMemcpyDeviceToDevice, Stream()));
      Y->Add(std::move(*target_tensor));  // Add will check type consistency inside
    }
    return Status::OK();
  }
};  // SequenceInsert
} // namespace cuda
} // namespace onnxruntime

View file

@ -105,6 +105,7 @@ template <>
// Each definition in this group forwards the lookup to the matching
// DataTypeImpl__* entry on the object returned by Provider_GetHost().
MLDataType DataTypeImpl::GetType<Tensor>() { return Provider_GetHost()->DataTypeImpl__GetType_Tensor(); }
template <>
MLDataType DataTypeImpl::GetType<TensorSeq>() { return Provider_GetHost()->DataTypeImpl__GetType_TensorSeq(); }
// Non-template lookup: onnx_type is passed through to the host unchanged.
// NOTE(review): presumably an ONNX TensorProto data-type enum value - confirm.
MLDataType DataTypeImpl::GetTypeFromOnnxType(int onnx_type) { return Provider_GetHost()->DataTypeImpl__GetTypeFromOnnxType(onnx_type); }
template <>
MLDataType DataTypeImpl::GetType<bool>() { return Provider_GetHost()->DataTypeImpl__GetType_bool(); }
template <>
@ -132,6 +133,8 @@ MLDataType DataTypeImpl::GetType<BFloat16>() { return Provider_GetHost()->DataTy
// Scalar-type lookups (GetType<T>) forwarded to the provider host.
template <>
MLDataType DataTypeImpl::GetType<MLFloat16>() { return Provider_GetHost()->DataTypeImpl__GetType_MLFloat16(); }
template <>
MLDataType DataTypeImpl::GetType<std::string>() { return Provider_GetHost()->DataTypeImpl__GetType_string(); }
// Tensor-type lookups (GetTensorType<T>) forwarded to the provider host.
template <>
MLDataType DataTypeImpl::GetTensorType<bool>() { return Provider_GetHost()->DataTypeImpl__GetTensorType_bool(); }
template <>
MLDataType DataTypeImpl::GetTensorType<int8_t>() { return Provider_GetHost()->DataTypeImpl__GetTensorType_int8(); }

View file

@ -423,6 +423,7 @@ struct ProviderHost {
// DataTypeImpl: pure-virtual hooks, one per type lookup, each returning an
// MLDataType. Implementations are supplied by whoever derives from ProviderHost.
virtual MLDataType DataTypeImpl__GetType_Tensor() = 0;
virtual MLDataType DataTypeImpl__GetType_TensorSeq() = 0;
// Takes an int type code; NOTE(review): presumably an ONNX data-type enum value - confirm.
virtual MLDataType DataTypeImpl__GetTypeFromOnnxType(int) = 0;
virtual MLDataType DataTypeImpl__GetType_bool() = 0;
virtual MLDataType DataTypeImpl__GetType_int8() = 0;
virtual MLDataType DataTypeImpl__GetType_uint8() = 0;
@ -436,6 +437,7 @@ struct ProviderHost {
// Remaining GetType_* hooks (pure virtual, no arguments, return MLDataType).
virtual MLDataType DataTypeImpl__GetType_double() = 0;
virtual MLDataType DataTypeImpl__GetType_BFloat16() = 0;
virtual MLDataType DataTypeImpl__GetType_MLFloat16() = 0;
virtual MLDataType DataTypeImpl__GetType_string() = 0;
// GetTensorType_* hooks follow the same shape as the GetType_* hooks above.
virtual MLDataType DataTypeImpl__GetTensorType_bool() = 0;
virtual MLDataType DataTypeImpl__GetTensorType_int8() = 0;
virtual MLDataType DataTypeImpl__GetTensorType_uint8() = 0;
@ -458,6 +460,9 @@ struct ProviderHost {
// Hooks returning a reference to a list of MLDataType entries; the name of
// each hook states which group of types the list is meant to hold.
virtual const std::vector<MLDataType>& DataTypeImpl__AllTensorTypes() = 0;
virtual const std::vector<MLDataType>& DataTypeImpl__AllIEEEFloatTensorTypes() = 0;
virtual const std::vector<MLDataType>& DataTypeImpl__AllTensorAndSequenceTensorTypes() = 0;
virtual const std::vector<MLDataType>& DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes() = 0;
virtual const std::vector<MLDataType>& DataTypeImpl__AllSequenceTensorTypes() = 0;
virtual const std::vector<MLDataType>& DataTypeImpl__AllFixedSizeSequenceTensorTypes() = 0;
// Per-instance hooks: the DataTypeImpl being queried is passed in as `p`.
virtual size_t DataTypeImpl__Size(const DataTypeImpl* p) = 0;
virtual const PrimitiveDataTypeBase* DataTypeImpl__AsPrimitiveDataType(const DataTypeImpl* p) = 0;
@ -1282,6 +1287,7 @@ class DataTypeImpl final {
static MLDataType GetType();
template <typename elemT>
static MLDataType GetTensorType();
// Declared here, defined out of line.
// NOTE(review): the int is presumably an ONNX data-type enum value - confirm.
static MLDataType GetTypeFromOnnxType(int);
// Type-category queries: each forwards to g_host, passing this object.
bool IsTensorType() const { return g_host->DataTypeImpl__IsTensorType(this); }
bool IsTensorSequenceType() const { return g_host->DataTypeImpl__IsTensorSequenceType(this); }
@ -1292,6 +1298,9 @@ class DataTypeImpl final {
// Static accessors: each returns the reference obtained from the g_host hook
// of the same name (prefixed with DataTypeImpl__).
static const std::vector<MLDataType>& AllTensorTypes() { return g_host->DataTypeImpl__AllTensorTypes(); }
static const std::vector<MLDataType>& AllIEEEFloatTensorTypes() { return g_host->DataTypeImpl__AllIEEEFloatTensorTypes(); }
static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypes() { return g_host->DataTypeImpl__AllTensorAndSequenceTensorTypes(); }
static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeTensorAndSequenceTensorTypes(); }
static const std::vector<MLDataType>& AllSequenceTensorTypes() { return g_host->DataTypeImpl__AllSequenceTensorTypes(); }
static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypes() { return g_host->DataTypeImpl__AllFixedSizeSequenceTensorTypes(); }
// Instance query: forwards to g_host, passing this object.
const PrimitiveDataTypeBase* AsPrimitiveDataType() const { return g_host->DataTypeImpl__AsPrimitiveDataType(this); }

View file

@ -202,13 +202,24 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
// Register MemCpy schema;
// These ops are internal-only, so register outside of onnx
// Type-constraint list shared by the memcpy schemas registered below. It is a
// function-local static initialized from an immediately-invoked lambda.
static std::vector<std::string> all_fixed_size_types = []() {
std::vector<std::string> all_types;
std::vector<std::string> all_tensor_types = OpSchema::all_tensor_types_with_bfloat();
std::vector<std::string> all_sequence_types = OpSchema::all_tensor_sequence_types();
// Start from every tensor type (bfloat16 included) followed by every tensor-sequence type.
all_types.insert(all_types.end(), all_tensor_types.begin(), all_tensor_types.end());
all_types.insert(all_types.end(), all_sequence_types.begin(), all_sequence_types.end());
// Added explicitly; NOTE(review): presumably all_tensor_sequence_types() omits the bfloat16 sequence - confirm.
all_types.emplace_back("seq(tensor(bfloat16))");
// Erase-remove: drop every entry whose name contains "string", i.e. the string tensor and string sequence types.
all_types.erase(std::remove_if(all_types.begin(), all_types.end(),
[](const std::string& s) { return s.find("string") != std::string::npos; }), all_types.end());
return all_types; }();
ORT_ATTRIBUTE_UNUSED ONNX_OPERATOR_SCHEMA(MemcpyFromHost)
.Input(0, "X", "input", "T")
.Output(0, "Y", "output", "T")
.TypeConstraint(
"T",
OpSchema::all_tensor_types_with_bfloat(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
all_fixed_size_types,
"Constrain to all fixed size tensor and sequence types. If the dtype attribute is not provided this must be a valid output type.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
.SetDoc(R"DOC(
Internal copy node
@ -219,8 +230,8 @@ Internal copy node
.Output(0, "Y", "output", "T")
.TypeConstraint(
"T",
OpSchema::all_tensor_types_with_bfloat(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
all_fixed_size_types,
"Constrain to all fixed size tensor and sequence types. If the dtype attribute is not provided this must be a valid output type.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
.SetDoc(R"DOC(
Internal copy node