diff --git a/verl_getting_started.ipynb b/verl_getting_started.ipynb
new file mode 100644
index 0000000..85b4d99
--- /dev/null
+++ b/verl_getting_started.ipynb
@@ -0,0 +1,1683 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eXkR4NjYhezg"
+ },
+ "source": [
+ "# Run Qwen PPO with [verl](https://github.com/volcengine/verl)\n",
+ "\n",
+ "This tutorial provides a step-by-step guide to using veRL for executing your RLHF pipeline. You can find our [github repo](https://github.com/volcengine/verl/) and [documentation](https://verl.readthedocs.io/en/latest/index.html) for more details.\n",
+ "\n",
+ "This notebook is also published on the [Lightning Studio](https://lightning.ai/hlin-verl/studios/verl-getting-started) platform, which provides free GPU quota every month. Check out the published notebook with pre-installed dependencies using a free L4 GPU [here](https://lightning.ai/hlin-verl/studios/verl-getting-started) (no credit card required).\n",
+ "\n",
+ "### You will learn:\n",
+ "\n",
+ "- How to install veRL from scratch.\n",
+ "- How to use existing scripts to run an RLHF pipeline with your own models and data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XSDNzNuQkJJh"
+ },
+ "source": [
+ "# Dependency Installation\n",
+ "\n",
+ "If you are running on Lightning Studio using the published notebook, the dependencies are **already installed** and you can proceed to step \"**Load Pretrained Language Model**\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "gnfZyMm-3BNC",
+ "outputId": "c8520289-511f-447a-8d78-12a147484dee"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: pip in /usr/local/lib/python3.11/dist-packages (25.0.1)\n",
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (75.8.0)\n",
+ "Requirement already satisfied: wheel in /usr/local/lib/python3.11/dist-packages (0.45.1)\n",
+ "Collecting torch==2.4.0\n",
+ " Using cached torch-2.4.0-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)\n",
+ "Collecting torchvision==0.19.0\n",
+ " Using cached torchvision-0.19.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.0 kB)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (3.17.0)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (4.12.2)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (1.13.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (3.4.2)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (3.1.5)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (2024.10.0)\n",
+ "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.0)\n",
+ " Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.4.0)\n",
+ " Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.4.0)\n",
+ " Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
+ "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.4.0)\n",
+ " Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
+ "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.4.0)\n",
+ " Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.4.0)\n",
+ " Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-curand-cu12==10.3.2.106 (from torch==2.4.0)\n",
+ " Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
+ "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch==2.4.0)\n",
+ " Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
+ "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch==2.4.0)\n",
+ " Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
+ "Collecting nvidia-nccl-cu12==2.20.5 (from torch==2.4.0)\n",
+ " Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n",
+ "Collecting nvidia-nvtx-cu12==12.1.105 (from torch==2.4.0)\n",
+ " Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n",
+ "Collecting triton==3.0.0 (from torch==2.4.0)\n",
+ " Using cached triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision==0.19.0) (1.26.4)\n",
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision==0.19.0) (11.1.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.11/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.4.0) (12.5.82)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.4.0) (3.0.2)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->torch==2.4.0) (1.3.0)\n",
+ "Downloading torch-2.4.0-cp311-cp311-manylinux1_x86_64.whl (797.3 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m797.3/797.3 MB\u001b[0m \u001b[31m26.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading torchvision-0.19.0-cp311-cp311-manylinux1_x86_64.whl (7.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m128.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m51.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m153.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m162.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m42.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m37.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m129.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m136.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m138.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m101.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m176.2/176.2 MB\u001b[0m \u001b[31m130.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
+ "Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.4/209.4 MB\u001b[0m \u001b[31m47.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hInstalling collected packages: triton, nvidia-nvtx-cu12, nvidia-nccl-cu12, nvidia-cusparse-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusolver-cu12, nvidia-cudnn-cu12, torch, torchvision\n",
+ " Attempting uninstall: triton\n",
+ " Found existing installation: triton 3.1.0\n",
+ " Uninstalling triton-3.1.0:\n",
+ " Successfully uninstalled triton-3.1.0\n",
+ " Attempting uninstall: nvidia-nvtx-cu12\n",
+ " Found existing installation: nvidia-nvtx-cu12 12.4.127\n",
+ " Uninstalling nvidia-nvtx-cu12-12.4.127:\n",
+ " Successfully uninstalled nvidia-nvtx-cu12-12.4.127\n",
+ " Attempting uninstall: nvidia-nccl-cu12\n",
+ " Found existing installation: nvidia-nccl-cu12 2.21.5\n",
+ " Uninstalling nvidia-nccl-cu12-2.21.5:\n",
+ " Successfully uninstalled nvidia-nccl-cu12-2.21.5\n",
+ " Attempting uninstall: nvidia-cusparse-cu12\n",
+ " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
+ " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
+ " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
+ " Attempting uninstall: nvidia-curand-cu12\n",
+ " Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
+ " Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
+ " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
+ " Attempting uninstall: nvidia-cufft-cu12\n",
+ " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
+ " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
+ " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
+ " Attempting uninstall: nvidia-cuda-runtime-cu12\n",
+ " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
+ " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
+ " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
+ " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
+ " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
+ " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
+ " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
+ " Attempting uninstall: nvidia-cuda-cupti-cu12\n",
+ " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
+ " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
+ " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
+ " Attempting uninstall: nvidia-cublas-cu12\n",
+ " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
+ " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
+ " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
+ " Attempting uninstall: nvidia-cusolver-cu12\n",
+ " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
+ " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
+ " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
+ " Attempting uninstall: nvidia-cudnn-cu12\n",
+ " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
+ " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
+ " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
+ " Attempting uninstall: torch\n",
+ " Found existing installation: torch 2.5.1+cu124\n",
+ " Uninstalling torch-2.5.1+cu124:\n",
+ " Successfully uninstalled torch-2.5.1+cu124\n",
+ " Attempting uninstall: torchvision\n",
+ " Found existing installation: torchvision 0.20.1+cu124\n",
+ " Uninstalling torchvision-0.20.1+cu124:\n",
+ " Successfully uninstalled torchvision-0.20.1+cu124\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "torchaudio 2.5.1+cu124 requires torch==2.5.1, but you have torch 2.4.0 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvtx-cu12-12.1.105 torch-2.4.0 torchvision-0.19.0 triton-3.0.0\n",
+ "torch 2.4.0\n",
+ "torchaudio 2.5.1+cu124\n",
+ "torchsummary 1.5.1\n",
+ "torchvision 0.19.0\n",
+ "Collecting flash-attn\n",
+ " Using cached flash_attn-2.7.4.post1.tar.gz (6.0 MB)\n",
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from flash-attn) (2.4.0)\n",
+ "Requirement already satisfied: einops in /usr/local/lib/python3.11/dist-packages (from flash-attn) (0.8.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.17.0)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (4.12.2)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (1.13.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.4.2)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.1.5)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (2024.10.0)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (9.1.0.70)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (2.20.5)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.1.105)\n",
+ "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.0.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.11/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->flash-attn) (12.5.82)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch->flash-attn) (3.0.2)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->torch->flash-attn) (1.3.0)\n",
+ "Building wheels for collected packages: flash-attn\n",
+ " Building wheel for flash-attn (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for flash-attn: filename=flash_attn-2.7.4.post1-cp311-cp311-linux_x86_64.whl size=187805408 sha256=92cf49e6f66795b6934cec0cba526ed6e45d3313de3f905d45df8773f19092a9\n",
+ " Stored in directory: /root/.cache/pip/wheels/3d/88/d8/284b89f56af7d5bf366b10d6b8e251ac8a7c7bf3f04203fb4f\n",
+ "Successfully built flash-attn\n",
+ "Installing collected packages: flash-attn\n",
+ "Successfully installed flash-attn-2.7.4.post1\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip3 install --upgrade pip setuptools wheel\n",
+ "!pip3 install torch==2.4.0 torchvision==0.19.0\n",
+ "!pip3 list | grep torch\n",
+ "!pip3 install flash-attn --no-build-isolation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!git clone https://github.com/Jiayi-Pan/TinyZero.git\n",
+ "%cd TinyZero\n",
+ "\n",
+ "!pip install verl"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "collapsed": true,
+ "id": "_6DOQNxJ1MC1",
+ "outputId": "3f9f1d9d-9f29-4563-acba-0d7f97f14695"
+ },
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "fatal: destination path 'TinyZero' already exists and is not an empty directory.\n",
+ "/content/TinyZero\n",
+ "Requirement already satisfied: verl in /usr/local/lib/python3.11/dist-packages (0.1)\n",
+ "Requirement already satisfied: torch==2.4.0 in /usr/local/lib/python3.11/dist-packages (from verl) (2.4.0)\n",
+ "Requirement already satisfied: tensordict in /usr/local/lib/python3.11/dist-packages (from verl) (0.5.0)\n",
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (from verl) (4.48.2)\n",
+ "Requirement already satisfied: codetiming in /usr/local/lib/python3.11/dist-packages (from verl) (1.4.0)\n",
+ "Requirement already satisfied: pybind11 in /usr/local/lib/python3.11/dist-packages (from verl) (2.13.6)\n",
+ "Requirement already satisfied: hydra-core in /usr/local/lib/python3.11/dist-packages (from verl) (1.3.2)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from verl) (1.26.4)\n",
+ "Requirement already satisfied: yapf in /usr/local/lib/python3.11/dist-packages (from verl) (0.43.0)\n",
+ "Requirement already satisfied: dill in /usr/local/lib/python3.11/dist-packages (from verl) (0.3.9)\n",
+ "Requirement already satisfied: accelerate in /usr/local/lib/python3.11/dist-packages (from verl) (1.3.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (3.17.0)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (4.12.2)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (1.13.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (3.4.2)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (3.1.5)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (2024.10.0)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (9.1.0.70)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (2.20.5)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (12.1.105)\n",
+ "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->verl) (3.0.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.11/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.4.0->verl) (12.5.82)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from accelerate->verl) (24.2)\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate->verl) (5.9.5)\n",
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.11/dist-packages (from accelerate->verl) (6.0.2)\n",
+ "Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.11/dist-packages (from accelerate->verl) (0.28.1)\n",
+ "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from accelerate->verl) (0.5.2)\n",
+ "Requirement already satisfied: omegaconf<2.4,>=2.2 in /usr/local/lib/python3.11/dist-packages (from hydra-core->verl) (2.3.0)\n",
+ "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from hydra-core->verl) (4.9.3)\n",
+ "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.11/dist-packages (from tensordict->verl) (3.1.1)\n",
+ "Requirement already satisfied: orjson in /usr/local/lib/python3.11/dist-packages (from tensordict->verl) (3.10.15)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers->verl) (2024.11.6)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers->verl) (2.32.3)\n",
+ "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers->verl) (0.21.0)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers->verl) (4.67.1)\n",
+ "Requirement already satisfied: platformdirs>=3.5.1 in /usr/local/lib/python3.11/dist-packages (from yapf->verl) (4.3.6)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.4.0->verl) (3.0.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers->verl) (3.4.1)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers->verl) (3.10)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers->verl) (2.3.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers->verl) (2025.1.31)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->torch==2.4.0->verl) (1.3.0)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HzV28CwOmruV"
+ },
+ "source": [
+ "## Install and verify verl\n",
+ "Now we're ready to install verl!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0mtIn1VOk2E7",
+ "outputId": "8a83156e-c3aa-4921-97e2-a9472f22ed9d",
+ "collapsed": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Obtaining file:///teamspace/studios/this_studio/verl_repo\n",
+ " Installing build dependencies ... \u001b[?25ldone\n",
+ "\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n",
+ "\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n",
+ "\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n",
+ "\u001b[?25hRequirement already satisfied: accelerate in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (1.1.1)\n",
+ "Requirement already satisfied: codetiming in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (1.4.0)\n",
+ "Requirement already satisfied: datasets in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (3.1.0)\n",
+ "Requirement already satisfied: dill in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (0.3.8)\n",
+ "Requirement already satisfied: hydra-core in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (1.3.2)\n",
+ "Requirement already satisfied: numpy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (1.26.4)\n",
+ "Requirement already satisfied: pybind11 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (2.13.6)\n",
+ "Requirement already satisfied: ray in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (2.10.0)\n",
+ "Requirement already satisfied: tensordict in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (0.5.0)\n",
+ "Requirement already satisfied: transformers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (4.46.3)\n",
+ "Requirement already satisfied: vllm<=0.6.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from verl==0.1) (0.5.4)\n",
+ "Requirement already satisfied: cmake>=3.21 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (3.31.1)\n",
+ "Requirement already satisfied: ninja in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (1.11.1.2)\n",
+ "Requirement already satisfied: psutil in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (6.1.0)\n",
+ "Requirement already satisfied: sentencepiece in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.2.0)\n",
+ "Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (2.32.3)\n",
+ "Requirement already satisfied: tqdm in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (4.67.1)\n",
+ "Requirement already satisfied: py-cpuinfo in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (9.0.0)\n",
+ "Requirement already satisfied: tokenizers>=0.19.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.20.3)\n",
+ "Requirement already satisfied: fastapi in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.115.4)\n",
+ "Requirement already satisfied: aiohttp in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (3.10.10)\n",
+ "Requirement already satisfied: openai in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (1.55.3)\n",
+ "Requirement already satisfied: uvicorn[standard] in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.32.0)\n",
+ "Requirement already satisfied: pydantic>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (2.9.2)\n",
+ "Requirement already satisfied: pillow in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (10.4.0)\n",
+ "Requirement already satisfied: prometheus-client>=0.18.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.21.0)\n",
+ "Requirement already satisfied: prometheus-fastapi-instrumentator>=7.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (7.0.0)\n",
+ "Requirement already satisfied: tiktoken>=0.6.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.7.0)\n",
+ "Requirement already satisfied: lm-format-enforcer==0.10.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.10.3)\n",
+ "Requirement already satisfied: outlines<0.1,>=0.0.43 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.0.46)\n",
+ "Requirement already satisfied: typing-extensions in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (4.12.2)\n",
+ "Requirement already satisfied: filelock>=3.10.4 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (3.16.1)\n",
+ "Requirement already satisfied: pyzmq in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (26.2.0)\n",
+ "Requirement already satisfied: nvidia-ml-py in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (12.560.30)\n",
+ "Requirement already satisfied: torch==2.4.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (2.4.0)\n",
+ "Requirement already satisfied: torchvision==0.19 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.19.0)\n",
+ "Requirement already satisfied: xformers==0.0.27.post2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (0.0.27.post2)\n",
+ "Requirement already satisfied: vllm-flash-attn==2.6.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from vllm<=0.6.3->verl==0.1) (2.6.1)\n",
+ "Requirement already satisfied: interegular>=0.3.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from lm-format-enforcer==0.10.3->vllm<=0.6.3->verl==0.1) (0.3.3)\n",
+ "Requirement already satisfied: packaging in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from lm-format-enforcer==0.10.3->vllm<=0.6.3->verl==0.1) (24.1)\n",
+ "Requirement already satisfied: pyyaml in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from lm-format-enforcer==0.10.3->vllm<=0.6.3->verl==0.1) (6.0.2)\n",
+ "Requirement already satisfied: sympy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (1.13.3)\n",
+ "Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (3.4.2)\n",
+ "Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (3.1.4)\n",
+ "Requirement already satisfied: fsspec in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (2024.9.0)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (9.1.0.70)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (2.20.5)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (12.1.105)\n",
+ "Requirement already satisfied: triton==3.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.4.0->vllm<=0.6.3->verl==0.1) (3.0.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.4.0->vllm<=0.6.3->verl==0.1) (12.6.77)\n",
+ "Requirement already satisfied: click>=7.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ray->verl==0.1) (8.1.7)\n",
+ "Requirement already satisfied: jsonschema in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ray->verl==0.1) (4.23.0)\n",
+ "Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ray->verl==0.1) (1.1.0)\n",
+ "Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ray->verl==0.1) (4.23.4)\n",
+ "Requirement already satisfied: aiosignal in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ray->verl==0.1) (1.3.1)\n",
+ "Requirement already satisfied: frozenlist in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ray->verl==0.1) (1.5.0)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers->verl==0.1) (0.26.3)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers->verl==0.1) (2023.10.3)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers->verl==0.1) (0.4.5)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from datasets->verl==0.1) (18.1.0)\n",
+ "Requirement already satisfied: pandas in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from datasets->verl==0.1) (2.1.4)\n",
+ "Requirement already satisfied: xxhash in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from datasets->verl==0.1) (3.5.0)\n",
+ "Requirement already satisfied: multiprocess<0.70.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from datasets->verl==0.1) (0.70.16)\n",
+ "Requirement already satisfied: omegaconf<2.4,>=2.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from hydra-core->verl==0.1) (2.3.0)\n",
+ "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from hydra-core->verl==0.1) (4.9.3)\n",
+ "Requirement already satisfied: cloudpickle in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from tensordict->verl==0.1) (3.1.0)\n",
+ "Requirement already satisfied: orjson in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from tensordict->verl==0.1) (3.10.12)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp->vllm<=0.6.3->verl==0.1) (2.4.3)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp->vllm<=0.6.3->verl==0.1) (24.2.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp->vllm<=0.6.3->verl==0.1) (6.1.0)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.12.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp->vllm<=0.6.3->verl==0.1) (1.17.1)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp->vllm<=0.6.3->verl==0.1) (4.0.3)\n",
+ "Requirement already satisfied: lark in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from outlines<0.1,>=0.0.43->vllm<=0.6.3->verl==0.1) (1.2.2)\n",
+ "Requirement already satisfied: nest-asyncio in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from outlines<0.1,>=0.0.43->vllm<=0.6.3->verl==0.1) (1.6.0)\n",
+ "Requirement already satisfied: diskcache in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from outlines<0.1,>=0.0.43->vllm<=0.6.3->verl==0.1) (5.6.3)\n",
+ "Requirement already satisfied: numba in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from outlines<0.1,>=0.0.43->vllm<=0.6.3->verl==0.1) (0.60.0)\n",
+ "Requirement already satisfied: referencing in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from outlines<0.1,>=0.0.43->vllm<=0.6.3->verl==0.1) (0.35.1)\n",
+ "Requirement already satisfied: pycountry in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from outlines<0.1,>=0.0.43->vllm<=0.6.3->verl==0.1) (24.6.1)\n",
+ "Requirement already satisfied: pyairports in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from outlines<0.1,>=0.0.43->vllm<=0.6.3->verl==0.1) (2.1.1)\n",
+ "Requirement already satisfied: starlette<1.0.0,>=0.30.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from prometheus-fastapi-instrumentator>=7.0.0->vllm<=0.6.3->verl==0.1) (0.41.2)\n",
+ "Requirement already satisfied: annotated-types>=0.6.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pydantic>=2.0->vllm<=0.6.3->verl==0.1) (0.7.0)\n",
+ "Requirement already satisfied: pydantic-core==2.23.4 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pydantic>=2.0->vllm<=0.6.3->verl==0.1) (2.23.4)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->vllm<=0.6.3->verl==0.1) (3.4.0)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->vllm<=0.6.3->verl==0.1) (3.10)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->vllm<=0.6.3->verl==0.1) (2.2.3)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->vllm<=0.6.3->verl==0.1) (2024.8.30)\n",
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jsonschema->ray->verl==0.1) (2024.10.1)\n",
+ "Requirement already satisfied: rpds-py>=0.7.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jsonschema->ray->verl==0.1) (0.20.1)\n",
+ "Requirement already satisfied: anyio<5,>=3.5.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from openai->vllm<=0.6.3->verl==0.1) (4.6.2.post1)\n",
+ "Requirement already satisfied: distro<2,>=1.7.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from openai->vllm<=0.6.3->verl==0.1) (1.9.0)\n",
+ "Requirement already satisfied: httpx<1,>=0.23.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from openai->vllm<=0.6.3->verl==0.1) (0.27.2)\n",
+ "Requirement already satisfied: jiter<1,>=0.4.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from openai->vllm<=0.6.3->verl==0.1) (0.8.0)\n",
+ "Requirement already satisfied: sniffio in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from openai->vllm<=0.6.3->verl==0.1) (1.3.1)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pandas->datasets->verl==0.1) (2.9.0.post0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pandas->datasets->verl==0.1) (2024.2)\n",
+ "Requirement already satisfied: tzdata>=2022.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pandas->datasets->verl==0.1) (2024.2)\n",
+ "Requirement already satisfied: h11>=0.8 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from uvicorn[standard]->vllm<=0.6.3->verl==0.1) (0.14.0)\n",
+ "Requirement already satisfied: httptools>=0.5.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from uvicorn[standard]->vllm<=0.6.3->verl==0.1) (0.6.4)\n",
+ "Requirement already satisfied: python-dotenv>=0.13 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from uvicorn[standard]->vllm<=0.6.3->verl==0.1) (1.0.1)\n",
+ "Requirement already satisfied: uvloop!=0.15.0,!=0.15.1,>=0.14.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from uvicorn[standard]->vllm<=0.6.3->verl==0.1) (0.21.0)\n",
+ "Requirement already satisfied: watchfiles>=0.13 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from uvicorn[standard]->vllm<=0.6.3->verl==0.1) (0.24.0)\n",
+ "Requirement already satisfied: websockets>=10.4 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from uvicorn[standard]->vllm<=0.6.3->verl==0.1) (13.1)\n",
+ "Requirement already satisfied: exceptiongroup>=1.0.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from anyio<5,>=3.5.0->openai->vllm<=0.6.3->verl==0.1) (1.2.2)\n",
+ "Requirement already satisfied: httpcore==1.* in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from httpx<1,>=0.23.0->openai->vllm<=0.6.3->verl==0.1) (1.0.6)\n",
+ "Requirement already satisfied: six>=1.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets->verl==0.1) (1.16.0)\n",
+ "Requirement already satisfied: propcache>=0.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from yarl<2.0,>=1.12.0->aiohttp->vllm<=0.6.3->verl==0.1) (0.2.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch==2.4.0->vllm<=0.6.3->verl==0.1) (3.0.2)\n",
+ "Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from numba->outlines<0.1,>=0.0.43->vllm<=0.6.3->verl==0.1) (0.43.0)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy->torch==2.4.0->vllm<=0.6.3->verl==0.1) (1.3.0)\n",
+ "Building wheels for collected packages: verl\n",
+ " Building editable for verl (pyproject.toml) ... \u001b[?25ldone\n",
+ "\u001b[?25h Created wheel for verl: filename=verl-0.1-0.editable-py3-none-any.whl size=13000 sha256=8fd1f1241dfe89d7f8384fe884f50ec4e070d18029c37472e5584300f5a326de\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-pz36kou4/wheels/f4/30/ea/7a2d2086bd780aba22048a0b415dc5e5a9e50b2c87e39e9717\n",
+ "Successfully built verl\n",
+ "Installing collected packages: verl\n",
+ "Successfully installed verl-0.1\n"
+ ]
+ }
+ ],
+ "source": [
+ "# In case you run this notebook and have not cloned verl yet:\n",
+ "# !git clone https://github.com/volcengine/verl $HOME/verl_repo\n",
+ "\n",
+ "!cd $HOME/verl_repo && pip3 install -e . -U"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6DT4dmsB0WxH"
+ },
+ "source": [
+ "## Restart the Python kernel"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "M9R9_tfa0WxH",
+ "outputId": "85ad8aae-74a5-4e4b-d4f3-e433e21d40b8"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'status': 'ok', 'restart': True}"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import IPython\n",
+ "\n",
+ "# Restart the kernel to pick up the latest Python packages\n",
+ "IPython.get_ipython().kernel.do_shutdown(restart=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# %pip (unlike !pip) guarantees the package is installed into the kernel's own environment\n",
+ "%pip install ray"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "collapsed": true,
+ "id": "EaUqfCKK2AF7",
+ "outputId": "a4107a60-e4ce-4a97-c824-b443aa4cac60"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting ray\n",
+ " Downloading ray-2.42.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (18 kB)\n",
+ "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.11/dist-packages (from ray) (8.1.8)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from ray) (3.17.0)\n",
+ "Requirement already satisfied: jsonschema in /usr/local/lib/python3.11/dist-packages (from ray) (4.23.0)\n",
+ "Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from ray) (1.1.0)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from ray) (24.2)\n",
+ "Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.11/dist-packages (from ray) (4.25.6)\n",
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.11/dist-packages (from ray) (6.0.2)\n",
+ "Requirement already satisfied: aiosignal in /usr/local/lib/python3.11/dist-packages (from ray) (1.3.2)\n",
+ "Requirement already satisfied: frozenlist in /usr/local/lib/python3.11/dist-packages (from ray) (1.5.0)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from ray) (2.32.3)\n",
+ "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema->ray) (25.1.0)\n",
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema->ray) (2024.10.1)\n",
+ "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema->ray) (0.36.2)\n",
+ "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema->ray) (0.22.3)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->ray) (3.4.1)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->ray) (3.10)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->ray) (2.3.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->ray) (2025.1.31)\n",
+ "Requirement already satisfied: typing-extensions>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from referencing>=0.28.4->jsonschema->ray) (4.12.2)\n",
+ "Downloading ray-2.42.0-cp311-cp311-manylinux2014_x86_64.whl (67.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.4/67.4 MB\u001b[0m \u001b[31m112.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hInstalling collected packages: ray\n",
+ "Successfully installed ray-2.42.0\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "mOBX8Jqc-ZBe"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# Confirm a CUDA device is present and supports bfloat16 before training.\n",
+ "try:\n",
+ "    assert torch.cuda.is_available()\n",
+ "    torch.ones(1, dtype=torch.bfloat16).cuda()\n",
+ "except (AssertionError, RuntimeError):\n",
+ "    # .cuda() raises RuntimeError (not AssertionError) when no usable CUDA\n",
+ "    # device exists, so catch it too to show the friendly message instead\n",
+ "    # of a raw traceback.\n",
+ "    print(\"Please switch to an env with GPUs supporting bfloat16 (L4 RTX 5000, A5000, A100, H100, A10, etc)\")\n",
+ "\n",
+ "try:\n",
+ "    import verl\n",
+ "except Exception as e:\n",
+ "    print(\"Please install verl via pip and restart the kernel\")\n",
+ "    raise e\n",
+ "\n",
+ "import flash_attn"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9mawNxDfo3Uu"
+ },
+ "source": [
+ "# Load Pretrained Language Model\n",
+ "\n",
+ "verl supports models available in Huggingface transformers (as well as custom Megatron models).\n",
+ "\n",
+ "Let's download the model first."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "k8FsgBYnpR-R",
+ "outputId": "ff52d13a-b84b-4739-cb84-0a8993b39ab8",
+ "collapsed": true
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\rFetching 10 files: 0% 0/10 [00:00, ?it/s]Downloading 'merges.txt' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/PtHk0z_I45atnj23IIRhTExwT3w=.20024bfe7c83998e9aeaf98a0cd6a2ce6306c2f0.incomplete'\n",
+ "Downloading 'config.json' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/8_PA_wEVGiVa2goH2H4KQOQpvVY=.0dbb161213629a23f0fc00ef286e6b1e366d180f.incomplete'\n",
+ "Downloading 'generation_config.json' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/3EVKVggOldJcKSsGjSdoUCN1AyQ=.dfc11073787daf1b0f9c0f1499487ab5f4c93738.incomplete'\n",
+ "Downloading 'model.safetensors' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/xGOKKLRSlIhH692hSVvI1-gpoa8=.fdf756fa7fcbe7404d5c60e26bff1a0c8b8aa1f72ced49e7dd0210fe288fb7fe.incomplete'\n",
+ "\n",
+ "\rconfig.json: 0% 0.00/659 [00:00, ?B/s]\u001b[A\rconfig.json: 100% 659/659 [00:00<00:00, 5.84MB/s]\n",
+ "\n",
+ "\rmerges.txt: 0% 0.00/1.67M [00:00, ?B/s]\u001b[ADownload complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/config.json\n",
+ "Downloading 'LICENSE' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/DhCjcNQuMpl4FL346qr3tvNUCgY=.6634c8cc3133b3848ec74b9f275acaaa1ea618ab.incomplete'\n",
+ "Downloading 'tokenizer.json' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/HgM_lKo9sdSCfRtVg7MMFS7EKqo=.443909a61d429dff23010e5bddd28ff530edda00.incomplete'\n",
+ "Downloading '.gitattributes' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/wPaCkH-WbT7GsmxMKKrNZTV4nSM=.a6344aac8c09253b3b630fb776ae94478aa0275b.incomplete'\n",
+ "\n",
+ "\n",
+ "\rtokenizer.json: 0% 0.00/7.03M [00:00, ?B/s]\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "\r.gitattributes: 0% 0.00/1.52k [00:00, ?B/s]\u001b[A\u001b[A\u001b[A\r.gitattributes: 100% 1.52k/1.52k [00:00<00:00, 13.8MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/.gitattributes\n",
+ "\rFetching 10 files: 10% 1/10 [00:00<00:01, 6.32it/s]Downloading 'README.md' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/Xn7B-BWUGOee2Y6hCZtEhtFu4BE=.4b8373851d093eb9f3017443f27781c6971eff24.incomplete'\n",
+ "Downloading 'tokenizer_config.json' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/vzaExXFZNBay89bvlQv-ZcI6BTg=.07bfe0640cb5a0037f9322287fbfc682806cf672.incomplete'\n",
+ "\n",
+ "\n",
+ "\n",
+ "\rgeneration_config.json: 0% 0.00/242 [00:00, ?B/s]\u001b[A\u001b[A\u001b[A\rgeneration_config.json: 100% 242/242 [00:00<00:00, 1.72MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/generation_config.json\n",
+ "Downloading 'vocab.json' to '/root/models/Qwen2.5-0.5B-Instruct/.cache/huggingface/download/j3m-Hy6QvBddw8RXA1uSWl1AJ0c=.4783fe10ac3adce15ac8f358ef5462739852c569.incomplete'\n",
+ "\n",
+ "\n",
+ "\n",
+ "\rmodel.safetensors: 0% 0.00/988M [00:00, ?B/s]\u001b[A\u001b[A\u001b[A\n",
+ "\rmerges.txt: 100% 1.67M/1.67M [00:00<00:00, 16.1MB/s]\u001b[A\rmerges.txt: 100% 1.67M/1.67M [00:00<00:00, 15.9MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/merges.txt\n",
+ "\n",
+ "\rtokenizer_config.json: 0% 0.00/7.30k [00:00, ?B/s]\u001b[A\rtokenizer_config.json: 100% 7.30k/7.30k [00:00<00:00, 39.6MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/tokenizer_config.json\n",
+ "\n",
+ "\rLICENSE: 0% 0.00/11.3k [00:00, ?B/s]\u001b[A\rLICENSE: 100% 11.3k/11.3k [00:00<00:00, 68.9MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/LICENSE\n",
+ "\n",
+ "\n",
+ "tokenizer.json: 100% 7.03M/7.03M [00:00<00:00, 23.8MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/tokenizer.json\n",
+ "\n",
+ "\n",
+ "\n",
+ "model.safetensors: 1% 10.5M/988M [00:00<00:25, 38.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "vocab.json: 0% 0.00/2.78M [00:00, ?B/s]\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 2% 21.0M/988M [00:00<00:23, 40.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "README.md: 100% 4.92k/4.92k [00:00<00:00, 25.7MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/README.md\n",
+ "Fetching 10 files: 30% 3/10 [00:00<00:01, 3.91it/s]\n",
+ "vocab.json: 100% 2.78M/2.78M [00:00<00:00, 6.93MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/vocab.json\n",
+ "\n",
+ "\n",
+ "\n",
+ "model.safetensors: 3% 31.5M/988M [00:00<00:23, 41.1MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 4% 41.9M/988M [00:01<00:22, 41.9MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 5% 52.4M/988M [00:01<00:22, 42.1MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 6% 62.9M/988M [00:01<00:21, 42.2MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 7% 73.4M/988M [00:01<00:21, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 8% 83.9M/988M [00:01<00:21, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 10% 94.4M/988M [00:02<00:21, 42.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 11% 105M/988M [00:02<00:21, 42.0MB/s] \u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 12% 115M/988M [00:02<00:20, 42.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 13% 126M/988M [00:02<00:20, 42.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 14% 136M/988M [00:03<00:20, 42.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 15% 147M/988M [00:03<00:19, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 16% 157M/988M [00:03<00:19, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 17% 168M/988M [00:03<00:19, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 18% 178M/988M [00:04<00:19, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 19% 189M/988M [00:04<00:18, 42.8MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 20% 199M/988M [00:04<00:18, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 21% 210M/988M [00:04<00:18, 42.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 22% 220M/988M [00:05<00:18, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 23% 231M/988M [00:05<00:17, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 24% 241M/988M [00:05<00:17, 42.9MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 25% 252M/988M [00:05<00:17, 42.9MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 27% 262M/988M [00:06<00:16, 42.8MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 28% 273M/988M [00:06<00:16, 42.8MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 29% 283M/988M [00:06<00:16, 42.8MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 30% 294M/988M [00:06<00:16, 42.8MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 31% 304M/988M [00:07<00:16, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 32% 315M/988M [00:07<00:15, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 33% 325M/988M [00:07<00:15, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 34% 336M/988M [00:07<00:15, 42.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 35% 346M/988M [00:08<00:15, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 36% 357M/988M [00:08<00:14, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 37% 367M/988M [00:08<00:14, 42.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 38% 377M/988M [00:08<00:14, 42.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 39% 388M/988M [00:09<00:14, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 40% 398M/988M [00:09<00:13, 42.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 41% 409M/988M [00:09<00:13, 42.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 42% 419M/988M [00:09<00:13, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 44% 430M/988M [00:10<00:13, 42.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 45% 440M/988M [00:10<00:11, 47.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 46% 451M/988M [00:10<00:11, 46.1MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 47% 461M/988M [00:10<00:11, 45.1MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 48% 472M/988M [00:11<00:11, 44.2MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 49% 482M/988M [00:11<00:11, 43.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 50% 493M/988M [00:11<00:11, 43.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 51% 503M/988M [00:11<00:11, 43.1MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 52% 514M/988M [00:12<00:11, 42.9MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 53% 524M/988M [00:12<00:10, 42.9MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 54% 535M/988M [00:12<00:10, 43.0MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 55% 545M/988M [00:12<00:10, 42.8MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 56% 556M/988M [00:12<00:10, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 57% 566M/988M [00:13<00:09, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 58% 577M/988M [00:13<00:12, 33.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 59% 587M/988M [00:13<00:10, 37.8MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 60% 598M/988M [00:14<00:10, 39.0MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 62% 608M/988M [00:14<00:09, 39.9MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 63% 619M/988M [00:14<00:09, 40.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 64% 629M/988M [00:14<00:08, 41.0MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 65% 640M/988M [00:15<00:08, 41.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 66% 650M/988M [00:15<00:08, 41.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 67% 661M/988M [00:15<00:07, 41.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 68% 671M/988M [00:15<00:07, 41.9MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 69% 682M/988M [00:16<00:07, 42.1MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 70% 692M/988M [00:16<00:07, 42.1MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 71% 703M/988M [00:16<00:06, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 72% 713M/988M [00:16<00:06, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 73% 724M/988M [00:17<00:06, 42.1MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 74% 734M/988M [00:17<00:05, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 75% 744M/988M [00:17<00:05, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 76% 755M/988M [00:17<00:05, 42.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 77% 765M/988M [00:18<00:05, 42.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 79% 776M/988M [00:18<00:05, 42.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 80% 786M/988M [00:18<00:04, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 81% 797M/988M [00:18<00:04, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 82% 807M/988M [00:19<00:04, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 83% 818M/988M [00:19<00:03, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 84% 828M/988M [00:19<00:03, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 85% 839M/988M [00:19<00:03, 42.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 86% 849M/988M [00:20<00:03, 42.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 87% 860M/988M [00:20<00:03, 42.6MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 88% 870M/988M [00:20<00:02, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 89% 881M/988M [00:20<00:02, 42.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 90% 891M/988M [00:21<00:02, 42.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 91% 902M/988M [00:21<00:02, 38.8MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 92% 912M/988M [00:21<00:02, 35.7MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 93% 923M/988M [00:22<00:01, 36.9MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 94% 933M/988M [00:22<00:01, 38.3MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 96% 944M/988M [00:22<00:01, 39.4MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 97% 954M/988M [00:22<00:00, 40.2MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 98% 965M/988M [00:23<00:00, 40.5MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 99% 975M/988M [00:23<00:00, 41.2MB/s]\u001b[A\u001b[A\u001b[A\n",
+ "\n",
+ "\n",
+ "model.safetensors: 100% 988M/988M [00:23<00:00, 42.0MB/s]\n",
+ "Download complete. Moving file to /root/models/Qwen2.5-0.5B-Instruct/model.safetensors\n",
+ "Fetching 10 files: 100% 10/10 [00:23<00:00, 2.37s/it]\n",
+ "/root/models/Qwen2.5-0.5B-Instruct\n"
+ ]
+ }
+ ],
+ "source": [
+ "!huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir $HOME/models/Qwen2.5-0.5B-Instruct\n",
+ "\n",
+ "# If huggingface-cli is not stable, use the method below\n",
+ "# import transformers\n",
+ "# transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zWlDQ6EjnBWz"
+ },
+ "source": [
+ "# Dataset preparation\n",
+ "\n",
+ "We train with the Grade School Math 8K (GSM8k) task in this demo. The dataset is downloaded from huggingface [gsm8k](https://huggingface.co/datasets/openai/gsm8k) and below are some samples:\n",
+ "\n",
+ "\n",
+ "**Prompt**\n",
+ "\n",
+ "Katy makes coffee using teaspoons of sugar and cups of water in the ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups of water, calculate the number of teaspoonfuls of sugar she used.\n",
+ "\n",
+ "**Solution**\n",
+ "\n",
+ "The total ratio representing the ingredients she used to make the coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the number of teaspoons she used is 7/20, she used 7/20\\*120 = <<7/20\\*120=42>>42 #### 42"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "AgRCvb6V6B3A",
+ "outputId": "ec3fd239-1727-4995-c935-e972949e5c49"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.11/dist-packages (3.2.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets) (3.17.0)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from datasets) (1.26.4)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (17.0.0)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.3.8)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from datasets) (2.2.2)\n",
+ "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.11/dist-packages (from datasets) (2.32.3)\n",
+ "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.11/dist-packages (from datasets) (4.67.1)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from datasets) (3.5.0)\n",
+ "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets) (3.11.12)\n",
+ "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.28.1)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from datasets) (24.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from datasets) (6.0.2)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (2.4.4)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.3.2)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (25.1.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.5.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (6.1.0)\n",
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (0.2.1)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.18.3)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (3.4.1)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (3.10)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (2.3.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (2025.1.31)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2025.1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2025.1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
+ "README.md: 100% 7.94k/7.94k [00:00<00:00, 44.3MB/s]\n",
+ "train-00000-of-00001.parquet: 100% 2.31M/2.31M [00:00<00:00, 29.1MB/s]\n",
+ "test-00000-of-00001.parquet: 100% 419k/419k [00:00<00:00, 187MB/s]\n",
+ "Generating train split: 100% 7473/7473 [00:00<00:00, 173916.13 examples/s]\n",
+ "Generating test split: 100% 1319/1319 [00:00<00:00, 330088.72 examples/s]\n",
+ "Map: 100% 7473/7473 [00:00<00:00, 19404.02 examples/s]\n",
+ "Map: 100% 1319/1319 [00:00<00:00, 17962.32 examples/s]\n",
+ "Creating parquet from Arrow format: 100% 8/8 [00:00<00:00, 190.79ba/s]\n",
+ "Creating parquet from Arrow format: 100% 2/2 [00:00<00:00, 282.67ba/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install datasets\n",
+ "!mkdir -p $HOME/data/gsm8k\n",
+ "!python3 /content/TinyZero/examples/data_preprocess/gsm8k.py --local_dir $HOME/data/gsm8k"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JPZZKBxunAoj"
+ },
+ "source": [
+ "# The Reward Function\n",
+ "\n",
+ "We use a rule-based reward model. We force the model to produce its final answer after four `#` characters (`####`), as shown in the solution. We extract the final answer from both the solution and the model's output using regular-expression matching, then compare them: a correct answer receives a reward of 1, an incorrect answer receives the format score (0 by default), and a missing answer receives 0."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "SjjLVuO60WD1",
+ "outputId": "cd96ca72-a6ea-448a-e9da-6c39d6fc55d2"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):\n",
+ " answer = extract_solution(solution_str=solution_str, method=method)\n",
+ " if answer is None:\n",
+ " return 0\n",
+ " else:\n",
+ " if answer == ground_truth:\n",
+ " return score\n",
+ " else:\n",
+ " return format_score\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import inspect\n",
+ "from verl.utils.reward_score.gsm8k import compute_score as gsm8k_reward\n",
+ "print(inspect.getsource(gsm8k_reward))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NPBGPdSD0sCF"
+ },
+ "source": [
+ "# Run the RL Pipeline\n",
+ "Let's start with the Proximal Policy Optimization (PPO) algorithm, one of the most widely used methods for post-training large language models.\n",
+ "\n",
+ "The main entry point of the PPO algorithm example is: `main_ppo.py`. A detailed guide to understanding the code architecture of `main_ppo.py` is available [here](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html).\n",
+ "\n",
+ "In this tutorial, we will demonstrate how to run the PPO algorithm with **Qwen 2.5-0.5B** by setting:\n",
+ "- `trainer.n_gpus_per_node`: Number of GPUs per node.\n",
+ "\n",
+ "- `actor_rollout_ref.rollout.tensor_model_parallel_size`: TP size for rollout. Only effective for vllm.\n",
+ "\n",
+ "- `actor_rollout_ref/critic.model.path`: Huggingface model path. This can be either local path or HDFS path. For HDFS path, we provide utils to download it to DRAM and convert the HDFS path to local path.\n",
+ "\n",
+ "- `data.train_batch_size`: Batch size sampled for one training iteration of different RL algorithms.\n",
+ "\n",
+ "- `data.max_prompt_length`: Maximum prompt length. All prompts will be left-padded to this length. An error will be reported if the length is too long.\n",
+ "\n",
+ "- `data.max_response_length`: Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.\n",
+ "\n",
+ "- `actor_rollout_ref.actor.ppo_mini_batch_size`: One sample is split into multiple sub-batches with batch_size=ppo_mini_batch_size for PPO updates.\n",
+ "\n",
+ "- `actor_rollout_ref.actor.ppo_micro_batch_size`/`critic.ppo_micro_batch_size`: Similar to gradient accumulation, the micro_batch_size for one forward pass, trading speed for GPU memory.\n",
+ "\n",
+ "The full configuration explanation is available [here](https://verl.readthedocs.io/en/latest/examples/config.html).\n",
+ "\n",
+ "The training may take a few hours to finish but you can observe how the model performance increases. It will progressively output:\n",
+ "\n",
+ "- generated sentences.\n",
+ "\n",
+ "- step information with RL metrics, such as entropy loss, kl, and ``val/test_score/openai/gsm8k`` (validated every ``trainer.test_freq`` steps)\n",
+ "\n",
+ "If you come across GPU out of memory issues, set smaller values for the micro batch size used for gradient accumulation:\n",
+ "\n",
+ "- actor_rollout_ref.actor.ppo_micro_batch_size=1\n",
+ "- critic.ppo_micro_batch_size=1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install vllm==0.6.3"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "twD-2n6L4qrI",
+ "outputId": "2be29b46-66ae-40c2-c4b1-fe2c70f81152"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting vllm==0.6.3\n",
+ " Downloading vllm-0.6.3-cp38-abi3-manylinux1_x86_64.whl.metadata (10 kB)\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (5.9.5)\n",
+ "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (0.2.0)\n",
+ "Requirement already satisfied: numpy<2.0.0 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (1.26.4)\n",
+ "Requirement already satisfied: requests>=2.26.0 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (2.32.3)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (4.67.1)\n",
+ "Requirement already satisfied: py-cpuinfo in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (9.0.0)\n",
+ "Requirement already satisfied: transformers>=4.45.0 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (4.48.2)\n",
+ "Requirement already satisfied: tokenizers>=0.19.1 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (0.21.0)\n",
+ "Requirement already satisfied: protobuf in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (4.25.6)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (3.11.12)\n",
+ "Requirement already satisfied: openai>=1.40.0 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (1.61.1)\n",
+ "Collecting uvicorn[standard] (from vllm==0.6.3)\n",
+ " Downloading uvicorn-0.34.0-py3-none-any.whl.metadata (6.5 kB)\n",
+ "Requirement already satisfied: pydantic>=2.9 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (2.10.6)\n",
+ "Requirement already satisfied: pillow in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (11.1.0)\n",
+ "Requirement already satisfied: prometheus-client>=0.18.0 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (0.21.1)\n",
+ "Collecting prometheus-fastapi-instrumentator>=7.0.0 (from vllm==0.6.3)\n",
+ " Downloading prometheus_fastapi_instrumentator-7.0.2-py3-none-any.whl.metadata (13 kB)\n",
+ "Collecting tiktoken>=0.6.0 (from vllm==0.6.3)\n",
+ " Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)\n",
+ "Collecting lm-format-enforcer==0.10.6 (from vllm==0.6.3)\n",
+ " Downloading lm_format_enforcer-0.10.6-py3-none-any.whl.metadata (16 kB)\n",
+ "Collecting outlines<0.1,>=0.0.43 (from vllm==0.6.3)\n",
+ " Downloading outlines-0.0.46-py3-none-any.whl.metadata (15 kB)\n",
+ "Requirement already satisfied: typing-extensions>=4.10 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (4.12.2)\n",
+ "Requirement already satisfied: filelock>=3.10.4 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (3.17.0)\n",
+ "Collecting partial-json-parser (from vllm==0.6.3)\n",
+ " Downloading partial_json_parser-0.2.1.1.post5-py3-none-any.whl.metadata (6.1 kB)\n",
+ "Requirement already satisfied: pyzmq in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (24.0.1)\n",
+ "Collecting msgspec (from vllm==0.6.3)\n",
+ " Downloading msgspec-0.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)\n",
+ "Collecting gguf==0.10.0 (from vllm==0.6.3)\n",
+ " Downloading gguf-0.10.0-py3-none-any.whl.metadata (3.5 kB)\n",
+ "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (8.6.1)\n",
+ "Collecting mistral-common>=1.4.4 (from mistral-common[opencv]>=1.4.4->vllm==0.6.3)\n",
+ " Downloading mistral_common-1.5.3-py3-none-any.whl.metadata (4.5 kB)\n",
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (6.0.2)\n",
+ "Requirement already satisfied: einops in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (0.8.0)\n",
+ "Requirement already satisfied: ray>=2.9 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (2.42.0)\n",
+ "Collecting nvidia-ml-py (from vllm==0.6.3)\n",
+ " Downloading nvidia_ml_py-12.570.86-py3-none-any.whl.metadata (8.7 kB)\n",
+ "Requirement already satisfied: torch==2.4.0 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (2.4.0)\n",
+ "Requirement already satisfied: torchvision==0.19 in /usr/local/lib/python3.11/dist-packages (from vllm==0.6.3) (0.19.0)\n",
+ "Collecting xformers==0.0.27.post2 (from vllm==0.6.3)\n",
+ " Downloading xformers-0.0.27.post2-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.0 kB)\n",
+ "Collecting fastapi!=0.113.*,!=0.114.0,>=0.107.0 (from vllm==0.6.3)\n",
+ " Downloading fastapi-0.115.8-py3-none-any.whl.metadata (27 kB)\n",
+ "Collecting interegular>=0.3.2 (from lm-format-enforcer==0.10.6->vllm==0.6.3)\n",
+ " Downloading interegular-0.3.3-py37-none-any.whl.metadata (3.0 kB)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from lm-format-enforcer==0.10.6->vllm==0.6.3) (24.2)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (1.13.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (3.4.2)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (3.1.5)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (2024.9.0)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (9.1.0.70)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (2.20.5)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (12.1.105)\n",
+ "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0->vllm==0.6.3) (3.0.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.11/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.4.0->vllm==0.6.3) (12.5.82)\n",
+ "Collecting starlette<0.46.0,>=0.40.0 (from fastapi!=0.113.*,!=0.114.0,>=0.107.0->vllm==0.6.3)\n",
+ " Downloading starlette-0.45.3-py3-none-any.whl.metadata (6.3 kB)\n",
+ "Requirement already satisfied: jsonschema>=4.21.1 in /usr/local/lib/python3.11/dist-packages (from mistral-common>=1.4.4->mistral-common[opencv]>=1.4.4->vllm==0.6.3) (4.23.0)\n",
+ "Requirement already satisfied: opencv-python-headless>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from mistral-common[opencv]>=1.4.4->vllm==0.6.3) (4.11.0.86)\n",
+ "Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.11/dist-packages (from openai>=1.40.0->vllm==0.6.3) (3.7.1)\n",
+ "Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.11/dist-packages (from openai>=1.40.0->vllm==0.6.3) (1.9.0)\n",
+ "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from openai>=1.40.0->vllm==0.6.3) (0.28.1)\n",
+ "Requirement already satisfied: jiter<1,>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from openai>=1.40.0->vllm==0.6.3) (0.8.2)\n",
+ "Requirement already satisfied: sniffio in /usr/local/lib/python3.11/dist-packages (from openai>=1.40.0->vllm==0.6.3) (1.3.1)\n",
+ "Collecting lark (from outlines<0.1,>=0.0.43->vllm==0.6.3)\n",
+ " Downloading lark-1.2.2-py3-none-any.whl.metadata (1.8 kB)\n",
+ "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from outlines<0.1,>=0.0.43->vllm==0.6.3) (1.6.0)\n",
+ "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.11/dist-packages (from outlines<0.1,>=0.0.43->vllm==0.6.3) (3.1.1)\n",
+ "Collecting diskcache (from outlines<0.1,>=0.0.43->vllm==0.6.3)\n",
+ " Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 kB)\n",
+ "Requirement already satisfied: numba in /usr/local/lib/python3.11/dist-packages (from outlines<0.1,>=0.0.43->vllm==0.6.3) (0.61.0)\n",
+ "Requirement already satisfied: referencing in /usr/local/lib/python3.11/dist-packages (from outlines<0.1,>=0.0.43->vllm==0.6.3) (0.36.2)\n",
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.11/dist-packages (from outlines<0.1,>=0.0.43->vllm==0.6.3) (3.2.0)\n",
+ "Collecting pycountry (from outlines<0.1,>=0.0.43->vllm==0.6.3)\n",
+ " Downloading pycountry-24.6.1-py3-none-any.whl.metadata (12 kB)\n",
+ "Collecting pyairports (from outlines<0.1,>=0.0.43->vllm==0.6.3)\n",
+ " Downloading pyairports-2.1.1-py3-none-any.whl.metadata (1.7 kB)\n",
+ "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from pydantic>=2.9->vllm==0.6.3) (0.7.0)\n",
+ "Requirement already satisfied: pydantic-core==2.27.2 in /usr/local/lib/python3.11/dist-packages (from pydantic>=2.9->vllm==0.6.3) (2.27.2)\n",
+ "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.11/dist-packages (from ray>=2.9->vllm==0.6.3) (8.1.8)\n",
+ "Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from ray>=2.9->vllm==0.6.3) (1.1.0)\n",
+ "Requirement already satisfied: aiosignal in /usr/local/lib/python3.11/dist-packages (from ray>=2.9->vllm==0.6.3) (1.3.2)\n",
+ "Requirement already satisfied: frozenlist in /usr/local/lib/python3.11/dist-packages (from ray>=2.9->vllm==0.6.3) (1.5.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.26.0->vllm==0.6.3) (3.4.1)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.26.0->vllm==0.6.3) (3.10)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.26.0->vllm==0.6.3) (2.3.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.26.0->vllm==0.6.3) (2025.1.31)\n",
+ "Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.11/dist-packages (from tiktoken>=0.6.0->vllm==0.6.3) (2024.11.6)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.11/dist-packages (from tokenizers>=0.19.1->vllm==0.6.3) (0.28.1)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers>=4.45.0->vllm==0.6.3) (0.5.2)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->vllm==0.6.3) (2.4.4)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->vllm==0.6.3) (25.1.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->vllm==0.6.3) (6.1.0)\n",
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->vllm==0.6.3) (0.2.1)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->vllm==0.6.3) (1.18.3)\n",
+ "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.11/dist-packages (from importlib-metadata->vllm==0.6.3) (3.21.0)\n",
+ "Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.11/dist-packages (from uvicorn[standard]->vllm==0.6.3) (0.14.0)\n",
+ "Collecting httptools>=0.6.3 (from uvicorn[standard]->vllm==0.6.3)\n",
+ " Downloading httptools-0.6.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)\n",
+ "Collecting python-dotenv>=0.13 (from uvicorn[standard]->vllm==0.6.3)\n",
+ " Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)\n",
+ "Collecting uvloop!=0.15.0,!=0.15.1,>=0.14.0 (from uvicorn[standard]->vllm==0.6.3)\n",
+ " Downloading uvloop-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)\n",
+ "Collecting watchfiles>=0.13 (from uvicorn[standard]->vllm==0.6.3)\n",
+ " Downloading watchfiles-1.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)\n",
+ "Requirement already satisfied: websockets>=10.4 in /usr/local/lib/python3.11/dist-packages (from uvicorn[standard]->vllm==0.6.3) (14.2)\n",
+ "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->openai>=1.40.0->vllm==0.6.3) (1.0.7)\n",
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.21.1->mistral-common>=1.4.4->mistral-common[opencv]>=1.4.4->vllm==0.6.3) (2024.10.1)\n",
+ "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.21.1->mistral-common>=1.4.4->mistral-common[opencv]>=1.4.4->vllm==0.6.3) (0.22.3)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (17.0.0)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (0.3.8)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (2.2.2)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (3.5.0)\n",
+ "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.11/dist-packages (from datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (0.70.16)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.4.0->vllm==0.6.3) (3.0.2)\n",
+ "Requirement already satisfied: llvmlite<0.45,>=0.44.0dev0 in /usr/local/lib/python3.11/dist-packages (from numba->outlines<0.1,>=0.0.43->vllm==0.6.3) (0.44.0)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->torch==2.4.0->vllm==0.6.3) (1.3.0)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (2025.1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (2025.1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->datasets->outlines<0.1,>=0.0.43->vllm==0.6.3) (1.17.0)\n",
+ "Downloading vllm-0.6.3-cp38-abi3-manylinux1_x86_64.whl (193.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m193.5/193.5 MB\u001b[0m \u001b[31m107.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading gguf-0.10.0-py3-none-any.whl (71 kB)\n",
+ "Downloading lm_format_enforcer-0.10.6-py3-none-any.whl (43 kB)\n",
+ "Downloading xformers-0.0.27.post2-cp311-cp311-manylinux2014_x86_64.whl (20.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.8/20.8 MB\u001b[0m \u001b[31m109.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading fastapi-0.115.8-py3-none-any.whl (94 kB)\n",
+ "Downloading mistral_common-1.5.3-py3-none-any.whl (6.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.5/6.5 MB\u001b[0m \u001b[31m106.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading outlines-0.0.46-py3-none-any.whl (101 kB)\n",
+ "Downloading prometheus_fastapi_instrumentator-7.0.2-py3-none-any.whl (18 kB)\n",
+ "Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m62.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading msgspec-0.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (210 kB)\n",
+ "Downloading nvidia_ml_py-12.570.86-py3-none-any.whl (44 kB)\n",
+ "Downloading partial_json_parser-0.2.1.1.post5-py3-none-any.whl (10 kB)\n",
+ "Downloading httptools-0.6.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (459 kB)\n",
+ "Downloading interegular-0.3.3-py37-none-any.whl (23 kB)\n",
+ "Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)\n",
+ "Downloading starlette-0.45.3-py3-none-any.whl (71 kB)\n",
+ "Downloading uvloop-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.0/4.0 MB\u001b[0m \u001b[31m123.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading watchfiles-1.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (452 kB)\n",
+ "Downloading diskcache-5.6.3-py3-none-any.whl (45 kB)\n",
+ "Downloading lark-1.2.2-py3-none-any.whl (111 kB)\n",
+ "Downloading pyairports-2.1.1-py3-none-any.whl (371 kB)\n",
+ "Downloading pycountry-24.6.1-py3-none-any.whl (6.3 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.3/6.3 MB\u001b[0m \u001b[31m127.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading uvicorn-0.34.0-py3-none-any.whl (62 kB)\n",
+ "Installing collected packages: pyairports, nvidia-ml-py, uvloop, uvicorn, python-dotenv, pycountry, partial-json-parser, msgspec, lark, interegular, httptools, gguf, diskcache, watchfiles, tiktoken, starlette, prometheus-fastapi-instrumentator, lm-format-enforcer, fastapi, xformers, mistral-common, outlines, vllm\n",
+ "Successfully installed diskcache-5.6.3 fastapi-0.115.8 gguf-0.10.0 httptools-0.6.4 interegular-0.3.3 lark-1.2.2 lm-format-enforcer-0.10.6 mistral-common-1.5.3 msgspec-0.19.0 nvidia-ml-py-12.570.86 outlines-0.0.46 partial-json-parser-0.2.1.1.post5 prometheus-fastapi-instrumentator-7.0.2 pyairports-2.1.1 pycountry-24.6.1 python-dotenv-1.0.1 starlette-0.45.3 tiktoken-0.8.0 uvicorn-0.34.0 uvloop-0.21.0 vllm-0.6.3 watchfiles-1.0.4 xformers-0.0.27.post2\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "GvyEebBB4eCA",
+ "outputId": "1d69a439-128a-4bf8-bc20-14009e703361"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "2025-02-11 17:59:40,119\tINFO worker.py:1841 -- Started a local Ray instance.\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m {'actor_rollout_ref': {'actor': {'clip_ratio': 0.2,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'entropy_coeff': 0.001,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'fsdp_config': {'grad_offload': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'optimizer_offload': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'param_offload': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'wrap_policy': {'min_num_params': 0}},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'grad_clip': 1.0,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'optim': {'lr': 1e-06,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'lr_warmup_steps_ratio': 0.0,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'min_lr_ratio': None,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'total_training_steps': -1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'warmup_style': 'constant'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'ppo_epochs': 1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'ppo_micro_batch_size': 1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'ppo_mini_batch_size': 64,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'shuffle': True,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'strategy': 'fsdp'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'hybrid_engine': True,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'model': {'enable_gradient_checkpointing': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'external_lib': None,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'override_config': {},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'path': '/root/models/Qwen2.5-0.5B-Instruct'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'ref': {'fsdp_config': {'param_offload': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'wrap_policy': {'min_num_params': 0}},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'log_prob_micro_batch_size': 4},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'rollout': {'do_sample': True,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'dtype': 'bfloat16',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'enforce_eager': True,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'free_cache_engine': True,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'gpu_memory_utilization': 0.4,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'ignore_eos': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'load_format': 'dummy_dtensor',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'log_prob_micro_batch_size': 1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'max_num_batched_tokens': 8192,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'max_num_seqs': 1024,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'name': 'vllm',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'prompt_length': 512,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'response_length': 256,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'temperature': 1.0,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'tensor_model_parallel_size': 1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'top_k': -1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'top_p': 1}},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'algorithm': {'adv_estimator': 'gae',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'gamma': 1.0,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'kl_ctrl': {'kl_coef': 0.001, 'type': 'fixed'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'kl_penalty': 'kl',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'lam': 1.0},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'critic': {'cliprange_value': 0.5,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'grad_clip': 1.0,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'model': {'enable_gradient_checkpointing': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'external_lib': None,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'fsdp_config': {'grad_offload': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'optimizer_offload': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'param_offload': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'wrap_policy': {'min_num_params': 0}},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'override_config': {},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'path': '/root/models/Qwen2.5-0.5B-Instruct',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'tokenizer_path': '/root/models/Qwen2.5-0.5B-Instruct'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'optim': {'lr': 1e-05,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'lr_warmup_steps_ratio': 0.0,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'min_lr_ratio': None,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'total_training_steps': -1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'warmup_style': 'constant'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'ppo_epochs': 1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'ppo_micro_batch_size': 1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'ppo_mini_batch_size': 64,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'shuffle': True,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'strategy': 'fsdp'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'data': {'max_prompt_length': 512,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'max_response_length': 256,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'prompt_key': 'prompt',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'return_raw_chat': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'return_raw_input_ids': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'tokenizer': None,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'train_batch_size': 256,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'train_files': '/root/data/gsm8k/train.parquet',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'val_batch_size': 1312,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'val_files': '/root/data/gsm8k/test.parquet'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'reward_model': {'enable': False,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'max_length': None,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'micro_batch_size': 64,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'model': {'external_lib': None,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'fsdp_config': {'min_num_params': 0,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'param_offload': False},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'input_tokenizer': '/root/models/Qwen2.5-0.5B-Instruct',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'path': '~/models/FsfairX-LLaMA3-RM-v0.1'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'strategy': 'fsdp'},\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'trainer': {'critic_warmup': 0,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'default_hdfs_dir': '',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'default_local_dir': 'checkpoints/verl_examples/gsm8k',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'experiment_name': 'gsm8k',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'logger': ['console'],\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'n_gpus_per_node': 1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'nnodes': 1,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'project_name': 'verl_examples',\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'save_freq': 10,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'test_freq': 10,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'total_epochs': 15,\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 'val_before_train': False}}\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 2025-02-11 17:59:47.193745: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 2025-02-11 17:59:47.214042: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m E0000 00:00:1739296787.233627 24124 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m E0000 00:00:1739296787.239225 24124 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 2025-02-11 17:59:47.258356: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m original dataset len: 7473\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m filter dataset len: 7473\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m original dataset len: 1319\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m filter dataset len: 1319\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m Size of train dataloader: 29\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m Size of val dataloader: 1\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m 2025-02-11 18:00:01.605983: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m 2025-02-11 18:00:01.623081: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m E0000 00:00:1739296801.644696 24317 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m E0000 00:00:1739296801.651419 24317 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m 2025-02-11 18:00:01.673346: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Critic overriding config {'bos_token_id': None, 'eos_token_id': 151645, 'pad_token_id': 151643}\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Qwen2ForCausalLM contains 494.03M parameters\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Before critic FSDP, memory allocated (GB): 0.0, memory reserved (GB): 0.0\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m /usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_init_utils.py:440: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.FULL_SHARD since the world size is 1.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m warnings.warn(\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m NCCL version 2.20.5+cuda12.4\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m After critic FSDP, memory allocated (GB): 1.8410840034484863, memory reserved (GB): 2.44921875\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Total steps: 435, num_warmup_steps: 0\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Model config after override: Qwen2Config {\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"_name_or_path\": \"/root/models/Qwen2.5-0.5B-Instruct\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"architectures\": [\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"Qwen2ForCausalLM\"\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m ],\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"attention_dropout\": 0.0,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"eos_token_id\": 151645,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"hidden_act\": \"silu\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"hidden_size\": 896,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"initializer_range\": 0.02,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"intermediate_size\": 4864,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"max_position_embeddings\": 32768,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"max_window_layers\": 21,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"model_type\": \"qwen2\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"num_attention_heads\": 14,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"num_hidden_layers\": 24,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"num_key_value_heads\": 2,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"pad_token_id\": 151643,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"rms_norm_eps\": 1e-06,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"rope_scaling\": null,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"rope_theta\": 1000000.0,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"sliding_window\": null,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"tie_word_embeddings\": true,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"torch_dtype\": \"bfloat16\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"transformers_version\": \"4.48.2\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"use_cache\": true,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"use_sliding_window\": false,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"vocab_size\": 151936\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m }\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Qwen2ForCausalLM contains 494.03M parameters\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m wrap_policy: functools.partial(, transformer_layer_cls={})\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m /usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_init_utils.py:440: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.FULL_SHARD since the world size is 1.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m warnings.warn(\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Model config after override: Qwen2Config {\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"_name_or_path\": \"/root/models/Qwen2.5-0.5B-Instruct\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"architectures\": [\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"Qwen2ForCausalLM\"\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m ],\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"attention_dropout\": 0.0,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"eos_token_id\": 151645,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"hidden_act\": \"silu\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"hidden_size\": 896,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"initializer_range\": 0.02,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"intermediate_size\": 4864,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"max_position_embeddings\": 32768,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"max_window_layers\": 21,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"model_type\": \"qwen2\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"num_attention_heads\": 14,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"num_hidden_layers\": 24,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"num_key_value_heads\": 2,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"pad_token_id\": 151643,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"rms_norm_eps\": 1e-06,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"rope_scaling\": null,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"rope_theta\": 1000000.0,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"sliding_window\": null,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"tie_word_embeddings\": true,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"torch_dtype\": \"bfloat16\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"transformers_version\": \"4.48.2\",\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"use_cache\": true,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"use_sliding_window\": false,\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \"vocab_size\": 151936\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m }\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m \n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Qwen2ForCausalLM contains 494.03M parameters\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m wrap_policy: functools.partial(, transformer_layer_cls={})\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m /usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_init_utils.py:440: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.FULL_SHARD since the world size is 1.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m warnings.warn(\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Total steps: 435, num_warmup_steps: 0\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m /usr/local/lib/python3.11/dist-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m No module named 'vllm._version'\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m from vllm.version import __version__ as VLLM_VERSION\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m Before building vllm rollout, memory allocated (GB): 4.602716445922852, memory reserved (GB): 5.2734375\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m WARNING 02-11 18:00:20 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m local rank 0\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m no hf weight loader need to be updated\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m before init cache memory allocated: 5.94458624GB, reserved: 6.067060736GB\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m after init cache memory allocated: 19.5341312GB, reserved: 19.656605696GB\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m Using LocalLogger is deprecated. The constructor API will change \n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m kwargs: {'n': 1, 'logprobs': 1, 'max_tokens': 256, 'detokenize': False, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m After building vllm rollout, memory allocated (GB): 17.271177291870117, memory reserved (GB): 18.306640625\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m After building sharding manager, memory allocated (GB): 17.271177291870117, memory reserved (GB): 18.306640625\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m /usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m warnings.warn(\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m /usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:773: UserWarning: When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict willbe returned.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m warnings.warn(\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m /usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:716: UserWarning: When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict willbe returned.\n",
+ "\u001b[36m(WorkerDict pid=24317)\u001b[0m warnings.warn(\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m validation generation end\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m <|im_start|>system\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m <|im_start|>user\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m After scoring 14 points, Erin now has three times more points than Sara, who scored 8. How many points did Erin have before? Let's think step by step and output the final answer after \"####\".<|im_end|>\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m <|im_start|>assistant\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m To determine how many points Erin had before scoring 14 points, we need to follow these steps:\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 1. **Identify the current score of Sara**: Sara scored 8 points.\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 2. **Determine the current score of Erin**: Erin has three times more points than Sara.\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 3. **Calculate Erin's current score**: Since Erin has three times more points than Sara, we multiply Sara's score by 3.\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \\[\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \\text{Erin's current score} = 3 \\times 8 = 24\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \\]\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m 4. **Find the initial score of Erin**: Since Erin's current score is 24 points, we subtract this from her current score to find out how many points she had before scoring 14 points.\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \\[\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \\text{Erin's initial score} = 24 - 14 = 10\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \\]\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m Therefore, Erin had \\(\\boxed{10}\\) points before scoring 14 points.<|im_end|>\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m \"Initial validation metrics: {'test_score/openai/gsm8k': 0.0}\"\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:0 - timing/gen:32.204 - timing/ref:4.338 - timing/values:16.998 - critic/kl:0.000 - critic/kl_coeff:0.001 - timing/adv:0.103 - timing/update_critic:46.955 - critic/vf_loss:9.680 - critic/vf_clipfrac:0.000 - critic/vpred_mean:-2.392 - critic/grad_norm:810.800 - critic/lr(1e-4):0.100 - timing/update_actor:47.846 - actor/entropy_loss:0.450 - actor/pg_loss:0.005 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/grad_norm:2.205 - actor/lr(1e-4):0.010 - critic/score/mean:0.008 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.008 - critic/rewards/max:1.000 - critic/rewards/min:-0.001 - critic/advantages/mean:-0.000 - critic/advantages/max:2.247 - critic/advantages/min:-3.419 - critic/returns/mean:0.005 - critic/returns/max:0.001 - critic/returns/min:-0.001 - critic/values/mean:-2.422 - critic/values/max:10.125 - critic/values/min:-10.688 - response_length/mean:235.449 - response_length/max:256.000 - response_length/min:51.000 - prompt_length/mean:104.625 - prompt_length/max:196.000 - prompt_length/min:63.000\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:1 - timing/gen:30.479 - timing/ref:4.295 - timing/values:17.340 - critic/kl:-0.000 - critic/kl_coeff:0.001 - timing/adv:0.099 - timing/update_critic:48.050 - critic/vf_loss:12.785 - critic/vf_clipfrac:0.337 - critic/vpred_mean:1.318 - critic/grad_norm:868.434 - critic/lr(1e-4):0.100 - timing/update_actor:48.040 - actor/entropy_loss:0.424 - actor/pg_loss:0.010 - actor/pg_clipfrac:0.001 - actor/ppo_kl:0.000 - actor/grad_norm:2.179 - actor/lr(1e-4):0.010 - critic/score/mean:0.008 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.008 - critic/rewards/max:1.000 - critic/rewards/min:-0.002 - critic/advantages/mean:-0.000 - critic/advantages/max:2.990 - critic/advantages/min:-3.271 - critic/returns/mean:0.007 - critic/returns/max:0.001 - critic/returns/min:-0.000 - critic/values/mean:-2.375 - critic/values/max:9.500 - critic/values/min:-13.250 - response_length/mean:230.027 - response_length/max:256.000 - response_length/min:24.000 - prompt_length/mean:103.516 - prompt_length/max:189.000 - prompt_length/min:66.000\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:2 - timing/gen:30.691 - timing/ref:4.358 - timing/values:17.346 - critic/kl:-0.001 - critic/kl_coeff:0.001 - timing/adv:0.100 - timing/update_critic:47.085 - critic/vf_loss:5.529 - critic/vf_clipfrac:0.446 - critic/vpred_mean:1.411 - critic/grad_norm:292.781 - critic/lr(1e-4):0.100 - timing/update_actor:48.152 - actor/entropy_loss:0.453 - actor/pg_loss:-0.016 - actor/pg_clipfrac:0.002 - actor/ppo_kl:-0.000 - actor/grad_norm:8.181 - actor/lr(1e-4):0.010 - critic/score/mean:0.008 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.008 - critic/rewards/max:1.000 - critic/rewards/min:-0.002 - critic/advantages/mean:0.000 - critic/advantages/max:3.292 - critic/advantages/min:-3.259 - critic/returns/mean:0.005 - critic/returns/max:0.000 - critic/returns/min:-0.000 - critic/values/mean:2.250 - critic/values/max:11.062 - critic/values/min:-6.656 - response_length/mean:229.094 - response_length/max:256.000 - response_length/min:4.000 - prompt_length/mean:103.168 - prompt_length/max:189.000 - prompt_length/min:70.000\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:3 - timing/gen:30.730 - timing/ref:4.299 - timing/values:17.278 - critic/kl:0.000 - critic/kl_coeff:0.001 - timing/adv:0.100 - timing/update_critic:47.572 - critic/vf_loss:1.733 - critic/vf_clipfrac:0.317 - critic/vpred_mean:0.059 - critic/grad_norm:169.059 - critic/lr(1e-4):0.100 - timing/update_actor:47.580 - actor/entropy_loss:0.446 - actor/pg_loss:0.001 - actor/pg_clipfrac:0.001 - actor/ppo_kl:-0.000 - actor/grad_norm:2.593 - actor/lr(1e-4):0.010 - critic/score/mean:0.008 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.008 - critic/rewards/max:1.000 - critic/rewards/min:-0.002 - critic/advantages/mean:0.000 - critic/advantages/max:2.847 - critic/advantages/min:-3.170 - critic/returns/mean:0.008 - critic/returns/max:0.000 - critic/returns/min:-0.000 - critic/values/mean:0.758 - critic/values/max:6.281 - critic/values/min:-4.219 - response_length/mean:236.652 - response_length/max:256.000 - response_length/min:10.000 - prompt_length/mean:103.527 - prompt_length/max:202.000 - prompt_length/min:70.000\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:4 - timing/gen:31.015 - timing/ref:4.392 - timing/values:17.348 - critic/kl:0.002 - critic/kl_coeff:0.001 - timing/adv:0.100 - timing/update_critic:47.830 - critic/vf_loss:0.749 - critic/vf_clipfrac:0.177 - critic/vpred_mean:0.267 - critic/grad_norm:89.965 - critic/lr(1e-4):0.100 - timing/update_actor:48.226 - actor/entropy_loss:0.435 - actor/pg_loss:0.014 - actor/pg_clipfrac:0.001 - actor/ppo_kl:0.000 - actor/grad_norm:6.182 - actor/lr(1e-4):0.010 - critic/score/mean:0.008 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.008 - critic/rewards/max:0.999 - critic/rewards/min:-0.002 - critic/advantages/mean:0.000 - critic/advantages/max:2.795 - critic/advantages/min:-4.796 - critic/returns/mean:0.006 - critic/returns/max:0.001 - critic/returns/min:-0.001 - critic/values/mean:0.471 - critic/values/max:6.438 - critic/values/min:-3.016 - response_length/mean:232.574 - response_length/max:256.000 - response_length/min:4.000 - prompt_length/mean:102.988 - prompt_length/max:201.000 - prompt_length/min:69.000\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:5 - timing/gen:30.906 - timing/ref:4.306 - timing/values:17.189 - critic/kl:0.019 - critic/kl_coeff:0.001 - timing/adv:0.100 - timing/update_critic:47.114 - critic/vf_loss:0.420 - critic/vf_clipfrac:0.178 - critic/vpred_mean:0.146 - critic/grad_norm:76.500 - critic/lr(1e-4):0.100 - timing/update_actor:47.928 - actor/entropy_loss:0.488 - actor/pg_loss:0.026 - actor/pg_clipfrac:0.001 - actor/ppo_kl:-0.001 - actor/grad_norm:13.241 - actor/lr(1e-4):0.010 - critic/score/mean:0.016 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.015 - critic/rewards/max:1.000 - critic/rewards/min:-0.005 - critic/advantages/mean:0.000 - critic/advantages/max:4.234 - critic/advantages/min:-2.572 - critic/returns/mean:0.013 - critic/returns/max:0.001 - critic/returns/min:-0.002 - critic/values/mean:0.424 - critic/values/max:2.797 - critic/values/min:-3.516 - response_length/mean:219.227 - response_length/max:256.000 - response_length/min:5.000 - prompt_length/mean:101.426 - prompt_length/max:180.000 - prompt_length/min:69.000\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:6 - timing/gen:30.797 - timing/ref:4.373 - timing/values:17.376 - critic/kl:0.021 - critic/kl_coeff:0.001 - timing/adv:0.117 - timing/update_critic:48.074 - critic/vf_loss:0.273 - critic/vf_clipfrac:0.094 - critic/vpred_mean:-0.124 - critic/grad_norm:65.489 - critic/lr(1e-4):0.100 - timing/update_actor:48.469 - actor/entropy_loss:0.493 - actor/pg_loss:0.004 - actor/pg_clipfrac:0.006 - actor/ppo_kl:0.002 - actor/grad_norm:6.809 - actor/lr(1e-4):0.010 - critic/score/mean:0.012 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.011 - critic/rewards/max:0.999 - critic/rewards/min:-0.006 - critic/advantages/mean:-0.000 - critic/advantages/max:3.729 - critic/advantages/min:-5.827 - critic/returns/mean:0.008 - critic/returns/max:0.001 - critic/returns/min:-0.001 - critic/values/mean:0.148 - critic/values/max:4.438 - critic/values/min:-2.609 - response_length/mean:215.906 - response_length/max:256.000 - response_length/min:4.000 - prompt_length/mean:103.758 - prompt_length/max:183.000 - prompt_length/min:69.000\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:7 - timing/gen:31.524 - timing/ref:4.330 - timing/values:17.490 - critic/kl:0.001 - critic/kl_coeff:0.001 - timing/adv:0.113 - timing/update_critic:47.659 - critic/vf_loss:0.169 - critic/vf_clipfrac:0.084 - critic/vpred_mean:0.105 - critic/grad_norm:62.722 - critic/lr(1e-4):0.100 - timing/update_actor:48.111 - actor/entropy_loss:0.442 - actor/pg_loss:-0.008 - actor/pg_clipfrac:0.001 - actor/ppo_kl:0.000 - actor/grad_norm:2.453 - actor/lr(1e-4):0.010 - critic/score/mean:0.016 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.016 - critic/rewards/max:1.000 - critic/rewards/min:-0.003 - critic/advantages/mean:0.000 - critic/advantages/max:5.417 - critic/advantages/min:-3.238 - critic/returns/mean:0.013 - critic/returns/max:0.000 - critic/returns/min:-0.001 - critic/values/mean:0.328 - critic/values/max:2.062 - critic/values/min:-2.609 - response_length/mean:230.543 - response_length/max:256.000 - response_length/min:63.000 - prompt_length/mean:102.227 - prompt_length/max:215.000 - prompt_length/min:63.000\n",
+ "\u001b[36m(main_task pid=24124)\u001b[0m step:8 - timing/gen:30.809 - timing/ref:4.268 - timing/values:17.369 - critic/kl:0.000 - critic/kl_coeff:0.001 - timing/adv:0.100 - timing/update_critic:48.183 - critic/vf_loss:0.098 - critic/vf_clipfrac:0.019 - critic/vpred_mean:0.028 - critic/grad_norm:38.603 - critic/lr(1e-4):0.100 - timing/update_actor:48.014 - actor/entropy_loss:0.410 - actor/pg_loss:-0.004 - actor/pg_clipfrac:0.001 - actor/ppo_kl:0.000 - actor/grad_norm:2.225 - actor/lr(1e-4):0.010 - critic/score/mean:0.012 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.012 - critic/rewards/max:1.000 - critic/rewards/min:-0.002 - critic/advantages/mean:0.000 - critic/advantages/max:5.276 - critic/advantages/min:-3.639 - critic/returns/mean:0.008 - critic/returns/max:0.000 - critic/returns/min:-0.000 - critic/values/mean:0.034 - critic/values/max:1.477 - critic/values/min:-2.078 - response_length/mean:232.293 - response_length/max:256.000 - response_length/min:54.000 - prompt_length/mean:103.660 - prompt_length/max:183.000 - prompt_length/min:66.000\n"
+ ]
+ }
+ ],
+ "source": [
+ "!PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \\\n",
+ " data.train_files=$HOME/data/gsm8k/train.parquet \\\n",
+ " data.val_files=$HOME/data/gsm8k/test.parquet \\\n",
+ " data.train_batch_size=256 \\\n",
+ " data.val_batch_size=1312 \\\n",
+ " data.max_prompt_length=512 \\\n",
+ " data.max_response_length=256 \\\n",
+ " actor_rollout_ref.model.path=$HOME/models/Qwen2.5-0.5B-Instruct \\\n",
+ " actor_rollout_ref.actor.optim.lr=1e-6 \\\n",
+ " actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n",
+ " actor_rollout_ref.actor.ppo_micro_batch_size=1 \\\n",
+ " actor_rollout_ref.rollout.log_prob_micro_batch_size=1 \\\n",
+ " actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n",
+ " actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n",
+ " actor_rollout_ref.ref.log_prob_micro_batch_size=4 \\\n",
+ " critic.optim.lr=1e-5 \\\n",
+ " critic.model.path=$HOME/models/Qwen2.5-0.5B-Instruct \\\n",
+ " critic.ppo_micro_batch_size=1 \\\n",
+ " algorithm.kl_ctrl.kl_coef=0.001 \\\n",
+ " +trainer.val_before_train=False \\\n",
+ " trainer.default_hdfs_dir='' \\\n",
+ " trainer.n_gpus_per_node=1 \\\n",
+ " trainer.nnodes=1 \\\n",
+ " trainer.save_freq=10 \\\n",
+ " trainer.test_freq=10 \\\n",
+ " trainer.total_epochs=15 \\\n",
+ " trainer.logger=\\[console\\]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zSn7lNlZ2vfL"
+ },
+ "source": [
+ "# Stop and clean up resources"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "QuJ-LgdTAPkb",
+ "outputId": "c531239e-f4d2-4b6a-c0c5-0f3c96362b25"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/bin/bash: line 1: ray: command not found\n"
+ ]
+ }
+ ],
+ "source": [
+ "!ray stop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XxhdsvVR0WxI"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "A100",
+ "provenance": [],
+ "machine_shape": "hm",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file