onnxruntime/orttraining/tools/ci_test/run_gpt2_perf_test.py
Justin Chu c7c8757a1c
Use ruff as the formatter to replace black-isort (#23397)
Use ruff as the code formatter in place of black and isort since it is
much faster, and as projects like PyTorch and ONNX have adopted ruff
format as well.

This PR includes only auto-fixed changes in formatting.
2025-01-16 11:14:15 -08:00

71 lines
2.2 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import os
import subprocess
import sys
from collections import namedtuple
# Directory containing this script, with symlinks in the path resolved.
# main() writes perf results to the "results" subdirectory of this path.
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
def parse_args(argv=None):
    """Parse the command-line arguments for the GPT-2 perf test.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads the arguments from ``sys.argv[1:]``,
            which is the behavior existing callers rely on.

    Returns:
        argparse.Namespace with the required ``binary_dir``,
        ``training_data_root`` and ``model_root`` attributes.

    Raises:
        SystemExit: Raised by argparse when a required argument is missing
            or an unknown argument is supplied.
    """
    parser = argparse.ArgumentParser(description="Runs GPT-2 performance tests.")
    parser.add_argument("--binary_dir", required=True, help="Path to the ORT binary directory.")
    parser.add_argument("--training_data_root", required=True, help="Path to the training data root directory.")
    parser.add_argument("--model_root", required=True, help="Path to the model root directory.")
    return parser.parse_args(argv)
# TODO - review to finalize params
def main():
    """Run the GPT-2 training perf test once per configuration.

    For each configuration (with and without ``--use_mixed_precision``) this
    builds the ``onnxruntime_training_gpt2`` command line from the parsed
    arguments and runs it as a subprocess. Perf output is directed to the
    ``results`` directory next to this script.

    Returns:
        0 after every configuration has run successfully.

    Raises:
        subprocess.CalledProcessError: If a training run exits with a
            non-zero return code; remaining configurations are not run.
    """
    args = parse_args()

    Config = namedtuple("Config", ["use_mixed_precision", "max_seq_length", "batch_size"])
    configs = [Config(True, 1024, 1), Config(False, 1024, 1)]

    # run GPT-2 training
    for c in configs:
        precision = "fp16" if c.use_mixed_precision else "fp32"
        print(f"######## testing name - {precision}-{c.max_seq_length} ##############")

        cmds = [
            os.path.join(args.binary_dir, "onnxruntime_training_gpt2"),
            "--model_name",
            os.path.join(
                args.model_root,
                "megatron-gpt2_hidden-size-1024_num-layers-24_vocab-size-50257_num-attention-heads-16_max-position-embeddings-1024_optimized_opset12",
            ),
            "--train_data_dir",
            os.path.join(args.training_data_root, "train"),
            "--test_data_dir",
            os.path.join(args.training_data_root, "test"),
            "--train_batch_size",
            str(c.batch_size),
            "--mode",
            "train",
            "--max_seq_length",
            str(c.max_seq_length),
            "--num_train_steps",
            "640",
            "--gradient_accumulation_steps",
            "1",
            "--perf_output_dir",
            os.path.join(SCRIPT_DIR, "results"),
        ]

        if c.use_mixed_precision:
            cmds.append("--use_mixed_precision")

        # check=True raises CalledProcessError on a non-zero exit code, so a
        # failed run aborts the test instead of being silently ignored.
        subprocess.run(cmds, check=True)

    return 0
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())