mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[pipelining] Add pipeline stage test (#126721)"
This reverts commit b948b1ad7a.
Reverted https://github.com/pytorch/pytorch/pull/126721 on behalf of https://github.com/clee2000 due to The test_public_bindings failure is real, you just got unlucky since it was also broken on trunk for a different reason ([comment](https://github.com/pytorch/pytorch/pull/126721#issuecomment-2121725408))
This commit is contained in:
parent
dc2560f073
commit
e363a8a222
6 changed files with 56 additions and 228 deletions
|
|
@ -17,8 +17,9 @@ class ExampleCode(torch.nn.Module):
|
|||
self.lin0 = torch.nn.Linear(d_hid, d_hid)
|
||||
self.lin1 = torch.nn.Linear(d_hid, d_hid)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)):
|
||||
x = torch.mm(x, self.mm_param0)
|
||||
x = x + y
|
||||
x = torch.relu(x)
|
||||
# try passing a value that doesn't require_grad across skip boundaries
|
||||
a_constant = self.cval.clone()
|
||||
|
|
@ -31,29 +32,6 @@ class ExampleCode(torch.nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class ModelWithKwargs(torch.nn.Module):
|
||||
default_dhid = 512
|
||||
default_batch_size = 256
|
||||
|
||||
def __init__(self, d_hid: int = default_dhid):
|
||||
super().__init__()
|
||||
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
|
||||
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
|
||||
self.lin0 = torch.nn.Linear(d_hid, d_hid)
|
||||
self.lin1 = torch.nn.Linear(d_hid, d_hid)
|
||||
|
||||
def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)):
|
||||
x = torch.mm(x, self.mm_param0)
|
||||
x = x + y
|
||||
x = self.lin0(x)
|
||||
x = torch.relu(x)
|
||||
pipe_split()
|
||||
x = torch.mm(x, self.mm_param1)
|
||||
x = self.lin1(x)
|
||||
x = torch.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
# MLP Layer
|
||||
class MLPModule(torch.nn.Module):
|
||||
def __init__(self, d_hid):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ batch_size = 256
|
|||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class ModelWithKwargs(torch.nn.Module):
|
||||
class ExampleCode(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
|
||||
|
|
@ -44,7 +44,7 @@ class ModelWithKwargs(torch.nn.Module):
|
|||
|
||||
class ChunkSpecTests(TestCase):
|
||||
def test_chunk_spec(self):
|
||||
mod = ModelWithKwargs()
|
||||
mod = ExampleCode()
|
||||
|
||||
x = torch.randn(batch_size, d_hid)
|
||||
y = torch.randn(batch_size, d_hid)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import tempfile
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from model_registry import ModelWithKwargs, MultiMLP
|
||||
from model_registry import ExampleCode, MultiMLP
|
||||
from torch.distributed.pipelining import (
|
||||
pipeline,
|
||||
PipelineStage,
|
||||
|
|
@ -50,11 +50,60 @@ class ScheduleTest(MultiProcContinousTest):
|
|||
dev_id = cls.rank % torch.cuda.device_count()
|
||||
cls.device = torch.device(f"cuda:{dev_id}")
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_ec_forward(self):
|
||||
# Setting this flag for numerical stability
|
||||
torch.distributed.pipelining.microbatch._debug_mask_minibatches = True
|
||||
|
||||
mod = ExampleCode(d_hid)
|
||||
mod.to(self.device)
|
||||
|
||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
y = torch.randn(batch_size, d_hid, device=self.device)
|
||||
|
||||
pipe = pipeline(
|
||||
mod,
|
||||
chunks,
|
||||
example_args=(x,),
|
||||
example_kwargs={"y": y},
|
||||
)
|
||||
|
||||
stage = PipelineStage(
|
||||
pipe,
|
||||
self.rank,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Attach to a schedule
|
||||
schedule = ScheduleGPipe(stage, chunks)
|
||||
|
||||
# Run
|
||||
if self.rank == 0:
|
||||
schedule.step(x, y=y)
|
||||
else:
|
||||
out = schedule.step()
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# Last rank checks result
|
||||
if self.rank == self.world_size - 1:
|
||||
ref_out = mod(x, y=y)
|
||||
torch.testing.assert_close(out, ref_out)
|
||||
|
||||
# Test qualname mapping
|
||||
submod_keys = stage.submod.state_dict().keys()
|
||||
# Confirm keys are consistent with original model
|
||||
old_keys = mod.state_dict().keys()
|
||||
assert all(k in old_keys for k in submod_keys)
|
||||
# Reset this flag
|
||||
torch.distributed.pipelining.microbatch._debug_mask_minibatches = False
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
def test_ec_backward(self, ScheduleClass):
|
||||
mod = ModelWithKwargs(d_hid)
|
||||
mod = ExampleCode(d_hid)
|
||||
mod.to(self.device)
|
||||
|
||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
|
|
|
|||
|
|
@ -1,198 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from model_registry import ExampleCode, ModelWithKwargs, MultiMLP
|
||||
from torch.distributed.pipelining import (
|
||||
ManualPipelineStage,
|
||||
pipeline,
|
||||
PipelineStage,
|
||||
ScheduleGPipe,
|
||||
)
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcContinousTest,
|
||||
requires_nccl,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
)
|
||||
|
||||
|
||||
d_hid = 512
|
||||
batch_size = 256
|
||||
chunks = 4
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class StageTest(MultiProcContinousTest):
|
||||
@classmethod
|
||||
def backend_str(cls) -> str:
|
||||
# Testing with NCCL backend
|
||||
return "nccl"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Class-scope test fixture. Run once for entire test class, before any test starts.
|
||||
Set up the device.
|
||||
"""
|
||||
super().setUpClass()
|
||||
dev_id = cls.rank % torch.cuda.device_count()
|
||||
cls.device = torch.device(f"cuda:{dev_id}")
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("ModelClass", [ExampleCode, MultiMLP])
|
||||
def test_tracer(self, ModelClass):
|
||||
mod = ModelClass(d_hid)
|
||||
mod.to(self.device)
|
||||
|
||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
|
||||
pipe = pipeline(
|
||||
mod,
|
||||
chunks,
|
||||
example_args=(x,),
|
||||
)
|
||||
|
||||
stage = PipelineStage(
|
||||
pipe,
|
||||
self.rank,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Attach to a schedule
|
||||
schedule = ScheduleGPipe(stage, chunks)
|
||||
|
||||
# Run
|
||||
if self.rank == 0:
|
||||
schedule.step(x)
|
||||
else:
|
||||
out = schedule.step()
|
||||
|
||||
# Last rank checks result
|
||||
if self.rank == self.world_size - 1:
|
||||
ref_out = mod(x)
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2)
|
||||
|
||||
# Test qualname mapping
|
||||
submod_keys = stage.submod.state_dict().keys()
|
||||
# Confirm keys are consistent with original model
|
||||
old_keys = mod.state_dict().keys()
|
||||
assert all(k in old_keys for k in submod_keys)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("ModelClass", [ModelWithKwargs])
|
||||
def test_tracer_kwargs(self, ModelClass):
|
||||
mod = ModelClass(d_hid)
|
||||
mod.to(self.device)
|
||||
|
||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
y = torch.randn(batch_size, d_hid, device=self.device)
|
||||
|
||||
pipe = pipeline(
|
||||
mod,
|
||||
chunks,
|
||||
example_args=(x,),
|
||||
example_kwargs={"y": y},
|
||||
)
|
||||
|
||||
stage = PipelineStage(
|
||||
pipe,
|
||||
self.rank,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Attach to a schedule
|
||||
schedule = ScheduleGPipe(stage, chunks)
|
||||
|
||||
# Run
|
||||
if self.rank == 0:
|
||||
schedule.step(x, y=y)
|
||||
else:
|
||||
out = schedule.step()
|
||||
|
||||
# Last rank checks result
|
||||
if self.rank == self.world_size - 1:
|
||||
ref_out = mod(x, y=y)
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2)
|
||||
|
||||
# Test qualname mapping
|
||||
submod_keys = stage.submod.state_dict().keys()
|
||||
# Confirm keys are consistent with original model
|
||||
old_keys = mod.state_dict().keys()
|
||||
assert all(k in old_keys for k in submod_keys)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_manual(self):
|
||||
full_mod = MultiMLP(d_hid).to(self.device)
|
||||
stage_mod = full_mod.get_submodule(f"mlp{self.rank}")
|
||||
stage_mod.to(self.device)
|
||||
|
||||
x = torch.randn(batch_size, d_hid, device=self.device)
|
||||
|
||||
stage = ManualPipelineStage(
|
||||
stage_mod,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.device,
|
||||
chunks,
|
||||
input_args=x.chunk(chunks)[0],
|
||||
)
|
||||
|
||||
# Attach to a schedule
|
||||
schedule = ScheduleGPipe(stage, chunks)
|
||||
|
||||
# Run
|
||||
if self.rank == 0:
|
||||
schedule.step(x)
|
||||
else:
|
||||
out = schedule.step()
|
||||
|
||||
# Last rank checks result
|
||||
if self.rank == self.world_size - 1:
|
||||
ref_out = full_mod(x)
|
||||
torch.testing.assert_close(out, ref_out)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(StageTest)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if GPU and NCCL are available
|
||||
if not (
|
||||
dist.is_available()
|
||||
and dist.is_nccl_available()
|
||||
and torch.cuda.device_count() > 1
|
||||
):
|
||||
print(
|
||||
"c10d NCCL not available or not enough GPUs, skipping tests",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
rank = int(os.getenv("RANK", -1))
|
||||
world_size = int(os.getenv("WORLD_SIZE", 2))
|
||||
|
||||
if rank != -1:
|
||||
# Launched with torchrun or other multi-proc launchers. Directly run the test.
|
||||
StageTest.run_rank(rank, world_size)
|
||||
else:
|
||||
# Launched as a single process. Spawn subprocess to run the tests.
|
||||
# Also need a rendezvous file for `init_process_group` purpose.
|
||||
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
|
||||
torch.multiprocessing.spawn(
|
||||
StageTest.run_rank,
|
||||
nprocs=world_size,
|
||||
args=(world_size, rdvz_file),
|
||||
)
|
||||
|
|
@ -8,7 +8,7 @@ from ._IR import (
|
|||
pipeline,
|
||||
SplitPoint,
|
||||
)
|
||||
from ._PipelineStage import ManualPipelineStage, PipelineStage
|
||||
from ._PipelineStage import PipelineStage
|
||||
from .PipelineSchedule import (
|
||||
Schedule1F1B,
|
||||
ScheduleGPipe,
|
||||
|
|
@ -24,7 +24,6 @@ __all__ = [
|
|||
"pipeline",
|
||||
"ArgsChunkSpec",
|
||||
"KwargsChunkSpec",
|
||||
"ManualPipelineStage",
|
||||
"PipelineStage",
|
||||
"Schedule1F1B",
|
||||
"ScheduleGPipe",
|
||||
|
|
|
|||
Loading…
Reference in a new issue