mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
placeholder for internal contrib ops (#219)
This commit is contained in:
parent
b9cc134576
commit
94f8f2b05c
3 changed files with 80 additions and 44 deletions
|
|
@ -5,6 +5,7 @@
|
|||
#include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/graph/contrib_ops/range_schema_defs.h"
|
||||
#include "core/graph/contrib_ops/internal_schema_defs.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "onnx/defs/shape_inference.h"
|
||||
|
||||
|
|
@ -590,49 +591,49 @@ The bounding box coordinates corresponding to the selected indices can then be o
|
|||
.SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.Input (0, "data", "Tensor of rank r >= 1.", "T" )
|
||||
.Input (1, "indices", "Tensor of rank q >= 1.", "Tind" )
|
||||
.Output (0, "output", "Tensor of rank q-1+r-indices[-1].", "T" )
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types(),
|
||||
"Constrain input and output types to any tensor type.")
|
||||
.TypeConstraint(
|
||||
"Tind",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indice type to int32 or int64")
|
||||
.TypeAndShapeInferenceFunction( [] (ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasNInputShapes(ctx, 2)) {
|
||||
fail_shape_inference("GatherND requires two tensor inputs.");
|
||||
}
|
||||
auto& data_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
auto& indices_shape = ctx.getInputType(1)->tensor_type().shape();
|
||||
auto data_rank = data_shape.dim_size();
|
||||
auto indices_rank = indices_shape.dim_size();
|
||||
if (data_rank < 1 || indices_rank < 1) {
|
||||
fail_shape_inference("both data and indices tensor need to have rank larger than zero.");
|
||||
}
|
||||
auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value();
|
||||
if (last_indice_dimension > data_rank) {
|
||||
fail_shape_inference("last dimension of indices must not be larger and rank of data tensor");
|
||||
}
|
||||
for (int i = 0; i < indices_rank - 1; ++i) {
|
||||
*ctx.getOutputType(0)
|
||||
->mutable_tensor_type()
|
||||
->mutable_shape()
|
||||
->add_dim() = indices_shape.dim(i);
|
||||
}
|
||||
for (int i = static_cast<int>(last_indice_dimension); i < data_rank; ++i) {
|
||||
*ctx.getOutputType(0)
|
||||
->mutable_tensor_type()
|
||||
->mutable_shape()
|
||||
->add_dim() = data_shape.dim(i);
|
||||
}
|
||||
})
|
||||
.SetDoc(R"DOC(
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.Input(0, "data", "Tensor of rank r >= 1.", "T")
|
||||
.Input(1, "indices", "Tensor of rank q >= 1.", "Tind")
|
||||
.Output(0, "output", "Tensor of rank q-1+r-indices[-1].", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types(),
|
||||
"Constrain input and output types to any tensor type.")
|
||||
.TypeConstraint(
|
||||
"Tind",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indice type to int32 or int64")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasNInputShapes(ctx, 2)) {
|
||||
fail_shape_inference("GatherND requires two tensor inputs.");
|
||||
}
|
||||
auto& data_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
auto& indices_shape = ctx.getInputType(1)->tensor_type().shape();
|
||||
auto data_rank = data_shape.dim_size();
|
||||
auto indices_rank = indices_shape.dim_size();
|
||||
if (data_rank < 1 || indices_rank < 1) {
|
||||
fail_shape_inference("both data and indices tensor need to have rank larger than zero.");
|
||||
}
|
||||
auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value();
|
||||
if (last_indice_dimension > data_rank) {
|
||||
fail_shape_inference("last dimension of indices must not be larger and rank of data tensor");
|
||||
}
|
||||
for (int i = 0; i < indices_rank - 1; ++i) {
|
||||
*ctx.getOutputType(0)
|
||||
->mutable_tensor_type()
|
||||
->mutable_shape()
|
||||
->add_dim() = indices_shape.dim(i);
|
||||
}
|
||||
for (int i = static_cast<int>(last_indice_dimension); i < data_rank; ++i) {
|
||||
*ctx.getOutputType(0)
|
||||
->mutable_tensor_type()
|
||||
->mutable_shape()
|
||||
->add_dim() = data_shape.dim(i);
|
||||
}
|
||||
})
|
||||
.SetDoc(R"DOC(
|
||||
Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather
|
||||
slices of `data` into an output tensor of rank q - 1 + r - indices[-1].
|
||||
Example 1:
|
||||
|
|
@ -652,7 +653,8 @@ Example 4:
|
|||
indices = [[[0,1]],[[1,0]]]
|
||||
output = [[[2,3]],[[4,5]]]
|
||||
)DOC");
|
||||
|
||||
// register internal ops
|
||||
RegisterInternalSchemas();
|
||||
}
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
12
onnxruntime/core/graph/contrib_ops/internal_schema_defs.cc
Normal file
12
onnxruntime/core/graph/contrib_ops/internal_schema_defs.cc
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "range_schema_defs.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
void RegisterInternalSchemas() {}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
22
onnxruntime/core/graph/contrib_ops/internal_schema_defs.h
Normal file
22
onnxruntime/core/graph/contrib_ops/internal_schema_defs.h
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#endif
|
||||
#include "onnx/defs/schema.h"
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
void RegisterInternalSchemas();
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue