diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.cc b/orttraining/orttraining/core/graph/gradient_builder_base.cc index 207f10850b..341742486f 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_base.cc @@ -83,7 +83,10 @@ void ComputeBroadcastBackwardAxes( } std::vector GetShape(const ArgDef& arg_def) { - ORT_ENFORCE(arg_def.type_proto, "During GetShape, ", arg_def.name, "'s type_proto is null."); + ORT_ENFORCE(arg_def.type_proto + && arg_def.type_proto->has_tensor_type() + && arg_def.type_proto->tensor_type().has_shape(), + "During GetShape, ", arg_def.name, "'s shape is null."); std::vector shape; const auto& dims = arg_def.type_proto->tensor_type().shape().dim(); for (auto dim = dims.begin(); dim < dims.end(); dim++) { diff --git a/orttraining/orttraining/core/graph/gradient_schema_defs.cc b/orttraining/orttraining/core/graph/gradient_schema_defs.cc index e49c858181..ec81e08261 100644 --- a/orttraining/orttraining/core/graph/gradient_schema_defs.cc +++ b/orttraining/orttraining/core/graph/gradient_schema_defs.cc @@ -937,7 +937,25 @@ Example 4: .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types") - .SetDoc(R"DOC(SparseSoftmaxCrossEntropy)DOC"); + .SetDoc(R"DOC(SparseSoftmaxCrossEntropy)DOC") + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + std::string reduction = getAttribute(ctx, "reduction", "mean"); + if (reduction.compare("none") == 0) { + if (hasInputShape(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 1, 0); + } + } else { + updateOutputShape(ctx, 0, TensorShapeProto()); + } + + if(ctx.getNumOutputs() == 2) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + if (hasInputShape(ctx, 0)) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + }); ONNX_CONTRIB_OPERATOR_SCHEMA(SparseSoftmaxCrossEntropyGrad) .SetDomain(kOnnxDomain) diff --git a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc index ceff7c3f77..ac8f1caad2 100644 --- a/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc +++ b/orttraining/orttraining/core/optimizer/insert_output_rewriter.cc @@ -14,7 +14,9 @@ Status InsertMaxPoolOutput::Apply(Graph& graph, Node& node, RewriteRuleEffect& r TypeProto t; t.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); - t.mutable_tensor_type()->mutable_shape()->CopyFrom(*Y->Shape()); + if (Y->Shape() != nullptr) { + t.mutable_tensor_type()->mutable_shape()->CopyFrom(*Y->Shape()); + } NodeArg& node_arg = graph.GetOrCreateNodeArg(Y->Name() + "_mask", &t); @@ -38,7 +40,9 @@ Status InsertSoftmaxCrossEntropyLossOutput::Apply(Graph& graph, Node& node, Rewr TypeProto t; t.mutable_tensor_type()->set_elem_type(X->TypeAsProto()->tensor_type().elem_type()); - t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits. + if (X->Shape() != nullptr) { + t.mutable_tensor_type()->mutable_shape()->CopyFrom(*X->Shape()); // log probability should have the same shape as logits. + } NodeArg& node_arg = graph.GetOrCreateNodeArg(X->Name() + "_log_prob", &t);