From 7ff9d450cdf39e723cf6c8aa6a96a9bdff980596 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 1 Mar 2022 17:47:17 +0100 Subject: [PATCH] Scatter should run on CUDA (#15872) --- docker/transformers-all-latest-gpu/Dockerfile | 2 +- docker/transformers-pytorch-gpu/Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 3a61866f1..cdef4c96e 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -13,7 +13,7 @@ RUN python3 -m pip install --no-cache-dir -e ./transformers[dev,onnxruntime] RUN python3 -m pip install --no-cache-dir -U torch tensorflow RUN python3 -m pip uninstall -y flax jax -RUN python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "from torch import version; print(version.__version__.split('+')[0])")+cpu.html +RUN python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "from torch import version; print(version.__version__.split('+')[0])")+cu102.html RUN python3 -m pip install --no-cache-dir git+https://github.com/facebookresearch/detectron2.git pytesseract https://github.com/kpu/kenlm/archive/master.zip RUN python3 -m pip install -U "itsdangerous<2.1.0" diff --git a/docker/transformers-pytorch-gpu/Dockerfile b/docker/transformers-pytorch-gpu/Dockerfile index 54e804937..4e6b81c02 100644 --- a/docker/transformers-pytorch-gpu/Dockerfile +++ b/docker/transformers-pytorch-gpu/Dockerfile @@ -17,7 +17,7 @@ ARG PYTORCH='' RUN [ ${#PYTORCH} -gt 0 ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch'; python3 -m pip install --no-cache-dir -U $VERSION RUN python3 -m pip uninstall -y tensorflow flax -RUN python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "from torch import version; print(version.__version__.split('+')[0])")+cpu.html +RUN python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "from torch import version; print(version.__version__.split('+')[0])")+cu102.html RUN python3 -m pip install --no-cache-dir git+https://github.com/facebookresearch/detectron2.git pytesseract https://github.com/kpu/kenlm/archive/master.zip RUN python3 -m pip install -U "itsdangerous<2.1.0"