mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[release/1.6] [JIT] Don't include view ops in autodiff graphs (#42029)
* Don't include view ops in autodiff graphs * skip view ops in autodiff testing * two more tests * appease clang format * Pacify clang-format Co-authored-by: eellison <eellison@fb.com> Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
This commit is contained in:
parent
35ad2d8586
commit
29fe90e2a2
3 changed files with 45 additions and 23 deletions
|
|
@ -103,6 +103,9 @@ def LSTMCellF(input, hx, cx, *params):
|
|||
|
||||
|
||||
def doAutodiffCheck(testname):
|
||||
# TODO: setting false on test itself is not working
|
||||
if "test_t_" in testname or testname == "test_t":
|
||||
return False
|
||||
|
||||
if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -110,6 +110,20 @@ class SubgraphSlicer {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Returns true when `n` is one of the shape-manipulation ("view") ops
// (view/view_as, reshape/reshape_as, transpose, expand/expand_as).
// NOTE(review): per the surrounding pass, such ops are kept out of
// differentiable subgraphs because having them as subgraph outputs can
// cause incorrect differentiation — presumably since they alias their
// input; confirm against autodiff semantics.
bool isViewOp(Node* n) {
  const auto kind = n->kind();
  return kind == aten::view || kind == aten::view_as ||
      kind == aten::reshape || kind == aten::reshape_as ||
      kind == aten::transpose || kind == aten::expand ||
      kind == aten::expand_as;
}
|
||||
|
||||
bool shouldConsiderForMerge(Node* node) {
|
||||
// if we're already in the process of merging
|
||||
if (node->kind() == prim::DifferentiableGraph) {
|
||||
|
|
@ -118,6 +132,11 @@ class SubgraphSlicer {
|
|||
if (node->kind() == prim::Constant) {
|
||||
return false;
|
||||
}
|
||||
// view ops as outputs of differentiable subgraphs can cause incorrect
|
||||
// differentiation for now, do not include them in the subgraph
|
||||
if (isViewOp(node)) {
|
||||
return false;
|
||||
}
|
||||
return isDifferentiable(node);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -179,22 +179,22 @@ def method_tests():
|
|||
('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True, 'aten::pow')),
|
||||
('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
|
||||
('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
|
||||
('transpose', (1, 2, 3), (1, 2), 'dim', (True,), [0, 1]),
|
||||
('transpose', (), (0, 0), 'scalar', (True,)),
|
||||
('transpose', (1,), (0, 0), '1d', (True,)),
|
||||
('transpose', (L, L), (0, 1), '2d', (True,)),
|
||||
('transpose', (S, S, S), (2, 0), '3d', (True,)),
|
||||
('t', (1, 2), NO_ARGS, '', (True,)),
|
||||
('view', (S, S, S), (S * S, S), '', (True,)),
|
||||
('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
|
||||
('view', (S,), (S,), '1d', (True,)),
|
||||
('view', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
|
||||
('view', (), (1,), 'scalar_to_1d', (True,)),
|
||||
('reshape', (S, S, S), (S * S, S), '', (True,)),
|
||||
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
|
||||
('reshape', (S,), (S,), '1d', (True,)),
|
||||
('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
|
||||
('reshape', (), (1,), 'scalar_to_1d', (True,)),
|
||||
('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]),
|
||||
('transpose', (), (0, 0), 'scalar', (False,)),
|
||||
('transpose', (1,), (0, 0), '1d', (False,)),
|
||||
('transpose', (L, L), (0, 1), '2d', (False,)),
|
||||
('transpose', (S, S, S), (2, 0), '3d', (False,)),
|
||||
('t', (1, 2), NO_ARGS, '', (False,)),
|
||||
('view', (S, S, S), (S * S, S), '', (False,)),
|
||||
('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
|
||||
('view', (S,), (S,), '1d', (False,)),
|
||||
('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
|
||||
('view', (), (1,), 'scalar_to_1d', (False,)),
|
||||
('reshape', (S, S, S), (S * S, S), '', (False,)),
|
||||
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
|
||||
('reshape', (S,), (S,), '1d', (False,)),
|
||||
('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
|
||||
('reshape', (), (1,), 'scalar_to_1d', (False,)),
|
||||
('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
|
||||
('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'),
|
||||
('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
|
||||
|
|
@ -220,14 +220,14 @@ def method_tests():
|
|||
('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
|
||||
('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
|
||||
('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
|
||||
('expand', (S, 1, 1), (S, S, S), '', (True,)),
|
||||
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (True,)),
|
||||
('expand', (S, 1), (S, S, S), 'new_dim', (True,)),
|
||||
('expand', (1,), (S, S, S), '1_element', (True,)),
|
||||
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (True,)),
|
||||
('expand', (S, 1, 1), (S, S, S), '', (False,)),
|
||||
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (False,)),
|
||||
('expand', (S, 1), (S, S, S), 'new_dim', (False,)),
|
||||
('expand', (1,), (S, S, S), '1_element', (False,)),
|
||||
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (False,)),
|
||||
('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
|
||||
('expand', (), (1, 3, 2), 'scalar_to_dims', (True,)),
|
||||
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (True,)),
|
||||
('expand', (), (1, 3, 2), 'scalar_to_dims', (False,)),
|
||||
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)),
|
||||
('exp', (S, S, S), NO_ARGS, '', (True,)),
|
||||
('exp', (), NO_ARGS, 'scalar', (True,)),
|
||||
('expm1', (S, S, S), NO_ARGS, '', (True,)),
|
||||
|
|
|
|||
Loading…
Reference in a new issue