onnxruntime/orttraining/tools/ci_test/run_gpt2_perf_test.py
Justin Chu be7541ef4a
[Linter] Bump ruff and remove pylint (#17797)
Bump ruff version and remove pylint from the linter list. Fix any new
errors detected by ruff.

### Motivation and Context

Ruff covers many of the pylint rules. Since pylint is not enabled in
this repo and runs slowly, we remove it from the linters.
2023-10-05 21:07:33 -07:00

71 lines
2.2 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import os
import subprocess
import sys
from collections import namedtuple
# Absolute, symlink-resolved path of the directory that contains this script.
SCRIPT_DIR = os.path.realpath(os.path.split(__file__)[0])
def parse_args(argv=None):
    """Parse the command-line options for the GPT-2 performance test.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default), argparse reads ``sys.argv[1:]``, which is
            the behavior callers invoking ``parse_args()`` rely on.

    Returns:
        argparse.Namespace with the string attributes ``binary_dir``,
        ``training_data_root`` and ``model_root``.

    Raises:
        SystemExit: if a required option is missing or unparsable (argparse
            prints usage and exits with status 2).
    """
    parser = argparse.ArgumentParser(description="Runs GPT-2 performance tests.")
    parser.add_argument("--binary_dir", required=True, help="Path to the ORT binary directory.")
    parser.add_argument("--training_data_root", required=True, help="Path to the training data root directory.")
    parser.add_argument("--model_root", required=True, help="Path to the model root directory.")
    return parser.parse_args(argv)
# TODO: Review and finalize the test parameters.
def main():
    """Run the GPT-2 training performance test once per configuration.

    For each configuration this prints a header naming the precision
    (fp16/fp32) and max sequence length, then runs the
    ``onnxruntime_training_gpt2`` binary found in ``--binary_dir`` in
    ``train`` mode. Perf output goes to ``<SCRIPT_DIR>/results``.

    Returns:
        0 once every configured run has completed successfully.

    Raises:
        subprocess.CalledProcessError: if any training run exits with a
            non-zero status; later configurations are not run.
    """
    args = parse_args()

    Config = namedtuple("Config", ["use_mixed_precision", "max_seq_length", "batch_size"])
    configs = [Config(True, 1024, 1), Config(False, 1024, 1)]

    # These paths do not depend on the configuration, so build them once.
    gpt2_binary = os.path.join(args.binary_dir, "onnxruntime_training_gpt2")
    model_name = os.path.join(
        args.model_root,
        "megatron-gpt2_hidden-size-1024_num-layers-24_vocab-size-50257_num-attention-heads-16_max-position-embeddings-1024_optimized_opset12",
    )
    train_data_dir = os.path.join(args.training_data_root, "train")
    test_data_dir = os.path.join(args.training_data_root, "test")
    perf_output_dir = os.path.join(SCRIPT_DIR, "results")

    # run GPT-2 training
    for c in configs:
        precision = "fp16" if c.use_mixed_precision else "fp32"
        # Flush so the header is written before the child process's output.
        # When stdout is a pipe (e.g. CI logs), Python buffers it while the
        # child writes straight to the inherited file descriptor.
        print(f"######## testing name - {precision}-{c.max_seq_length} ##############", flush=True)

        cmds = [
            gpt2_binary,
            "--model_name",
            model_name,
            "--train_data_dir",
            train_data_dir,
            "--test_data_dir",
            test_data_dir,
            "--train_batch_size",
            str(c.batch_size),
            "--mode",
            "train",
            "--max_seq_length",
            str(c.max_seq_length),
            "--num_train_steps",
            "640",
            "--gradient_accumulation_steps",
            "1",
            "--perf_output_dir",
            perf_output_dir,
        ]
        if c.use_mixed_precision:
            cmds.append("--use_mixed_precision")

        # check=True raises CalledProcessError on a non-zero exit status.
        subprocess.run(cmds, check=True)

    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)