2021-10-29 19:40:39 +00:00
|
|
|
# Owner(s): ["module: unknown"]
|
|
|
|
|
|
2021-04-06 03:50:39 +00:00
|
|
|
import unittest
|
2021-07-10 21:04:48 +00:00
|
|
|
from typing import Dict, Optional
|
2021-04-06 03:50:39 +00:00
|
|
|
|
2021-07-10 21:04:48 +00:00
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
2020-08-12 20:02:29 +00:00
|
|
|
from torch import nn
|
2020-08-29 06:17:17 +00:00
|
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
2024-03-04 19:10:46 +00:00
|
|
|
from torch.testing._internal.static_module import StaticModule
|
2022-06-03 23:39:04 +00:00
|
|
|
from typing import List
|
2020-08-12 20:02:29 +00:00
|
|
|
|
2020-09-14 19:33:02 +00:00
|
|
|
|
2021-07-10 21:04:48 +00:00
|
|
|
def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Drop-in replacement for F.linear: input @ weight.T, plus bias if given."""
    result = input.matmul(weight.t())
    if bias is not None:
        result += bias
    return result
|
2020-09-14 19:33:02 +00:00
|
|
|
|
|
|
|
|
|
2020-08-14 03:16:57 +00:00
|
|
|
# Monkey-patch torch.nn.functional.linear with the shim defined above, so any
# model scripted below traces the explicit matmul/add decomposition instead of
# the fused aten::linear op.
# NOTE(review): presumably done so the static runtime exercises the decomposed
# ops — confirm intent.
torch.nn.functional.linear = linear_shim
|
|
|
|
|
|
2020-08-12 20:02:29 +00:00
|
|
|
|
|
|
|
|
class MultiHeadAttentionLayer(nn.Module):
|
|
|
|
|
def __init__(self, hid_dim, n_heads, dropout, device):
|
|
|
|
|
super().__init__()
|
|
|
|
|
assert hid_dim % n_heads == 0
|
|
|
|
|
self.hid_dim = hid_dim
|
|
|
|
|
self.n_heads = n_heads
|
|
|
|
|
self.head_dim = hid_dim // n_heads
|
|
|
|
|
self.fc_q = nn.Linear(hid_dim, hid_dim)
|
|
|
|
|
self.fc_k = nn.Linear(hid_dim, hid_dim)
|
|
|
|
|
self.fc_v = nn.Linear(hid_dim, hid_dim)
|
|
|
|
|
self.fc_o = nn.Linear(hid_dim, hid_dim)
|
|
|
|
|
# self.dropout = nn.Dropout(dropout)
|
|
|
|
|
self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
|
|
|
|
|
|
|
|
|
|
def forward(self, query, key, value, mask):
|
|
|
|
|
batch_size = query.shape[0]
|
|
|
|
|
Q = self.fc_q(query)
|
|
|
|
|
K = self.fc_k(key)
|
|
|
|
|
V = self.fc_v(value)
|
|
|
|
|
Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
|
|
|
|
|
K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
|
|
|
|
|
V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
|
|
|
|
|
energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
|
|
|
|
|
# energy = energy.masked_fill(mask == 0, -1e10)
|
|
|
|
|
attention = torch.softmax(energy, dim=-1)
|
|
|
|
|
# x = torch.matmul(self.dropout(attention), V)
|
|
|
|
|
x = torch.matmul(attention, V)
|
|
|
|
|
x = x.permute(0, 2, 1, 3).contiguous()
|
|
|
|
|
x = x.view(batch_size, -1, self.hid_dim)
|
|
|
|
|
x = self.fc_o(x)
|
|
|
|
|
return x, attention
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
|
|
|
|
|
def create_mlp(ln, sigmoid_layer):
|
|
|
|
|
layers = nn.ModuleList()
|
|
|
|
|
for i in range(0, len(ln) - 1):
|
|
|
|
|
n = ln[i]
|
|
|
|
|
m = ln[i + 1]
|
|
|
|
|
|
|
|
|
|
LL = nn.Linear(int(n), int(m), bias=True)
|
|
|
|
|
|
|
|
|
|
mean = 0.0 # std_dev = np.sqrt(variance)
|
|
|
|
|
std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
|
|
|
|
|
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
|
|
|
|
|
std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
|
|
|
|
|
bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
|
|
|
|
|
LL.weight.data = torch.tensor(W, requires_grad=True)
|
|
|
|
|
LL.bias.data = torch.tensor(bt, requires_grad=True)
|
|
|
|
|
layers.append(LL)
|
|
|
|
|
|
|
|
|
|
if i == sigmoid_layer:
|
|
|
|
|
layers.append(nn.Sigmoid())
|
|
|
|
|
else:
|
|
|
|
|
layers.append(nn.ReLU())
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
s = torch.jit.script(torch.nn.Sequential(*layers))
|
|
|
|
|
s.eval()
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def trivial_graph(a, b, c):
    """Tiny graph: a + b * c plus a constant 2x2 tensor of threes."""
    offset = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + offset
|
|
|
|
|
|
2022-06-03 23:39:04 +00:00
|
|
|
def elementwise_square_addition(input1, input2):
    """Return the elementwise sum of squares of the two inputs."""
    sq1 = input1 * input1
    sq2 = input2 * input2
    return sq1 + sq2
|
|
|
|
|
|
|
|
|
|
def fork_wait_graph1(input1, input2):
    """Fork the elementwise sum-of-squares graph and wait on its result."""
    future = torch.jit.fork(elementwise_square_addition, input1, input2)
    return torch.jit.wait(future)
|
|
|
|
|
|
|
|
|
|
def fork_wait_graph2(input1, input2):
    """Fork loop_graph (fixed at 5 iterations) and wait on its result."""
    future = torch.jit.fork(loop_graph, input1, input2, 5)
    return torch.jit.wait(future)
|
|
|
|
|
|
2022-06-20 16:32:17 +00:00
|
|
|
"""
|
|
|
|
|
graph with multiple fork/wait operations
|
|
|
|
|
:param input: torch.tensor input to forked subgraph
|
|
|
|
|
:param iters: number of future/wait pairs to be created
|
|
|
|
|
"""
|
|
|
|
|
def fork_wait_graph3(input, iters: int):
    """Fork `iters` torch.neg tasks on `input`, wait on each future, and
    return the sum of the stacked results."""
    pending: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        pending.append(torch.jit.fork(torch.neg, input))
    gathered = [torch.jit.wait(fut) for fut in pending]
    return torch.sum(torch.stack(gathered))
|
2021-07-10 21:04:48 +00:00
|
|
|
|
2022-06-20 16:32:17 +00:00
|
|
|
"""
|
|
|
|
|
graph with multi-level fork/wait operations
|
|
|
|
|
:param input: torch.tensor input to forked subgraph
|
|
|
|
|
:param num_forks: number of top level forks
|
|
|
|
|
:param num_child_forks: number of child forks per parent fork
|
|
|
|
|
"""
|
|
|
|
|
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    """Two-level fork fan-out: each of `num_forks` top-level forks runs
    fork_wait_graph3 with `num_child_forks` children; sum all results."""
    pending: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        pending.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    gathered = [torch.jit.wait(fut) for fut in pending]
    return torch.sum(torch.stack(gathered))
|
|
|
|
|
|
2022-06-11 03:11:49 +00:00
|
|
|
def add_tensor(input1, input2):
    """Plain elementwise addition; used as the forked subgraph in the
    exception-propagation tests."""
    return input1 + input2
|
|
|
|
|
|
|
|
|
|
def fork_wait_graph_exception(input1, input2):
    """Fork add_tensor and wait; with shape-mismatched inputs the forked
    subgraph raises and the error surfaces through the wait."""
    future = torch.jit.fork(add_tensor, input1, input2)
    return torch.jit.wait(future)
|
|
|
|
|
|
2021-07-10 21:04:48 +00:00
|
|
|
def loop_graph(a, b, iters: int):
    """Start from a + 2*b, then `iters` times compute ((acc + b) * 2) - a,
    using in-place ops so static-runtime fusion sees mul_/sub_."""
    acc = a + b * 2
    for _ in range(iters):
        acc = acc + b
        acc *= 2
        acc -= a
    return acc
|
|
|
|
|
|
2021-07-10 21:04:48 +00:00
|
|
|
|
|
|
|
|
def output_graph(a, b, c, iters: int):
    """Return a dict mapping i -> (a + b*c + const) + i for i in range(iters);
    exercises non-tensor (dict) graph outputs."""
    const = torch.tensor([[3, 3], [3, 3]])
    base = a + b * c + const
    outputs: Dict[int, torch.Tensor] = {}
    for i in range(iters):
        outputs[i] = base + i
    return outputs
|
2020-09-14 19:33:02 +00:00
|
|
|
|
2021-07-10 21:04:48 +00:00
|
|
|
|
|
|
|
|
class SubModule(nn.Module):
    """Leaf module returning x plus constant int attributes (a=11, b=2)."""

    def __init__(self) -> None:
        super().__init__()
        self.b = 2
        self.a = 11

    def forward(self, x):
        return self.a + self.b + x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SubModule2(nn.Module):
    """Leaf module that deliberately mutates self.b inside forward."""

    def __init__(self) -> None:
        super().__init__()
        self.b = 2
        self.a = 12

    def forward(self, x):
        # attribute mutation during forward is the point of this module
        self.b = 30
        return self.a + self.b + x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestModule(nn.Module):
    """Composite module holding two submodules; like SubModule2 it mutates
    one of its own attributes during forward."""

    def __init__(self) -> None:
        super().__init__()
        self.a = 3
        self.b = 4
        self.sub1 = SubModule()
        self.sub2 = SubModule2()

    def forward(self, x):
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)
|
|
|
|
|
|
|
|
|
|
|
2021-03-05 18:12:17 +00:00
|
|
|
class TestStaticModule(TestCase):
|
2022-06-03 23:39:04 +00:00
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Test Case: To test simple fork/wait operation in a graph
|
|
|
|
|
fork is called on simple addition operation on input tensors
|
|
|
|
|
"""
|
|
|
|
|
def test_fork_wait_1(self):
|
|
|
|
|
inp1 = torch.ones(5, 5)
|
|
|
|
|
inp2 = torch.randn(5, 5)
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph1)
|
|
|
|
|
output_ref = torch_graph(inp1, inp2)
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_test = static_runtime_module(inp1, inp2)
|
|
|
|
|
torch.testing.assert_close(output_test, output_ref)
|
|
|
|
|
|
2022-07-05 23:40:53 +00:00
|
|
|
"""
|
|
|
|
|
Test Case: To test simple fork/wait operation with
|
|
|
|
|
StaticRuntime runAsync API returning future
|
|
|
|
|
"""
|
|
|
|
|
def test_fork_wait_1_async(self):
|
|
|
|
|
inp1 = torch.ones(5, 5)
|
|
|
|
|
inp2 = torch.randn(5, 5)
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph1)
|
|
|
|
|
output_ref = torch_graph(inp1, inp2)
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_test = static_runtime_module.runAsync((inp1, inp2), {})
|
|
|
|
|
output_test.wait()
|
|
|
|
|
torch.testing.assert_close(output_test.value(), output_ref)
|
|
|
|
|
|
2022-06-03 23:39:04 +00:00
|
|
|
"""
|
|
|
|
|
Test Case: To test fork/wait operation in a graph on
|
|
|
|
|
a loop subgraph performing mix of operations
|
|
|
|
|
"""
|
|
|
|
|
def test_fork_wait_2(self):
|
|
|
|
|
inp1 = torch.randn(5, 5)
|
|
|
|
|
inp2 = torch.randn(5, 5)
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph2)
|
|
|
|
|
output_ref = torch_graph(inp1, inp2)
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_test = static_runtime_module(inp1, inp2)
|
|
|
|
|
torch.testing.assert_close(output_test, output_ref)
|
|
|
|
|
|
2022-07-05 23:40:53 +00:00
|
|
|
"""
|
|
|
|
|
Test Case: To test fork/wait operation on a loop
|
|
|
|
|
subgraph with StaticRuntime runAsync API returning future
|
|
|
|
|
"""
|
|
|
|
|
def test_fork_wait_2_async(self):
|
|
|
|
|
inp1 = torch.randn(5, 5)
|
|
|
|
|
inp2 = torch.randn(5, 5)
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph2)
|
|
|
|
|
output_ref = torch_graph(inp1, inp2)
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_test = static_runtime_module.runAsync((inp1, inp2), {})
|
|
|
|
|
output_test.wait()
|
|
|
|
|
torch.testing.assert_close(output_test.value(), output_ref)
|
|
|
|
|
|
2022-06-03 23:39:04 +00:00
|
|
|
"""
|
|
|
|
|
Test Case: To test fork/wait operation in a graph on
|
|
|
|
|
having multiple fork/wait operations
|
|
|
|
|
"""
|
|
|
|
|
def test_fork_wait_3(self):
|
|
|
|
|
input = torch.ones(3, 3)
|
2022-06-20 16:32:17 +00:00
|
|
|
num_forks = 10
|
2022-06-03 23:39:04 +00:00
|
|
|
torch_graph = torch.jit.script(fork_wait_graph3)
|
2022-06-20 16:32:17 +00:00
|
|
|
output_ref = torch_graph(input, num_forks)
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_test = static_runtime_module(input, num_forks)
|
|
|
|
|
torch.testing.assert_close(output_test, output_ref)
|
2022-07-05 23:40:53 +00:00
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Test Case: To test fork/wait operation in a graph with
|
|
|
|
|
multiple fork/wait operations on runAsync API returning future
|
|
|
|
|
"""
|
|
|
|
|
def test_fork_wait_3_async(self):
|
|
|
|
|
input = torch.ones(3, 3)
|
|
|
|
|
num_forks = 10
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph3)
|
|
|
|
|
output_ref = torch_graph(input, num_forks)
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_test = static_runtime_module.runAsync((input, num_forks), {})
|
|
|
|
|
output_test.wait()
|
|
|
|
|
torch.testing.assert_close(output_test.value(), output_ref)
|
|
|
|
|
|
2022-06-20 16:32:17 +00:00
|
|
|
"""
|
|
|
|
|
Test Case: To test fork/wait operation in a graph on
|
|
|
|
|
multiple nested fork/wait operations
|
|
|
|
|
"""
|
2023-11-02 20:46:24 +00:00
|
|
|
@unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
|
2022-06-20 16:32:17 +00:00
|
|
|
def test_fork_wait_4(self):
|
|
|
|
|
input = torch.ones(3, 3)
|
|
|
|
|
num_forks = 10
|
|
|
|
|
num_child_forks = 10
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph4)
|
2022-06-03 23:39:04 +00:00
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
2022-06-20 16:32:17 +00:00
|
|
|
output_ref = torch_graph(input, num_forks, num_child_forks)
|
|
|
|
|
output_test = static_runtime_module(input, num_forks, num_child_forks)
|
2022-06-03 23:39:04 +00:00
|
|
|
torch.testing.assert_close(output_test, output_ref)
|
|
|
|
|
|
2022-07-05 23:40:53 +00:00
|
|
|
"""
|
|
|
|
|
Test Case: To test fork/wait operation in a graph with multiple
|
|
|
|
|
nested fork/wait operations on runAsync API returning future
|
|
|
|
|
"""
|
2023-11-02 20:46:24 +00:00
|
|
|
@unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
|
2022-07-05 23:40:53 +00:00
|
|
|
def test_fork_wait_4_async(self):
|
|
|
|
|
input = torch.ones(3, 3)
|
|
|
|
|
num_forks = 10
|
|
|
|
|
num_child_forks = 10
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph4)
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_ref = torch_graph(input, num_forks, num_child_forks)
|
|
|
|
|
output_test = static_runtime_module.runAsync(
|
|
|
|
|
(input, num_forks, num_child_forks), {})
|
|
|
|
|
output_test.wait()
|
|
|
|
|
torch.testing.assert_close(output_test.value(), output_ref)
|
|
|
|
|
|
2022-06-11 03:11:49 +00:00
|
|
|
"""
|
|
|
|
|
Test Case: To test exception handling in fork/wait
|
|
|
|
|
operation. Add.Tensor op is called for tensors with
|
|
|
|
|
non-matching dims on the forked subgraph and the
|
|
|
|
|
exception raised by subgraph is set on future returned
|
|
|
|
|
by prim::fork to parent graph. Returned exception is
|
|
|
|
|
checked for substring expected_error_msg as declared below
|
|
|
|
|
"""
|
|
|
|
|
def test_fork_wait_exception(self):
|
|
|
|
|
# incompatible tensors for add due to shape mismatch
|
|
|
|
|
input1 = torch.randn(4, 7)
|
|
|
|
|
input2 = torch.randn(4, 5)
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph_exception)
|
|
|
|
|
try:
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_test = static_runtime_module(input1, input2)
|
2022-07-05 23:40:53 +00:00
|
|
|
except Exception as error:
|
|
|
|
|
expected_error_msg = (
|
|
|
|
|
"The size of tensor a (7) must match the size "
|
|
|
|
|
"of tensor b (5) at non-singleton dimension 1"
|
|
|
|
|
)
|
|
|
|
|
# test fails if error does not contain expected substr
|
|
|
|
|
if str(error).find(expected_error_msg) == -1:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Tried execution of add.Tensors with incompatible shape. "
|
|
|
|
|
"Exception raised by forked runtime execution does "
|
2024-06-02 23:25:26 +00:00
|
|
|
f'not contain expected substring: "{expected_error_msg}"'
|
2022-07-05 23:40:53 +00:00
|
|
|
) from error
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Test Case: To test exception handling in fork/wait
|
|
|
|
|
operation with runAsync API. Add.Tensor op is called for
|
|
|
|
|
tensors with non-matching dims on the forked subgraph
|
|
|
|
|
and the exception raised by subgraph is set on future returned
|
|
|
|
|
by prim::fork to parent graph. Returned exception is
|
|
|
|
|
checked for substring expected_error_msg as declared below
|
|
|
|
|
"""
|
|
|
|
|
def test_fork_wait_exception_async(self):
|
|
|
|
|
# incompatible tensors for add due to shape mismatch
|
|
|
|
|
input1 = torch.randn(4, 7)
|
|
|
|
|
input2 = torch.randn(4, 5)
|
|
|
|
|
torch_graph = torch.jit.script(fork_wait_graph_exception)
|
|
|
|
|
try:
|
|
|
|
|
static_runtime_module = StaticModule(torch_graph)
|
|
|
|
|
output_test = static_runtime_module.runAsync(
|
|
|
|
|
(input1, input2), {})
|
2022-06-11 03:11:49 +00:00
|
|
|
except Exception as error:
|
|
|
|
|
expected_error_msg = (
|
|
|
|
|
"The size of tensor a (7) must match the size "
|
|
|
|
|
"of tensor b (5) at non-singleton dimension 1"
|
|
|
|
|
)
|
|
|
|
|
# test fails if error does not contain expected substr
|
|
|
|
|
if str(error).find(expected_error_msg) == -1:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Tried execution of add.Tensors with incompatible shape. "
|
|
|
|
|
"Exception raised by forked runtime execution does "
|
2024-06-02 23:25:26 +00:00
|
|
|
f'not contain expected substring: "{expected_error_msg}"'
|
2022-06-11 03:11:49 +00:00
|
|
|
) from error
|
|
|
|
|
|
2020-08-29 06:17:17 +00:00
|
|
|
def test_multihead_attention_layer(self):
|
|
|
|
|
HID_DIM = 256
|
|
|
|
|
QUERY_LEN = 8
|
|
|
|
|
BATCH_SIZE = 128
|
|
|
|
|
LAYERS = 3
|
|
|
|
|
HEADS = 8
|
|
|
|
|
DROPOUT = 0.1
|
|
|
|
|
device = torch.device("cpu")
|
|
|
|
|
attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
|
2020-09-25 18:01:10 +00:00
|
|
|
with torch.no_grad():
|
|
|
|
|
src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
|
2020-08-29 06:17:17 +00:00
|
|
|
src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
|
|
|
|
|
|
|
|
|
|
attention.eval()
|
|
|
|
|
attention = torch.jit.script(attention)
|
|
|
|
|
attention.eval()
|
|
|
|
|
o_ref = attention(src, src, src, src_mask)
|
|
|
|
|
|
2021-03-05 18:12:17 +00:00
|
|
|
attention_a = StaticModule(attention)
|
2020-08-29 06:17:17 +00:00
|
|
|
o_test = attention_a(src, src, src, src_mask)
|
2020-10-07 03:52:29 +00:00
|
|
|
o_test_kw = attention_a(src, src, value=src, mask=src_mask)
|
[pt][static_runtime] Memory model (#46896)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46896
The idea of the memory model is quite similar to that of BlackBoxPredictor, however, it's more complicated in pt due to 1) tensor views that share storage with storage refcount bumps but with different TensorImpls, 2) tensors sharing the same TensorImpl and the same storage, but with no refcount bump of the StorageImpl, 3) data types such as TensorList and Tuples that have Tensors in them, 4) need to support non-out/out variant mix while we move the aten ops to out variants.
As a result, I have to make the following adjustments:
1) remove tensors in output Tuples from internal blob list;
2) for memory allocation/deallocation, get candidate Tensors from the outputs of ops with out variant, extract StorageImpls from the Tensors, dedup, and remove output tensor StorageImpls, and get the final list of blobs for memory planning;
3) during the clean_up_memory pass, clean up memory held by the StorageImpls as well as Tensors/Lists/Tuples in IValues that don't participate in memory planning to reduce overall memory usage
Risk:
PyTorch team is planning to deprecate the current resize_outout api, which we do rely on. This is a pretty big risk.
https://www.internalfb.com/intern/diffusion/FBS/browsefile/master/fbcode/caffe2/aten/src/ATen/native/Resize.cpp?commit=6457b329847607553d34e788a3a7092f41f38895&lines=9-23
Test Plan:
```
buck test //caffe2/test:static_runtime
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test //caffe2/caffe2/fb/predictor:pytorch_predictor_test
```
Benchmarks:
```
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 numactl -m 0 -C 13 \
buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench \
--scripted_model=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/traced_precomputation.pt \
--pt_inputs=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/container_precomputation_bs1.pt \
--iters=1000 --warmup_iters=10000 --num_threads=1 --pt_enable_static_runtime=true \
--pt_cleanup_activations=true --pt_enable_out_variant=false
```
|pt_cleanup_activations |pt_enable_out_variant |old ms/iter |new ms/iter |
|--- |--- |--- |--- |
|0 |0 |0.31873 |0.30228 |
|0 |1 |0.30018 |0.29184 |
|1 |0 |0.35246 |0.31895 |
|1 |1 |0.35742 |0.30417 |
Reviewed By: bwasti, raziel
Differential Revision: D24471854
fbshipit-source-id: 4ac37dca7d2a0c362120a7f02fd3995460c9a55c
2020-11-04 07:42:24 +00:00
|
|
|
|
2020-08-29 06:17:17 +00:00
|
|
|
for a, b in zip(o_ref, o_test):
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(a, b)
|
[pt][static_runtime] Memory model (#46896)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46896
The idea of the memory model is quite similar to that of BlackBoxPredictor, however, it's more complicated in pt due to 1) tensor views that share storage with storage refcount bumps but with different TensorImpls, 2) tensors sharing the same TensorImpl and the same storage, but with no refcount bump of the StorageImpl, 3) data types such as TensorList and Tuples that have Tensors in them, 4) need to support non-out/out variant mix while we move the aten ops to out variants.
As a result, I have to make the following adjustments:
1) remove tensors in output Tuples from internal blob list;
2) for memory allocation/deallocation, get candidate Tensors from the outputs of ops with out variant, extract StorageImpls from the Tensors, dedup, and remove output tensor StorageImpls, and get the final list of blobs for memory planning;
3) during the clean_up_memory pass, clean up memory held by the StorageImpls as well as Tensors/Lists/Tuples in IValues that don't participate in memory planning to reduce overall memory usage
Risk:
PyTorch team is planning to deprecate the current resize_outout api, which we do rely on. This is a pretty big risk.
https://www.internalfb.com/intern/diffusion/FBS/browsefile/master/fbcode/caffe2/aten/src/ATen/native/Resize.cpp?commit=6457b329847607553d34e788a3a7092f41f38895&lines=9-23
Test Plan:
```
buck test //caffe2/test:static_runtime
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test //caffe2/caffe2/fb/predictor:pytorch_predictor_test
```
Benchmarks:
```
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 numactl -m 0 -C 13 \
buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench \
--scripted_model=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/traced_precomputation.pt \
--pt_inputs=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/container_precomputation_bs1.pt \
--iters=1000 --warmup_iters=10000 --num_threads=1 --pt_enable_static_runtime=true \
--pt_cleanup_activations=true --pt_enable_out_variant=false
```
|pt_cleanup_activations |pt_enable_out_variant |old ms/iter |new ms/iter |
|--- |--- |--- |--- |
|0 |0 |0.31873 |0.30228 |
|0 |1 |0.30018 |0.29184 |
|1 |0 |0.35246 |0.31895 |
|1 |1 |0.35742 |0.30417 |
Reviewed By: bwasti, raziel
Differential Revision: D24471854
fbshipit-source-id: 4ac37dca7d2a0c362120a7f02fd3995460c9a55c
2020-11-04 07:42:24 +00:00
|
|
|
|
2020-10-07 03:52:29 +00:00
|
|
|
for a, b in zip(o_ref, o_test_kw):
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(a, b)
|
2020-10-07 03:52:29 +00:00
|
|
|
|
|
|
|
|
def test_multihead_attention_layer_benchmark(self):
|
|
|
|
|
HID_DIM = 256
|
|
|
|
|
QUERY_LEN = 8
|
|
|
|
|
BATCH_SIZE = 128
|
|
|
|
|
LAYERS = 3
|
|
|
|
|
HEADS = 8
|
|
|
|
|
DROPOUT = 0.1
|
|
|
|
|
device = torch.device("cpu")
|
|
|
|
|
attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
|
|
|
|
|
src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
|
|
|
|
|
|
|
|
|
|
attention.eval()
|
|
|
|
|
attention = torch.jit.script(attention)
|
2021-03-05 18:12:17 +00:00
|
|
|
attention_a = StaticModule(attention)
|
2020-10-07 03:52:29 +00:00
|
|
|
|
[pt][static_runtime] Memory model (#46896)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46896
The idea of the memory model is quite similar to that of BlackBoxPredictor, however, it's more complicated in pt due to 1) tensor views that share storage with storage refcount bumps but with different TensorImpls, 2) tensors sharing the same TensorImpl and the same storage, but with no refcount bump of the StorageImpl, 3) data types such as TensorList and Tuples that have Tensors in them, 4) need to support non-out/out variant mix while we move the aten ops to out variants.
As a result, I have to make the following adjustments:
1) remove tensors in output Tuples from internal blob list;
2) for memory allocation/deallocation, get candidate Tensors from the outputs of ops with out variant, extract StorageImpls from the Tensors, dedup, and remove output tensor StorageImpls, and get the final list of blobs for memory planning;
3) during the clean_up_memory pass, clean up memory held by the StorageImpls as well as Tensors/Lists/Tuples in IValues that don't participate in memory planning to reduce overall memory usage
Risk:
PyTorch team is planning to deprecate the current resize_outout api, which we do rely on. This is a pretty big risk.
https://www.internalfb.com/intern/diffusion/FBS/browsefile/master/fbcode/caffe2/aten/src/ATen/native/Resize.cpp?commit=6457b329847607553d34e788a3a7092f41f38895&lines=9-23
Test Plan:
```
buck test //caffe2/test:static_runtime
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test //caffe2/caffe2/fb/predictor:pytorch_predictor_test
```
Benchmarks:
```
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 numactl -m 0 -C 13 \
buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench \
--scripted_model=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/traced_precomputation.pt \
--pt_inputs=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/container_precomputation_bs1.pt \
--iters=1000 --warmup_iters=10000 --num_threads=1 --pt_enable_static_runtime=true \
--pt_cleanup_activations=true --pt_enable_out_variant=false
```
|pt_cleanup_activations |pt_enable_out_variant |old ms/iter |new ms/iter |
|--- |--- |--- |--- |
|0 |0 |0.31873 |0.30228 |
|0 |1 |0.30018 |0.29184 |
|1 |0 |0.35246 |0.31895 |
|1 |1 |0.35742 |0.30417 |
Reviewed By: bwasti, raziel
Differential Revision: D24471854
fbshipit-source-id: 4ac37dca7d2a0c362120a7f02fd3995460c9a55c
2020-11-04 07:42:24 +00:00
|
|
|
attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
|
2020-10-07 03:52:29 +00:00
|
|
|
metrics = attention_a.benchmark_individual_ops(
|
[pt][static_runtime] Memory model (#46896)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46896
The idea of the memory model is quite similar to that of BlackBoxPredictor, however, it's more complicated in pt due to 1) tensor views that share storage with storage refcount bumps but with different TensorImpls, 2) tensors sharing the same TensorImpl and the same storage, but with no refcount bump of the StorageImpl, 3) data types such as TensorList and Tuples that have Tensors in them, 4) need to support non-out/out variant mix while we move the aten ops to out variants.
As a result, I have to make the following adjustments:
1) remove tensors in output Tuples from internal blob list;
2) for memory allocation/deallocation, get candidate Tensors from the outputs of ops with out variant, extract StorageImpls from the Tensors, dedup, and remove output tensor StorageImpls, and get the final list of blobs for memory planning;
3) during the clean_up_memory pass, clean up memory held by the StorageImpls as well as Tensors/Lists/Tuples in IValues that don't participate in memory planning to reduce overall memory usage
Risk:
PyTorch team is planning to deprecate the current resize_outout api, which we do rely on. This is a pretty big risk.
https://www.internalfb.com/intern/diffusion/FBS/browsefile/master/fbcode/caffe2/aten/src/ATen/native/Resize.cpp?commit=6457b329847607553d34e788a3a7092f41f38895&lines=9-23
Test Plan:
```
buck test //caffe2/test:static_runtime
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test //caffe2/caffe2/fb/predictor:pytorch_predictor_test
```
Benchmarks:
```
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 numactl -m 0 -C 13 \
buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench \
--scripted_model=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/traced_precomputation.pt \
--pt_inputs=/home/hlu/ads/adindexer/adindexer_ctr_mobilefeed/pt/merge/container_precomputation_bs1.pt \
--iters=1000 --warmup_iters=10000 --num_threads=1 --pt_enable_static_runtime=true \
--pt_cleanup_activations=true --pt_enable_out_variant=false
```
|pt_cleanup_activations |pt_enable_out_variant |old ms/iter |new ms/iter |
|--- |--- |--- |--- |
|0 |0 |0.31873 |0.30228 |
|0 |1 |0.30018 |0.29184 |
|1 |0 |0.35246 |0.31895 |
|1 |1 |0.35742 |0.30417 |
Reviewed By: bwasti, raziel
Differential Revision: D24471854
fbshipit-source-id: 4ac37dca7d2a0c362120a7f02fd3995460c9a55c
2020-11-04 07:42:24 +00:00
|
|
|
[src, src, src, src_mask], {}, 2, 2
|
2020-10-07 03:52:29 +00:00
|
|
|
)
|
2020-08-29 06:17:17 +00:00
|
|
|
|
|
|
|
|
def test_mlp(self):
|
|
|
|
|
# Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
|
|
|
|
|
ln_bot = [512, 512, 64]
|
|
|
|
|
sigmoid_bot = -1
|
|
|
|
|
ln_top = [100, 1024, 1024, 1024, 1]
|
|
|
|
|
sigmoid_top = 3
|
|
|
|
|
bot_l = create_mlp(ln_bot, sigmoid_bot)
|
2021-03-05 18:12:17 +00:00
|
|
|
bot_l_acc = StaticModule(bot_l)
|
2020-08-29 06:17:17 +00:00
|
|
|
top_l = create_mlp(ln_top, sigmoid_top)
|
2021-03-05 18:12:17 +00:00
|
|
|
top_l_acc = StaticModule(top_l)
|
2020-09-25 18:01:10 +00:00
|
|
|
with torch.no_grad():
|
|
|
|
|
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
|
|
|
|
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
2020-08-29 06:17:17 +00:00
|
|
|
ref_bot = bot_l(bot_inp)
|
2021-10-25 15:16:14 +00:00
|
|
|
acc_bot = bot_l_acc(bot_inp)
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(acc_bot, ref_bot)
|
2020-08-29 06:17:17 +00:00
|
|
|
ref_top = top_l(top_inp)
|
2021-10-25 15:16:14 +00:00
|
|
|
acc_top = top_l_acc(top_inp)
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(acc_top, ref_top)
|
2020-09-14 19:33:02 +00:00
|
|
|
for _ in range(5):
|
2020-09-25 18:01:10 +00:00
|
|
|
with torch.no_grad():
|
|
|
|
|
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
|
|
|
|
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
2020-09-14 19:33:02 +00:00
|
|
|
ref_bot = bot_l(bot_inp)
|
2021-10-25 15:16:14 +00:00
|
|
|
acc_bot = bot_l_acc(bot_inp)
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(acc_bot, ref_bot)
|
2020-09-14 19:33:02 +00:00
|
|
|
ref_top = top_l(top_inp)
|
2021-10-25 15:16:14 +00:00
|
|
|
acc_top = top_l_acc(top_inp)
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(acc_top, ref_top)
|
2020-08-29 06:17:17 +00:00
|
|
|
|
2020-09-28 19:53:59 +00:00
|
|
|
def test_trivial_graph(self):
|
|
|
|
|
s = torch.full((2, 2), 2)
|
|
|
|
|
tg = torch.jit.script(trivial_graph)
|
|
|
|
|
o_ref = tg(s, s, s)
|
2021-03-05 18:12:17 +00:00
|
|
|
tg_a = StaticModule(tg)
|
2021-10-25 15:16:14 +00:00
|
|
|
o_test = tg_a(s, s, s)
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(o_ref, o_test)
|
2020-08-12 20:02:29 +00:00
|
|
|
|
2020-11-14 06:04:06 +00:00
|
|
|
def test_leaky_relu(self):
|
|
|
|
|
s = torch.randn(5, 5)
|
|
|
|
|
tg = torch.jit.script(nn.LeakyReLU(0.1))
|
|
|
|
|
o_ref = tg(s)
|
2021-03-05 18:12:17 +00:00
|
|
|
tg_a = StaticModule(tg)
|
2021-10-25 15:16:14 +00:00
|
|
|
o_test = tg_a(s)
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(o_ref, o_test)
|
2020-09-14 19:33:02 +00:00
|
|
|
|
2021-07-10 21:04:48 +00:00
|
|
|
    def test_attr(self):
        """
        TorchScript IR of TestModule() after freezing:
        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
              %x.1 : Tensor):
            %18 : int = prim::Constant[value=30]()
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
        m = TestModule()

        m.eval()
        input = torch.randn(2, 2)
        # Eager-mode reference output, computed before scripting.
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        # Static Runtime must handle the SetAttr/GetAttr nodes above and
        # still produce the eager result.
        torch.testing.assert_close(output_s, output_sm)
        # Smoke-test the benchmark helpers with both positional and
        # keyword-argument calling conventions (2 warmup + 2 main iters).
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)
|
@unittest.skip("Temporarily disabled")
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
def test_fusion_trivial_graph(self):
|
|
|
|
|
s = torch.full((2, 2), 2)
|
|
|
|
|
tg = torch.jit.script(trivial_graph)
|
|
|
|
|
o_ref = tg(s, s, s)
|
2021-03-05 18:12:17 +00:00
|
|
|
torch._C._fuse_to_static_module(tg.graph)
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
assert "StaticSubgraph" in str(tg.graph)
|
|
|
|
|
o_test = tg(s, s, s)
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(o_ref, o_test)
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
|
2021-04-06 03:50:39 +00:00
|
|
|
@unittest.skip("Temporarily disabled")
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
def test_fusion_multihead_attention_layer(self):
|
|
|
|
|
HID_DIM = 256
|
|
|
|
|
QUERY_LEN = 8
|
|
|
|
|
BATCH_SIZE = 128
|
|
|
|
|
LAYERS = 3
|
|
|
|
|
HEADS = 8
|
|
|
|
|
DROPOUT = 0.1
|
|
|
|
|
device = torch.device("cpu")
|
|
|
|
|
attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
|
|
|
|
|
src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
|
|
|
|
|
|
|
|
|
|
attention.eval()
|
|
|
|
|
attention = torch.jit.script(attention)
|
|
|
|
|
attention.eval()
|
|
|
|
|
o_ref = attention(src, src, src, src_mask)
|
|
|
|
|
|
2021-03-05 18:12:17 +00:00
|
|
|
torch._C._fuse_to_static_module(attention._c)
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
o_test = attention(src, src, src, src_mask)
|
|
|
|
|
|
|
|
|
|
for a, b in zip(o_ref, o_test):
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(a, b)
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
|
2021-04-06 03:50:39 +00:00
|
|
|
@unittest.skip("Temporarily disabled")
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
def test_fusion_loop(self):
|
|
|
|
|
a = torch.randn(5, 5)
|
|
|
|
|
b = torch.randn(5, 5)
|
|
|
|
|
c = 4
|
|
|
|
|
lg = torch.jit.script(loop_graph)
|
|
|
|
|
o_ref = lg(a, b, c)
|
2021-03-05 18:12:17 +00:00
|
|
|
torch._C._fuse_to_static_module(lg.graph)
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
assert "StaticSubgraph" in str(lg.graph)
|
|
|
|
|
o_test = lg(a, b, c)
|
2021-08-19 19:45:32 +00:00
|
|
|
torch.testing.assert_close(o_ref, o_test)
|
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185
This diff adds a fusion feature that will let us use static runtime for *parts* of the graph. This will prove useful in cases where fully eliminating control flow is hard etc.
TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add graph that isn't fully lowered test case
[x] add graph that has weird list/tuple outputs test case
the loop example looks quite good:
```
graph(%a.1 : Tensor,
%b.1 : Tensor,
%iters.1 : int):
%12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
%c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
%c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
block0(%i : int, %c.12 : Tensor):
%c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
-> (%12, %c.10)
return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
%4 : Tensor):
%5 : int = prim::Constant[value=2]()
%6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
%2 : int = prim::Constant[value=1]()
%c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
%7 : Tensor,
%8 : Tensor):
%9 : int = prim::Constant[value=1]()
%c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
%5 : int = prim::Constant[value=2]()
%c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
%2 : int = prim::Constant[value=1]()
%c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
return (%c.10)
```
(Note: this ignores all push blocking failures!)
Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime
Reviewed By: bertmaher
Differential Revision: D25385702
fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
2020-12-10 22:01:36 +00:00
|
|
|
|
2021-04-06 03:50:39 +00:00
|
|
|
@unittest.skip("Temporarily disabled")
def test_fusion_outputs(self):
    """Check that static-subgraph fusion preserves a scripted graph's outputs.

    Scripts ``output_graph`` (defined elsewhere in this file), records its
    reference outputs, rewrites the graph with the static-subgraph fusion
    pass, and verifies the fused graph still produces identical outputs.
    """
    a = torch.randn(2, 2)
    b = torch.randn(2, 2)
    c = 4
    og = torch.jit.script(output_graph)
    o_ref = og(a, b, b, c)
    # Rewrite eligible regions of the graph into prim::StaticSubgraph ops.
    torch._C._fuse_to_static_module(og.graph)
    # self.assertIn gives a useful failure message and, unlike a bare
    # `assert` statement, is not stripped when Python runs with -O.
    self.assertIn("StaticSubgraph", str(og.graph))
    o_test = og(a, b, b, c)
    # o_ref behaves like a dict of named outputs (it has .keys()); compare
    # the fused result against the reference value-by-value.
    for key in o_ref:
        torch.testing.assert_close(o_ref[key], o_test[key])
2022-02-03 12:13:51 +00:00
|
|
|
def test_create_object(self):
    """Verify StaticModule matches the JIT result for a module whose
    ``forward`` constructs a plain Python object."""

    class Holder:  # noqa: B903
        def __init__(self, x: torch.Tensor) -> None:
            self.x = x

    class Mod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, y: torch.Tensor) -> torch.Tensor:
            holder = Holder(y)
            return y * holder.x

    # Reference: run the scripted module eagerly through the JIT interpreter.
    scripted = torch.jit.script(Mod()).eval()
    sample = torch.randn((1, ))
    ref_out = scripted(sample)

    # Static runtime requires a frozen module; its output must match.
    static_mod = StaticModule(torch.jit.freeze(scripted))
    static_out = static_mod(sample)

    self.assertEqual(ref_out, static_out)
2020-08-12 20:02:29 +00:00
|
|
|
# Standard entry point: run this test file's cases directly via the
# torch.testing harness.
if __name__ == "__main__":
    run_tests()