mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[docs] DeepSpeed (#28542)
* config
* optim
* pre deploy
* deploy
* save weights, memory, troubleshoot, non-Trainer
* done
parent bb6aa8bc5f
commit 738ec75c90
5 changed files with 1312 additions and 2286 deletions
docs/source/en/_toctree.yml

```diff
@@ -144,6 +144,8 @@
       title: Multiple GPUs and parallelism
     - local: fsdp
       title: Fully Sharded Data Parallel
+    - local: deepspeed
+      title: DeepSpeed
     - local: perf_train_cpu
       title: Efficient training on CPU
     - local: perf_train_cpu_many
```
```diff
@@ -253,7 +255,7 @@
     - local: main_classes/trainer
       title: Trainer
     - local: main_classes/deepspeed
-      title: DeepSpeed Integration
+      title: DeepSpeed
     - local: main_classes/feature_extractor
       title: Feature Extractor
     - local: main_classes/image_processor
```
@@ -84,6 +84,76 @@

```bash
sudo ln -s /usr/bin/gcc-7 /usr/local/cuda-10.2/bin/gcc
sudo ln -s /usr/bin/g++-7 /usr/local/cuda-10.2/bin/g++
```

### Prebuild

If you're still having issues with installing DeepSpeed or if you're building DeepSpeed at run time, you can try to prebuild the DeepSpeed modules before installing them. To make a local build for DeepSpeed:

```bash
git clone https://github.com/microsoft/DeepSpeed/
cd DeepSpeed
rm -rf build
TORCH_CUDA_ARCH_LIST="8.6" DS_BUILD_CPU_ADAM=1 DS_BUILD_UTILS=1 pip install . \
--global-option="build_ext" --global-option="-j8" --no-cache -v \
--disable-pip-version-check 2>&1 | tee build.log
```

<Tip>

To use NVMe offload, add the `DS_BUILD_AIO=1` parameter to the build command and make sure you install the libaio-dev package system-wide.

</Tip>

Next, you'll have to specify your GPU's architecture in the `TORCH_CUDA_ARCH_LIST` variable (find a complete list of NVIDIA GPUs and their corresponding architectures on this [page](https://developer.nvidia.com/cuda-gpus)). To check which architectures your PyTorch build supports, run the following command:

```bash
python -c "import torch; print(torch.cuda.get_arch_list())"
```

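The command above returns `sm_XY`-style tags rather than the dotted form `TORCH_CUDA_ARCH_LIST` expects. A small hypothetical helper (not part of the docs, and assuming plain two-digit tags like `sm_86`, not suffixed variants such as `sm_90a`) can do the conversion:

```python
def sm_tag_to_arch(tag: str) -> str:
    """Convert a PyTorch arch tag like 'sm_86' to the dotted form '8.6'."""
    digits = tag.split("_")[1]
    # the last digit is the minor version, everything before it is the major version
    return f"{digits[:-1]}.{digits[-1]}"

# example input, shaped like a torch.cuda.get_arch_list() result
print(";".join(sm_tag_to_arch(t) for t in ["sm_70", "sm_80", "sm_86"]))  # → 7.0;8.0;8.6
```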
Find the architecture for a GPU with the following command:

<hfoptions id="arch">
<hfoption id="same GPUs">

```bash
CUDA_VISIBLE_DEVICES=0 python -c "import torch; print(torch.cuda.get_device_capability())"
```

</hfoption>
<hfoption id="specific GPU">

To find the architecture for GPU `0`:

```bash
CUDA_VISIBLE_DEVICES=0 python -c "import torch; print(torch.cuda.get_device_properties(torch.device('cuda')))"
"_CudaDeviceProperties(name='GeForce RTX 3090', major=8, minor=6, total_memory=24268MB, multi_processor_count=82)"
```

This means your GPU architecture is `8.6`.

</hfoption>
</hfoptions>

If you get `8, 6`, then you can set `TORCH_CUDA_ARCH_LIST="8.6"`. For multiple GPUs with different architectures, list them like `TORCH_CUDA_ARCH_LIST="6.1;8.6"`.
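The `8, 6` to `"8.6"` mapping above can also be automated. As a sketch, `build_arch_list` is a hypothetical helper (not part of the docs), assuming you collect one `torch.cuda.get_device_capability()` tuple per visible GPU:

```python
def build_arch_list(capabilities):
    """Format (major, minor) compute-capability tuples as a TORCH_CUDA_ARCH_LIST value."""
    archs = []
    for major, minor in capabilities:
        arch = f"{major}.{minor}"
        if arch not in archs:  # deduplicate identical GPUs, keeping first-seen order
            archs.append(arch)
    return ";".join(archs)

# example: one GPU with capability (6, 1) plus two with (8, 6)
print(build_arch_list([(6, 1), (8, 6), (8, 6)]))  # → 6.1;8.6
```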

You can also leave `TORCH_CUDA_ARCH_LIST` unset, and the build program will automatically query the GPU architecture of the build machine. However, it may or may not match the actual GPU on the target machine, which is why it is better to explicitly specify the correct architecture.

For training on multiple machines with the same setup, you'll need to make a binary wheel:

```bash
git clone https://github.com/microsoft/DeepSpeed/
cd DeepSpeed
rm -rf build
TORCH_CUDA_ARCH_LIST="8.6" DS_BUILD_CPU_ADAM=1 DS_BUILD_UTILS=1 \
python setup.py build_ext -j8 bdist_wheel
```

This command generates a binary wheel that'll look something like `dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl`. Now you can install this wheel locally or on another machine.

```bash
pip install deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl
```

## Multi-GPU Network Issues Debug

When training or running inference with `DistributedDataParallel` and multiple GPUs, if you run into issues with intercommunication between processes and/or nodes, you can use the following script to diagnose network issues.

docs/source/en/deepspeed.md · 1219 additions · new file (diff suppressed because it is too large)
@@ -1312,3 +1312,19 @@

You can vote for this feature and see where it is at these CI-specific threads:

- [Github Actions:](https://github.com/actions/toolkit/issues/399)
- [CircleCI:](https://ideas.circleci.com/ideas/CCI-I-344)

## DeepSpeed integration

For a PR that involves the DeepSpeed integration, keep in mind our CircleCI PR CI setup doesn't have GPUs. Tests requiring GPUs are run on a different CI nightly. This means if you get a passing CI report in your PR, it doesn’t mean the DeepSpeed tests pass.

To run the DeepSpeed tests:

```bash
RUN_SLOW=1 pytest tests/deepspeed/test_deepspeed.py
```

Any changes to the modeling or PyTorch examples code require running the model zoo tests as well.

```bash
RUN_SLOW=1 pytest tests/deepspeed
```
