mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[ROCm] Optimize ROCm CI to reduce time (#16620)
This PR mainly optimizes the ROCm CI tests to reduce time and CPU utilization. - use a smaller batch size in the strided_batched_gemm/batched_gemm tests - disable the CPU training tests - fix occasional test_e2e_padding_elimination failures on ROCm.
This commit is contained in:
parent
af89496fc7
commit
ebc311365b
6 changed files with 36 additions and 12 deletions
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
|
@ -12,6 +13,8 @@ import numpy as np
|
|||
import pytest
|
||||
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix
|
||||
|
||||
max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", 64))
|
||||
|
||||
|
||||
def dtype_to_suffix(dtype):
|
||||
return {
|
||||
|
|
@ -85,7 +88,7 @@ dtypes = ["float32", "float16"]
|
|||
all_transabs = list(product([True, False], repeat=2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch", [1, 64])
|
||||
@pytest.mark.parametrize("batch", [1, max_batch_size])
|
||||
@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("transa, transb", all_transabs)
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
|
|
@ -95,7 +98,7 @@ def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k, batch):
|
|||
|
||||
|
||||
# Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests
|
||||
@pytest.mark.parametrize("batch", [1, 64])
|
||||
@pytest.mark.parametrize("batch", [1, max_batch_size])
|
||||
@pytest.mark.parametrize("m, n, k", get_gemm_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("transa, transb", all_transabs)
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
|
@ -12,6 +13,8 @@ import numpy as np
|
|||
import pytest
|
||||
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix
|
||||
|
||||
max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", 64))
|
||||
|
||||
|
||||
def dtype_to_suffix(dtype):
|
||||
return {
|
||||
|
|
@ -91,7 +94,7 @@ dtypes = ["float32", "float16"]
|
|||
all_transabs = list(product([True, False], repeat=2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch", [1, 64])
|
||||
@pytest.mark.parametrize("batch", [1, max_batch_size])
|
||||
@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("transa, transb", all_transabs)
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
|
|
@ -101,7 +104,7 @@ def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k, batch):
|
|||
|
||||
|
||||
@pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled")
|
||||
@pytest.mark.parametrize("batch", [1, 64])
|
||||
@pytest.mark.parametrize("batch", [1, max_batch_size])
|
||||
@pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("transa, transb", all_transabs)
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
|
|
@ -111,7 +114,7 @@ def test_ck_gemm_all_cases(dtype, transa, transb, m, n, k, batch):
|
|||
|
||||
|
||||
# Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests
|
||||
@pytest.mark.parametrize("batch", [1, 64])
|
||||
@pytest.mark.parametrize("batch", [1, max_batch_size])
|
||||
@pytest.mark.parametrize("m, n, k", get_gemm_bert_sizes(full=False))
|
||||
@pytest.mark.parametrize("transa, transb", all_transabs)
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from _test_commons import run_subprocess
|
||||
|
|
@ -178,15 +179,17 @@ def main():
|
|||
|
||||
run_ortmodule_poc_net(cwd, log, no_cuda=False, data_dir=args.mnist)
|
||||
|
||||
run_ortmodule_poc_net(cwd, log, no_cuda=True, data_dir=args.mnist)
|
||||
if os.getenv("ORTMODULE_DISABLE_CPU_TRAINING_TEST", "0") != "1":
|
||||
run_ortmodule_poc_net(cwd, log, no_cuda=True, data_dir=args.mnist)
|
||||
|
||||
run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(
|
||||
cwd, log, no_cuda=False, data_dir=args.bert_data, transformers_cache=args.transformers_cache
|
||||
)
|
||||
|
||||
run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(
|
||||
cwd, log, no_cuda=True, data_dir=args.bert_data, transformers_cache=args.transformers_cache
|
||||
)
|
||||
if os.getenv("ORTMODULE_DISABLE_CPU_TRAINING_TEST", "0") != "1":
|
||||
run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(
|
||||
cwd, log, no_cuda=True, data_dir=args.bert_data, transformers_cache=args.transformers_cache
|
||||
)
|
||||
|
||||
run_ortmodule_torch_lightning(cwd, log, args.mnist)
|
||||
|
||||
|
|
|
|||
|
|
@ -6051,7 +6051,11 @@ def test_e2e_padding_elimination():
|
|||
if pt_param.grad is not None:
|
||||
_test_helpers.assert_values_are_close(pt_param.grad, ort_param.grad, atol=1e-4, rtol=1e-5)
|
||||
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)
|
||||
if os.getenv("ORTMODULE_ROCM_TEST", "0") == "1":
|
||||
# For ROCm EP, the difference between ORT and PyTorch is larger than CUDA EP.
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=2e-3, rtol=2e-4)
|
||||
else:
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)
|
||||
|
||||
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
|
||||
assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ jobs:
|
|||
inputs:
|
||||
script: |-
|
||||
echo "Select agent: $(Agent.Name), GPU: $HIP_VISIBLE_DEVICES, render: $DRIVER_RENDER"
|
||||
bash $(Build.SourcesDirectory)/tools/ci_build/github/pai/pai_get_thread.sh $(Agent.Name) $HIP_VISIBLE_DEVICES
|
||||
bash $(Build.SourcesDirectory)/tools/ci_build/github/pai/pai_clean_device.sh $(Agent.Name) $HIP_VISIBLE_DEVICES
|
||||
workingDirectory: $(Build.SourcesDirectory)
|
||||
displayName: 'Check ROCm Environment'
|
||||
|
||||
|
|
@ -162,7 +162,7 @@ jobs:
|
|||
/onnxruntime_src/tools/ci_build/github/pai/pai_test_launcher.sh"
|
||||
workingDirectory: $(Build.SourcesDirectory)
|
||||
displayName: 'Run onnxruntime unit tests'
|
||||
condition: succeededOrFailed()
|
||||
condition: succeeded()
|
||||
|
||||
- task: CmdLine@2
|
||||
inputs:
|
||||
|
|
@ -181,6 +181,7 @@ jobs:
|
|||
/bin/bash -c "
|
||||
set -ex; \
|
||||
export KERNEL_EXPLORER_BUILD_DIR=/build/$(BuildConfig); \
|
||||
export KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE=8; \
|
||||
pytest /onnxruntime_src/onnxruntime/python/tools/kernel_explorer/ -n 8 --reruns 1 --durations=100"
|
||||
workingDirectory: $(Build.SourcesDirectory)
|
||||
displayName: 'Run kernel explorer tests'
|
||||
|
|
@ -240,6 +241,8 @@ jobs:
|
|||
pip install /build/$(BuildConfig)/dist/$whlfilename; \
|
||||
python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install; \
|
||||
mkdir /home/onnxruntimedev/mnist /home/onnxruntimedev/bert_data; \
|
||||
export ORTMODULE_DISABLE_CPU_TRAINING_TEST=1; \
|
||||
export ORTMODULE_ROCM_TEST=1; \
|
||||
python orttraining_ortmodule_tests.py \
|
||||
--mnist /home/onnxruntimedev/mnist \
|
||||
--bert_data /home/onnxruntimedev/bert_data/hf_data/glue_data/CoLA/original/raw"
|
||||
|
|
@ -247,4 +250,12 @@ jobs:
|
|||
displayName: 'Run orttraining_ortmodule_tests.py'
|
||||
condition: succeededOrFailed()
|
||||
|
||||
- task: CmdLine@2
|
||||
inputs:
|
||||
script: |-
|
||||
bash $(Build.SourcesDirectory)/tools/ci_build/github/pai/pai_clean_device.sh $(Agent.Name) $HIP_VISIBLE_DEVICES
|
||||
workingDirectory: $(Build.SourcesDirectory)
|
||||
displayName: 'Clean ROCm Environment'
|
||||
condition: always()
|
||||
|
||||
- template: templates/clean-agent-build-directory-step.yml
|
||||
|
|
|
|||
Loading…
Reference in a new issue