diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 22d3dc7b8f..40729da3b4 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -418,9 +418,9 @@ void RegisterTrainingOpSchemas() { .Input(0, "dY", "Gradient of output Y", "T") .Input(1, "X", "Input tensor", "T") .Input(2, "W", "Weight tensor", "T") - .Output(0, "dX", "Gradient of input X", "T") + .Output(0, "dX", "Gradient of input X", "T", OpSchema::Optional) .Output(1, "dW", "Gradient of W", "T") - .Output(2, "dB", "Gradient of B", "T") + .Output(2, "dB", "Gradient of B", "T", OpSchema::Optional) .AllowUncheckedAttributes() .TypeConstraint( "T",