Make dW optional for convgrad (#7083)

This commit is contained in:
Pranav Prakash 2021-04-05 17:05:20 -07:00 committed by GitHub
parent c5973fbbac
commit 3b16afc0db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 66 deletions

View file

@ -711,7 +711,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetDropoutGradient) {
IMPLEMENT_GRADIENT_BUILDER(GetConvGradient) {
std::vector<ArgDef> outputs;
for (int i = 0; i < 3; i++) {
for (int i = 0; i < GetSrcNodeInputSize(); i++) {
if (IsGradientRequiredForSrcNodeInput(i)) {
outputs.push_back(GI(i));
} else {

View file

@ -473,7 +473,7 @@ void RegisterTrainingOpSchemas() {
.Input(1, "X", "Input tensor", "T")
.Input(2, "W", "Weight tensor", "T")
.Output(0, "dX", "Gradient of input X", "T", OpSchema::Optional)
.Output(1, "dW", "Gradient of W", "T")
.Output(1, "dW", "Gradient of W", "T", OpSchema::Optional)
.Output(2, "dB", "Gradient of B", "T", OpSchema::Optional)
.AllowUncheckedAttributes()
.TypeConstraint(

View file

@ -56,7 +56,6 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
}
Tensor* dW = context->Output(1, W->Shape());
T* dWdata = dW->template MutableData<T>();
TensorShape input_shape = X->Shape().Slice(2);
TensorShape output_shape = dY->Shape().Slice(2);
@ -80,8 +79,6 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
const T* Wdata = W->template Data<T>();
const T* dYdata = dY->template Data<T>();
// Pre-setting the gradients to zero.
math::Set<T, CPUMathUtil>(dW->Shape().Size(), 0, dWdata, &CPUMathUtil::Instance());
BufferUniquePtr bias_multiplier(alloc->Alloc(sizeof(T) * output_image_size), BufferDeleter(alloc));
T* bias_multiplier_data = nullptr;
@ -97,73 +94,82 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
bias_multiplier_data,
&CPUMathUtil::Instance());
}
T* dWdata = nullptr;
if (dW) {
dWdata = dW->template MutableData<T>();
// Pre-setting the gradients to zero.
math::Set<T, CPUMathUtil>(dW->Shape().Size(), 0, dWdata, &CPUMathUtil::Instance());
}
bool skip_im2col = (kernel_size == 1) && conv_attrs_.HasStridesOneAndNoPadding();
for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
if (!skip_im2col) {
if (kernel_rank == 1) {
math::Im2col<T, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
C / conv_attrs_.group,
1,
input_shape[0],
1,
kernel_shape[0],
1,
dilations[0],
0,
pads[0],
0,
pads[1],
1,
strides[0],
col_buffer_data);
} else if (kernel_rank == 2) {
math::Im2col<T, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
col_buffer_data);
} else {
math::Im2col<T, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_dim,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_shape.size()),
col_buffer_data);
if (dW) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
if (!skip_im2col) {
if (kernel_rank == 1) {
math::Im2col<T, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
C / conv_attrs_.group,
1,
input_shape[0],
1,
kernel_shape[0],
1,
dilations[0],
0,
pads[0],
0,
pads[1],
1,
strides[0],
col_buffer_data);
} else if (kernel_rank == 2) {
math::Im2col<T, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
col_buffer_data);
} else {
math::Im2col<T, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_dim,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_shape.size()),
col_buffer_data);
}
}
// Gradient with respect to W, filter.
math::Gemm<T>(
CblasNoTrans,
CblasTrans,
M / conv_attrs_.group,
kernel_dim,
output_image_size,
1,
dYdata + group_id * Y_offset,
skip_im2col ? Xdata + group_id * X_offset : col_buffer_data,
1,
dWdata + group_id * W_offset,
tp);
}
// Gradient with respect to W, filter.
math::Gemm<T>(
CblasNoTrans,
CblasTrans,
M / conv_attrs_.group,
kernel_dim,
output_image_size,
1,
dYdata + group_id * Y_offset,
skip_im2col ? Xdata + group_id * X_offset : col_buffer_data,
1,
dWdata + group_id * W_offset,
tp);
}
if (dB) {
// Gradient with respect to bias can be computed independent from group.
@ -182,6 +188,7 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
dYdata += Y_offset * conv_attrs_.group;
}
Tensor* dX = context->Output(0, X->Shape());
if (dX) {
T* dXdata = dX->template MutableData<T>();
@ -251,3 +258,4 @@ ONNX_CPU_OPERATOR_KERNEL(
} // namespace contrib
} // namespace onnxruntime