mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665 Approved by: https://github.com/albanD
75 lines
2.1 KiB
Python
75 lines
2.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
# Owner(s): ["oncall: distributed"]
|
|
import torch
|
|
from torch.distributed.pipelining import pipe_split, pipeline
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
# Building block for model
|
|
class Block(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=16, out_channels=16, kernel_size=3, padding=1
|
|
)
|
|
self.lin0 = torch.nn.Linear(256, 256)
|
|
self.relu = torch.nn.ReLU()
|
|
self.lin1 = torch.nn.Linear(256, 256)
|
|
|
|
def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
|
|
x = self.conv(x)
|
|
x = self.lin0(x)
|
|
pipe_split()
|
|
x.add(constant)
|
|
x = self.lin1(x)
|
|
return self.relu(x)
|
|
|
|
|
|
# Full model
|
|
class M(torch.nn.Module):
    """Two stacked ``Block`` modules separated by an explicit pipeline split.

    The attribute names ``block0`` / ``block1`` become the prefixes of the
    parameter qualnames (e.g. ``block0.conv.weight``).
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order (block0 first, then block1) is kept as-is.
        self.block0 = Block()
        self.block1 = Block()

    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
        """Apply ``block0``, mark a stage boundary, then apply ``block1``.

        ``constant`` is forwarded unchanged, as a keyword, to both blocks.
        """
        hidden = self.block0(x, constant=constant)
        pipe_split()  # stage boundary between the two blocks
        return self.block1(hidden, constant=constant)
|
|
|
|
|
|
class UnflattenTests(TestCase):
    """Checks that ``pipeline`` splits ``M`` into stages correctly."""

    def test_unflatten(self):
        """Split ``M`` into stages and verify qualnames and numerics.

        ``M`` has one explicit split between its blocks plus one inside each
        ``Block``, so four stages are expected. Every stage parameter must
        keep a name present in the original state dict, and running the pipe
        must reproduce the eager model's output.
        """
        x = torch.randn(1, 16, 256, 256)
        constant = torch.ones(1, 16, 256, 256)

        mod = M()

        pipe = pipeline(
            mod,
            (x,),
            {"constant": constant},
        )

        # unittest assertions (not bare ``assert``) so the checks still run
        # under ``python -O`` and report informative failure messages.
        self.assertEqual(pipe.num_stages, 4)
        orig_state_dict = mod.state_dict()

        # Check qualnames
        for stage_idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(stage_idx)
            for param_name, _ in stage_mod.named_parameters():
                self.assertIn(
                    param_name,
                    orig_state_dict,
                    f"{param_name} not in original state dict",
                )
        print("Param qualname test passed")

        # Check equivalence
        ref = mod(x, constant)
        out = pipe(x, constant)[0]
        torch.testing.assert_close(out, ref)
        print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
|
|
|
|
|
|
# When executed directly, hand control to PyTorch's shared test runner.
if __name__ == "__main__":
    run_tests()
|