mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
make the variable declaration closer to usage
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9262 Differential Revision: D10363576 Pulled By: ezyang fbshipit-source-id: 05c8eb12f3b389caf562cca9e338cc91b0e9acc1
This commit is contained in:
parent
15bdb9fe61
commit
239b2ac718
1 changed file with 3 additions and 2 deletions
|
|
@@ -264,7 +264,7 @@ class GradientChecker:
|
|||
# hack.
|
||||
grad_ops, g_input = getGradientForOp(op)
|
||||
|
||||
dims_to_check = inputs[input_to_check].size
|
||||
|
||||
_input_device_options = input_device_options or \
|
||||
core.InferOpBlobDevicesAsDict(op)[0]
|
||||
# First, feed in the input.
|
||||
|
|
@@ -285,7 +285,8 @@ class GradientChecker:
|
|||
raise Exception(
|
||||
"Mismatched gradient shapes: estimated ({}), grad ({})".format(
|
||||
grad_estimate.shape, grad.shape))
|
||||
|
||||
|
||||
dims_to_check = inputs[input_to_check].size
|
||||
for current_dim in range(dims_to_check):
|
||||
# Positive gradient
|
||||
inputs[input_to_check].flat[current_dim] += self._stepsize
|
||||
|
|
|
|||
Loading…
Reference in a new issue