From a6279d4cfb51be98a5cc25dc642013e358fcd01f Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Wed, 29 Mar 2023 15:19:52 +0800 Subject: [PATCH] [ROCm] update Stable Diffusion benchmark to support ROCm EP (#15094) Update Stable Diffusion benchmark to support ROCm EP --- .../models/stable_diffusion/benchmark.py | 96 ++++++++++++++++--- .../stable_diffusion/optimize_pipeline.py | 10 +- .../stable_diffusion/requirements-rocm.txt | 17 ++++ 3 files changed, 110 insertions(+), 13 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-rocm.txt diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index a1d2247639..99704d9cfd 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -16,6 +16,11 @@ SD_MODELS = { "2.1": "stabilityai/stable-diffusion-2-1", } +PROVIDERS = { + "cuda": "CUDAExecutionProvider", + "rocm": "ROCMExecutionProvider", +} + def example_prompts(): prompts = [ @@ -187,7 +192,16 @@ def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, dis def run_ort_pipeline( - pipe, batch_size: int, image_filename_prefix: str, height, width, steps, num_prompts, batch_count, start_memory + pipe, + batch_size: int, + image_filename_prefix: str, + height, + width, + steps, + num_prompts, + batch_count, + start_memory, + enable_mem_measure, ): from diffusers import OnnxStableDiffusionPipeline @@ -199,8 +213,11 @@ def run_ort_pipeline( pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) # Run warm up, and measure GPU memory of two runs (The first run has cuDNN algo search so it might need more memory) - first_run_memory = measure_gpu_memory(warmup, start_memory) - second_run_memory = measure_gpu_memory(warmup, start_memory) + first_run_memory = measure_gpu_memory(warmup, start_memory) if enable_mem_measure else -1 + second_run_memory = measure_gpu_memory(warmup, start_memory) if enable_mem_measure else -1 + + if not enable_mem_measure: + warmup() latency_list = [] for i, prompt in enumerate(prompts): @@ -243,19 +260,31 @@ def run_ort_pipeline( def run_torch_pipeline( - pipe, batch_size: int, image_filename_prefix: str, height, width, steps, num_prompts, batch_count, start_memory + pipe, + batch_size: int, + image_filename_prefix: str, + height, + width, + steps, + num_prompts, + batch_count, + start_memory, + enable_mem_measure, ): import torch prompts = example_prompts() - # total 2 runs of warm up, and measure GPU memory + # total 2 runs of warm up, and measure GPU memory for CUDA EP def warmup(): pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) # Run warm up, and measure GPU memory of two runs (The first run has cuDNN algo search so it might need more memory) - first_run_memory = measure_gpu_memory(warmup, start_memory) - second_run_memory = measure_gpu_memory(warmup, start_memory) + first_run_memory = measure_gpu_memory(warmup, start_memory) if enable_mem_measure else -1 + second_run_memory = measure_gpu_memory(warmup, start_memory) if enable_mem_measure else -1 + + if not enable_mem_measure: + warmup() torch.set_grad_enabled(False) @@ -313,6 +342,7 @@ def run_ort( num_prompts, batch_count, start_memory, + enable_mem_measure, ): load_start = time.time() pipe = get_ort_pipeline(model_name, directory, provider, disable_safety_checker) @@ -321,7 +351,16 @@ def run_ort( image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, disable_safety_checker) result = run_ort_pipeline( - pipe, batch_size, image_filename_prefix, height, width, steps, num_prompts, batch_count, start_memory + pipe, + batch_size, + image_filename_prefix, + height, + width, + steps, + num_prompts, + batch_count, + start_memory, + enable_mem_measure, ) result.update( @@ -347,6 +386,7 @@ def run_torch( num_prompts, batch_count, start_memory, + enable_mem_measure, ): import torch @@ -365,11 +405,29 @@ def run_torch( if not enable_torch_compile: with torch.inference_mode(): result = run_torch_pipeline( - pipe, batch_size, image_filename_prefix, height, width, steps, num_prompts, batch_count, start_memory + pipe, + batch_size, + image_filename_prefix, + height, + width, + steps, + num_prompts, + batch_count, + start_memory, + enable_mem_measure, ) else: result = run_torch_pipeline( - pipe, batch_size, image_filename_prefix, height, width, steps, num_prompts, batch_count, start_memory + pipe, + batch_size, + image_filename_prefix, + height, + width, + steps, + num_prompts, + batch_count, + start_memory, + enable_mem_measure, ) result.update( @@ -396,6 +454,16 @@ def parse_arguments(): help="Engines to benchmark. Default is onnxruntime.", ) + parser.add_argument( + "-r", + "--provider", + required=False, + type=str, + default="cuda", + choices=list(PROVIDERS.keys()), + help="Provider to benchmark. Default is CUDAExecutionProvider.", + ) + parser.add_argument( "-v", "--version", @@ -500,14 +568,16 @@ def main(): args = parse_arguments() print(args) - start_memory = measure_gpu_memory(None) + enable_mem_measure = args.provider == "cuda" + + start_memory = measure_gpu_memory(None) if enable_mem_measure else -1 print("GPU memory used before loading models:", start_memory) sd_model = SD_MODELS[args.version] + provider = PROVIDERS[args.provider] if args.engine == "onnxruntime": assert args.pipeline, "--pipeline should be specified for onnxruntime engine" - provider = "CUDAExecutionProvider" result = run_ort( sd_model, args.pipeline, @@ -520,6 +590,7 @@ def main(): args.num_prompts, args.batch_count, start_memory, + enable_mem_measure, ) else: result = run_torch( @@ -534,6 +605,7 @@ def main(): args.num_prompts, args.batch_count, start_memory, + enable_mem_measure, ) print(result) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 46e46accf9..8e78635093 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -17,6 +17,9 @@ # # If you are using nightly package (or built from source), you can force MultiHeadAttention to run in float32: # python optimize_pipeline.py -i ./sd-v2-1 -o ./sd-v2-1-fp16 --float16 --force_fp32_ops unet:MultiHeadAttention +# +# ROCm EP doesn't support MultiHeadAttention, add --disable_attention to disable attention fusion: +# python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16 --disable_attention import argparse import logging @@ -51,6 +54,7 @@ def optimize_sd_pipeline( float16: bool, force_fp32_ops: List[str], enable_runtime_optimization: bool, + args, ): """Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16. @@ -123,7 +127,8 @@ def optimize_sd_pipeline( # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead. logger.info(f"Optimize {onnx_model_path}...") - fusion_options = FusionOptions(model_type) + args.model_type = model_type + fusion_options = FusionOptions.parse(args) if model_type in ["unet"]: # Some optimizations are not available in v1.14 or older version: packed QKV and BiasAdd @@ -286,6 +291,8 @@ def parse_arguments(): ) parser.set_defaults(use_external_data_format=False) + FusionOptions.add_arguments(parser) + args = parser.parse_args() return args @@ -303,6 +310,7 @@ def main(): args.float16, args.force_fp32_ops, args.inspect, + args, ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-rocm.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-rocm.txt new file mode 100644 index 0000000000..b632e38191 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-rocm.txt @@ -0,0 +1,17 @@ +transformers==4.26.0 +numpy==1.24.1 +accelerate==0.15.0 +onnx==1.13.0 +coloredlogs +packaging==23.0 +protobuf==3.20.3 +psutil==5.9.4 +sympy==1.11.1 + +# Install diffusers from source +# git clone https://github.com/huggingface/diffusers.git +# cd diffusers && git checkout c4892f1855097a68703ca2e949aca15829526958 +# pip install -e . + +# Install onnxruntime-rocm or onnxruntime_training +# Build onnxruntime-rocm from source or install lastest onnxruntime_training rocm nightly python package