diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 9c33a032e3..8e1437dd9c 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -711,7 +711,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetDropoutGradient) { IMPLEMENT_GRADIENT_BUILDER(GetConvGradient) { std::vector outputs; - for (int i = 0; i < 3; i++) { + for (int i = 0; i < GetSrcNodeInputSize(); i++) { if (IsGradientRequiredForSrcNodeInput(i)) { outputs.push_back(GI(i)); } else { diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index b5dcca390c..f59a2ebf71 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -473,7 +473,7 @@ void RegisterTrainingOpSchemas() { .Input(1, "X", "Input tensor", "T") .Input(2, "W", "Weight tensor", "T") .Output(0, "dX", "Gradient of input X", "T", OpSchema::Optional) - .Output(1, "dW", "Gradient of W", "T") + .Output(1, "dW", "Gradient of W", "T", OpSchema::Optional) .Output(2, "dB", "Gradient of B", "T", OpSchema::Optional) .AllowUncheckedAttributes() .TypeConstraint( diff --git a/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc index 5b886fc37d..fa58bf0def 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc @@ -56,7 +56,6 @@ Status ConvGrad::Compute(OpKernelContext* context) const { } Tensor* dW = context->Output(1, W->Shape()); - T* dWdata = dW->template MutableData(); TensorShape input_shape = X->Shape().Slice(2); TensorShape output_shape = dY->Shape().Slice(2); @@ -80,8 +79,6 @@ Status ConvGrad::Compute(OpKernelContext* context) const { const T* Wdata = W->template Data(); const T* dYdata = dY->template Data(); - // Pre-setting the gradients to zero. - math::Set(dW->Shape().Size(), 0, dWdata, &CPUMathUtil::Instance()); BufferUniquePtr bias_multiplier(alloc->Alloc(sizeof(T) * output_image_size), BufferDeleter(alloc)); T* bias_multiplier_data = nullptr; @@ -97,73 +94,82 @@ Status ConvGrad::Compute(OpKernelContext* context) const { bias_multiplier_data, &CPUMathUtil::Instance()); } + + T* dWdata = nullptr; + if (dW) { + dWdata = dW->template MutableData(); + // Pre-setting the gradients to zero. + math::Set(dW->Shape().Size(), 0, dWdata, &CPUMathUtil::Instance()); + } bool skip_im2col = (kernel_size == 1) && conv_attrs_.HasStridesOneAndNoPadding(); for (int image_id = 0; image_id < N; ++image_id) { - for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) { - if (!skip_im2col) { - if (kernel_rank == 1) { - math::Im2col()( - Xdata + group_id * X_offset, - C / conv_attrs_.group, - 1, - input_shape[0], - 1, - kernel_shape[0], - 1, - dilations[0], - 0, - pads[0], - 0, - pads[1], - 1, - strides[0], - col_buffer_data); - } else if (kernel_rank == 2) { - math::Im2col()( - Xdata + group_id * X_offset, - C / conv_attrs_.group, - input_shape[0], - input_shape[1], - kernel_shape[0], - kernel_shape[1], - dilations[0], - dilations[1], - pads[0], - pads[1], - pads[2], - pads[3], - strides[0], - strides[1], - col_buffer_data); - } else { - math::Im2col()( - Xdata + group_id * X_offset, - input_shape.GetDims().data(), - output_shape.GetDims().data(), - kernel_dim, - kernel_shape.data(), - strides.data(), - dilations.data(), - pads.data(), - static_cast(kernel_shape.size()), - col_buffer_data); + if (dW) { + for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) { + if (!skip_im2col) { + if (kernel_rank == 1) { + math::Im2col()( + Xdata + group_id * X_offset, + C / conv_attrs_.group, + 1, + input_shape[0], + 1, + kernel_shape[0], + 1, + dilations[0], + 0, + pads[0], + 0, + pads[1], + 1, + strides[0], + col_buffer_data); + } else if (kernel_rank == 2) { + math::Im2col()( + Xdata + group_id * X_offset, + C / conv_attrs_.group, + input_shape[0], + input_shape[1], + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + pads[2], + pads[3], + strides[0], + strides[1], + col_buffer_data); + } else { + math::Im2col()( + Xdata + group_id * X_offset, + input_shape.GetDims().data(), + output_shape.GetDims().data(), + kernel_dim, + kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_shape.size()), + col_buffer_data); + } } + // Gradient with respect to W, filter. + math::Gemm( + CblasNoTrans, + CblasTrans, + M / conv_attrs_.group, + kernel_dim, + output_image_size, + 1, + dYdata + group_id * Y_offset, + skip_im2col ? Xdata + group_id * X_offset : col_buffer_data, + 1, + dWdata + group_id * W_offset, + tp); } - // Gradient with respect to W, filter. - math::Gemm( - CblasNoTrans, - CblasTrans, - M / conv_attrs_.group, - kernel_dim, - output_image_size, - 1, - dYdata + group_id * Y_offset, - skip_im2col ? Xdata + group_id * X_offset : col_buffer_data, - 1, - dWdata + group_id * W_offset, - tp); } if (dB) { // Gradient with respect to bias can be computed independent from group. @@ -182,6 +188,7 @@ Status ConvGrad::Compute(OpKernelContext* context) const { dYdata += Y_offset * conv_attrs_.group; } + Tensor* dX = context->Output(0, X->Shape()); if (dX) { T* dXdata = dX->template MutableData(); @@ -251,3 +258,4 @@ ONNX_CPU_OPERATOR_KERNEL( } // namespace contrib } // namespace onnxruntime +