mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Adding shaper inference for Op expand_dims
This commit is contained in:
parent
c7513e676f
commit
e7e801b45e
2 changed files with 34 additions and 4 deletions
|
|
@ -12,8 +12,8 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
using ::ONNX_NAMESPACE::AttributeProto;
|
||||
using ::ONNX_NAMESPACE::OPTIONAL;
|
||||
using ::ONNX_NAMESPACE::OpSchema;
|
||||
using ::ONNX_NAMESPACE::OPTIONAL;
|
||||
|
||||
void RegisterContribSchemas() {
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp)
|
||||
|
|
@ -41,7 +41,37 @@ Sample echo operator.)DOC");
|
|||
"T",
|
||||
ONNX_NAMESPACE::OpSchema::all_tensor_types(),
|
||||
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
// Type inference
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
// Shape inference
|
||||
if (!hasInputShape(ctx, 0))
|
||||
return;
|
||||
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
const int rank = input_shape.dim_size();
|
||||
const ONNX_NAMESPACE::TensorProto* axis_initializer = ctx.getInputData(1);
|
||||
if (!axis_initializer)
|
||||
return;
|
||||
const int axis = axis_initializer->int32_data()[0];
|
||||
if (axis > rank || axis < -rank - 1) {
|
||||
fail_shape_inference("Input axis is invalid: ", axis);
|
||||
}
|
||||
int pos = axis >= 0 ? axis : rank + axis - 1;
|
||||
ONNX_NAMESPACE::TensorShapeProto output_shape;
|
||||
for (int i = 0; i < pos; ++i) {
|
||||
output_shape.add_dim();
|
||||
*(output_shape.mutable_dim(i)) = input_shape.dim(i);
|
||||
}
|
||||
output_shape.add_dim();
|
||||
output_shape.mutable_dim(pos)->set_dim_value(1);
|
||||
for (int i = pos + 1; i < rank + 1; ++i) {
|
||||
output_shape.add_dim();
|
||||
*(output_shape.mutable_dim(i)) = input_shape.dim(i - 1);
|
||||
}
|
||||
updateOutputShape(ctx, 0, output_shape);
|
||||
})
|
||||
.SetDoc(R"DOC(ExpandDims echo operator.)DOC");
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema);
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ namespace test {
|
|||
|
||||
TEST(ContribOpTest, ExpandDims_0) {
|
||||
OpTester test("ExpandDims", 1, onnxruntime::kMSDomain);
|
||||
test.AddShapeToTensorData(false); // TODO: re-enable shape inference test
|
||||
test.AddShapeToTensorData(true); // TODO: re-enable shape inference test
|
||||
test.AddInput<float>("X", {2, 3}, std::vector<float>(6, 1.0f));
|
||||
test.AddInput<int32_t>("axis", {}, {-1});
|
||||
test.AddOutput<float>("Y", {2, 3, 1}, std::vector<float>(6, 1.0f));
|
||||
|
|
@ -18,7 +18,7 @@ TEST(ContribOpTest, ExpandDims_0) {
|
|||
|
||||
TEST(ContribOpTest, ExpandDims_1) {
|
||||
OpTester test("ExpandDims", 1, onnxruntime::kMSDomain);
|
||||
test.AddShapeToTensorData(false); // TODO: re-enable shape inference test
|
||||
test.AddShapeToTensorData(true); // TODO: re-enable shape inference test
|
||||
test.AddInput<float>("X", {2, 3}, std::vector<float>(6, 1.0f));
|
||||
test.AddInput<int32_t>("axis", {}, {1});
|
||||
test.AddOutput<float>("Y", {2, 1, 3}, std::vector<float>(6, 1.0f));
|
||||
|
|
|
|||
Loading…
Reference in a new issue