Adding shaper inference for Op expand_dims

This commit is contained in:
Du Li 2018-11-24 18:53:42 -08:00
parent c7513e676f
commit e7e801b45e
2 changed files with 34 additions and 4 deletions

View file

@ -12,8 +12,8 @@
namespace onnxruntime {
namespace contrib {
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::OPTIONAL;
using ::ONNX_NAMESPACE::OpSchema;
using ::ONNX_NAMESPACE::OPTIONAL;
void RegisterContribSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp)
@ -41,7 +41,37 @@ Sample echo operator.)DOC");
"T",
ONNX_NAMESPACE::OpSchema::all_tensor_types(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
if (!hasInputShape(ctx, 0))
return;
auto& input_shape = getInputShape(ctx, 0);
const int rank = input_shape.dim_size();
const ONNX_NAMESPACE::TensorProto* axis_initializer = ctx.getInputData(1);
if (!axis_initializer)
return;
const int axis = axis_initializer->int32_data()[0];
if (axis > rank || axis < -rank - 1) {
fail_shape_inference("Input axis is invalid: ", axis);
}
int pos = axis >= 0 ? axis : rank + axis - 1;
ONNX_NAMESPACE::TensorShapeProto output_shape;
for (int i = 0; i < pos; ++i) {
output_shape.add_dim();
*(output_shape.mutable_dim(i)) = input_shape.dim(i);
}
output_shape.add_dim();
output_shape.mutable_dim(pos)->set_dim_value(1);
for (int i = pos + 1; i < rank + 1; ++i) {
output_shape.add_dim();
*(output_shape.mutable_dim(i)) = input_shape.dim(i - 1);
}
updateOutputShape(ctx, 0, output_shape);
})
.SetDoc(R"DOC(ExpandDims echo operator.)DOC");
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema);

View file

@ -9,7 +9,7 @@ namespace test {
TEST(ContribOpTest, ExpandDims_0) {
OpTester test("ExpandDims", 1, onnxruntime::kMSDomain);
test.AddShapeToTensorData(false); // TODO: re-enable shape inference test
test.AddShapeToTensorData(true); // TODO: re-enable shape inference test
test.AddInput<float>("X", {2, 3}, std::vector<float>(6, 1.0f));
test.AddInput<int32_t>("axis", {}, {-1});
test.AddOutput<float>("Y", {2, 3, 1}, std::vector<float>(6, 1.0f));
@ -18,7 +18,7 @@ TEST(ContribOpTest, ExpandDims_0) {
TEST(ContribOpTest, ExpandDims_1) {
OpTester test("ExpandDims", 1, onnxruntime::kMSDomain);
test.AddShapeToTensorData(false); // TODO: re-enable shape inference test
test.AddShapeToTensorData(true); // TODO: re-enable shape inference test
test.AddInput<float>("X", {2, 3}, std::vector<float>(6, 1.0f));
test.AddInput<int32_t>("axis", {}, {1});
test.AddOutput<float>("Y", {2, 1, 3}, std::vector<float>(6, 1.0f));