Bugfix for shape inference and GetShape. (#4243)

Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
Vincent Wang 2020-06-17 15:11:02 +08:00 committed by GitHub
parent 12367a6b11
commit b41fcf1570
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 4 deletions

View file

@ -83,7 +83,10 @@ void ComputeBroadcastBackwardAxes(
}
std::vector<Dimension> GetShape(const ArgDef& arg_def) {
ORT_ENFORCE(arg_def.type_proto, "During GetShape, ", arg_def.name, "'s type_proto is null.");
ORT_ENFORCE(arg_def.type_proto
&& arg_def.type_proto->has_tensor_type()
&& arg_def.type_proto->tensor_type().has_shape(),
"During GetShape, ", arg_def.name, "'s shape is null.");
std::vector<Dimension> shape;
const auto& dims = arg_def.type_proto->tensor_type().shape().dim();
for (auto dim = dims.begin(); dim < dims.end(); dim++) {

View file

@ -937,7 +937,25 @@ Example 4:
.TypeConstraint("Tind",
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types")
.SetDoc(R"DOC(SparseSoftmaxCrossEntropy)DOC");
.SetDoc(R"DOC(SparseSoftmaxCrossEntropy)DOC")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
std::string reduction = getAttribute(ctx, "reduction", "mean");
if (reduction.compare("none") == 0) {
if (hasInputShape(ctx, 1)) {
propagateShapeFromInputToOutput(ctx, 1, 0);
}
} else {
updateOutputShape(ctx, 0, TensorShapeProto());
}
if(ctx.getNumOutputs() == 2) {
propagateElemTypeFromInputToOutput(ctx, 0, 1);
if (hasInputShape(ctx, 0)) {
propagateShapeFromInputToOutput(ctx, 0, 1);
}
}
});
ONNX_CONTRIB_OPERATOR_SCHEMA(SparseSoftmaxCrossEntropyGrad)
.SetDomain(kOnnxDomain)

View file

@ -14,7 +14,9 @@ Status InsertMaxPoolOutput::Apply(Graph& graph, Node& node, RewriteRuleEffect& r
TypeProto t;
t.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*Y->Shape());
if (Y->Shape() != nullptr) {
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*Y->Shape());
}
NodeArg& node_arg = graph.GetOrCreateNodeArg(Y->Name() + "_mask", &t);
@ -38,7 +40,9 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr
TypeProto t;
t.mutable_tensor_type()->set_elem_type(X->TypeAsProto()->tensor_type().elem_type());
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits.
if (X->Shape() != nullptr) {
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits.
}
NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t);