graph fuser inserts explicit expands where necessary (#10325)

Summary:
Fixes #10096

If the only thing preventing a simple mappable operator from being fused
into a fusion group is that its Tensor inputs do not have the same shape as
its output, the graph fuser now inserts explicit aten::expand nodes for those
inputs. This keeps the graph fuser from missing fusion opportunities for
simple mappable operations whose Tensor inputs merely need broadcasting. This
PR doesn't do anything for the scalar case; that can be addressed later.
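
As a concrete illustration, consider a broadcasting scale-and-shift, the
pattern exercised by the new test below. This is a minimal sketch, not part of
the PR; it assumes the current torch.jit.trace/graph_for APIs rather than the
test-suite harness:

import torch

def scaleshift(x, scale, shift):
    # scale and shift are 1-D and broadcast against the 2-D x; before this
    # change, the shape mismatch kept mul and add out of one FusionGroup.
    return x * scale + shift

x = torch.randn(4, 4, device='cuda')
scale = torch.randn(4, device='cuda')
shift = torch.randn(4, device='cuda')

traced = torch.jit.trace(scaleshift, (x, scale, shift))
# The optimized graph should show two aten::expand nodes feeding a single
# prim::FusionGroup (compare the first expect file below).
print(traced.graph_for(x, scale, shift))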

Test Plan:
- Simple expect test case
- Added expect tests for a raw LSTMCell (a sketch of the cell follows this
  list). The expands speed up the forward pass by allowing more operations
  to be fused into the LSTMCell's single FusionGroup.
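
For reference, the raw LSTMCell exercised by those tests computes the standard
gate arithmetic. A minimal sketch, assuming the usual parameter names (the
in-tree test helper may differ):

def lstm_cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    # b_ih and b_hh are 1-D and broadcast over the batch dimension; these
    # adds are what the new explicit expands make fusible.
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = ingate.sigmoid()
    forgetgate = forgetgate.sigmoid()
    cellgate = cellgate.tanh()
    outgate = outgate.sigmoid()
    cy = forgetgate * cx + ingate * cellgate
    hy = outgate * cy.tanh()
    return hy, cy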

cc apaszke zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10325

Differential Revision: D9379308

Pulled By: zou3519

fbshipit-source-id: 86d2202eb97e9bb16e511667b7fe177aeaf88245
Author: Richard Zou, 2018-08-17 15:46:06 -07:00 (committed by Facebook Github Bot)
parent 7c55d11ba5
commit e29b5a1ea8
6 changed files with 306 additions and 83 deletions


@@ -0,0 +1,18 @@
graph(%0 : Float(4, 4)
%1 : Float(4)
%2 : Float(4)) {
%3 : int[] = prim::Constant[value=[4, 4]]()
%4 : int = prim::Constant[value=0]()
%5 : Float(4!, 4) = aten::expand(%1, %3, %4)
%6 : Float(4!, 4) = aten::expand(%2, %3, %4)
%7 : Float(4, 4) = prim::FusionGroup_0[device=0](%6, %0, %5)
return (%7);
}
with prim::FusionGroup_0 = graph(%1 : Float(4!, 4)
%4 : Float(4, 4)
%5 : Float(4!, 4)) {
%6 : Float(4, 4) = aten::mul(%4, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(4, 4) = aten::add(%6, %1, %2)
return (%3);
}


@@ -10,24 +10,51 @@ graph(%x.1 : Float(3, 10)
%9 : Float(3, 80) = aten::mm(%hx.1, %8)
%10 : int = prim::Constant[value=1]()
%11 : float = prim::Constant[value=1]()
%12 : float = prim::Constant[value=1]()
%13 : Float(3, 80) = aten::addmm(%9, %x.1, %7, %11, %12)
%14 : Float(3, 80) = aten::add(%13, %b_ih, %10)
%gates : Float(3, 80) = aten::add(%14, %b_hh, %10)
%16 : int = prim::Constant[value=4]()
%ingate.1 : Float(3!, 20), %forgetgate.1 : Float(3!, 20), %cellgate.1 : Float(3!, 20), %outgate.1 : Float(3!, 20) = aten::chunk(%gates, %16, %10)
%hy : Float(3, 20), %22 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)
return (%hy, %cy, %7, %8, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %22);
%12 : Float(3, 80) = aten::addmm(%9, %x.1, %7, %11, %11)
%13 : int[] = prim::Constant[value=[3, 80]]()
%14 : int = prim::Constant[value=0]()
%15 : Float(3!, 80) = aten::expand(%b_ih, %13, %14)
%16 : Float(3!, 80) = aten::expand(%b_hh, %13, %14)
%17 : int = prim::Constant[value=4]()
%18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20), %21 : Float(3!, 20) = aten::chunk(%12, %17, %10)
%22 : Float(3!, 20), %23 : Float(3!, 20), %24 : Float(3!, 20), %25 : Float(3!, 20) = aten::chunk(%15, %17, %10)
%26 : Float(3!, 20), %27 : Float(3!, 20), %28 : Float(3!, 20), %29 : Float(3!, 20) = aten::chunk(%16, %17, %10)
%hy : Float(3, 20), %31 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %29, %28, %27, %26, %21, %25, %20, %24, %19, %23, %18, %22)
return (%hy, %cy, %7, %8, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %31);
}
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
%15 : Float(3!, 20)
%17 : Float(3!, 20)
%19 : Float(3!, 20)
%21 : Float(3!, 20)) {
%ingate.2 : Float(3, 20) = aten::sigmoid(%21)
%forgetgate.2 : Float(3, 20) = aten::sigmoid(%19)
%cellgate.2 : Float(3, 20) = aten::tanh(%17)
%outgate.2 : Float(3, 20) = aten::sigmoid(%15)
%24 : Float(3!, 20)
%28 : Float(3!, 20)
%32 : Float(3!, 20)
%36 : Float(3!, 20)
%39 : Float(3!, 20)
%40 : Float(3!, 20)
%43 : Float(3!, 20)
%44 : Float(3!, 20)
%47 : Float(3!, 20)
%48 : Float(3!, 20)
%51 : Float(3!, 20)
%52 : Float(3!, 20)) {
%53 : int = prim::Constant[value=1]()
%54 : Float(3, 20) = aten::add(%51, %52, %53)
%49 : int = prim::Constant[value=1]()
%50 : Float(3, 20) = aten::add(%47, %48, %49)
%45 : int = prim::Constant[value=1]()
%46 : Float(3, 20) = aten::add(%43, %44, %45)
%41 : int = prim::Constant[value=1]()
%42 : Float(3, 20) = aten::add(%39, %40, %41)
%37 : int = prim::Constant[value=1]()
%38 : Float(3, 20) = aten::add(%54, %36, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(3, 20) = aten::add(%50, %32, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%46, %28, %29)
%25 : int = prim::Constant[value=1]()
%26 : Float(3, 20) = aten::add(%42, %24, %25)
%ingate.2 : Float(3, 20) = aten::sigmoid(%38)
%forgetgate.2 : Float(3, 20) = aten::sigmoid(%34)
%cellgate.2 : Float(3, 20) = aten::tanh(%30)
%outgate.2 : Float(3, 20) = aten::sigmoid(%26)
%14 : Float(3, 20) = aten::mul(%forgetgate.2, %13)
%11 : Float(3, 20) = aten::mul(%ingate.2, %cellgate.2)
%7 : int = prim::Constant[value=1]()


@@ -26,21 +26,18 @@ graph(%0 : Float(3, 20!)
%cellgate : Float(3, 20)
%outgate : Float(3, 20)
%27 : Float(3, 20)) {
%28 : int = prim::Constant[value=1]()
%29 : Float(3, 80) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %27, %0)
%30 : Float(3, 80), %31 : Float(3, 80) = prim::FusionGroup_1[device=0](%Uz, %29)
%32 : Float(3, 80) = aten::mul(%31, %beta_h)
%33 : Float(3, 80) = aten::mul(%31, %Wx)
%34 : Float(3, 80) = aten::mul(%31, %beta_i)
%35 : Float(3, 80) = prim::FusionGroup_2[device=0](%32, %29, %22)
%36 : Float(3, 80), %37 : Float(3, 80) = prim::FusionGroup_3[device=0](%Wx, %29, %Uz)
%38 : Float(3, 80) = aten::mul(%37, %alpha)
%39 : Float(3, 80) = aten::add(%34, %38, %28)
%40 : Float(80!, 3!) = aten::t(%35)
%28 : Float(3, 80) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %27, %0)
%29 : int[] = prim::Constant[value=[3, 80]]()
%30 : int = prim::Constant[value=0]()
%31 : Float(3!, 80) = aten::expand(%beta_h, %29, %30)
%32 : Float(3!, 80) = aten::expand(%beta_i, %29, %30)
%33 : Float(3!, 80) = aten::expand(%alpha, %29, %30)
%34 : Float(3, 80), %35 : Float(3, 80), %36 : Float(3, 80), %37 : Float(3, 80), %38 : Float(3, 80), %39 : Float(3, 80) = prim::FusionGroup_1[device=0](%33, %32, %Wx, %28, %Uz, %22, %31)
%40 : Float(80!, 3!) = aten::t(%36)
%41 : Float(80, 20) = aten::mm(%40, %hx)
%42 : Float(80!, 3!) = aten::t(%39)
%42 : Float(80!, 3!) = aten::t(%34)
%43 : Float(80, 10) = aten::mm(%42, %x)
return (%43, %41, %36, %33, %30, %31);
return (%43, %41, %35, %37, %38, %39);
}
with prim::FusionGroup_0 = graph(%9 : Float(3, 20)
%19 : Float(3, 20)
@@ -87,25 +84,30 @@ with prim::FusionGroup_0 = graph(%9 : Float(3, 20)
%4 : Float(3, 80) = prim::FusedConcat[dim=1](%7, %17, %27, %37)
return (%4);
}
with prim::FusionGroup_1 = graph(%1 : Float(3, 80)
%3 : Float(3, 80)) {
%4 : int = prim::Constant[value=1]()
%5 : Float(3, 80) = aten::mul(%3, %4)
%2 : Float(3, 80) = aten::mul(%5, %1)
return (%2, %5);
}
with prim::FusionGroup_2 = graph(%0 : Float(3, 80)
%4 : Float(3, 80)
%5 : Float(3, 80)) {
%6 : Float(3, 80) = aten::mul(%4, %5)
with prim::FusionGroup_1 = graph(%5 : Float(3!, 80)
%8 : Float(3!, 80)
%10 : Float(3, 80)
%12 : Float(3, 80)
%13 : Float(3, 80)
%20 : Float(3, 80)
%22 : Float(3!, 80)) {
%30 : int = prim::Constant[value=1]()
%29 : int = prim::Constant[value=1]()
%28 : int = prim::Constant[value=1]()
%26 : int = prim::Constant[value=1]()
%27 : Float(3, 80) = aten::mul(%12, %26)
%25 : Float(3, 80) = aten::mul(%27, %13)
%24 : Float(3, 80) = aten::mul(%27, %10)
%23 : Float(3, 80) = aten::mul(%27, %22)
%21 : Float(3, 80) = aten::mul(%12, %20)
%19 : int = prim::Constant[value=1]()
%17 : int = prim::Constant[value=1]()
%18 : Float(3, 80) = aten::add(%23, %21, %17)
%14 : Float(3, 80) = aten::mul(%12, %13)
%11 : Float(3, 80) = aten::mul(%14, %10)
%9 : Float(3, 80) = aten::mul(%27, %8)
%6 : Float(3, 80) = aten::mul(%14, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(3, 80) = aten::add(%0, %6, %2)
return (%3);
}
with prim::FusionGroup_3 = graph(%1 : Float(3, 80)
%3 : Float(3, 80)
%4 : Float(3, 80)) {
%5 : Float(3, 80) = aten::mul(%3, %4)
%2 : Float(3, 80) = aten::mul(%5, %1)
return (%2, %5);
%3 : Float(3, 80) = aten::add(%9, %6, %2)
return (%3, %11, %18, %24, %25, %27);
}


@@ -12,36 +12,91 @@ graph(%x.1 : Float(3, 10)
%11 : Float(20!, 80!) = aten::t(%w_hh)
%Uz.1 : Float(3, 80) = aten::mm(%hx.1, %11)
%13 : Float(3, 80) = aten::mul(%alpha.1, %Wx.1)
%14 : Float(3, 80) = aten::mul(%beta_i.1, %Wx.1)
%15 : int = prim::Constant[value=1]()
%16 : Float(3, 80) = aten::mul(%beta_h.1, %Uz.1)
%17 : Float(3, 80) = prim::FusionGroup_0[device=0](%16, %14, %13, %Uz.1)
%gates : Float(3, 80) = aten::add(%17, %bias, %15)
%19 : int = prim::Constant[value=4]()
%ingate.1 : Float(3!, 20), %forgetgate.1 : Float(3!, 20), %cellgate.1 : Float(3!, 20), %outgate.1 : Float(3!, 20) = aten::chunk(%gates, %19, %15)
%hy : Float(3, 20), %25 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_1[device=0](%cx.1, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)
return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %25);
%14 : int[] = prim::Constant[value=[3, 80]]()
%15 : int = prim::Constant[value=0]()
%16 : Float(3!, 80) = aten::expand(%beta_i.1, %14, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(3!, 80) = aten::expand(%beta_h.1, %14, %15)
%19 : Float(3!, 80) = aten::expand(%bias, %14, %15)
%20 : int = prim::Constant[value=4]()
%21 : Float(3!, 20), %22 : Float(3!, 20), %23 : Float(3!, 20), %24 : Float(3!, 20) = aten::chunk(%13, %20, %17)
%25 : Float(3!, 20), %26 : Float(3!, 20), %27 : Float(3!, 20), %28 : Float(3!, 20) = aten::chunk(%Uz.1, %20, %17)
%29 : Float(3!, 20), %30 : Float(3!, 20), %31 : Float(3!, 20), %32 : Float(3!, 20) = aten::chunk(%16, %20, %17)
%33 : Float(3!, 20), %34 : Float(3!, 20), %35 : Float(3!, 20), %36 : Float(3!, 20) = aten::chunk(%Wx.1, %20, %17)
%37 : Float(3!, 20), %38 : Float(3!, 20), %39 : Float(3!, 20), %40 : Float(3!, 20) = aten::chunk(%18, %20, %17)
%41 : Float(3!, 20), %42 : Float(3!, 20), %43 : Float(3!, 20), %44 : Float(3!, 20) = aten::chunk(%19, %20, %17)
%hy : Float(3, 20), %46 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %44, %43, %42, %41, %40, %28, %39, %27, %37, %25, %38, %26, %31, %35, %30, %34, %22, %26, %29, %33, %21, %25, %23, %27, %32, %36, %24, %28)
return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %46);
}
with prim::FusionGroup_0 = graph(%1 : Float(3, 80)
%5 : Float(3, 80)
%8 : Float(3, 80)
%9 : Float(3, 80)) {
%10 : Float(3, 80) = aten::mul(%8, %9)
%6 : int = prim::Constant[value=1]()
%7 : Float(3, 80) = aten::add(%10, %5, %6)
%2 : int = prim::Constant[value=1]()
%3 : Float(3, 80) = aten::add(%7, %1, %2)
return (%3);
}
with prim::FusionGroup_1 = graph(%13 : Float(3, 20)
%15 : Float(3!, 20)
%17 : Float(3!, 20)
%19 : Float(3!, 20)
%21 : Float(3!, 20)) {
%ingate.2 : Float(3, 20) = aten::sigmoid(%21)
%forgetgate.2 : Float(3, 20) = aten::sigmoid(%19)
%cellgate.2 : Float(3, 20) = aten::tanh(%17)
%outgate.2 : Float(3, 20) = aten::sigmoid(%15)
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
%24 : Float(3!, 20)
%28 : Float(3!, 20)
%32 : Float(3!, 20)
%36 : Float(3!, 20)
%59 : Float(3!, 20)
%60 : Float(3!, 20)
%66 : Float(3!, 20)
%67 : Float(3!, 20)
%69 : Float(3!, 20)
%70 : Float(3!, 20)
%76 : Float(3!, 20)
%77 : Float(3!, 20)
%83 : Float(3!, 20)
%84 : Float(3!, 20)
%86 : Float(3!, 20)
%87 : Float(3!, 20)
%89 : Float(3!, 20)
%90 : Float(3!, 20)
%92 : Float(3!, 20)
%93 : Float(3!, 20)
%95 : Float(3!, 20)
%96 : Float(3!, 20)
%98 : Float(3!, 20)
%99 : Float(3!, 20)
%101 : Float(3!, 20)
%102 : Float(3!, 20)
%104 : Float(3!, 20)
%105 : Float(3!, 20)) {
%106 : Float(3, 20) = aten::mul(%104, %105)
%103 : Float(3, 20) = aten::mul(%101, %102)
%100 : Float(3, 20) = aten::mul(%98, %99)
%97 : Float(3, 20) = aten::mul(%95, %96)
%94 : Float(3, 20) = aten::mul(%92, %93)
%91 : Float(3, 20) = aten::mul(%89, %90)
%88 : Float(3, 20) = aten::mul(%86, %87)
%85 : Float(3, 20) = aten::mul(%83, %84)
%81 : int = prim::Constant[value=1]()
%82 : Float(3, 20) = aten::add(%91, %88, %81)
%78 : Float(3, 20) = aten::mul(%76, %77)
%74 : int = prim::Constant[value=1]()
%75 : Float(3, 20) = aten::add(%97, %94, %74)
%71 : Float(3, 20) = aten::mul(%69, %70)
%68 : Float(3, 20) = aten::mul(%66, %67)
%64 : int = prim::Constant[value=1]()
%65 : Float(3, 20) = aten::add(%100, %85, %64)
%61 : Float(3, 20) = aten::mul(%59, %60)
%57 : int = prim::Constant[value=1]()
%58 : Float(3, 20) = aten::add(%106, %103, %57)
%53 : int = prim::Constant[value=1]()
%54 : Float(3, 20) = aten::add(%75, %71, %53)
%49 : int = prim::Constant[value=1]()
%50 : Float(3, 20) = aten::add(%82, %78, %49)
%45 : int = prim::Constant[value=1]()
%46 : Float(3, 20) = aten::add(%65, %68, %45)
%41 : int = prim::Constant[value=1]()
%42 : Float(3, 20) = aten::add(%58, %61, %41)
%37 : int = prim::Constant[value=1]()
%38 : Float(3, 20) = aten::add(%54, %36, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(3, 20) = aten::add(%50, %32, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%46, %28, %29)
%25 : int = prim::Constant[value=1]()
%26 : Float(3, 20) = aten::add(%42, %24, %25)
%ingate.2 : Float(3, 20) = aten::sigmoid(%38)
%forgetgate.2 : Float(3, 20) = aten::sigmoid(%34)
%cellgate.2 : Float(3, 20) = aten::tanh(%30)
%outgate.2 : Float(3, 20) = aten::sigmoid(%26)
%14 : Float(3, 20) = aten::mul(%forgetgate.2, %13)
%11 : Float(3, 20) = aten::mul(%ingate.2, %cellgate.2)
%7 : int = prim::Constant[value=1]()


@@ -475,6 +475,21 @@ class TestJit(JitTestCase):
        self.assertExportImport(trace, (t,) + tuple(model.parameters()))
        self.assertExpectedONNXGraph(trace)

    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @skipIfRocm
    def test_broadcast_fusion_cuda(self):
        def scaleshift(x, scale, shift):
            return x * scale + shift

        inputs = [
            torch.randn(4, 4, dtype=torch.float, device='cuda'),
            torch.randn(4, dtype=torch.float, device='cuda'),
            torch.randn(4, dtype=torch.float, device='cuda'),
        ]
        ge = self.checkTrace(scaleshift, inputs)
        self.assertExpectedGraph(ge.graph_for(*inputs))

    # TODO: Fuser doesn't work at all when inputs require grad. Fix that
    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")


@@ -1,7 +1,9 @@
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/fusion_compiler.h"
#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/assertions.h"
#include "ATen/ExpandUtils.h"
#include <unordered_map>
#ifdef USE_CUDA
@@ -186,8 +188,7 @@ struct GraphFuser {
}
bool hasSupportedType(Node* node) {
return areTensorsOfSameShape(node->inputs()) &&
haveSupportedType(node->inputs()) &&
return haveSupportedType(node->inputs()) &&
haveSupportedType(node->outputs());
}
@@ -204,6 +205,11 @@
});
}
  // Checks if the node is fusible into a FusionGroup. A node is fusible if:
  // - it is a FusionGroup
  // - it is a simple map op and its inputs/outputs have compatible types.
  // NB: two nodes that are fusible might not be fused together
  // if they don't have compatible map_size.
  bool isFusable(Node * node) {
    if (node->owningBlock() != block) return false;
    if (node->kind() == prim::FusionGroup) return true;
@@ -224,7 +230,7 @@
node->matches("aten::div(Tensor self, Scalar other) -> Tensor", /*const=*/attr::other) ||
node->matches("aten::div(Scalar other, Tensor self) -> Tensor", /*const=*/attr::other)) {
auto inputs = tensorInputs(node);
return areTensorsOfSameShape(inputs) && haveSupportedType(inputs);
return haveSupportedType(inputs);
}
else if (
node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") ||
@@ -244,7 +250,7 @@
node->matches("aten::ne(Scalar other, Tensor self) -> Tensor", /*const=*/attr::other)) {
// comparison operators produce Byte type, and it's ok, check only inputs
auto inputs = tensorInputs(node);
return areTensorsOfSameShape(inputs) && haveSupportedType(inputs);
return haveSupportedType(inputs);
} else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
// type_as can have different input types as long as output is float, check only output
return haveSupportedType(node->outputs());
@@ -361,6 +367,52 @@
    return true;
  }
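  // Returns the "map size" of a node: the shape over which it operates as an
  // elementwise map. This is the output shape for a simple map op, the shape
  // of the first input for a FusionGroup, and the common tensor shape for
  // aten::cat inputs / aten::chunk outputs; nullopt for anything else.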
  at::optional<at::IntList> mapSize(Node * node) {
    if (isSimpleMap(node)) {
      auto type = node->output()->type()->cast<TensorType>();
      if (!type) {
        return at::nullopt;
      }
      return at::optional<at::IntList>(at::in_place, type->sizes());
    }
    if (node->kind() == prim::FusionGroup) {
      // inputs are guaranteed to be the map_size
      auto type = node->inputs().at(0)->type()->cast<TensorType>();
      JIT_ASSERT(type);
      return at::optional<at::IntList>(at::in_place, type->sizes());
    }
    if (node->kind() == aten::cat) {
      // Assuming all inputs to aten::cat are the same size. This is
      // a condition for aten::cat to be fusible.
      Node * list_construct = node->namedInput(attr::tensors)->node();
      JIT_ASSERT(areTensorsOfSameShape(list_construct->inputs()));
      auto type = list_construct->inputs().at(0)->type()->cast<TensorType>();
      return at::optional<at::IntList>(at::in_place, type->sizes());
    }
    if (node->kind() == aten::chunk) {
      // Assuming all outputs of aten::chunk are the same size.
      // This is a condition for the graph fuser to operate on
      // aten::chunk nodes and is checked elsewhere.
      JIT_ASSERT(areTensorsOfSameShape(node->outputs()));
      auto type = node->outputs().at(0)->type()->cast<TensorType>();
      return at::optional<at::IntList>(at::in_place, type->sizes());
    }
    return at::nullopt;
  }
  bool equalSizes(at::IntList a, at::IntList b) {
    return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
  }
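  // Fusion is only legal between nodes that run their map over the same size.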
  bool haveSameMapSize(Node * consumer, Node * producer) {
    auto consumer_map_size = mapSize(consumer);
    auto producer_map_size = mapSize(producer);
    if (!consumer_map_size || !producer_map_size) {
      return false;
    }
    return equalSizes(*consumer_map_size, *producer_map_size);
  }
  bool shouldFuse(Node * consumer, Value * producer) {
    // this handles cases where producer can be moved _into_ the fusion group of consumer.
    // TODO: extend to fusion of consumer into _producer's_ fusion blob
@@ -369,10 +421,46 @@
    // but this requires better handling of merging fusion groups so it is not done now
    Node *real_consumer = consumer->kind() == aten::cat ? consumer->namedInput(attr::tensors)->node() : consumer;
    return isFusable(producer->node()) &&
        haveSameMapSize(consumer, producer->node()) &&
        allUsersAreThisConsumerOrOccurAfterIt(real_consumer, producer) &&
        compatibleDevices(consumer, producer);
  }
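  // If a simple map op has tensor inputs that don't match its map size (i.e.
  // they would be broadcast at runtime), rewrite them through explicit
  // aten::expand nodes so every tensor input has the output's shape.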
  void maybeInsertExplicitExpands(Node * node) {
    if (!isSimpleMap(node)) {
      return;
    }
    WithInsertPoint guard(node);
    auto map_size = mapSize(node).value();
    auto * graph = node->owningGraph();
    auto tensor_inputs = tensorInputs(node);
    for (auto * producer : tensor_inputs) {
      auto type = producer->type()->cast<TensorType>();
      JIT_ASSERT(type);
      if (equalSizes(map_size, type->sizes())) {
        continue;
      }
      // Insert explicit expand node when input doesn't have correct size.
      //
      // XXX: This hardcodes the "map size" for this FusionGroup.
      // If we want to make the graph fuser more general in the future,
      // we could use aten::broadcast_tensors or add a primitive op that broadcasts.
      auto * expand = graph->insert(
          aten::expand,
          {producer, graph->insertConstant(IValue(map_size)), graph->insertConstant(0)})->node();
      {
        std::vector<int64_t> sizes, strides;
        std::tie(sizes, strides) = at::inferExpandGeometry(
            type->sizes(), type->strides(), map_size);
        expand->output()->setType(type->withSizesStrides(sizes, strides));
      }
      topological_index[expand] = topological_index[producer->node()];
      node->replaceInputWith(producer, expand->output());
    }
  }
  // insert a producer node into a consuming fusion group.
  // DOES NOT WORK if n is a consumer of an output of the fusion group
  // returns the node _inside_ the group that represents the node
@@ -438,6 +526,7 @@
  Node * mergeNodeIntoGroup(Node* group, Node * n) {
    JIT_ASSERT(n->kind() != prim::FusionGroup);
    maybeInsertExplicitExpands(n);
    auto & subgraph = getSubgraph(group);
    // map from nodes in the surrounding graph to parameters in the fusion
    // group's subgraph that correspond to them
@@ -489,6 +578,7 @@
  // turn consumer node n into a fusion group with just n inside
  // to prepare for fusion and replace uses of n with the new group
  Node * createSingletonFusionGroup(Node * n) {
    maybeInsertExplicitExpands(n);
    auto group = block->owningGraph()->createFusionGroup(getDevice(n).index());
    // propagate position information for the new node so we can always
    // have a valid mapping
@@ -582,7 +672,12 @@
return false;
// and the thing being chunked is fusable into the consumer
Value * producer_for_chunk = chunk->namedInput(attr::self);
if (!isFusable(producer_for_chunk->node()) || !allUsersAreThisConsumer(chunk,producer_for_chunk))
if (!isFusable(producer_for_chunk->node()) ||
!allUsersAreThisConsumer(chunk,producer_for_chunk) ||
!areTensorsOfSameShape(chunk->outputs()) ||
// After moving the chunk, the op will have the same map_size as the chunk.
// This checks whether the op will have the same map_size as the consumer after the move.
!haveSameMapSize(consumer, chunk))
return false;
// and all uses of the chunk are in this consumer
for (auto s : chunk->outputs()) {
@@ -592,6 +687,15 @@
}
}
    // First, we'll add explicit expands where necessary to make the chunk
    // move valid. Let's say we have:
    //   %z = aten::mul(%x, %y)
    //   %z.1, %z.2 = aten::chunk(%z, ...)
    //   ... = prim::FusionGroup(%z.1, %z.2, ...)
    // It's possible that %x and %y do not have the same size as %z and
    // need to be expanded first so that they can be chunked like %z.
    maybeInsertExplicitExpands(producer_for_chunk->node());

    // multiple return operators
    Node * producer_for_chunk_node = producer_for_chunk->node();
    JIT_ASSERT(producer_for_chunk_node->outputs().size() == 1);
@@ -746,6 +850,8 @@
void FuseGraph(std::shared_ptr<Graph>& graph) {
  GraphFuser(graph->block()).run();
  // After FuseGraph some common subexpressions may come back
  EliminateCommonSubexpression(graph);
}
}}