Revert "[pipelining] Add pipeline stage test (#126721)"

This reverts commit b948b1ad7a.

Reverted https://github.com/pytorch/pytorch/pull/126721 on behalf of https://github.com/clee2000 because the test_public_bindings failure is real; you just got unlucky, since it was also broken on trunk for a different reason ([comment](https://github.com/pytorch/pytorch/pull/126721#issuecomment-2121725408))
This commit is contained in:
PyTorch MergeBot 2024-05-21 04:40:05 +00:00
parent dc2560f073
commit e363a8a222
6 changed files with 56 additions and 228 deletions

View file

@ -17,8 +17,9 @@ class ExampleCode(torch.nn.Module):
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)
def forward(self, x):
def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)):
x = torch.mm(x, self.mm_param0)
x = x + y
x = torch.relu(x)
# try passing a value that doesn't require_grad across skip boundaries
a_constant = self.cval.clone()
@ -31,29 +32,6 @@ class ExampleCode(torch.nn.Module):
return x
class ModelWithKwargs(torch.nn.Module):
    """Toy two-part model whose ``forward`` accepts an extra tensor ``y``.

    Each part is a matrix multiply, a linear layer and a ReLU; the first part
    additionally adds ``y`` before its linear layer. ``pipe_split()`` is called
    between the two parts (presumably marking the pipeline stage boundary --
    confirm against the pipelining API).
    """

    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        """Create two ``d_hid x d_hid`` matmul parameters and two linear layers."""
        super().__init__()
        # Creation order is kept unchanged so that a seeded run draws the same
        # initial values for every parameter.
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    # NOTE: the default for ``y`` is built once, when the class body is
    # executed, with shape (default_batch_size, default_dhid). It is only read
    # here, never modified in place.
    def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)):
        """Run both parts on ``x`` (plus ``y`` in the first part) and return the result."""
        first_part = torch.relu(self.lin0(torch.mm(x, self.mm_param0) + y))
        pipe_split()
        return torch.relu(self.lin1(torch.mm(first_part, self.mm_param1)))
# MLP Layer
class MLPModule(torch.nn.Module):
def __init__(self, d_hid):

View file

@ -16,7 +16,7 @@ batch_size = 256
torch.manual_seed(0)
class ModelWithKwargs(torch.nn.Module):
class ExampleCode(torch.nn.Module):
def __init__(self):
super().__init__()
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
@ -44,7 +44,7 @@ class ModelWithKwargs(torch.nn.Module):
class ChunkSpecTests(TestCase):
def test_chunk_spec(self):
mod = ModelWithKwargs()
mod = ExampleCode()
x = torch.randn(batch_size, d_hid)
y = torch.randn(batch_size, d_hid)

View file

@ -8,7 +8,7 @@ import tempfile
import torch
import torch.distributed as dist
from model_registry import ModelWithKwargs, MultiMLP
from model_registry import ExampleCode, MultiMLP
from torch.distributed.pipelining import (
pipeline,
PipelineStage,
@ -50,11 +50,60 @@ class ScheduleTest(MultiProcContinousTest):
dev_id = cls.rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{dev_id}")
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_ec_forward(self):
    """Run ``ExampleCode`` forward through a GPipe schedule and compare with the full model.

    Rank 0 feeds ``x`` and the keyword input ``y`` into the schedule; every
    other rank calls ``step()`` without inputs and keeps what it returns. The
    last rank compares that output with a direct call of the original module.
    All ranks then check that the stage submodule's state-dict keys also exist
    in the original module.
    """
    # Setting this flag for numerical stability
    torch.distributed.pipelining.microbatch._debug_mask_minibatches = True
    try:
        mod = ExampleCode(d_hid)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        y = torch.randn(batch_size, d_hid, device=self.device)

        pipe = pipeline(
            mod,
            chunks,
            example_args=(x,),
            example_kwargs={"y": y},
        )

        stage = PipelineStage(
            pipe,
            self.rank,
            device=self.device,
        )

        # Attach to a schedule
        schedule = ScheduleGPipe(stage, chunks)

        # Run
        if self.rank == 0:
            schedule.step(x, y=y)
        else:
            out = schedule.step()

        dist.barrier()

        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = mod(x, y=y)
            torch.testing.assert_close(out, ref_out)

        # Test qualname mapping
        submod_keys = stage.submod.state_dict().keys()
        # Confirm keys are consistent with original model
        old_keys = mod.state_dict().keys()
        assert all(k in old_keys for k in submod_keys)
    finally:
        # Reset this flag even when the test body raises, so the module-level
        # setting does not carry over into tests that run later in this process.
        torch.distributed.pipelining.microbatch._debug_mask_minibatches = False
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
def test_ec_backward(self, ScheduleClass):
mod = ModelWithKwargs(d_hid)
mod = ExampleCode(d_hid)
mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)

View file

