diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a125c72068..e11b3af821 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -5,6 +5,7 @@ #include "core/graph/contrib_ops/attn_lstm_schema_defs.h" #include "core/graph/contrib_ops/contrib_defs.h" #include "core/graph/contrib_ops/range_schema_defs.h" +#include "core/graph/contrib_ops/internal_schema_defs.h" #include "core/graph/op.h" #include "onnx/defs/shape_inference.h" @@ -590,49 +591,49 @@ The bounding box coordinates corresponding to the selected indices can then be o .SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC"); ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND) - .SetDomain(kMSDomain) - .SinceVersion(1) - .Input (0, "data", "Tensor of rank r >= 1.", "T" ) - .Input (1, "indices", "Tensor of rank q >= 1.", "Tind" ) - .Output (0, "output", "Tensor of rank q-1+r-indices[-1].", "T" ) - .TypeConstraint( - "T", - OpSchema::all_tensor_types(), - "Constrain input and output types to any tensor type.") - .TypeConstraint( - "Tind", - {"tensor(int32)", "tensor(int64)"}, - "Constrain indice type to int32 or int64") - .TypeAndShapeInferenceFunction( [] (ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (!hasNInputShapes(ctx, 2)) { - fail_shape_inference("GatherND requires two tensor inputs."); - } - auto& data_shape = ctx.getInputType(0)->tensor_type().shape(); - auto& indices_shape = ctx.getInputType(1)->tensor_type().shape(); - auto data_rank = data_shape.dim_size(); - auto indices_rank = indices_shape.dim_size(); - if 
(data_rank < 1 || indices_rank < 1) { - fail_shape_inference("both data and indices tensor need to have rank larger than zero."); - } - auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value(); - if (last_indice_dimension > data_rank) { - fail_shape_inference("last dimension of indices must not be larger and rank of data tensor"); - } - for (int i = 0; i < indices_rank - 1; ++i) { - *ctx.getOutputType(0) - ->mutable_tensor_type() - ->mutable_shape() - ->add_dim() = indices_shape.dim(i); - } - for (int i = static_cast<int>(last_indice_dimension); i < data_rank; ++i) { - *ctx.getOutputType(0) - ->mutable_tensor_type() - ->mutable_shape() - ->add_dim() = data_shape.dim(i); - } - }) - .SetDoc(R"DOC( + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "data", "Tensor of rank r >= 1.", "T") + .Input(1, "indices", "Tensor of rank q >= 1.", "Tind") + .Output(0, "output", "Tensor of rank q-1+r-indices[-1].", "T") + .TypeConstraint( + "T", + OpSchema::all_tensor_types(), + "Constrain input and output types to any tensor type.") + .TypeConstraint( + "Tind", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indice type to int32 or int64") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 2)) { + fail_shape_inference("GatherND requires two tensor inputs."); + } + auto& data_shape = ctx.getInputType(0)->tensor_type().shape(); + auto& indices_shape = ctx.getInputType(1)->tensor_type().shape(); + auto data_rank = data_shape.dim_size(); + auto indices_rank = indices_shape.dim_size(); + if (data_rank < 1 || indices_rank < 1) { + fail_shape_inference("both data and indices tensor need to have rank larger than zero."); + } + auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value(); + if (last_indice_dimension > data_rank) { + fail_shape_inference("last dimension of indices must not be larger than rank of data tensor"); + } + for (int i = 0; i < 
indices_rank - 1; ++i) { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = indices_shape.dim(i); + } + for (int i = static_cast<int>(last_indice_dimension); i < data_rank; ++i) { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = data_shape.dim(i); + } + }) + .SetDoc(R"DOC( Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather slices of `data` into an output tensor of rank q - 1 + r - indices[-1]. Example 1: @@ -652,7 +653,8 @@ Example 4: indices = [[[0,1]],[[1,0]]] output = [[[2,3]],[[4,5]]] )DOC"); - + // register internal ops + RegisterInternalSchemas(); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/internal_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/internal_schema_defs.cc new file mode 100644 index 0000000000..f88ed27cfc --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/internal_schema_defs.cc @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "internal_schema_defs.h" + +namespace onnxruntime { +namespace contrib { + +void RegisterInternalSchemas() {} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/internal_schema_defs.h b/onnxruntime/core/graph/contrib_ops/internal_schema_defs.h new file mode 100644 index 0000000000..25333585fa --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/internal_schema_defs.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-qualifiers" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif +#include "onnx/defs/schema.h" +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +namespace onnxruntime { +namespace contrib { + +void RegisterInternalSchemas(); + +} // namespace contrib +} // namespace onnxruntime