onnxruntime/tools/ci_build/compile_triton.py

175 lines
5.5 KiB
Python

integrate triton into ort (#15862)

### Description
In some scenarios, kernels written in Triton outperform CK or other handwritten kernels, so this PR implements a framework that lets onnxruntime use them. It integrates Triton into ORT so that ORT can use kernels written and compiled with Triton. The change has two main parts:
1. A build step that compiles the Triton kernels and combines them into libonnxruntime_providers_rocm.so.
2. A loader and launcher in C++ for loading and launching the Triton kernels.

#### Build
The script `tools/ci_build/compile_triton.py` compiles the Triton kernels. It dynamically loads all kernel files, compiles them, and generates `triton_kernel_infos.a` and `triton_kernel_infos.h`.
`triton_kernel_infos.a` contains all compiled kernel instructions and is combined into libonnxruntime_providers_rocm.so using the --whole-archive flag.
`triton_kernel_infos.h` defines a const array containing the metadata for each compiled kernel. This metadata is used for loading and launching, so the header is included by 'triton_kernel.cu', which defines the load and launch functions.
A build flag is added to build.py and CMakeList.txt; when the ROCm provider is built, the triton_kernel build command is called to generate all necessary files.

#### C++ Load and Launch
On the C++ side, the load and launch functions are implemented in triton_kernel.cu and triton_kernel.h. These two files are located in `providers/cuda` and are hipified when compiling for ROCm, so this part supports both CUDA and ROCm. Currently, however, Triton kernels are only called on ROCm.
A softmax Triton op is also implemented as an example. Because many kernels are generated for different softmax input shapes, TunableOp is used to select the best one.
2023-05-17 01:35:28 +00:00
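The script below loads each kernel file and calls its `get_function_table()`, reading the keys `name`, `group`, `sig`, `func`, and `kwargs` from every entry. As a minimal sketch of what such a kernel file might look like: the kernel itself, the signature string, the `num_warps` value, and the block sizes are all placeholders chosen for illustration, not taken from the PR.

```python
# Hypothetical kernel script showing the table shape compile_triton.py reads.
# In a real script, "func" would be a @triton.jit kernel; a plain function
# stands in here so the sketch runs without Triton installed.


def placeholder_kernel(output_ptr, input_ptr, n_cols, BLOCK_SIZE):
    """Stand-in for a Triton kernel (assumption, not a real kernel)."""
    raise NotImplementedError


def get_function_table():
    """Return one entry per kernel variant that should be compiled."""
    table = []
    for block_size in (1024, 2048):  # illustrative values only
        table.append(
            {
                # unique name; also used as the stem of the .hsaco/.cubin file
                "name": f"softmax_fp32_{block_size}",
                # group name shared by all variants of the same op
                "group": "softmax_fp32",
                "func": placeholder_kernel,
                # signature format is an assumption for illustration
                "sig": "*fp32, *fp32, i32",
                # forwarded to triton.compile; "constants" is copied into the header
                "kwargs": {"num_warps": 4, "constants": {"BLOCK_SIZE": block_size}},
            }
        )
    return table
```

Because the script derives the object file name from the part of the library file name before the first `.`, entry names in such a table should not contain dots.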
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import importlib.util
import os
import shutil
import triton


def compile(function_table, out_dir):
    def compile_one(func, sig, **kwargs):
        ret = triton.compile(func, signature=sig, **kwargs)
        return ret

    metadata = []
    for func_desc in function_table:
        name = func_desc["name"]
        group = func_desc["group"]
        sig = func_desc["sig"]
        func = func_desc["func"]
        kwargs = func_desc["kwargs"]

        # print("compile func: ", func_desc)

        ret = compile_one(func, sig, **kwargs)

        compile_res = {}
        compile_res["name"] = name
        compile_res["group"] = group
        compile_res["func_name"] = ret.metadata["name"]
        compile_res["num_warps"] = ret.metadata["num_warps"]
        compile_res["shared"] = ret.metadata["shared"]
        if "constants" in kwargs:
            compile_res["constants"] = kwargs["constants"]

        # copy the temporary kernel file into the output dir
        if "hsaco_path" in ret.asm and os.path.exists(ret.asm["hsaco_path"]):
            # ROCm: the compiled kernel is already a file on disk
            lib_name = f"{name}.hsaco"
            shutil.copyfile(ret.asm["hsaco_path"], f"{out_dir}/{lib_name}")
        elif "cubin" in ret.asm:
            # CUDA: the cubin is held in memory and must be written to a file
            lib_name = f"{name}.cubin"
            with open(f"{out_dir}/{lib_name}", "wb") as fp:
                fp.write(ret.asm["cubin"])
        else:
            raise Exception("no ROCm or CUDA compiled kernel was found")

        compile_res["lib_file"] = lib_name
        metadata.append(compile_res)

    return metadata


def convert_lib_to_obj(lib_file, out_dir):
    obj_file = lib_file.split(".")[0] + ".o"
    command = f"cd {out_dir}; objcopy -I binary -O elf64-x86-64 -B i386:x86-64 {lib_file} {obj_file}; cd -"
    ret = os.system(command)

    if ret != 0:
        raise Exception(f"convert command failed: {command}")
    # check that the output file exists
    if not os.path.exists(f"{out_dir}/{obj_file}"):
        raise Exception(f"the output file does not exist after executing command: {command}")

    return obj_file


def archive_obj_files(obj_files, out_dir, out_obj_file):
    obj_files = " ".join(obj_files)
    command = f"cd {out_dir}; ar rcs {out_obj_file} {obj_files}; cd -"
    ret = os.system(command)

    if ret != 0:
        raise Exception(f"archive command failed: {command}")
    # check that the output file exists
    if not os.path.exists(f"{out_dir}/{out_obj_file}"):
        raise Exception(f"the output file does not exist after executing command: {command}")


def convert_and_save(metadata, header_file, out_dir, out_obj_file):
    c_metadata = []
    binary_files = []
    for m in metadata:
        meta_ele = []
        obj_file = convert_lib_to_obj(m["lib_file"], out_dir)
        binary_files.append(obj_file)

        # objcopy names the embedded data symbol after the input file name
        lib_name = m["lib_file"].replace(".", "_")
        meta_ele.append(f'"_binary_{lib_name}_start"')
        meta_ele.append(f"\"{m['func_name']}\"")
        meta_ele.append(f"\"{m['group']}\"")
        meta_ele.append(f"\"{m['name']}\"")
        meta_ele.append(str(m["num_warps"]))
        meta_ele.append(str(m["shared"]))

        # convert constants; kernels compiled without constants get an empty map
        constants = []
        for k, v in m.get("constants", {}).items():
            constants.append(f'{{ "{k}", {v!s}}}')
        meta_ele.append(f"{{ { ', '.join(constants) } }}")

        c_metadata.append(f"{{ { ', '.join(meta_ele) } }}")

    archive_obj_files(binary_files, out_dir, out_obj_file)

    # the generated header uses std::string, so <string> is included as well
    code = f"""
#include <string>
#include <unordered_map>

struct _TritonKernelInfo {{
  const char* name_start;
  const char* func_name;
  const char* group_name;
  const char* name;
  int num_warps;
  int shared;
  std::unordered_map<std::string, int> constants;
}};

const _TritonKernelInfo kernel_infos[] = {{
  { ', '.join(c_metadata) },
}};
    """

    with open(header_file, "w") as fp:
        fp.write(code)


def main(args):
    out_obj_file = args.obj_file
    # fall back to the current directory when obj_file has no directory part
    out_dir = os.path.dirname(out_obj_file) or "."
    out_obj_file = os.path.basename(out_obj_file)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    metadata = []
    print("[triton kernel] start compiling triton kernels.")
    for i, f in enumerate(args.script_files):
        # import the module in f and call its get_function_table()
        spec = importlib.util.spec_from_file_location(f"module_{i}", f)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        func_tb = module.get_function_table()
        m = compile(func_tb, out_dir)
        metadata.extend(m)

    print("[triton kernel] compile triton kernel done.")

    # save metadata into header file
    convert_and_save(metadata, args.header, out_dir, out_obj_file)
    print("[triton kernel] save into file done.")


def get_args():
    parser = argparse.ArgumentParser(
        description="Compile Triton kernels and generate the kernel info header and archive."
    )
    parser.add_argument(
        "--header", type=str, default="triton_kernel_infos.h", help="the header file that should be generated."
    )
    parser.add_argument("--ort_root", type=str, default="onnxruntime", help="the root dir of onnxruntime.")
    parser.add_argument("--script_files", type=str, nargs="+", help="the Triton kernel script files to compile.")
    parser.add_argument("--obj_file", type=str, default="triton_kernel_infos.a", help="the output archive file.")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    main(args)
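
The header generation in `convert_and_save` can be easier to follow in isolation. The sketch below renders one metadata entry into the C++ initializer that ends up in `kernel_infos[]`, including the `_binary_<file>_start` symbol name that `objcopy -I binary` derives from the input file name. The helper `render_kernel_info` and the sample values are hypothetical and only mirror the formatting logic above; they are not part of the script.

```python
# Hypothetical helper mirroring how one metadata entry becomes a C++ initializer.


def render_kernel_info(meta):
    """Render one compiled-kernel metadata dict as a C++ aggregate initializer."""
    # objcopy replaces non-alphanumeric characters in the file name with "_"
    symbol = "_binary_" + meta["lib_file"].replace(".", "_") + "_start"
    constants = ", ".join(
        '{ "%s", %s }' % (k, v) for k, v in meta.get("constants", {}).items()
    )
    fields = [
        '"%s"' % symbol,
        '"%s"' % meta["func_name"],
        '"%s"' % meta["group"],
        '"%s"' % meta["name"],
        str(meta["num_warps"]),
        str(meta["shared"]),
        "{ %s }" % constants,
    ]
    return "{ %s }" % ", ".join(fields)


# Sample entry with made-up values, shaped like the dicts built in compile().
sample = {
    "name": "softmax_fp32_1024",
    "group": "softmax_fp32",
    "func_name": "softmax_kernel",
    "num_warps": 4,
    "shared": 0,
    "constants": {"BLOCK_SIZE": 1024},
    "lib_file": "softmax_fp32_1024.hsaco",
}

if __name__ == "__main__":
    print(render_kernel_info(sample))
```

The field order matches the member order of `_TritonKernelInfo`, which is what lets the generated array use positional aggregate initialization.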