class CustomEmbeddings(nn.Module):
    """Word embeddings plus learned position embeddings, followed by LayerNorm.

    Adapted from
    https://github.com/huggingface/transformers/blob/312b104ff65514736c0475814fec19e47425b0b5/src/transformers/models/distilbert/modeling_distilbert.py#L91.

    The defaults reproduce the fixed configuration this test was written with,
    so ``CustomEmbeddings()`` builds exactly the same module as before.

    Args:
        vocab_size: Number of rows in the word-embedding table.
        dim: Embedding width; also the size normalized by ``LayerNorm``.
        pad_token_id: Index passed to ``nn.Embedding`` as ``padding_idx``.
        max_position_embeddings: Number of rows in the position-embedding
            table, i.e. the longest sequence the position lookup can serve.
    """

    def __init__(
        self,
        vocab_size: int = 511,
        dim: int = 10,
        pad_token_id: int = 0,
        max_position_embeddings: int = 16,
    ):
        super().__init__()
        # Layers are created in the same order as before so that parameter
        # initialization consumes the random stream identically.
        self.word_embeddings = nn.Embedding(vocab_size, dim, padding_idx=pad_token_id)
        self.position_embeddings = nn.Embedding(max_position_embeddings, dim)
        self.LayerNorm = nn.LayerNorm(dim, eps=1e-12)
        # Shape (1, max_position_embeddings); persistent=False keeps the buffer
        # out of the module's state_dict.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return layer-normalized embeddings for ``input_ids``.

        Args:
            input_ids: Integer token ids of shape (bs, seq_length).

        Returns:
            Tensor of shape (bs, seq_length, dim).
        """
        input_embeds = self.word_embeddings(input_ids)  # (bs, seq_length, dim)
        seq_length = input_embeds.size(1)
        # Slice the pre-built ids by the runtime sequence length.
        position_ids = self.position_ids[:, :seq_length]  # (1, seq_length)
        position_embeddings = self.position_embeddings(position_ids)  # (1, seq_length, dim)

        # The position term has a leading dimension of 1 and broadcasts over the batch.
        embeddings = input_embeds + position_embeddings  # (bs, seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, seq_length, dim)
        return embeddings
class AnotherLayerOfToyModel(nn.Module):
    """Run `CustomEmbeddings` in front of `ToyModel`.

    `CustomEmbeddings` makes the exported graph contain a series of nodes
    (Shape -> Gather -> Unsqueeze ...) whose output tensors are allocated in
    CPU memory. Training through this module therefore checks that CPU device
    allocation and usage remain correct when the CUDA execution provider is
    enabled.
    """

    def __init__(self):
        super().__init__()
        self.embedding = CustomEmbeddings()
        self.t = ToyModel()

    def forward(self, x):
        """Embed the token ids in `x` and pass the embeddings to the toy model."""
        return self.t(self.embedding(x))


def demo_mixed_device_allocation_training(rank, world_size, use_ort_module):
    """Train `AnotherLayerOfToyModel` under DDP for five steps and verify the losses.

    "Mixed device allocation" means the ORT backend allocates output tensors on
    CPU for some nodes and on CUDA for other nodes. The test is meant to catch
    regressions when the ORT allocation-planner logic changes.

    Args:
        rank: Index of this process; also selects the device ``cuda:<rank>``.
        world_size: Forwarded to `setup`.
        use_ort_module: When true, the model is wrapped in `ORTModule` before
            it is handed to DDP.

    Raises:
        AssertionError: If the recorded loss history is not close to the
            expected one.
    """
    torch.manual_seed(0)
    print(f"Running mixed device allocation DDP example on rank {rank}.")
    setup(rank, world_size)
    device = f"cuda:{rank}"

    # Build the model and move it to this rank's GPU.
    model = AnotherLayerOfToyModel().to(device)
    if use_ort_module:
        model = ORTModule(model)
        print(f" Rank {rank} uses ORTModule.")
    else:
        print(f" Rank {rank} uses Pytorch's nn.Module.")

    ddp_model = DDP(model, device_ids=[device])

    loss_fn = nn.MSELoss()
    optimizer = optim.Adagrad(ddp_model.parameters(), lr=0.01)

    batch = 2
    max_seq_length = 16
    # Inputs are sampled on the CPU and only then moved to the GPU. The
    # hard-coded expected losses below were presumably recorded with this exact
    # seed and sampling order, so do not switch to device-side sampling.
    # 511 is the upper bound used for token ids (exclusive in torch.randint).
    x = torch.randint(1, 511, (batch, max_seq_length)).to(device)
    # NOTE(review): the trailing 5 presumably matches ToyModel's output width - confirm.
    y = torch.randn(batch, max_seq_length, 5).to(device)

    loss_history = []
    for i in range(5):
        optimizer.zero_grad()
        p = ddp_model(x)
        loss = loss_fn(p, y)
        print(f" Rank {rank} at iteration {i} has loss {loss}.")
        loss.backward()
        optimizer.step()
        # Detach so the stored value is not attached to the autograd graph.
        loss_history.append(loss.detach())

    loss_history = torch.stack(loss_history).cpu()
    expected_loss_history = torch.tensor(
        [1.1589857340, 1.0975260735, 1.0628030300, 1.0386666059, 1.0196533203],
        dtype=torch.float32,
    )

    # Raised explicitly so the check still runs under `python -O` and so a
    # failure reports both trajectories.
    if not torch.allclose(expected_loss_history, loss_history):
        raise AssertionError(
            f"Rank {rank}: loss history {loss_history.tolist()} does not match "
            f"expected {expected_loss_history.tolist()}."
        )

    cleanup()


def run_demo(demo_fn, world_size, use_ort_module):
    """Run `demo_fn` in `world_size` spawned processes and wait for all of them.

    Each process is started with ``(world_size, use_ort_module)`` as extra
    arguments. NOTE(review): `mp` is presumably `torch.multiprocessing`, whose
    `spawn` passes the process index as the first argument - confirm.
    """
    mp.spawn(demo_fn, args=(world_size, use_ort_module), nprocs=world_size, join=True)


if __name__ == "__main__":
    # NOTE(review): only the tail of this guard was visible in the reviewed
    # diff; any statements preceding parse_args() are not shown here.
    args = parse_args()
    run_demo(demo_basic, 4, args.use_ort_module)
    run_demo(demo_checkpoint, 4, args.use_ort_module)
    run_demo(demo_mixed_device_allocation_training, 4, args.use_ort_module)