[release/1.6] [JIT] Don't include view ops in autodiff graphs (#42029)

* Don't include view ops in autodiff graphs

* skip view ops in autodiff testing

* two more tests

* Appease clang-format

* Pacify clang-format

Co-authored-by: eellison <eellison@fb.com>
Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
This commit is contained in:
eellison 2020-07-24 13:41:32 -07:00 committed by GitHub
parent 35ad2d8586
commit 29fe90e2a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 23 deletions

View file

@ -103,6 +103,9 @@ def LSTMCellF(input, hx, cx, *params):
def doAutodiffCheck(testname):
# TODO: setting false on test itself is not working
if "test_t_" in testname or testname == "test_t":
return False
if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
return False

View file

@ -110,6 +110,20 @@ class SubgraphSlicer {
return result;
}
// Returns true when `n` is one of the view-producing/aliasing ops
// (view, reshape, transpose, expand and their *_as variants). These are
// excluded from differentiable subgraphs because, as subgraph outputs,
// they can lead to incorrect differentiation.
bool isViewOp(Node* n) {
  const auto kind = n->kind();
  return kind == aten::view || kind == aten::view_as ||
      kind == aten::reshape || kind == aten::reshape_as ||
      kind == aten::transpose || kind == aten::expand ||
      kind == aten::expand_as;
}
bool shouldConsiderForMerge(Node* node) {
// if we're already in the process of merging
if (node->kind() == prim::DifferentiableGraph) {
@ -118,6 +132,11 @@ class SubgraphSlicer {
if (node->kind() == prim::Constant) {
return false;
}
// view ops as outputs of differentiable subgraphs can cause incorrect
// differentiation for now, do not include them in the subgraph
if (isViewOp(node)) {
return false;
}
return isDifferentiable(node);
}

View file

@ -179,22 +179,22 @@ def method_tests():
('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True, 'aten::pow')),
('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
('transpose', (1, 2, 3), (1, 2), 'dim', (True,), [0, 1]),
('transpose', (), (0, 0), 'scalar', (True,)),
('transpose', (1,), (0, 0), '1d', (True,)),
('transpose', (L, L), (0, 1), '2d', (True,)),
('transpose', (S, S, S), (2, 0), '3d', (True,)),
('t', (1, 2), NO_ARGS, '', (True,)),
('view', (S, S, S), (S * S, S), '', (True,)),
('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
('view', (S,), (S,), '1d', (True,)),
('view', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
('view', (), (1,), 'scalar_to_1d', (True,)),
('reshape', (S, S, S), (S * S, S), '', (True,)),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
('reshape', (S,), (S,), '1d', (True,)),
('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
('reshape', (), (1,), 'scalar_to_1d', (True,)),
('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]),
('transpose', (), (0, 0), 'scalar', (False,)),
('transpose', (1,), (0, 0), '1d', (False,)),
('transpose', (L, L), (0, 1), '2d', (False,)),
('transpose', (S, S, S), (2, 0), '3d', (False,)),
('t', (1, 2), NO_ARGS, '', (False,)),
('view', (S, S, S), (S * S, S), '', (False,)),
('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
('view', (S,), (S,), '1d', (False,)),
('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
('view', (), (1,), 'scalar_to_1d', (False,)),
('reshape', (S, S, S), (S * S, S), '', (False,)),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
('reshape', (S,), (S,), '1d', (False,)),
('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
('reshape', (), (1,), 'scalar_to_1d', (False,)),
('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'),
('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
@ -220,14 +220,14 @@ def method_tests():
('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
('expand', (S, 1, 1), (S, S, S), '', (True,)),
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (True,)),
('expand', (S, 1), (S, S, S), 'new_dim', (True,)),
('expand', (1,), (S, S, S), '1_element', (True,)),
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (True,)),
('expand', (S, 1, 1), (S, S, S), '', (False,)),
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (False,)),
('expand', (S, 1), (S, S, S), 'new_dim', (False,)),
('expand', (1,), (S, S, S), '1_element', (False,)),
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (False,)),
('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
('expand', (), (1, 3, 2), 'scalar_to_dims', (True,)),
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (True,)),
('expand', (), (1, 3, 2), 'scalar_to_dims', (False,)),
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)),
('exp', (S, S, S), NO_ARGS, '', (True,)),
('exp', (), NO_ARGS, 'scalar', (True,)),
('expm1', (S, S, S), NO_ARGS, '', (True,)),