[pipelining] Add tests for tracing frontend (#125449)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125449
Approved by: https://github.com/wconstab
ghstack dependencies: #125273, #125448
This commit is contained in:
Ke Wen 2024-05-03 16:42:48 -07:00 committed by PyTorch MergeBot
parent bdaa7bbd7d
commit cbb3791891
4 changed files with 348 additions and 0 deletions

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.pipelining import (
ArgsChunkSpec,
KwargsChunkSpec,
pipe_split,
pipeline,
)
from torch.testing._internal.common_utils import run_tests, TestCase
# Model width and full-batch size shared by all tests in this file.
d_hid = 512
batch_size = 256
# Fixed seed for reproducible parameter init and inputs.
torch.manual_seed(0)
class ExampleCode(torch.nn.Module):
    """Toy model with explicit `pipe_split` markers dividing it into 4 stages."""

    def __init__(self):
        super().__init__()
        # Raw weight matrices used via torch.mm, plus two Linear layers.
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y, z=torch.zeros(batch_size, d_hid)):
        # NOTE: the default for `z` is evaluated once at definition time; it
        # is only read (never mutated) here, so the shared-default pitfall
        # does not bite — and the tests always pass `z` explicitly.
        out = torch.mm(x, self.mm_param0)
        out = torch.relu(out + y)
        out = out + z
        pipe_split()
        out = self.lin1(torch.mm(out, self.mm_param1))
        pipe_split()
        out = torch.mm(torch.relu(out), self.mm_param2)
        pipe_split()
        return torch.relu(self.lin2(out))
class ChunkSpecTests(TestCase):
    def test_chunk_spec(self):
        """Trace a pipeline under Args/Kwargs chunk specs and check it matches
        eager execution numerically."""
        model = ExampleCode()
        x, y, z = (torch.randn(batch_size, d_hid) for _ in range(3))
        n_chunks = 4
        # Chunk both positional args and the "z" kwarg along dim 0.
        with ArgsChunkSpec((0, 0)), KwargsChunkSpec({"z": 0}):
            pipe = pipeline(
                model,
                n_chunks,
                example_args=(x, y),
                example_kwargs={"z": z},
            )
        assert pipe.num_stages == 4
        ref = model(x, y, z)
        out = pipe(x, y, z)[0]
        torch.testing.assert_close(out, ref)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
# Allow running this test file directly (in addition to pytest discovery).
if __name__ == "__main__":
    run_tests()

View file

@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_utils import run_tests, TestCase
# Model width and full-batch size shared by all tests in this file.
d_hid = 512
batch_size = 256
# Fixed seed for reproducible parameter init and inputs.
torch.manual_seed(0)
# Basic example
class ExampleCode(torch.nn.Module):
    """Toy 4-stage model whose skip connection crosses two pipe_split points."""

    def __init__(self):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        out = torch.mm(x, self.mm_param0)
        residual = out  # consumed two stages later
        out = torch.relu(out + y)
        pipe_split()
        out = self.lin1(torch.mm(out, self.mm_param1))
        pipe_split()
        out = torch.relu(out)
        out = torch.mm(out + residual, self.mm_param2)
        pipe_split()
        return torch.relu(self.lin2(out))
# MLP example
class MLPModule(torch.nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear), all widths d_hid."""

    def __init__(self, d_hid):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        # Single fused expression; same op order as net1 -> relu -> net2.
        return self.net2(self.relu(self.net1(x)))
class MultiMLP(torch.nn.Module):
    """Four MLP blocks separated by pipe_split markers; subtracts y at the end."""

    def __init__(self):
        super().__init__()
        self.mlp0 = MLPModule(d_hid)
        self.mlp1 = MLPModule(d_hid)
        self.mlp2 = MLPModule(d_hid)
        self.mlp3 = MLPModule(d_hid)

    def forward(self, x, y):
        # One pipeline stage per MLP block; the loop unrolls during tracing.
        for idx, block in enumerate((self.mlp0, self.mlp1, self.mlp2, self.mlp3)):
            if idx:
                pipe_split()
            x = block(x)
        return x - y
class PipeTests(TestCase):
    def _test_model_split(self, model_class):
        """Trace `model_class` into a 4-stage pipe, then verify numerical
        equivalence with eager mode and that stage state_dict qualnames
        exactly cover the original model's state_dict."""
        model = model_class()
        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)
        pipe = pipeline(
            model,
            num_chunks=4,
            example_args=(x, y),
        )
        assert pipe.num_stages == 4, f"nstages = {pipe.num_stages}, expect 4"
        ref_out = model(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref_out)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}")
        # Check qualname
        # state_dict.keys include both parameters and persistent buffers
        old_names = set(model.state_dict().keys())
        new_names = set()
        for stage_idx in range(pipe.num_stages):
            new_names |= set(pipe.get_stage_module(stage_idx).state_dict().keys())
        assert (
            old_names == new_names
        ), f"""
        old names {old_names}
        new names {new_names}
        """
        print("Qualname check passed")

    def test_example_code(self):
        self._test_model_split(ExampleCode)

    def test_multi_mlp(self):
        self._test_model_split(MultiMLP)
