Bumped up to op_ver 11 for a bunch of Nuphar Ops (#2025)

This change enabled op_ver 11 for a dozen of Nuphar Ops
This commit is contained in:
Yang Chen 2019-10-07 10:34:05 -07:00 committed by KeDengMS
parent 3c26ae5b6d
commit 7d2f0c79bd
4 changed files with 93 additions and 26 deletions

View file

@ -28,6 +28,7 @@ tvm::Tensor Gather(const tvm::Tensor& t,
for (size_t i = axis_t + 1; i < t->shape.size(); ++i)
output_shape.push_back(t->shape[i]);
tvm::Expr idx_upper_bound = t->shape[axis_t];
auto l = [&](const tvm::Array<tvm::Var>& ovars) {
tvm::Array<tvm::Expr> ivars;
for (size_t i = 0; i < t->shape.size(); ++i) {
@ -37,7 +38,9 @@ tvm::Tensor Gather(const tvm::Tensor& t,
tvm::Array<tvm::Expr> idx_vars;
for (size_t d = 0; d < indices->shape.size(); ++d)
idx_vars.push_back(ovars[axis_t + d]);
ivars.push_back(tvm::cast(tvm::Int(32), indices(idx_vars))); // tvm indices must be Int32
// make sure idx is clamped in the range of [-idx_upper_bound, idx_upper_bound - 1]
tvm::Expr real_idx = tvm_codegen::ClampIndex(indices(idx_vars), idx_upper_bound);
ivars.push_back(tvm::cast(tvm::Int(32), real_idx)); // tvm indices must be Int32
} else {
ivars.push_back(ovars[i - 1 + indices->shape.size()]);
}

View file

@ -182,10 +182,22 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T2", DataTypeImpl::AllFixedSizeTensorExceptHalfTypes()),
nuphar::NupharKernel);
ONNX_OPERATOR_KERNEL_EX(
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gather,
kOnnxDomain,
1,
10,
kNupharExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
nuphar::NupharKernel);
ONNX_OPERATOR_KERNEL_EX(
Gather,
kOnnxDomain,
11,
kNupharExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
@ -228,10 +240,21 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
nuphar::NupharKernel);
ONNX_OPERATOR_KERNEL_EX(
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Scan,
kOnnxDomain,
9,
10,
kNupharExecutionProvider,
KernelDefBuilder()
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("V", DataTypeImpl::AllTensorTypes()),
nuphar::NupharKernel);
ONNX_OPERATOR_KERNEL_EX(
Scan,
kOnnxDomain,
11,
kNupharExecutionProvider,
KernelDefBuilder()
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
@ -250,6 +273,17 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
DataTypeImpl::GetTensorType<int64_t>()}),
nuphar::NupharKernel);
ONNX_OPERATOR_KERNEL_EX(
Scatter,
kOnnxDomain,
11,
kNupharExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
nuphar::NupharKernel);
ONNX_OPERATOR_KERNEL_EX(
ScatterElements,
kOnnxDomain,

View file

@ -75,23 +75,30 @@ class NupharKernelState {
#define LIST_NUPHAR_OPS() \
NUPHAR_OP(Abs, 6, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Add, 7, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ArgMax, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ArgMax, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ArgMax, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ArgMin, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ArgMin, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ArgMin, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(AveragePool, 7, 9, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(AveragePool, 10, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(AveragePool, 11, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(Ceil, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Clip, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Concat, 4, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(Concat, 4, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Concat, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
DISABLE_MACRO(NUPHAR_OP(Conv, 1, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes())) \
NUPHAR_OP(Crop, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Div, 7, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Dropout, 7, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Elu, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Equal, 7, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(Equal, 7, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Equal, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Erf, 9, DataTypeImpl::GetTensorType<float>()) \
NUPHAR_OP(Exp, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(Flatten, 1, 8, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Flatten, 9, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(Flatten, 9, 10, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Flatten, 11, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Floor, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(Gemm, 7, 8, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(Gemm, 9, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
@ -103,7 +110,8 @@ class NupharKernelState {
NUPHAR_OP(LeakyRelu, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Less, 9, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Log, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(LogSoftmax, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(LogSoftmax, 1, 10, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(LogSoftmax, 11, DataTypeImpl::AllIEEEFloatTensorTypes()) \
DISABLE_MACRO(NUPHAR_OP(LSTM, 7, DataTypeImpl::AllIEEEFloatTensorTypes())) \
NUPHAR_VERSIONED_OP(MatMul, 1, 8, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(MatMul, 9, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
@ -111,6 +119,7 @@ class NupharKernelState {
NUPHAR_VERSIONED_OP(MaxPool, 1, 7, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_VERSIONED_OP(MaxPool, 8, 9, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(MaxPool, 10, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(MaxPool, 11, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(Min, 8, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Mul, 7, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Neg, 6, DataTypeImpl::AllFixedSizeTensorTypes()) \
@ -119,27 +128,41 @@ class NupharKernelState {
NUPHAR_OP(PRelu, 7, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Relu, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Reciprocal, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(ReduceL1, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceL2, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(ReduceLogSum, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(ReduceLogSumExp, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(ReduceMax, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceMean, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceMin, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceProd, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceSum, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceSumSquare, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceL1, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceL1, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceL2, 1, 10, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(ReduceL2, 11, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceLogSum, 1, 10, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(ReduceLogSum, 11, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceLogSumExp, 1, 10, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(ReduceLogSumExp, 11, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceMax, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceMax, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceMean, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceMean, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceMin, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceMin, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceProd, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceProd, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceSum, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceSum, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(ReduceSumSquare, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ReduceSumSquare, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Reshape, 5, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(ScaledTanh, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Selu, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Sigmoid, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(Slice, 1, 9, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Slice, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Softmax, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Slice, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(Softmax, 1, 10, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Softmax, 11, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Softplus, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Softsign, 1, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Split, 2, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Squeeze, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(Split, 2, 10, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Split, 11, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(Squeeze, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Squeeze, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Sqrt, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Sub, 7, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Sum, 8, DataTypeImpl::AllFixedSizeTensorTypes()) \
@ -147,7 +170,8 @@ class NupharKernelState {
NUPHAR_OP(ThresholdedRelu, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Tile, 6, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Transpose, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Unsqueeze, 1, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_VERSIONED_OP(Unsqueeze, 1, 10, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Unsqueeze, 11, DataTypeImpl::AllFixedSizeTensorTypes()) \
NUPHAR_OP(Where, 9, DataTypeImpl::AllFixedSizeTensorTypes())
} // namespace nuphar

View file

@ -396,12 +396,15 @@ LIST_NUPHAR_OPS()
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 6, 8, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1, Gather);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1, 10, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, GatherElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scan);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, Scan);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scatter);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, Scatter);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, ScatterElements);
static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
@ -419,12 +422,15 @@ static void RegisterStandaloneNupharKernels(KernelRegistry& kernel_registry) {
// ops that have multiple type constraints
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 6, 8, Cast)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Cast)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1, Gather)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 1, 10, Gather)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, Gather)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, GatherElements)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 10, MatMulInteger)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kMSDomain, 1, MatMulInteger16)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, Scan)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scan)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, Scan)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 9, 10, Scatter)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, Scatter)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kNupharExecutionProvider, kOnnxDomain, 11, ScatterElements)>());
}