mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Fixes #10096 If the only thing preventing a simple mappable operator from being fused into a fusion group is that its Tensor inputs are not of the same shape as the output, then the graph fuser inserts explicit expand nodes for those inputs. This helps the graph fuser not miss out on any fusion opportunities involving simple mappable operations that have Tensor inputs. This PR doesn't do anything for the scalar case; that can be addressed later. Test Plan - Simple expect test case - Added expect tests for a raw LSTMCell. The expands help speed up the forwards pass by allowing more operations to be fused into the LSTMCell's single FusionGroup. cc apaszke zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/10325 Differential Revision: D9379308 Pulled By: zou3519 fbshipit-source-id: 86d2202eb97e9bb16e511667b7fe177aeaf88245
65 lines
2.9 KiB
Text
graph(%x.1 : Float(3, 10)
      %hx.1 : Float(3, 20)
      %cx.1 : Float(3, 20)
      %w_ih : Float(80, 10)
      %w_hh : Float(80, 20)
      %b_ih : Float(80)
      %b_hh : Float(80)) {
  %7 : Float(10!, 80!) = aten::t(%w_ih)
  %8 : Float(20!, 80!) = aten::t(%w_hh)
  %9 : Float(3, 80) = aten::mm(%hx.1, %8)
  %10 : int = prim::Constant[value=1]()
  %11 : float = prim::Constant[value=1]()
  %12 : Float(3, 80) = aten::addmm(%9, %x.1, %7, %11, %11)
  %13 : int[] = prim::Constant[value=[3, 80]]()
  %14 : int = prim::Constant[value=0]()
  %15 : Float(3!, 80) = aten::expand(%b_ih, %13, %14)
  %16 : Float(3!, 80) = aten::expand(%b_hh, %13, %14)
  %17 : int = prim::Constant[value=4]()
  %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20), %21 : Float(3!, 20) = aten::chunk(%12, %17, %10)
  %22 : Float(3!, 20), %23 : Float(3!, 20), %24 : Float(3!, 20), %25 : Float(3!, 20) = aten::chunk(%15, %17, %10)
  %26 : Float(3!, 20), %27 : Float(3!, 20), %28 : Float(3!, 20), %29 : Float(3!, 20) = aten::chunk(%16, %17, %10)
  %hy : Float(3, 20), %31 : Float(3, 20), %cy : Float(3, 20), %outgate.2 : Float(3, 20), %cellgate.2 : Float(3, 20), %forgetgate.2 : Float(3, 20), %ingate.2 : Float(3, 20) = prim::FusionGroup_0[device=0](%cx.1, %29, %28, %27, %26, %21, %25, %20, %24, %19, %23, %18, %22)
  return (%hy, %cy, %7, %8, %ingate.2, %forgetgate.2, %cellgate.2, %outgate.2, %31);
}
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
      %24 : Float(3!, 20)
      %28 : Float(3!, 20)
      %32 : Float(3!, 20)
      %36 : Float(3!, 20)
      %39 : Float(3!, 20)
      %40 : Float(3!, 20)
      %43 : Float(3!, 20)
      %44 : Float(3!, 20)
      %47 : Float(3!, 20)
      %48 : Float(3!, 20)
      %51 : Float(3!, 20)
      %52 : Float(3!, 20)) {
  %53 : int = prim::Constant[value=1]()
  %54 : Float(3, 20) = aten::add(%51, %52, %53)
  %49 : int = prim::Constant[value=1]()
  %50 : Float(3, 20) = aten::add(%47, %48, %49)
  %45 : int = prim::Constant[value=1]()
  %46 : Float(3, 20) = aten::add(%43, %44, %45)
  %41 : int = prim::Constant[value=1]()
  %42 : Float(3, 20) = aten::add(%39, %40, %41)
  %37 : int = prim::Constant[value=1]()
  %38 : Float(3, 20) = aten::add(%54, %36, %37)
  %33 : int = prim::Constant[value=1]()
  %34 : Float(3, 20) = aten::add(%50, %32, %33)
  %29 : int = prim::Constant[value=1]()
  %30 : Float(3, 20) = aten::add(%46, %28, %29)
  %25 : int = prim::Constant[value=1]()
  %26 : Float(3, 20) = aten::add(%42, %24, %25)
  %ingate.2 : Float(3, 20) = aten::sigmoid(%38)
  %forgetgate.2 : Float(3, 20) = aten::sigmoid(%34)
  %cellgate.2 : Float(3, 20) = aten::tanh(%30)
  %outgate.2 : Float(3, 20) = aten::sigmoid(%26)
  %14 : Float(3, 20) = aten::mul(%forgetgate.2, %13)
  %11 : Float(3, 20) = aten::mul(%ingate.2, %cellgate.2)
  %7 : int = prim::Constant[value=1]()
  %cy : Float(3, 20) = aten::add(%14, %11, %7)
  %4 : Float(3, 20) = aten::tanh(%cy)
  %hy : Float(3, 20) = aten::mul(%outgate.2, %4)
  return (%hy, %4, %cy, %outgate.2, %cellgate.2, %forgetgate.2, %ingate.2);
}