mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Bugfix for shape inference and GetShape. (#4243)
Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
parent
12367a6b11
commit
b41fcf1570
3 changed files with 29 additions and 4 deletions
|
|
@ -83,7 +83,10 @@ void ComputeBroadcastBackwardAxes(
|
|||
}
|
||||
|
||||
std::vector<Dimension> GetShape(const ArgDef& arg_def) {
|
||||
ORT_ENFORCE(arg_def.type_proto, "During GetShape, ", arg_def.name, "'s type_proto is null.");
|
||||
ORT_ENFORCE(arg_def.type_proto
|
||||
&& arg_def.type_proto->has_tensor_type()
|
||||
&& arg_def.type_proto->tensor_type().has_shape(),
|
||||
"During GetShape, ", arg_def.name, "'s shape is null.");
|
||||
std::vector<Dimension> shape;
|
||||
const auto& dims = arg_def.type_proto->tensor_type().shape().dim();
|
||||
for (auto dim = dims.begin(); dim < dims.end(); dim++) {
|
||||
|
|
|
|||
|
|
@ -937,7 +937,25 @@ Example 4:
|
|||
.TypeConstraint("Tind",
|
||||
{"tensor(int32)", "tensor(int64)"},
|
||||
"Constrain indices to integer types")
|
||||
.SetDoc(R"DOC(SparseSoftmaxCrossEntropy)DOC");
|
||||
.SetDoc(R"DOC(SparseSoftmaxCrossEntropy)DOC")
|
||||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
std::string reduction = getAttribute(ctx, "reduction", "mean");
|
||||
if (reduction.compare("none") == 0) {
|
||||
if (hasInputShape(ctx, 1)) {
|
||||
propagateShapeFromInputToOutput(ctx, 1, 0);
|
||||
}
|
||||
} else {
|
||||
updateOutputShape(ctx, 0, TensorShapeProto());
|
||||
}
|
||||
|
||||
if(ctx.getNumOutputs() == 2) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 1);
|
||||
if (hasInputShape(ctx, 0)) {
|
||||
propagateShapeFromInputToOutput(ctx, 0, 1);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SparseSoftmaxCrossEntropyGrad)
|
||||
.SetDomain(kOnnxDomain)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ Status InsertMaxPoolOutput::Apply(Graph& graph, Node& node, RewriteRuleEffect& r
|
|||
|
||||
TypeProto t;
|
||||
t.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
|
||||
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*Y->Shape());
|
||||
if (Y->Shape() != nullptr) {
|
||||
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*Y->Shape());
|
||||
}
|
||||
|
||||
NodeArg& node_arg = graph.GetOrCreateNodeArg(Y->Name() + "_mask", &t);
|
||||
|
||||
|
|
@ -38,7 +40,9 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr
|
|||
|
||||
TypeProto t;
|
||||
t.mutable_tensor_type()->set_elem_type(X->TypeAsProto()->tensor_type().elem_type());
|
||||
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits.
|
||||
if (X->Shape() != nullptr) {
|
||||
t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits.
|
||||
}
|
||||
|
||||
NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue