From 7eecdf2a8650306ed5fbb6150c64f99f587e004d Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Mon, 3 Feb 2025 09:37:02 +0100
Subject: [PATCH] Update-tp test (#35844)

* update test for now

* up

* cleanup

* update todo
---
 src/transformers/pytorch_utils.py |  2 ++
 tests/tp/test_tp.py               | 46 +++++++++++++++++++++++--------
 2 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index 1bd878a94..e058b639f 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -343,6 +343,8 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
     return torch.isin(elements, test_elements)
 
 
+# TODO need to add the __repr__ that shows that it is a colwise parallel
+# See https://github.com/pytorch/pytorch/issues/145726
 def translate_to_torch_parallel_style(style: str):
     """
     In model configurations, we use a neutral type (string) to specify parallel
diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py
index 3df57c5f9..7b9bff5f1 100644
--- a/tests/tp/test_tp.py
+++ b/tests/tp/test_tp.py
@@ -17,6 +17,7 @@ import subprocess
 import tempfile
 import textwrap
 
+# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
 from transformers import is_torch_available
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaModel
@@ -110,9 +111,8 @@ if __name__ == "__main__":
 
     # Test settings
     model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-    bs = 4
-    seqlen = 64
-
+    bs = 1
+    seqlen = 4096
     # Get distributed settings
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
@@ -124,23 +124,45 @@ if __name__ == "__main__":
 
     # Get model config
     config = LlamaConfig.from_pretrained(model_id)
-    # Shrink model size
-    config.num_hidden_layers //= 8
-    config.vocab_size //= 8
-
+    config.hidden_size = 2048
+    config.attention_bias = False
     # Instantiate model
     with device:
-        model = LlamaModel(config)
+        model = LlamaModel(config).to(dtype=torch.float16)
     model.eval()
 
-    # Tensor Parallel
     if world_size > 1:
         model.tensor_parallel(device_mesh)
 
-    # Run model
+
     inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
-    with torch.no_grad():
-        out = model(inputs)
+
+    # Test cuda graphing explicitly
+    with torch.cuda.device(device):
+        print("Cuda graphing")
+        with torch.no_grad():
+            inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
+            # CUDA Graph setup
+            s = torch.cuda.Stream(device=device)
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for i in range(3):
+                    out = model(inputs)
+            torch.cuda.current_stream().wait_stream(s)
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                out = model(inputs)
+
+            for _ in range(2):
+                g.replay()
+            s.synchronize()
 
     assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])
+
+    # Test compile
+    with torch.no_grad():
+        out = model(inputs)
+        model.forward = torch.compile(model.forward, mode="reduce-overhead")
+        out = model(inputs)
+        out = model(inputs)
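
The heart of the new test body is the standard PyTorch stream-warmup / capture / replay recipe for CUDA graphs. For reference, a minimal self-contained sketch of the same pattern; the torch.nn.Linear stand-in, the shapes, and the copy_-based input update are illustrative assumptions, not part of the patch:

# Minimal sketch of the warmup -> capture -> replay pattern the test exercises.
# Assumes a CUDA device; the Linear model is a stand-in for LlamaModel.
import torch

device = torch.device("cuda:0")
model = torch.nn.Linear(64, 64, device=device).eval()
static_input = torch.randn(8, 64, device=device)

# Warm up on a side stream so lazy initialization (cuBLAS handles, allocator
# pools) happens outside the capture.
s = torch.cuda.Stream(device=device)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    with torch.no_grad():
        for _ in range(3):
            model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass. Tensors touched during capture keep fixed
# addresses, so the same input/output buffers are reused on every replay.
g = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(g):
    static_output = model(static_input)

# Replay with new data by copying into the captured input buffer in place.
static_input.copy_(torch.randn(8, 64, device=device))
g.replay()
torch.cuda.synchronize()
print(static_output.shape)  # torch.Size([8, 64])

Shapes and buffer addresses must stay fixed across replays, which is why the test builds its inputs once up front and replays the same graph twice before synchronizing.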
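
The closing "# Test compile" block relies on mode="reduce-overhead" wrapping the compiled forward in CUDA graphs under the hood, which is why the test calls the model twice after compiling: the first call compiles and captures, later calls replay. A short sketch of that flow, again with an illustrative stand-in model rather than the patch's LlamaModel:

# Sketch of the reduce-overhead compile path the test finishes with.
import torch

model = torch.nn.Linear(64, 64, device="cuda").eval()
x = torch.randn(8, 64, device="cuda")

with torch.no_grad():
    out = model(x)                                   # eager warmup
    model.forward = torch.compile(model.forward, mode="reduce-overhead")
    out = model(x)                                   # first call: compile + capture
    out = model(x)                                   # subsequent calls replay the graph
print(out.shape)  # torch.Size([8, 64])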
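
For completeness, a hedged sketch of the distributed setup that surrounds these blocks when the file is launched via the command in the added comment (python -m torch.distributed.run --nproc_per_node 2 ...). The shrunken randomly-initialized LlamaConfig and the input shape are illustrative so the sketch needs no hub access; model.tensor_parallel(device_mesh) is the transformers entry point the test exercises:

# Run under torch.distributed.run with 2 processes, one per GPU.
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)
device_mesh = init_device_mesh("cuda", (world_size,))

# Small config so the sketch runs without downloading weights.
config = LlamaConfig(hidden_size=2048, num_hidden_layers=4, attention_bias=False)
with device:
    model = LlamaModel(config).to(dtype=torch.float16)
model.eval()

if world_size > 1:
    model.tensor_parallel(device_mesh)  # shard supported layers across the mesh

inputs = torch.randint(config.vocab_size, (1, 128), device=device)
with torch.no_grad():
    out = model(inputs)
print(rank, out.last_hidden_state.shape)

torch.distributed.destroy_process_group()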