diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 79efcb5587..b81e0746a1 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -510,8 +510,13 @@ def test_gradient_correctness(): assert torch.allclose(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) -@pytest.mark.parametrize("use_fp16", [False, True]) -def test_gradient_correctness_conv1d(use_fp16): +@pytest.mark.parametrize("use_fp16, input_requires_grad", [ + (False, False), + (False, True), + (True, False), + (True, True), + ]) +def test_gradient_correctness_conv1d(use_fp16, input_requires_grad): class NeuralNetConv1D(torch.nn.Module): def __init__(self, in_channels, out_channels, kernel_size, padding=0, groups=1): super(NeuralNetConv1D, self).__init__() @@ -536,7 +541,7 @@ def test_gradient_correctness_conv1d(use_fp16): return prediction for step in range(10): - x = torch.randn(N, seq_len, C_in, device=device, requires_grad=True) + x = torch.randn(N, seq_len, C_in, device=device, requires_grad=input_requires_grad) pt_prediction = run_step(pt_model, x) ort_prediction = run_step(ort_model, x) diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index b1e9f7a6df..ca742ad897 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -129,7 +129,7 @@ Status ConvGrad::ComputeInternal(OpKernelContext* context) const { Tensor* dW = context->Output(1, W->Shape()); Tensor* dB = context->Output(2, {M}); - ORT_RETURN_IF_ERROR(PrepareArgs(*dX, *dY, *dW, dB)); + ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB)); ORT_RETURN_IF_ERROR(ComputeWeightGradient(dW, dY, X)); ORT_RETURN_IF_ERROR(ComputeInputGradient(dX, dY, W));