@ -1,198 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import sys
import tempfile
import torch
import torch.distributed as dist
from model_registry import ExampleCode, ModelWithKwargs, MultiMLP
from torch.distributed.pipelining import (
ManualPipelineStage,
pipeline,
PipelineStage,
ScheduleGPipe,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
MultiProcContinousTest,
requires_nccl,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skip_but_pass_in_sandcastle_if,
)
# Hidden size: passed to every model constructor below and used as the
# feature dimension of the random test inputs.
d_hid = 512
# Number of rows in the random input tensors.
batch_size = 256
# Chunk count handed to `pipeline`, `ScheduleGPipe` and `ManualPipelineStage`
# (presumably the number of microbatches -- confirm against the pipelining API).
chunks = 4
# Seed torch's RNG at import time so the random tensors are reproducible.
torch.manual_seed(0)
class StageTest(MultiProcContinousTest):
    """Multi-process checks of pipeline stages driven by ``ScheduleGPipe``.

    Covers stages built from a traced ``pipeline`` (with positional-only and
    with keyword inputs) and a ``ManualPipelineStage`` that wraps one
    submodule per rank.
    """

    @classmethod
    def backend_str(cls) -> str:
        """Return the name of the backend this test class runs on."""
        # Testing with NCCL backend
        return "nccl"

    @classmethod
    def setUpClass(cls):
        """
        Class-scope fixture, run once before any test in the class starts.
        Chooses the CUDA device used by this rank.
        """
        super().setUpClass()
        # Wrap the rank around the number of visible GPUs.
        gpu_index = cls.rank % torch.cuda.device_count()
        cls.device = torch.device(f"cuda:{gpu_index}")

    def _assert_stage_keys_in_model(self, stage, model):
        """Assert that every state-dict key of the stage's submodule exists in ``model``."""
        stage_keys = stage.submod.state_dict().keys()
        model_keys = model.state_dict().keys()
        assert all(key in model_keys for key in stage_keys)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ModelClass", [ExampleCode, MultiMLP])
    def test_tracer(self, ModelClass):
        """Traced pipeline fed a single positional input matches the full model."""
        model = ModelClass(d_hid)
        model.to(self.device)
        inputs = torch.randn(batch_size, d_hid, device=self.device)

        traced = pipeline(
            model,
            chunks,
            example_args=(inputs,),
        )
        stage = PipelineStage(
            traced,
            self.rank,
            device=self.device,
        )
        gpipe = ScheduleGPipe(stage, chunks)

        # Rank 0 supplies the input; all other ranks keep what step() returns.
        if self.rank != 0:
            out = gpipe.step()
        else:
            gpipe.step(inputs)

        # Only the last rank compares against a direct call of the model.
        on_last_rank = self.rank == self.world_size - 1
        if on_last_rank:
            expected = model(inputs)
            torch.testing.assert_close(out, expected, atol=1e-3, rtol=5e-2)

        # Qualified names in the stage must map back to the original model.
        self._assert_stage_keys_in_model(stage, model)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ModelClass", [ModelWithKwargs])
    def test_tracer_kwargs(self, ModelClass):
        """Traced pipeline fed a positional and a keyword input matches the full model."""
        model = ModelClass(d_hid)
        model.to(self.device)
        inputs = torch.randn(batch_size, d_hid, device=self.device)
        extra = torch.randn(batch_size, d_hid, device=self.device)

        traced = pipeline(
            model,
            chunks,
            example_args=(inputs,),
            example_kwargs={"y": extra},
        )
        stage = PipelineStage(
            traced,
            self.rank,
            device=self.device,
        )
        gpipe = ScheduleGPipe(stage, chunks)

        # Rank 0 supplies the inputs; all other ranks keep what step() returns.
        if self.rank != 0:
            out = gpipe.step()
        else:
            gpipe.step(inputs, y=extra)

        # Only the last rank compares against a direct call of the model.
        on_last_rank = self.rank == self.world_size - 1
        if on_last_rank:
            expected = model(inputs, y=extra)
            torch.testing.assert_close(out, expected, atol=1e-3, rtol=5e-2)

        # Qualified names in the stage must map back to the original model.
        self._assert_stage_keys_in_model(stage, model)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_manual(self):
        """Manually built stage wrapping this rank's ``mlp{rank}`` submodule matches the full model."""
        whole_model = MultiMLP(d_hid).to(self.device)
        rank_module = whole_model.get_submodule(f"mlp{self.rank}")
        rank_module.to(self.device)
        inputs = torch.randn(batch_size, d_hid, device=self.device)

        # The first chunk of the input is handed to the stage as its example input.
        first_chunk = inputs.chunk(chunks)[0]
        stage = ManualPipelineStage(
            rank_module,
            self.rank,
            self.world_size,
            self.device,
            chunks,
            input_args=first_chunk,
        )
        gpipe = ScheduleGPipe(stage, chunks)

        # Rank 0 supplies the input; all other ranks keep what step() returns.
        if self.rank != 0:
            out = gpipe.step()
        else:
            gpipe.step(inputs)

        # Only the last rank compares against a direct call of the full model.
        if self.rank == self.world_size - 1:
            expected = whole_model(inputs)
            torch.testing.assert_close(out, expected)
instantiate_parametrized_tests(StageTest)

if __name__ == "__main__":
    # Check if GPU and NCCL are available
    if not (
        dist.is_available()
        and dist.is_nccl_available()
        and torch.cuda.device_count() > 1
    ):
        print(
            "c10d NCCL not available or not enough GPUs, skipping tests",
            file=sys.stderr,
        )
        sys.exit(0)

    # RANK is only present when an external launcher started this process;
    # -1 stands for "not set".
    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        StageTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocess to run the tests.
        # Also need a rendezvous file for `init_process_group` purpose.
        # Close our own handle right away instead of leaving it to garbage
        # collection; delete=False keeps the file on disk for the workers.
        with tempfile.NamedTemporaryFile(delete=False) as rdvz_handle:
            rdvz_file = rdvz_handle.name
        torch.multiprocessing.spawn(
            StageTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )

View file

@ -8,7 +8,7 @@ from ._IR import (
pipeline,
SplitPoint,
)
from ._PipelineStage import ManualPipelineStage, PipelineStage
from ._PipelineStage import PipelineStage
from .PipelineSchedule import (
Schedule1F1B,
ScheduleGPipe,
@ -24,7 +24,6 @@ __all__ = [
"pipeline",
"ArgsChunkSpec",
"KwargsChunkSpec",
"ManualPipelineStage",
"PipelineStage",
"Schedule1F1B",
"ScheduleGPipe",