ConvGrad CUDA Kernel Bugfix (#7273)

* bugfix: ConvGrad::ComputeInternal now calls PrepareArgs with X and W instead of dX and dW

* add unit test (UT): parametrize test_gradient_correctness_conv1d over input_requires_grad
This commit is contained in:
Vincent Wang 2021-04-08 08:22:18 +08:00 committed by GitHub
parent 844361bc67
commit beb299e17d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 4 deletions

View file

@ -510,8 +510,13 @@ def test_gradient_correctness():
assert torch.allclose(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@pytest.mark.parametrize("use_fp16", [False, True])
def test_gradient_correctness_conv1d(use_fp16):
@pytest.mark.parametrize("use_fp16, input_requires_grad", [
(False, False),
(False, True),
(True, False),
(True, True),
])
def test_gradient_correctness_conv1d(use_fp16, input_requires_grad):
class NeuralNetConv1D(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding=0, groups=1):
super(NeuralNetConv1D, self).__init__()
@ -536,7 +541,7 @@ def test_gradient_correctness_conv1d(use_fp16):
return prediction
for step in range(10):
x = torch.randn(N, seq_len, C_in, device=device, requires_grad=True)
x = torch.randn(N, seq_len, C_in, device=device, requires_grad=input_requires_grad)
pt_prediction = run_step(pt_model, x)
ort_prediction = run_step(ort_model, x)

View file

@ -129,7 +129,7 @@ Status ConvGrad<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* dW = context->Output(1, W->Shape());
Tensor* dB = context->Output(2, {M});
ORT_RETURN_IF_ERROR(PrepareArgs(*dX, *dY, *dW, dB));
ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB));
ORT_RETURN_IF_ERROR(ComputeWeightGradient(dW, dY, X));
ORT_RETURN_IF_ERROR(ComputeInputGradient(dX, dY, W));