Fix shape inference bug in GatherND contrib op (#1132)

This commit is contained in:
Hariharan Seshadri 2019-05-30 17:05:44 -07:00 committed by GitHub
parent 4757933afe
commit facdf77f84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -929,7 +929,7 @@ with the exception that numpy default keepdims to False instead of True.)DOC")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 2)) {
fail_shape_inference("GatherND requires two tensor inputs.");
return;
}
auto& data_shape = ctx.getInputType(0)->tensor_type().shape();
auto& indices_shape = ctx.getInputType(1)->tensor_type().shape();