mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
parent
844361bc67
commit
beb299e17d
2 changed files with 9 additions and 4 deletions
|
|
@ -510,8 +510,13 @@ def test_gradient_correctness():
|
|||
assert torch.allclose(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
@pytest.mark.parametrize("use_fp16", [False, True])
|
||||
def test_gradient_correctness_conv1d(use_fp16):
|
||||
@pytest.mark.parametrize("use_fp16, input_requires_grad", [
|
||||
(False, False),
|
||||
(False, True),
|
||||
(True, False),
|
||||
(True, True),
|
||||
])
|
||||
def test_gradient_correctness_conv1d(use_fp16, input_requires_grad):
|
||||
class NeuralNetConv1D(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, padding=0, groups=1):
|
||||
super(NeuralNetConv1D, self).__init__()
|
||||
|
|
@ -536,7 +541,7 @@ def test_gradient_correctness_conv1d(use_fp16):
|
|||
return prediction
|
||||
|
||||
for step in range(10):
|
||||
x = torch.randn(N, seq_len, C_in, device=device, requires_grad=True)
|
||||
x = torch.randn(N, seq_len, C_in, device=device, requires_grad=input_requires_grad)
|
||||
pt_prediction = run_step(pt_model, x)
|
||||
ort_prediction = run_step(ort_model, x)
|
||||
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ Status ConvGrad<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
Tensor* dW = context->Output(1, W->Shape());
|
||||
Tensor* dB = context->Output(2, {M});
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareArgs(*dX, *dY, *dW, dB));
|
||||
ORT_RETURN_IF_ERROR(PrepareArgs(*X, *dY, *W, dB));
|
||||
|
||||
ORT_RETURN_IF_ERROR(ComputeWeightGradient(dW, dY, X));
|
||||
ORT_RETURN_IF_ERROR(ComputeInputGradient(dX, dY, W));
|
||||
|
|
|
|||
Loading…
Reference in a new issue