pytorch/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect
Richard Zou e29b5a1ea8 graph fuser inserts explicit expands where necessary (#10325)
Summary:
Fixes #10096

If the only thing preventing a simple mappable operator from being fused
into a fusion group is that its Tensor inputs are not of the same shape as the
output, then the graph fuser inserts explicit expand nodes for those
inputs.
This helps the graph fuser not miss out on any fusion opportunities
involving simple mappable operations that have Tensor inputs. This PR
doesn't do anything for the scalar case; that can be addressed later.

Test Plan:
- Simple expect test case
- Added expect tests for a raw LSTMCell. The expands help speed up the
  forward pass by allowing more operations to be fused into the LSTMCell's single
  FusionGroup.

cc apaszke zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10325

Differential Revision: D9379308

Pulled By: zou3519

fbshipit-source-id: 86d2202eb97e9bb16e511667b7fe177aeaf88245
2018-08-17 16:03:46 -07:00

65 lines
2.9 KiB
Text

graph(%x.1 : Float(3, 10)
%hx.1 : Float(3, 20)
%cx.1 : Float(3, 20)
%w_ih : Float(80, 10)
%w_hh : Float(80, 20)
%b_ih : Float(80)
%b_hh : Float(80)) {
%7 : Float(10!, 80!) = aten::t(%w_ih)
%8 : Float(20!, 80!) = aten::t(%w_hh)
%9 : Float(3, 80) = aten::mm(%hx.1, %8)
%10 : int = prim::Constant[value=1]()
%11 : float = prim::Constant[value=1]()
%12 : Float(3, 80) = aten::addmm(%9, %x.1, %7, %11, %11)
%13 : int[] = prim::Constant[value=[3, 80]]()
%14 : int = prim::Constant[value=0]()
%15 : Float(3!, 80) = aten::expand(%b_ih, %13, %14)
%16 : Float(3!, 80) = aten::expand(%b_hh, %13, %14)
%17 : int = prim::Constant[value=4]()
%18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20), %21 : Float(3!, 20) = aten::chunk(%12, %17, %10)
%22 : Float(3!, 20), %23 : Float(3!, 20), %24 : Float(3!, 20), %25 : Float(3!, 20) = aten::chunk(%15, %17, %10)
%26 : Float(3!, 20), %27 : Float(3!, 20), %28 : Float(3!, 20), %29 : Float(3!, 20) = aten::chunk(%16, %17, %10)
%hy : Float(3, 20), %31 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %29, %28, %27, %26, %21, %25, %20, %24, %19, %23, %18, %22)
return (%hy, %cy, %7, %8, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %31);
}
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
%24 : Float(3!, 20)
%28 : Float(3!, 20)
%32 : Float(3!, 20)
%36 : Float(3!, 20)
%39 : Float(3!, 20)
%40 : Float(3!, 20)
%43 : Float(3!, 20)
%44 : Float(3!, 20)
%47 : Float(3!, 20)
%48 : Float(3!, 20)
%51 : Float(3!, 20)
%52 : Float(3!, 20)) {
%53 : int = prim::Constant[value=1]()
%54 : Float(3, 20) = aten::add(%51, %52, %53)
%49 : int = prim::Constant[value=1]()
%50 : Float(3, 20) = aten::add(%47, %48, %49)
%45 : int = prim::Constant[value=1]()
%46 : Float(3, 20) = aten::add(%43, %44, %45)
%41 : int = prim::Constant[value=1]()
%42 : Float(3, 20) = aten::add(%39, %40, %41)
%37 : int = prim::Constant[value=1]()
%38 : Float(3, 20) = aten::add(%54, %36, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(3, 20) = aten::add(%50, %32, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%46, %28, %29)
%25 : int = prim::Constant[value=1]()
%26 : Float(3, 20) = aten::add(%42, %24, %25)
%ingate.2 : Float(3, 20) = aten::sigmoid(%38)
%forgetgate.2 : Float(3, 20) = aten::sigmoid(%34)
%cellgate.2 : Float(3, 20) = aten::tanh(%30)
%outgate.2 : Float(3, 20) = aten::sigmoid(%26)
%14 : Float(3, 20) = aten::mul(%forgetgate.2, %13)
%11 : Float(3, 20) = aten::mul(%ingate.2, %cellgate.2)
%7 : int = prim::Constant[value=1]()
%cy : Float(3, 20) = aten::add(%14, %11, %7)
%4 : Float(3, 20) = aten::tanh(%cy)
%hy : Float(3, 20) = aten::mul(%outgate.2, %4)
return (%hy, %4, %cy, %outgate.2, %cellgate.2, %forgetgate.2, %ingate.2);
}