onnxruntime/tools/ci_build/github/pai/run_job.py
Weixing Zhang aec4cb489e
ROCm EP for AMD GPU (#5480)
The ROCm EP is designed and implemented on top of ROCm, the AMD GPU software stack. For details about ROCm, see https://rocmdocs.amd.com/en/latest/

The ROCm EP builds on the following components:
1. AMD GPU programming language: HIP
2. AMD GPU HIP language runtime: amdhip64
3. BLAS: rocBLAS, hipBLAS
4. DNN: MIOpen
5. Collective Communication library: RCCL
6. CUB: hipCUB
7. …

Current status:
BERT-L and GPT2 training can be run on AMD GPUs with data parallelism.

Next:
1. Make more GPU code shareable between the ROCm EP and the CUDA EP, since the HIP language and HIP runtime API are very close to CUDA.
2. Continue improving the implementation.
3. Continue GPU kernel optimization.
4. Support model parallelism on ROCm EP.
……

The ROCm kernels have been removed from this commit and will be in a separate PR. Because the original PR was too big (~180 files), it was suggested to split it into two parts: one for the ROCm kernels and one for everything else.

Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
Co-authored-by: sabreshao <sabre.shao@amd.com>
Co-authored-by: anghostcici <11013544+anghostcici@users.noreply.github.com>
Co-authored-by: Suffian Khan <sukha@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
2020-10-29 17:13:04 -07:00

106 lines
3.3 KiB
Python
Executable file

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import json
import os
import re
import sys
import time

import requests

pai_base_url = "https://rr.openpai.org/rest-server"


def parse_args():
    """Parse the command line arguments."""
    parser = argparse.ArgumentParser(description="Runs a job on PAI.")
    parser.add_argument("job_yaml_file", help="The job YAML file.")
    parser.add_argument("job_name", help="The job name.")
    parser.add_argument("--user-env", required=True, help="Environment variable containing the user name.")
    parser.add_argument("--token-env", required=True, help="Environment variable containing the authorization token.")
    parser.add_argument("--yaml-sub-env", action="append", nargs=2,
                        help="YAML substitution key and environment variable containing the value.")
    return parser.parse_args()


def get_yaml_text_with_substitutions(job_yaml_file_path, substitutions):
    """Read the job YAML file and replace each @@key@@ placeholder with its value."""
    substitution_pattern = re.compile(r"@@(\w+)@@")

    def replace(match):
        if match[1] in substitutions:
            return substitutions[match[1]]
        # Leave unknown placeholders untouched, but warn about them.
        print("Warning - no substitution was provided for '{}'.".format(match[0]))
        return match[0]

    with open(job_yaml_file_path, mode="r") as yaml_file:
        return re.sub(substitution_pattern, replace, yaml_file.read())


def submit_job(yaml, token):
    """Submit the job YAML to the PAI REST server."""
    url = "{}/api/v2/jobs".format(pai_base_url)
    headers = {
        "Authorization": "Bearer {}".format(token),
        "Content-Type": "text/yaml",
    }
    response = requests.post(url=url, data=yaml, headers=headers)
    response.raise_for_status()


def wait_for_job(job_name, user, token):
    """Poll the job status every 30 seconds until the job succeeds or fails."""
    url = "{}/api/v2/jobs/{}~{}".format(pai_base_url, user, job_name)
    headers = {
        "Authorization": "Bearer {}".format(token),
    }
    while True:
        response = requests.get(url=url, headers=headers)
        response.raise_for_status()
        response_json = response.json()
        job_status = response_json["jobStatus"]["state"]
        if job_status in ["WAITING", "RUNNING"]:
            time.sleep(30)
        elif job_status == "SUCCEEDED":
            break
        else:
            # Any other state is treated as a failure.
            raise RuntimeError("Job failed. Status query response:\n{}".format(json.dumps(response_json, indent=2)))


def main():
    args = parse_args()

    # job_name is always available as a substitution key.
    substitutions = {
        "job_name": args.job_name,
    }
    if args.yaml_sub_env is not None:
        # Each --yaml-sub-env pair is (substitution key, environment variable name).
        substitutions.update({kvp[0]: os.environ.get(kvp[1], "") for kvp in args.yaml_sub_env})

    yaml_with_substitutions = get_yaml_text_with_substitutions(args.job_yaml_file, substitutions)

    user = os.environ[args.user_env]
    token = os.environ[args.token_env]

    print("Submitting job {} ..".format(args.job_name))
    sys.stdout.flush()
    submit_job(yaml_with_substitutions, token)
    print('See https://rr.openpai.org/job-detail.html?username={}&jobName={}'.format(user, args.job_name))
    sys.stdout.flush()

    # This path is relative, so the script expects to be run from the repository root.
    print('\nWarning: The following tests will be excluded:')
    with open('tools/ci_build/github/pai/pai-excluded-tests.txt', 'r') as fin:
        print(fin.read())
    print('')

    print("Waiting for job to complete ..")
    sys.stdout.flush()
    wait_for_job(args.job_name, user, token)


if __name__ == "__main__":
    main()
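
The `@@key@@` substitution step in the script can be tried without a PAI server or a YAML file on disk. The following is a minimal sketch, not part of the original file: `substitute_placeholders` is a hypothetical in-memory variant of `get_yaml_text_with_substitutions`, and the template text and its field names are made up for illustration. It uses the same regular expression as the script, and additionally collects the keys that had no substitution instead of only printing a warning.

```python
import re

# Same placeholder syntax as the script: @@key@@
SUBSTITUTION_PATTERN = re.compile(r"@@(\w+)@@")


def substitute_placeholders(text, substitutions):
    """Hypothetical helper: substitute @@key@@ placeholders in a string.

    Returns the substituted text and the list of keys that had no value.
    """
    missing = []

    def replace(match):
        key = match.group(1)
        if key in substitutions:
            return substitutions[key]
        # Unknown placeholders are left as-is, as in the script.
        missing.append(key)
        return match.group(0)

    return SUBSTITUTION_PATTERN.sub(replace, text), missing


# Illustrative template only; these fields are assumptions, not a real PAI job file.
template = "name: @@job_name@@\nimage: @@docker_image@@\n"
result, missing = substitute_placeholders(template, {"job_name": "ci-run-1"})
print(result)
print(missing)
```

Here `@@job_name@@` is replaced, while `@@docker_image@@` stays in the text and is reported in `missing`, which mirrors the warning path in `get_yaml_text_with_substitutions`.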