Fix the type constraints on CUDA If operator to exclude strings. (#2431)

This commit is contained in:
Scott McKay 2019-11-20 06:48:14 +10:00 committed by GitHub
parent 8647201ac7
commit 3be554c2fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -18,7 +18,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
// output shape rules requiring the output shapes of the 'THEN' and 'ELSE'
@ -30,7 +30,7 @@ ONNX_OPERATOR_KERNEL_EX(If,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0) // 'cond' needs to be on CPU
.TypeConstraint("B", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
If);
Status If::Compute(OpKernelContext* ctx) const {