# Allow running this test file directly (in addition to pytest discovery).
if __name__ == "__main__":
    run_tests()

View file

@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.pipelining import pipeline, SplitPoint
from torch.testing._internal.common_utils import run_tests, TestCase
# Model width, number of stacked MLP layers, and batch size for these tests.
d_hid = 16
n_layers = 8
batch_size = 4
class MLPModule(torch.nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear), all widths d_hid."""

    def __init__(self, d_hid):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        # Single fused expression; same op order as net1 -> relu -> net2.
        return self.net2(self.relu(self.net1(x)))
class TransformerLike(torch.nn.Module):
    """Stack of n_layers MLP blocks, mimicking a transformer's layer list."""

    def __init__(self) -> None:
        super().__init__()
        blocks = [MLPModule(d_hid) for _ in range(n_layers)]
        self.layers = torch.nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
class TransformerTests(TestCase):
    def test_ir(self):
        """Split a sequential model into 2 stages via an explicit split_spec,
        then verify stage count, layer coverage, and numerical equivalence."""
        transformer = TransformerLike()
        print("Original model:\n", transformer)
        x = torch.randn(batch_size, d_hid)

        # Split into 2 stages
        num_stages = 2
        split_spec = {f"layers.{n_layers // num_stages}": SplitPoint.BEGINNING}

        pipe = pipeline(
            transformer,
            1,
            (x,),
            split_spec=split_spec,
        )
        assert pipe.num_stages == num_stages, f"{pipe.num_stages=}, expect {num_stages}"

        def get_layers(module):
            # Names of the immediate children under `module.layers`.
            return [name for name, _ in module.layers.named_children()]

        # Collect all layers in pipe
        layers = []
        for stage_idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(stage_idx)
            print(f"\nStage {stage_idx}: \n", stage_mod)
            layers.extend(get_layers(stage_mod))

        # Check layer completeness
        orig_layers = get_layers(transformer)
        assert sorted(layers) == sorted(orig_layers), f"{layers} != {orig_layers}"
        print("Layers matched! ", layers)

        # Check equivalence
        ref = transformer(x)
        out = pipe(x)[0]
        torch.testing.assert_close(out, ref)
        print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
# Allow running this test file directly (in addition to pytest discovery).
if __name__ == "__main__":
    run_tests()

View file

@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_utils import run_tests, TestCase
# Building block for model
class Block(torch.nn.Module):
    """Conv + two Linears with a pipe_split in the middle; the second half
    adds `constant` in place before the final activation."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=16, out_channels=16, kernel_size=3, padding=1
        )
        self.lin0 = torch.nn.Linear(256, 256)
        self.relu = torch.nn.ReLU()
        self.lin1 = torch.nn.Linear(256, 256)

    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
        out = self.lin0(self.conv(x))
        pipe_split()
        # In-place add: callers must supply a tensor `constant` (the default
        # None would raise here).
        out.add_(constant)
        return self.relu(self.lin1(out))
# Full model
class M(torch.nn.Module):
    """Two Blocks back to back, with a pipe_split between them (4 stages total,
    since each Block contains its own split)."""

    def __init__(self) -> None:
        super().__init__()
        self.block0 = Block()
        self.block1 = Block()

    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
        hidden = self.block0(x, constant=constant)
        pipe_split()
        return self.block1(hidden, constant=constant)
class UnflattenTests(TestCase):
    def test_unflatten(self):
        """Pipeline a nested model and check that every stage parameter keeps
        its original qualname, and that outputs match eager execution."""
        x = torch.randn(1, 16, 256, 256)
        constant = torch.ones(1, 16, 256, 256)
        model = M()
        print("Original model:\n", model)
        pipe = pipeline(
            model,
            1,
            (x,),
            {"constant": constant},
        )
        assert pipe.num_stages == 4
        orig_state_dict = model.state_dict()

        # Check qualnames
        print("\nParameters of each stage:")
        for stage_idx in range(pipe.num_stages):
            print(f"\nStage {stage_idx}:")
            stage_mod = pipe.get_stage_module(stage_idx)
            for param_name, param in stage_mod.named_parameters():
                assert (
                    param_name in orig_state_dict
                ), f"{param_name} not in original state dict"
                print(f"{param_name}: {param.size()}")

        # Check equivalence
        ref = model(x, constant)
        out = pipe(x, constant)[0]
        torch.testing.assert_close(out, ref)
        print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
# Allow running this test file directly (in addition to pytest discovery).
if __name__ == "__main__":
    run_tests()