mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Make dW optional for convgrad (#7083)
This commit is contained in:
parent
c5973fbbac
commit
3b16afc0db
3 changed files with 74 additions and 66 deletions
|
|
@ -711,7 +711,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetDropoutGradient) {
|
|||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetConvGradient) {
|
||||
std::vector<ArgDef> outputs;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int i = 0; i < GetSrcNodeInputSize(); i++) {
|
||||
if (IsGradientRequiredForSrcNodeInput(i)) {
|
||||
outputs.push_back(GI(i));
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -473,7 +473,7 @@ void RegisterTrainingOpSchemas() {
|
|||
.Input(1, "X", "Input tensor", "T")
|
||||
.Input(2, "W", "Weight tensor", "T")
|
||||
.Output(0, "dX", "Gradient of input X", "T", OpSchema::Optional)
|
||||
.Output(1, "dW", "Gradient of W", "T")
|
||||
.Output(1, "dW", "Gradient of W", "T", OpSchema::Optional)
|
||||
.Output(2, "dB", "Gradient of B", "T", OpSchema::Optional)
|
||||
.AllowUncheckedAttributes()
|
||||
.TypeConstraint(
|
||||
|
|
|
|||
|
|
@ -56,7 +56,6 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
Tensor* dW = context->Output(1, W->Shape());
|
||||
T* dWdata = dW->template MutableData<T>();
|
||||
|
||||
TensorShape input_shape = X->Shape().Slice(2);
|
||||
TensorShape output_shape = dY->Shape().Slice(2);
|
||||
|
|
@ -80,8 +79,6 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
|
|||
const T* Wdata = W->template Data<T>();
|
||||
const T* dYdata = dY->template Data<T>();
|
||||
|
||||
// Pre-setting the gradients to zero.
|
||||
math::Set<T, CPUMathUtil>(dW->Shape().Size(), 0, dWdata, &CPUMathUtil::Instance());
|
||||
|
||||
BufferUniquePtr bias_multiplier(alloc->Alloc(sizeof(T) * output_image_size), BufferDeleter(alloc));
|
||||
T* bias_multiplier_data = nullptr;
|
||||
|
|
@ -97,73 +94,82 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
|
|||
bias_multiplier_data,
|
||||
&CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
T* dWdata = nullptr;
|
||||
if (dW) {
|
||||
dWdata = dW->template MutableData<T>();
|
||||
// Pre-setting the gradients to zero.
|
||||
math::Set<T, CPUMathUtil>(dW->Shape().Size(), 0, dWdata, &CPUMathUtil::Instance());
|
||||
}
|
||||
|
||||
bool skip_im2col = (kernel_size == 1) && conv_attrs_.HasStridesOneAndNoPadding();
|
||||
|
||||
for (int image_id = 0; image_id < N; ++image_id) {
|
||||
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
|
||||
if (!skip_im2col) {
|
||||
if (kernel_rank == 1) {
|
||||
math::Im2col<T, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
C / conv_attrs_.group,
|
||||
1,
|
||||
input_shape[0],
|
||||
1,
|
||||
kernel_shape[0],
|
||||
1,
|
||||
dilations[0],
|
||||
0,
|
||||
pads[0],
|
||||
0,
|
||||
pads[1],
|
||||
1,
|
||||
strides[0],
|
||||
col_buffer_data);
|
||||
} else if (kernel_rank == 2) {
|
||||
math::Im2col<T, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
C / conv_attrs_.group,
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
kernel_shape[0],
|
||||
kernel_shape[1],
|
||||
dilations[0],
|
||||
dilations[1],
|
||||
pads[0],
|
||||
pads[1],
|
||||
pads[2],
|
||||
pads[3],
|
||||
strides[0],
|
||||
strides[1],
|
||||
col_buffer_data);
|
||||
} else {
|
||||
math::Im2col<T, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
input_shape.GetDims().data(),
|
||||
output_shape.GetDims().data(),
|
||||
kernel_dim,
|
||||
kernel_shape.data(),
|
||||
strides.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
static_cast<int>(kernel_shape.size()),
|
||||
col_buffer_data);
|
||||
if (dW) {
|
||||
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
|
||||
if (!skip_im2col) {
|
||||
if (kernel_rank == 1) {
|
||||
math::Im2col<T, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
C / conv_attrs_.group,
|
||||
1,
|
||||
input_shape[0],
|
||||
1,
|
||||
kernel_shape[0],
|
||||
1,
|
||||
dilations[0],
|
||||
0,
|
||||
pads[0],
|
||||
0,
|
||||
pads[1],
|
||||
1,
|
||||
strides[0],
|
||||
col_buffer_data);
|
||||
} else if (kernel_rank == 2) {
|
||||
math::Im2col<T, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
C / conv_attrs_.group,
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
kernel_shape[0],
|
||||
kernel_shape[1],
|
||||
dilations[0],
|
||||
dilations[1],
|
||||
pads[0],
|
||||
pads[1],
|
||||
pads[2],
|
||||
pads[3],
|
||||
strides[0],
|
||||
strides[1],
|
||||
col_buffer_data);
|
||||
} else {
|
||||
math::Im2col<T, StorageOrder::NCHW>()(
|
||||
Xdata + group_id * X_offset,
|
||||
input_shape.GetDims().data(),
|
||||
output_shape.GetDims().data(),
|
||||
kernel_dim,
|
||||
kernel_shape.data(),
|
||||
strides.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
static_cast<int>(kernel_shape.size()),
|
||||
col_buffer_data);
|
||||
}
|
||||
}
|
||||
// Gradient with respect to W, filter.
|
||||
math::Gemm<T>(
|
||||
CblasNoTrans,
|
||||
CblasTrans,
|
||||
M / conv_attrs_.group,
|
||||
kernel_dim,
|
||||
output_image_size,
|
||||
1,
|
||||
dYdata + group_id * Y_offset,
|
||||
skip_im2col ? Xdata + group_id * X_offset : col_buffer_data,
|
||||
1,
|
||||
dWdata + group_id * W_offset,
|
||||
tp);
|
||||
}
|
||||
// Gradient with respect to W, filter.
|
||||
math::Gemm<T>(
|
||||
CblasNoTrans,
|
||||
CblasTrans,
|
||||
M / conv_attrs_.group,
|
||||
kernel_dim,
|
||||
output_image_size,
|
||||
1,
|
||||
dYdata + group_id * Y_offset,
|
||||
skip_im2col ? Xdata + group_id * X_offset : col_buffer_data,
|
||||
1,
|
||||
dWdata + group_id * W_offset,
|
||||
tp);
|
||||
}
|
||||
if (dB) {
|
||||
// Gradient with respect to bias can be computed independent from group.
|
||||
|
|
@ -182,6 +188,7 @@ Status ConvGrad<T>::Compute(OpKernelContext* context) const {
|
|||
dYdata += Y_offset * conv_attrs_.group;
|
||||
}
|
||||
|
||||
|
||||
Tensor* dX = context->Output(0, X->Shape());
|
||||
if (dX) {
|
||||
T* dXdata = dX->template MutableData<T>();
|
||||
|
|
@ -251,3 +258,4 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue