Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00
[pipelining] Add tests for tracing frontend (#125449)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125449
Approved by: https://github.com/wconstab
ghstack dependencies: #125273, #125448

This commit is contained in:
parent bdaa7bbd7d
commit cbb3791891

4 changed files with 348 additions and 0 deletions
test/distributed/pipelining/test_chunkspec.py  (new file, 72 lines)
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch

from torch.distributed.pipelining import (
    ArgsChunkSpec,
    KwargsChunkSpec,
    pipe_split,
    pipeline,
)
from torch.testing._internal.common_utils import run_tests, TestCase


d_hid = 512
batch_size = 256

torch.manual_seed(0)


class ExampleCode(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y, z=torch.zeros(batch_size, d_hid)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = torch.relu(x)
        x = x + z
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        pipe_split()
        x = torch.relu(x)
        x = torch.mm(x, self.mm_param2)
        pipe_split()
        x = self.lin2(x)
        x = torch.relu(x)
        return x


class ChunkSpecTests(TestCase):
    def test_chunk_spec(self):
        mod = ExampleCode()

        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)
        z = torch.randn(batch_size, d_hid)

        chunks = 4

        with ArgsChunkSpec((0, 0)), KwargsChunkSpec({"z": 0}):
            pipe = pipeline(
                mod,
                chunks,
                example_args=(x, y),
                example_kwargs={"z": z},
            )

        assert pipe.num_stages == 4

        ref = mod(x, y, z)
        out = pipe(x, y, z)[0]
        torch.testing.assert_close(out, ref)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


if __name__ == "__main__":
    run_tests()
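A note on what the chunk specs above declare (a minimal sketch, not part of this PR): ArgsChunkSpec((0, 0)) says both positional inputs split along dimension 0 into microbatches, and KwargsChunkSpec({"z": 0}) says the same for the keyword input. Assuming the spec semantics behave like torch.chunk along the declared dimension, the test's 4-chunk split looks like this:

    # Illustration only: dim-0 chunking of the test inputs into 4 microbatches.
    # Assumes the spec semantics match torch.chunk along the given dimension.
    import torch

    x = torch.randn(256, 512)  # batch_size x d_hid, as in the test
    microbatches = torch.chunk(x, 4, dim=0)
    assert len(microbatches) == 4
    assert microbatches[0].shape == (64, 512)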
test/distributed/pipelining/test_pipe.py  (new file, 119 lines)
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch

from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_utils import run_tests, TestCase


d_hid = 512
batch_size = 256

torch.manual_seed(0)


# Basic example
class ExampleCode(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        skip_connection = x
        x = x + y
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        pipe_split()
        x = torch.relu(x)
        x = x + skip_connection
        x = torch.mm(x, self.mm_param2)
        pipe_split()
        x = self.lin2(x)
        x = torch.relu(x)
        return x


# MLP example
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


class MultiMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp0 = MLPModule(d_hid)
        self.mlp1 = MLPModule(d_hid)
        self.mlp2 = MLPModule(d_hid)
        self.mlp3 = MLPModule(d_hid)

    def forward(self, x, y):
        x = self.mlp0(x)
        pipe_split()
        x = self.mlp1(x)
        pipe_split()
        x = self.mlp2(x)
        pipe_split()
        x = self.mlp3(x)
        return x - y


class PipeTests(TestCase):
    def _test_model_split(self, model_class):
        mod = model_class()
        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)

        pipe = pipeline(
            mod,
            num_chunks=4,
            example_args=(x, y),
        )

        assert pipe.num_stages == 4, f"nstages = {pipe.num_stages}, expect 4"

        ref_out = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref_out)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}")

        # Check qualname
        # state_dict.keys include both parameters and persistent buffers
        old_names = set(mod.state_dict().keys())
        new_names = set()
        for idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(idx)
            new_names.update(stage_mod.state_dict().keys())

        assert (
            old_names == new_names
        ), f"""
        old names {old_names}
        new names {new_names}
        """
        print("Qualname check passed")

    def test_example_code(self):
        self._test_model_split(ExampleCode)

    def test_multi_mlp(self):
        self._test_model_split(MultiMLP)


if __name__ == "__main__":
    run_tests()
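Both test models above rely on the same invariant: each pipe_split() call in the traced forward marks a stage boundary, so three calls yield the four stages asserted in _test_model_split. A minimal sketch of that relationship, using the same pipeline API the tests exercise (illustrative, not part of this PR):

    # Sketch: one pipe_split() boundary -> two stages.
    import torch
    from torch.distributed.pipelining import pipe_split, pipeline

    class TwoStage(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = torch.nn.Linear(512, 512)
            self.b = torch.nn.Linear(512, 512)

        def forward(self, x):
            x = self.a(x)
            pipe_split()  # single boundary: stage 0 ends here, stage 1 begins
            return self.b(x)

    pipe = pipeline(TwoStage(), num_chunks=4, example_args=(torch.randn(256, 512),))
    assert pipe.num_stages == 2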
test/distributed/pipelining/test_transformer.py  (new file, 78 lines)
@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch

from torch.distributed.pipelining import pipeline, SplitPoint
from torch.testing._internal.common_utils import run_tests, TestCase


d_hid = 16
n_layers = 8
batch_size = 4


class MLPModule(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


class TransformerLike(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(*[MLPModule(d_hid) for _ in range(n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class TransformerTests(TestCase):
    def test_ir(self):
        transformer = TransformerLike()
        print("Original model:\n", transformer)
        x = torch.randn(batch_size, d_hid)

        # Split into 2 stages
        num_stages = 2
        split_spec = {f"layers.{n_layers // num_stages}": SplitPoint.BEGINNING}

        pipe = pipeline(
            transformer,
            1,
            (x,),
            split_spec=split_spec,
        )
        assert pipe.num_stages == num_stages, f"{pipe.num_stages=}, expect {num_stages}"

        def get_layers(module):
            layers = [name for name, _ in module.layers.named_children()]
            return layers

        # Collect all layers in pipe
        layers = []
        for stage_idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(stage_idx)
            print(f"\nStage {stage_idx}: \n", stage_mod)
            layers += get_layers(stage_mod)

        # Check layer completeness
        orig_layers = get_layers(transformer)
        assert sorted(layers) == sorted(orig_layers), f"{layers} != {orig_layers}"
        print("Layers matched! ", layers)

        # Check equivalence
        ref = transformer(x)
        out = pipe(x)[0]
        torch.testing.assert_close(out, ref)
        print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


if __name__ == "__main__":
    run_tests()
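The split_spec above cuts the 8-layer stack at layers.4, the layer whose index is n_layers // num_stages, so each of the two stages holds four layers. A hypothetical extension of the same mechanism to more stages (not part of the test; boundary indices chosen for illustration):

    # Hypothetical 4-stage split of the same 8-layer model: mark the beginning
    # of every second layer after the first pair as a stage boundary.
    from torch.distributed.pipelining import SplitPoint

    split_spec = {f"layers.{i}": SplitPoint.BEGINNING for i in (2, 4, 6)}
    # Three boundaries -> four stages, two layers per stage.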
test/distributed/pipelining/test_unflatten.py  (new file, 79 lines)
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch

from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_utils import run_tests, TestCase


# Building block for model
class Block(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=16, out_channels=16, kernel_size=3, padding=1
        )
        self.lin0 = torch.nn.Linear(256, 256)
        self.relu = torch.nn.ReLU()
        self.lin1 = torch.nn.Linear(256, 256)

    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
        x = self.conv(x)
        x = self.lin0(x)
        pipe_split()
        x.add_(constant)
        x = self.lin1(x)
        return self.relu(x)


# Full model
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block0 = Block()
        self.block1 = Block()

    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
        x = self.block0(x, constant=constant)
        pipe_split()
        x = self.block1(x, constant=constant)
        return x


class UnflattenTests(TestCase):
    def test_unflatten(self):
        x = torch.randn(1, 16, 256, 256)
        constant = torch.ones(1, 16, 256, 256)

        mod = M()
        print("Original model:\n", mod)

        pipe = pipeline(
            mod,
            1,
            (x,),
            {"constant": constant},
        )

        assert pipe.num_stages == 4
        orig_state_dict = mod.state_dict()

        # Check qualnames
        print("\nParameters of each stage:")
        for stage_idx in range(pipe.num_stages):
            print(f"\nStage {stage_idx}:")
            stage_mod = pipe.get_stage_module(stage_idx)
            for param_name, param in stage_mod.named_parameters():
                assert (
                    param_name in orig_state_dict
                ), f"{param_name} not in original state dict"
                print(f"{param_name}: {param.size()}")

        # Check equivalence
        ref = mod(x, constant)
        out = pipe(x, constant)[0]
        torch.testing.assert_close(out, ref)
        print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


if __name__ == "__main__":
    run_tests()
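The four stages asserted in test_unflatten follow from the splits-plus-one rule applied across the module hierarchy, since pipe_split() calls inside submodules count too:

    # Stage count arithmetic for M above:
    #   pipe_split() calls traced: block0.forward (1) + M.forward (1) + block1.forward (1) = 3
    #   stages = splits + 1 = 4, matching the assert in test_unflatten.

The qualname loop then verifies that after splitting, each stage module's parameters keep their original fully qualified names (e.g. block0.lin0.weight), which is what the "unflatten" in the test name refers to.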