diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py index e5aafe2732..d183825da0 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +import os import sys from dataclasses import dataclass from itertools import product @@ -12,6 +13,8 @@ import numpy as np import pytest from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix +max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", 64)) + def dtype_to_suffix(dtype): return { @@ -85,7 +88,7 @@ dtypes = ["float32", "float16"] all_transabs = list(product([True, False], repeat=2)) -@pytest.mark.parametrize("batch", [1, 64]) +@pytest.mark.parametrize("batch", [1, max_batch_size]) @pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) @@ -95,7 +98,7 @@ def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k, batch): # Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests -@pytest.mark.parametrize("batch", [1, 64]) +@pytest.mark.parametrize("batch", [1, max_batch_size]) @pytest.mark.parametrize("m, n, k", get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py index b54fa31cbf..312592e2f4 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +import os import sys from dataclasses import dataclass from itertools import product @@ -12,6 +13,8 @@ import numpy as np import pytest from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix +max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", 64)) + def dtype_to_suffix(dtype): return { @@ -91,7 +94,7 @@ dtypes = ["float32", "float16"] all_transabs = list(product([True, False], repeat=2)) -@pytest.mark.parametrize("batch", [1, 64]) +@pytest.mark.parametrize("batch", [1, max_batch_size]) @pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) @@ -101,7 +104,7 @@ def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k, batch): @pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") -@pytest.mark.parametrize("batch", [1, 64]) +@pytest.mark.parametrize("batch", [1, max_batch_size]) @pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) @@ -111,7 +114,7 @@ def test_ck_gemm_all_cases(dtype, transa, transb, m, n, k, batch): # Tunable is basically wrapped around of rocblas and ck gemm, so no need for full tests -@pytest.mark.parametrize("batch", [1, 64]) +@pytest.mark.parametrize("batch", [1, max_batch_size]) @pytest.mark.parametrize("m, n, k", get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index d97b3b5e2b..b6ca713503 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -4,6 +4,7 @@ import argparse import logging +import os import sys from _test_commons import run_subprocess @@ -178,15 +179,17 @@ def main(): run_ortmodule_poc_net(cwd, log, no_cuda=False, data_dir=args.mnist) - run_ortmodule_poc_net(cwd, log, no_cuda=True, data_dir=args.mnist) + if os.getenv("ORTMODULE_DISABLE_CPU_TRAINING_TEST", "0") != "1": + run_ortmodule_poc_net(cwd, log, no_cuda=True, data_dir=args.mnist) run_ortmodule_hf_bert_for_sequence_classification_from_pretrained( cwd, log, no_cuda=False, data_dir=args.bert_data, transformers_cache=args.transformers_cache ) - run_ortmodule_hf_bert_for_sequence_classification_from_pretrained( - cwd, log, no_cuda=True, data_dir=args.bert_data, transformers_cache=args.transformers_cache - ) + if os.getenv("ORTMODULE_DISABLE_CPU_TRAINING_TEST", "0") != "1": + run_ortmodule_hf_bert_for_sequence_classification_from_pretrained( + cwd, log, no_cuda=True, data_dir=args.bert_data, transformers_cache=args.transformers_cache + ) run_ortmodule_torch_lightning(cwd, log, args.mnist) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index d7eadc56db..625c1ce0d4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6051,7 +6051,11 @@ def test_e2e_padding_elimination(): if pt_param.grad is not None: _test_helpers.assert_values_are_close(pt_param.grad, ort_param.grad, atol=1e-4, rtol=1e-5) - _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4) + if os.getenv("ORTMODULE_ROCM_TEST", "0") == "1": + # For ROCm EP, the difference between ORT and PyTorch is larger than CUDA EP. + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=2e-3, rtol=2e-4) + else: + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4) training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node] diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index da042bc339..4e4073ae84 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -137,7 +137,7 @@ jobs: inputs: script: |- echo "Select agent: $(Agent.Name), GPU: $HIP_VISIBLE_DEVICES, render: $DRIVER_RENDER" - bash $(Build.SourcesDirectory)/tools/ci_build/github/pai/pai_get_thread.sh $(Agent.Name) $HIP_VISIBLE_DEVICES + bash $(Build.SourcesDirectory)/tools/ci_build/github/pai/pai_clean_device.sh $(Agent.Name) $HIP_VISIBLE_DEVICES workingDirectory: $(Build.SourcesDirectory) displayName: 'Check ROCm Environment' @@ -162,7 +162,7 @@ jobs: /onnxruntime_src/tools/ci_build/github/pai/pai_test_launcher.sh" workingDirectory: $(Build.SourcesDirectory) displayName: 'Run onnxruntime unit tests' - condition: succeededOrFailed() + condition: succeeded() - task: CmdLine@2 inputs: @@ -181,6 +181,7 @@ jobs: /bin/bash -c " set -ex; \ export KERNEL_EXPLORER_BUILD_DIR=/build/$(BuildConfig); \ + export KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE=8; \ pytest /onnxruntime_src/onnxruntime/python/tools/kernel_explorer/ -n 8 --reruns 1 --durations=100" workingDirectory: $(Build.SourcesDirectory) displayName: 'Run kernel explorer tests' @@ -240,6 +241,8 @@ jobs: pip install /build/$(BuildConfig)/dist/$whlfilename; \ python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install; \ mkdir /home/onnxruntimedev/mnist /home/onnxruntimedev/bert_data; \ + export ORTMODULE_DISABLE_CPU_TRAINING_TEST=1; \ + export ORTMODULE_ROCM_TEST=1; \ python orttraining_ortmodule_tests.py \ --mnist /home/onnxruntimedev/mnist \ --bert_data /home/onnxruntimedev/bert_data/hf_data/glue_data/CoLA/original/raw" @@ -247,4 +250,12 @@ jobs: displayName: 'Run orttraining_ortmodule_tests.py' condition: succeededOrFailed() + - task: CmdLine@2 + inputs: + script: |- + bash $(Build.SourcesDirectory)/tools/ci_build/github/pai/pai_clean_device.sh $(Agent.Name) $HIP_VISIBLE_DEVICES + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Clean ROCm Environment' + condition: always() + - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/pai/pai_get_thread.sh b/tools/ci_build/github/pai/pai_clean_device.sh similarity index 100% rename from tools/ci_build/github/pai/pai_get_thread.sh rename to tools/ci_build/github/pai/pai_clean_device.sh