mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-21 02:18:09 +00:00
Refactoring of Stable Diffusion scripts (#17138)
Reduce duplicated code in two stable diffusion pipelines (CUDA and TensorRT). Move the common code to models.py
This commit is contained in:
parent
5e971bc51a
commit
3aba736ee2
5 changed files with 594 additions and 716 deletions
|
|
@ -0,0 +1,368 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Models used in Stable diffusion.
|
||||
"""
|
||||
import logging
|
||||
|
||||
import onnx
|
||||
import onnx_graphsurgeon as gs
|
||||
import torch
|
||||
from onnx import shape_inference
|
||||
from ort_optimizer import OrtStableDiffusionOptimizer
|
||||
from polygraphy.backend.onnx.loader import fold_constants
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrtOptimizer:
    """Small wrapper around an onnx-graphsurgeon graph exposing a few graph passes."""

    def __init__(self, onnx_graph):
        """Import ``onnx_graph`` with graphsurgeon and hold the result in ``self.graph``."""
        self.graph = gs.import_onnx(onnx_graph)

    def cleanup(self):
        """Clean up the held graph, then topologically sort it."""
        cleaned = self.graph.cleanup()
        cleaned.toposort()

    def get_optimized_onnx_graph(self):
        """Return the held graph exported back to ONNX."""
        exported = gs.export_onnx(self.graph)
        return exported

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs whose indices are listed in ``keep``.

        When ``names`` is truthy, its entries are assigned positionally as the
        names of the retained outputs.
        """
        self.graph.outputs = [self.graph.outputs[index] for index in keep]
        if not names:
            return
        for position, new_name in enumerate(names):
            self.graph.outputs[position].name = new_name

    def fold_constants(self):
        """Export the graph, run constant folding on it, and re-import the result."""
        exported = gs.export_onnx(self.graph)
        folded = fold_constants(exported, allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(folded)

    def infer_shapes(self):
        """Run ONNX shape inference on the exported graph and re-import the result.

        Raises:
            TypeError: if the exported model's ``ByteSize()`` exceeds 2147483648.
        """
        size_limit = 2147483648
        exported = gs.export_onnx(self.graph)
        if exported.ByteSize() > size_limit:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")

        inferred = shape_inference.infer_shapes(exported)
        self.graph = gs.import_onnx(inferred)
|
||||
|
||||
|
||||
class BaseModel:
    """Common base for the Stable Diffusion sub-model wrappers (CLIP, UNet, VAE).

    Stores the wrapped model together with the batch-size and image-size limits
    that the dimension helpers below rely on. The ``get_*`` hooks return ``None``
    here and are overridden by the subclasses.
    """

    def __init__(self, model, name, device="cuda", fp16=False, max_batch_size=16, embedding_dim=768, text_maxlen=77):
        self.model = model
        self.name = name
        self.fp16 = fp16
        self.device = device

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256  # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        # Latent sizes are the image sizes divided by 8.
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

        # Any name other than "CLIP" / "UNet" is treated as the VAE.
        self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae"
        self.ort_optimizer = OrtStableDiffusionOptimizer(self.model_type)

    def get_model(self):
        """Return the wrapped model object."""
        return self.model

    def get_input_names(self):
        """Return the ONNX input names; ``None`` in the base class."""

    def get_output_names(self):
        """Return the ONNX output names; ``None`` in the base class."""

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping; ``None`` in the base class."""
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return sample input(s) for the model; ``None`` in the base class."""

    def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """For TensorRT EP

        Build a string suffix that encodes the batch range and, for every model
        except the one named "CLIP", the image height/width range.
        """
        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            *_,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)

        profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}"

        if self.name != "CLIP":
            if static_image_shape:
                profile_id += f"_h_{image_height}_w_{image_width}"
            else:
                profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}"

        return profile_id

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """For TensorRT

        Return the input shape profile; ``None`` in the base class.
        """
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return a name-to-shape mapping; ``None`` in the base class."""
        return None

    def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True):
        """Optimize the ONNX file at ``input_onnx_path`` with the ORT optimizer held by this model."""
        self.ort_optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16)

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        """Load an ONNX file, run the ``TrtOptimizer`` passes on it, and save the result."""
        onnx_graph = onnx.load(input_onnx_path)
        opt = TrtOptimizer(onnx_graph)
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        opt.cleanup()
        onnx_opt_graph = opt.get_optimized_onnx_graph()
        onnx.save(onnx_opt_graph, optimized_onnx_path)

    def check_dims(self, batch_size, image_height, image_width):
        """Validate the requested sizes and return ``(latent_height, latent_width)``.

        Raises:
            AssertionError: if the batch size is outside ``[min_batch, max_batch]``,
                if either image side is not a multiple of 8, or if either latent
                side is outside ``[min_latent_shape, max_latent_shape]``.
        """
        assert self.min_batch <= batch_size <= self.max_batch
        # Both sides must be divisible by 8: the latent sizes below are computed
        # with floor division, which would silently truncate the other side.
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert self.min_latent_shape <= latent_height <= self.max_latent_shape
        assert self.min_latent_shape <= latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the min/max batch, image and latent sizes as a 10-tuple.

        With ``static_batch`` (resp. ``static_image_shape``) set, min and max
        collapse to the requested value; otherwise the instance limits are used.

        Returns:
            ``(min_batch, max_batch, min_image_height, max_image_height,
            min_image_width, max_image_width, min_latent_height,
            max_latent_height, min_latent_width, max_latent_width)``
        """
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_image_height = image_height if static_image_shape else self.min_image_shape
        max_image_height = image_height if static_image_shape else self.max_image_shape
        min_image_width = image_width if static_image_shape else self.min_image_shape
        max_image_width = image_width if static_image_shape else self.max_image_shape
        min_latent_height = latent_height if static_image_shape else self.min_latent_shape
        max_latent_height = latent_height if static_image_shape else self.max_latent_shape
        min_latent_width = latent_width if static_image_shape else self.min_latent_shape
        max_latent_width = latent_width if static_image_shape else self.max_latent_shape
        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )
|
||||
|
||||
|
||||
class CLIP(BaseModel):
    """Describes the CLIP model's ONNX interface: ``input_ids`` in, ``text_embeddings`` out."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        base_kwargs = {
            "model": model,
            "name": "CLIP",
            "device": device,
            "max_batch_size": max_batch_size,
            "embedding_dim": embedding_dim,
        }
        super().__init__(**base_kwargs)

    def get_input_names(self):
        """Name of the single graph input."""
        return ["input_ids"]

    def get_output_names(self):
        """Name of the single graph output."""
        return ["text_embeddings"]

    def get_dynamic_axes(self):
        """Axis 0 of both the input and the output is the dynamic batch axis "B"."""
        batch_axis = {0: "B"}
        return {"input_ids": dict(batch_axis), "text_embeddings": dict(batch_axis)}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the ``[min, requested, max]`` shapes for ``input_ids``."""
        self.check_dims(batch_size, image_height, image_width)
        limits = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        lowest_batch, highest_batch = limits[0], limits[1]
        return {
            "input_ids": [
                (lowest_batch, self.text_maxlen),
                (batch_size, self.text_maxlen),
                (highest_batch, self.text_maxlen),
            ]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the input and output shapes for the given batch size."""
        self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["input_ids"] = (batch_size, self.text_maxlen)
        shapes["text_embeddings"] = (batch_size, self.text_maxlen, self.embedding_dim)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return an all-zero int32 tensor of shape ``(batch_size, text_maxlen)`` on ``self.device``."""
        self.check_dims(batch_size, image_height, image_width)
        token_shape = (batch_size, self.text_maxlen)
        return torch.zeros(*token_shape, dtype=torch.int32, device=self.device)

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        """Run the ``TrtOptimizer`` passes, keeping only the first output and naming it ``text_embeddings``."""
        optimizer = TrtOptimizer(onnx.load(input_onnx_path))
        optimizer.select_outputs([0])  # delete graph output#1
        optimizer.cleanup()
        optimizer.fold_constants()
        optimizer.infer_shapes()
        optimizer.select_outputs([0], names=["text_embeddings"])  # rename network output
        optimizer.cleanup()
        onnx.save(optimizer.get_optimized_onnx_graph(), optimized_onnx_path)
|
||||
|
||||
|
||||
class UNet(BaseModel):
    """Describes the UNet's ONNX interface: ``sample``, ``timestep`` and ``encoder_hidden_states`` in, ``latent`` out."""

    def __init__(
        self,
        model,
        device="cuda",
        fp16=False,  # used by TRT
        max_batch_size=16,
        embedding_dim=768,
        text_maxlen=77,
        unet_dim=4,
    ):
        base_kwargs = {
            "model": model,
            "name": "UNet",
            "device": device,
            "fp16": fp16,
            "max_batch_size": max_batch_size,
            "embedding_dim": embedding_dim,
            "text_maxlen": text_maxlen,
        }
        super().__init__(**base_kwargs)
        self.unet_dim = unet_dim

    def get_input_names(self):
        """Names of the three graph inputs, in order."""
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        """Name of the single graph output."""
        return ["latent"]

    def get_dynamic_axes(self):
        """Dynamic axes: batch axis "2B" everywhere, plus "H"/"W" on ``sample`` and ``latent``."""
        axes = {}
        axes["sample"] = {0: "2B", 2: "H", 3: "W"}
        axes["encoder_hidden_states"] = {0: "2B"}
        axes["latent"] = {0: "2B", 2: "H", 3: "W"}
        return axes

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return ``[min, requested, max]`` shapes for ``sample`` and ``encoder_hidden_states``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        limits = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        lowest_batch, highest_batch = limits[0], limits[1]
        min_latent_height, max_latent_height, min_latent_width, max_latent_width = limits[6:10]

        # Every batch dimension is doubled, matching the "2B" dynamic axis.
        sample_shapes = [
            (2 * lowest_batch, self.unet_dim, min_latent_height, min_latent_width),
            (2 * batch_size, self.unet_dim, latent_height, latent_width),
            (2 * highest_batch, self.unet_dim, max_latent_height, max_latent_width),
        ]
        hidden_state_shapes = [
            (2 * batch, self.text_maxlen, self.embedding_dim) for batch in (lowest_batch, batch_size, highest_batch)
        ]
        return {"sample": sample_shapes, "encoder_hidden_states": hidden_state_shapes}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the shapes of all inputs and the output for the given sizes."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        doubled_batch = 2 * batch_size
        return {
            "sample": (doubled_batch, self.unet_dim, latent_height, latent_width),
            "timestep": [1],
            "encoder_hidden_states": (doubled_batch, self.text_maxlen, self.embedding_dim),
            "latent": (doubled_batch, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a ``(sample, timestep, encoder_hidden_states)`` tuple of tensors on ``self.device``.

        ``sample`` and ``timestep`` are always float32; the hidden states are
        float16 when ``self.fp16`` is set and float32 otherwise.
        """
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        if self.fp16:
            hidden_dtype = torch.float16
        else:
            hidden_dtype = torch.float32

        # Tensors are created in the same order as they are returned.
        sample = torch.randn(
            2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
        )
        timestep = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        hidden_states = torch.randn(
            2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=hidden_dtype, device=self.device
        )
        return (sample, timestep, hidden_states)
|
||||
|
||||
|
||||
class VAE(BaseModel):
    """Describes the VAE decoder's ONNX interface: ``latent`` in, ``images`` out."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        base_kwargs = {
            "model": model,
            "name": "VAE Decoder",
            "device": device,
            "max_batch_size": max_batch_size,
            "embedding_dim": embedding_dim,
        }
        super().__init__(**base_kwargs)

    def get_input_names(self):
        """Name of the single graph input."""
        return ["latent"]

    def get_output_names(self):
        """Name of the single graph output."""
        return ["images"]

    def get_dynamic_axes(self):
        """Dynamic axes: batch "B" on both tensors, "H"/"W" on the input and "8H"/"8W" on the output."""
        axes = {}
        axes["latent"] = {0: "B", 2: "H", 3: "W"}
        axes["images"] = {0: "B", 2: "8H", 3: "8W"}
        return axes

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the ``[min, requested, max]`` shapes for ``latent``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        limits = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        lowest_batch, highest_batch = limits[0], limits[1]
        min_latent_height, max_latent_height, min_latent_width, max_latent_width = limits[6:10]

        latent_shapes = [
            (lowest_batch, 4, min_latent_height, min_latent_width),
            (batch_size, 4, latent_height, latent_width),
            (highest_batch, 4, max_latent_height, max_latent_width),
        ]
        return {"latent": latent_shapes}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the input and output shapes for the given sizes."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["latent"] = (batch_size, 4, latent_height, latent_width)
        shapes["images"] = (batch_size, 3, image_height, image_width)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 tensor of shape ``(batch_size, 4, latent_height, latent_width)``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        latent_shape = (batch_size, 4, latent_height, latent_width)
        return torch.randn(*latent_shape, dtype=torch.float32, device=self.device)
|
||||
|
|
@ -31,9 +31,8 @@ pip install numpy>=1.24.1 onnx>=1.13.0 coloredlogs protobuf==3.20.3 psutil sympy
|
|||
pip install onnxruntime-gpu
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -44,385 +43,13 @@ from diffusers.pipelines.stable_diffusion import (
|
|||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import DIFFUSERS_CACHE, logging
|
||||
from diffusers.utils import DIFFUSERS_CACHE
|
||||
from huggingface_hub import snapshot_download
|
||||
from ort_utils import OrtCudaSession
|
||||
from models import CLIP, VAE, UNet
|
||||
from ort_utils import Engines
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.transformers.fusion_options import FusionOptions
|
||||
from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
|
||||
from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
|
||||
from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
|
||||
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class Engine(OrtCudaSession):
    """An ``OrtCudaSession`` built from an ONNX file path and an execution provider name."""

    def __init__(self, engine_path, provider, device_id: int = 0, enable_cuda_graph=False):
        self.engine_path = engine_path
        self.provider = provider
        self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph)

        device = torch.device("cuda", device_id)
        # The requested provider is listed first, with "CPUExecutionProvider" after it.
        provider_list = [(provider, self.provider_options), "CPUExecutionProvider"]
        ort_session = ort.InferenceSession(self.engine_path, providers=provider_list)

        super().__init__(ort_session, device, enable_cuda_graph)

    def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool):
        """Return the provider-options dict passed to the inference session."""
        options = {}
        options["device_id"] = device_id
        options["arena_extend_strategy"] = "kSameAsRequested"
        options["enable_cuda_graph"] = enable_cuda_graph
        return options
|
||||
|
||||
|
||||
class OrtStableDiffusionOptimizer:
    """Optimizes one Stable Diffusion sub-model ("vae", "unet" or "clip") with the ORT transformers tooling."""

    def __init__(self, model_type: str):
        assert model_type in ["vae", "unet", "clip"]
        self.model_type = model_type
        # Maps each model type to the ONNX model wrapper class used for it.
        mapping = {}
        mapping["unet"] = UnetOnnxModel
        mapping["vae"] = VaeOnnxModel
        mapping["clip"] = ClipOnnxModel
        self.model_type_class_mapping = mapping

    def optimize_by_ort(self, onnx_model):
        """Round-trip ``onnx_model`` through ONNX Runtime's graph optimizer and return the re-wrapped result."""
        import tempfile
        from pathlib import Path

        import onnx

        # Use this step to see the final graph that executed by Onnx Runtime.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save to a temporary file so that we can load it with Onnx Runtime.
            logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
            tmp_model_path = Path(tmp_dir) / "model.onnx"
            onnx_model.save_model_to_file(str(tmp_model_path))

            # The optimized model is written over the temporary input file.
            ort_optimized_model_path = tmp_model_path
            optimize_by_onnxruntime(
                str(tmp_model_path),
                use_gpu=True,
                optimized_model_path=str(ort_optimized_model_path),
            )
            reloaded = onnx.load(str(ort_optimized_model_path), load_external_data=True)
            wrapper_class = self.model_type_class_mapping[self.model_type]
            return wrapper_class(reloaded)

    def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True):
        """Optimize onnx model using ONNX Runtime transformers optimizer"""
        logger.info(f"Optimize {input_fp32_onnx_path}...")
        is_unet = self.model_type == "unet"

        fusion_options = FusionOptions(self.model_type)
        if is_unet and not float16:
            fusion_options.enable_packed_kv = False
            fusion_options.enable_packed_qkv = False

        m = optimize_model(
            input_fp32_onnx_path,
            model_type=self.model_type,
            num_heads=0,  # will be deduced from graph
            hidden_size=0,  # will be deduced from graph
            opt_level=0,
            optimization_options=fusion_options,
            use_gpu=True,
        )

        if self.model_type == "clip":
            # remove the pooler_output, and only keep the first output.
            m.prune_graph(outputs=["text_embeddings"])

        if float16:
            logger.info("Convert to float16 ...")
            m.convert_float_to_float16(keep_io_types=False, op_block_list=["RandomNormalLike"])

        # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16
        if float16 or not is_unet:
            m = self.optimize_by_ort(m)

        m.get_operator_statistics()
        m.get_fused_operator_statistics()
        needs_external_data = is_unet and not float16
        m.save_model_to_file(optimized_onnx_path, use_external_data_format=needs_external_data)
        logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)
|
||||
|
||||
|
||||
class BaseModel:
    """Common base for the Stable Diffusion sub-model wrappers (CLIP, UNet, VAE).

    Stores the wrapped model together with the batch-size and image-size limits
    that the dimension helpers below rely on. The ``get_*`` hooks return ``None``
    here and are overridden by the subclasses.
    """

    def __init__(self, model, name, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
        self.model = model
        self.name = name
        self.device = device

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256  # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        # Latent sizes are the image sizes divided by 8.
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

        # Any name other than "CLIP" / "UNet" is treated as the VAE.
        self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae"
        self.optimizer = OrtStableDiffusionOptimizer(self.model_type)

    def get_model(self):
        """Return the wrapped model object."""
        return self.model

    def get_input_names(self):
        """Return the ONNX input names; ``None`` in the base class."""

    def get_output_names(self):
        """Return the ONNX output names; ``None`` in the base class."""

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping; ``None`` in the base class."""
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return sample input(s) for the model; ``None`` in the base class."""

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return a name-to-shape mapping; ``None`` in the base class."""
        return None

    def optimize(self, input_fp32_onnx_path, optimized_onnx_path, fp16):
        """Optimize the ONNX file at ``input_fp32_onnx_path`` with the optimizer held by this model."""
        self.optimizer.optimize(input_fp32_onnx_path, optimized_onnx_path, fp16)

    def check_dims(self, batch_size, image_height, image_width):
        """Validate the requested sizes and return ``(latent_height, latent_width)``.

        Raises:
            AssertionError: if the batch size is outside ``[min_batch, max_batch]``,
                if either image side is not a multiple of 8, or if either latent
                side is outside ``[min_latent_shape, max_latent_shape]``.
        """
        assert self.min_batch <= batch_size <= self.max_batch
        # Both sides must be divisible by 8: the latent sizes below are computed
        # with floor division, which would silently truncate the other side.
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert self.min_latent_shape <= latent_height <= self.max_latent_shape
        assert self.min_latent_shape <= latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the min/max batch, image and latent sizes as a 10-tuple.

        With ``static_batch`` (resp. ``static_image_shape``) set, min and max
        collapse to the requested value; otherwise the instance limits are used.

        Returns:
            ``(min_batch, max_batch, min_image_height, max_image_height,
            min_image_width, max_image_width, min_latent_height,
            max_latent_height, min_latent_width, max_latent_width)``
        """
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_image_height = image_height if static_image_shape else self.min_image_shape
        max_image_height = image_height if static_image_shape else self.max_image_shape
        min_image_width = image_width if static_image_shape else self.min_image_shape
        max_image_width = image_width if static_image_shape else self.max_image_shape
        min_latent_height = latent_height if static_image_shape else self.min_latent_shape
        max_latent_height = latent_height if static_image_shape else self.max_latent_shape
        min_latent_width = latent_width if static_image_shape else self.min_latent_shape
        max_latent_width = latent_width if static_image_shape else self.max_latent_shape
        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )
|
||||
|
||||
|
||||
def get_onnx_path(model_name, onnx_dir):
    """Return the path of ``<model_name>.onnx`` inside ``onnx_dir``."""
    file_name = model_name + ".onnx"
    return os.path.join(onnx_dir, file_name)
|
||||
|
||||
|
||||
def get_engine_path(engine_dir, model_name, profile_id):
    """Return the path of ``<model_name><profile_id>.onnx`` inside ``engine_dir``."""
    file_name = model_name + profile_id + ".onnx"
    return os.path.join(engine_dir, file_name)
|
||||
|
||||
|
||||
def build_engines(
    models,
    engine_dir,
    onnx_dir,
    onnx_opset,
    force_engine_rebuild: bool = False,
    fp16: bool = True,
    provider: str = "CUDAExecutionProvider",
    device_id: int = 0,
    enable_cuda_graph: bool = False,
):
    """Export, optimize and load every model in ``models``.

    For each ``(model_name, model_obj)`` pair, the raw ONNX export is stored in
    ``onnx_dir`` and the optimized model in ``engine_dir``. Files that already
    exist are reused unless ``force_engine_rebuild`` is set, in which case both
    directories are removed first.

    Args:
        models: mapping of model name to a model wrapper object.
        engine_dir: directory holding the optimized ONNX models.
        onnx_dir: directory holding the raw ONNX exports.
        onnx_opset: opset version passed to ``torch.onnx.export``.
        force_engine_rebuild: delete both directories before building.
        fp16: selects the "_fp16" / "_fp32" file suffix and is forwarded to
            ``model_obj.optimize``.
        provider: execution provider name passed to ``Engine``.
        device_id: device index passed to ``Engine``.
        enable_cuda_graph: forwarded to ``Engine``.

    Returns:
        dict mapping each model name to its ``Engine``.
    """
    profile_id = "_fp16" if fp16 else "_fp32"

    if force_engine_rebuild:
        for stale_dir in (onnx_dir, engine_dir):
            if os.path.isdir(stale_dir):
                logger.info("Remove existing directory %s since force_engine_rebuild is enabled", stale_dir)
                shutil.rmtree(stale_dir)

    # exist_ok=True replaces the isdir-then-makedirs sequence, which could fail
    # if the directory appeared between the check and the creation.
    os.makedirs(engine_dir, exist_ok=True)
    os.makedirs(onnx_dir, exist_ok=True)

    # Export models to ONNX
    for model_name, model_obj in models.items():
        onnx_path = get_onnx_path(model_name, onnx_dir)
        onnx_opt_path = get_engine_path(engine_dir, model_name, profile_id)
        if os.path.exists(onnx_opt_path):
            logger.info("Found cached optimized model: %s", onnx_opt_path)
            continue

        if os.path.exists(onnx_path):
            logger.info("Found cached model: %s", onnx_path)
        else:
            logger.info("Exporting model: %s", onnx_path)
            model = model_obj.get_model().to(model_obj.device)
            with torch.inference_mode():
                inputs = model_obj.get_sample_input(1, 512, 512)
                torch.onnx.export(
                    model,
                    inputs,
                    onnx_path,
                    export_params=True,
                    opset_version=onnx_opset,
                    do_constant_folding=True,
                    input_names=model_obj.get_input_names(),
                    output_names=model_obj.get_output_names(),
                    dynamic_axes=model_obj.get_dynamic_axes(),
                )
            # Drop the local reference and release cached memory before optimizing.
            del model
            torch.cuda.empty_cache()
            gc.collect()

        # Optimize onnx
        logger.info("Generating optimized model: %s", onnx_opt_path)
        model_obj.optimize(onnx_path, onnx_opt_path, fp16)

    built_engines = {}
    for model_name in models:
        engine_path = get_engine_path(engine_dir, model_name, profile_id)
        engine = Engine(engine_path, provider, device_id=device_id, enable_cuda_graph=enable_cuda_graph)
        logger.info("%s options for %s: %s", provider, model_name, engine.provider_options)
        built_engines[model_name] = engine

    return built_engines
|
||||
|
||||
|
||||
def run_engine(engine, feed_dict):
    """Call ``engine.infer`` with ``feed_dict`` and return whatever it returns."""
    outputs = engine.infer(feed_dict)
    return outputs
|
||||
|
||||
|
||||
class CLIP(BaseModel):
    """Describes the CLIP model's ONNX interface: ``input_ids`` in, ``text_embeddings`` and ``pooler_output`` out."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        base_kwargs = {
            "model": model,
            "name": "CLIP",
            "device": device,
            "max_batch_size": max_batch_size,
            "embedding_dim": embedding_dim,
        }
        super().__init__(**base_kwargs)

    def get_input_names(self):
        """Name of the single graph input."""
        return ["input_ids"]

    def get_output_names(self):
        """Names of the two graph outputs, in order."""
        return ["text_embeddings", "pooler_output"]

    def get_dynamic_axes(self):
        """Axis 0 of ``input_ids`` and ``text_embeddings`` is the dynamic batch axis "B"."""
        batch_axis = {0: "B"}
        return {"input_ids": dict(batch_axis), "text_embeddings": dict(batch_axis)}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the shapes of ``input_ids`` and ``text_embeddings`` for the given batch size."""
        self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["input_ids"] = (batch_size, self.text_maxlen)
        shapes["text_embeddings"] = (batch_size, self.text_maxlen, self.embedding_dim)
        # "pooler_output": (batch_size, self.embedding_dim)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return an all-zero int32 tensor of shape ``(batch_size, text_maxlen)`` on ``self.device``."""
        self.check_dims(batch_size, image_height, image_width)
        token_shape = (batch_size, self.text_maxlen)
        return torch.zeros(*token_shape, dtype=torch.int32, device=self.device)
|
||||
|
||||
|
||||
class UNet(BaseModel):
    """Describes the UNet's ONNX interface: ``sample``, ``timestep`` and ``encoder_hidden_states`` in, ``latent`` out."""

    def __init__(
        self,
        model,
        device="cuda",
        max_batch_size=16,
        embedding_dim=768,
        text_maxlen=77,
        unet_dim=4,
    ):
        base_kwargs = {
            "model": model,
            "name": "UNet",
            "device": device,
            "max_batch_size": max_batch_size,
            "embedding_dim": embedding_dim,
            "text_maxlen": text_maxlen,
        }
        super().__init__(**base_kwargs)
        self.unet_dim = unet_dim

    def get_input_names(self):
        """Names of the three graph inputs, in order."""
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        """Name of the single graph output."""
        return ["latent"]

    def get_dynamic_axes(self):
        """Dynamic axes: batch axis "2B" everywhere, plus "H"/"W" on ``sample`` and ``latent``."""
        axes = {}
        axes["sample"] = {0: "2B", 2: "H", 3: "W"}
        axes["encoder_hidden_states"] = {0: "2B"}
        axes["latent"] = {0: "2B", 2: "H", 3: "W"}
        return axes

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the shapes of all inputs and the output for the given sizes."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        doubled_batch = 2 * batch_size
        return {
            "sample": (doubled_batch, self.unet_dim, latent_height, latent_width),
            "timestep": [1],
            "encoder_hidden_states": (doubled_batch, self.text_maxlen, self.embedding_dim),
            "latent": (doubled_batch, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a ``(sample, timestep, encoder_hidden_states)`` tuple of float32 tensors on ``self.device``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)

        # Tensors are created in the same order as they are returned.
        sample = torch.randn(
            2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
        )
        timestep = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        hidden_states = torch.randn(
            2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=torch.float32, device=self.device
        )
        return (sample, timestep, hidden_states)
|
||||
|
||||
|
||||
class VAE(BaseModel):
    """Describes the VAE decoder's ONNX interface: ``latent`` in, ``images`` out."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        base_kwargs = {
            "model": model,
            "name": "VAE Decoder",
            "device": device,
            "max_batch_size": max_batch_size,
            "embedding_dim": embedding_dim,
        }
        super().__init__(**base_kwargs)

    def get_input_names(self):
        """Name of the single graph input."""
        return ["latent"]

    def get_output_names(self):
        """Name of the single graph output."""
        return ["images"]

    def get_dynamic_axes(self):
        """Dynamic axes: batch "B" on both tensors, "H"/"W" on the input and "8H"/"8W" on the output."""
        axes = {}
        axes["latent"] = {0: "B", 2: "H", 3: "W"}
        axes["images"] = {0: "B", 2: "8H", 3: "8W"}
        return axes

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the input and output shapes for the given sizes."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["latent"] = (batch_size, 4, latent_height, latent_width)
        shapes["images"] = (batch_size, 3, image_height, image_width)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 tensor of shape ``(batch_size, 4, latent_height, latent_width)``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        latent_shape = (batch_size, 4, latent_height, latent_width)
        return torch.randn(*latent_shape, dtype=torch.float32, device=self.device)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
|
@ -457,7 +84,6 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
self.unet_in_channels = unet.config.in_channels
|
||||
|
||||
self.inpaint = False
|
||||
self.onnx_opset = onnx_opset
|
||||
self.onnx_dir = onnx_dir
|
||||
self.engine_dir = engine_dir
|
||||
self.force_engine_rebuild = force_engine_rebuild
|
||||
|
|
@ -466,9 +92,8 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
self.max_batch_size = 16
|
||||
|
||||
self.models = {} # loaded in __load_models()
|
||||
self.engines = {} # loaded in build_engines()
|
||||
self.engines = Engines("CUDAExecutionProvider", onnx_opset)
|
||||
|
||||
self.provider = "CUDAExecutionProvider"
|
||||
self.fp16 = False
|
||||
|
||||
def __load_models(self):
|
||||
|
|
@ -484,6 +109,7 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
self.models["unet"] = UNet(
|
||||
self.unet,
|
||||
device=self.torch_device,
|
||||
fp16=self.fp16,
|
||||
max_batch_size=self.max_batch_size,
|
||||
embedding_dim=self.embedding_dim,
|
||||
unet_dim=(9 if self.inpaint else 4),
|
||||
|
|
@ -529,18 +155,16 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
self.torch_device = torch.device(torch_device)
|
||||
|
||||
# load models
|
||||
self.fp16 = torch_dtype == torch.float16
|
||||
self.__load_models()
|
||||
|
||||
# build engines
|
||||
self.fp16 = torch_dtype == torch.float16
|
||||
self.engines = build_engines(
|
||||
self.engines.build(
|
||||
self.models,
|
||||
self.engine_dir,
|
||||
self.onnx_dir,
|
||||
self.onnx_opset,
|
||||
force_engine_rebuild=self.force_engine_rebuild,
|
||||
fp16=self.fp16,
|
||||
provider=self.provider,
|
||||
device_id=self.torch_device.index or torch.cuda.current_device(),
|
||||
enable_cuda_graph=self.enable_cuda_graph,
|
||||
)
|
||||
|
|
@ -582,7 +206,9 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
)
|
||||
|
||||
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
|
||||
text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone()
|
||||
text_embeddings = (
|
||||
self.engines.get_engine("clip").infer({"input_ids": text_input_ids})["text_embeddings"].clone()
|
||||
)
|
||||
|
||||
# Tokenize negative prompt
|
||||
uncond_input_ids = (
|
||||
|
|
@ -597,7 +223,7 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
.to(self.torch_device)
|
||||
)
|
||||
|
||||
uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"]
|
||||
uncond_embeddings = self.engines.get_engine("clip").infer({"input_ids": uncond_input_ids})["text_embeddings"]
|
||||
|
||||
# Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
|
||||
|
|
@ -618,8 +244,7 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
timestep_float = timestep.to(torch.float16) if self.fp16 else timestep.to(torch.float32)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = run_engine(
|
||||
self.engines["unet"],
|
||||
noise_pred = self.engines.get_engine("unet").infer(
|
||||
{"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
|
||||
)["latent"]
|
||||
|
||||
|
|
@ -633,14 +258,16 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
return latents
|
||||
|
||||
def __decode_latent(self, latents):
|
||||
images = run_engine(self.engines["vae"], {"latent": latents})["images"]
|
||||
images = self.engines.get_engine("vae").infer({"latent": latents})["images"]
|
||||
images = (images / 2 + 0.5).clamp(0, 1)
|
||||
return images.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
def __allocate_buffers(self, image_height, image_width, batch_size):
|
||||
# Allocate output tensors for I/O bindings
|
||||
for model_name, obj in self.models.items():
|
||||
self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width))
|
||||
self.engines.get_engine(model_name).allocate_buffers(
|
||||
obj.get_shape_dict(batch_size, image_height, image_width)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
|
|
@ -736,9 +363,6 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch
|
||||
from diffusers import DDIMScheduler
|
||||
|
||||
model_name_or_path = "runwayml/stable-diffusion-v1-5"
|
||||
scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler")
|
||||
|
||||
|
|
|
|||
|
|
@ -37,8 +37,6 @@ import os
|
|||
import shutil
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import onnx
|
||||
import onnx_graphsurgeon as gs
|
||||
import torch
|
||||
from cuda import cudart
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
|
|
@ -50,9 +48,8 @@ from diffusers.pipelines.stable_diffusion import (
|
|||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import DIFFUSERS_CACHE, logging
|
||||
from huggingface_hub import snapshot_download
|
||||
from onnx import shape_inference
|
||||
from models import CLIP, VAE, UNet
|
||||
from ort_utils import OrtCudaSession
|
||||
from polygraphy.backend.onnx.loader import fold_constants
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import onnxruntime as ort
|
||||
|
|
@ -124,142 +121,6 @@ class Engine(OrtCudaSession):
|
|||
return trt_ep_options
|
||||
|
||||
|
||||
class Optimizer:
    """Small wrapper around an ONNX GraphSurgeon graph used to post-process an exported model."""

    def __init__(self, onnx_graph):
        # Hold a GraphSurgeon view of the model; the steps below mutate or replace it.
        self.graph = gs.import_onnx(onnx_graph)

    def cleanup(self):
        """Run GraphSurgeon ``cleanup()`` followed by ``toposort()`` on the graph."""
        self.graph.cleanup().toposort()

    def get_optimized_onnx_graph(self):
        """Export the current graph back to an ONNX model and return it."""
        return gs.export_onnx(self.graph)

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs whose indices are listed in ``keep``.

        When ``names`` is truthy, the remaining outputs are renamed positionally:
        ``names[i]`` becomes the name of the i-th kept output.
        """
        previous_outputs = self.graph.outputs
        self.graph.outputs = [previous_outputs[index] for index in keep]
        if not names:
            return
        for position, new_name in enumerate(names):
            self.graph.outputs[position].name = new_name

    def fold_constants(self):
        """Fold constants with Polygraphy and re-import the resulting model."""
        exported = gs.export_onnx(self.graph)
        # Inside a method body this name resolves to the module-level Polygraphy function.
        folded = fold_constants(exported, allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(folded)

    def infer_shapes(self):
        """Run ONNX shape inference and re-import the resulting model.

        Raises:
            TypeError: if the serialized model is larger than 2147483648 bytes (2GB).
        """
        exported = gs.export_onnx(self.graph)
        if exported.ByteSize() > 2147483648:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")

        self.graph = gs.import_onnx(shape_inference.infer_shapes(exported))
|
||||
|
||||
|
||||
class BaseModel:
    """Base class holding the shape limits and ONNX helpers shared by the Stable Diffusion sub-models.

    Subclasses override the ``get_*`` hooks to describe their inputs, outputs and shapes.
    """

    def __init__(self, model, name, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
        self.model = model
        self.name = name
        self.fp16 = fp16
        self.device = device

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256  # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        # Latent height/width are the image height/width divided by 8.
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

    def get_model(self):
        """Return the wrapped model object."""
        return self.model

    def get_input_names(self):
        """Hook for subclasses: ONNX input names. The base implementation returns None."""
        return None

    def get_output_names(self):
        """Hook for subclasses: ONNX output names. The base implementation returns None."""
        return None

    def get_dynamic_axes(self):
        """Hook for subclasses: dynamic axes for ONNX export. The base implementation returns None."""
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        """Hook for subclasses: sample input(s) for ONNX export. The base implementation returns None."""
        return None

    def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Build a string suffix that identifies the batch/image-size profile.

        The batch part is ``_b_{batch_size}`` for a static batch and ``_b_{min}_{max}`` otherwise.
        Unless ``self.name`` is "CLIP", an image part is appended: ``_h_{height}_w_{width}`` for a
        static image shape, or ``_h_{min}_{max}_w_{min}_{max}`` otherwise.
        """
        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)[:6]

        profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}"

        if self.name != "CLIP":
            if static_image_shape:
                profile_id += f"_h_{image_height}_w_{image_width}"
            else:
                profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}"

        return profile_id

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Hook for subclasses: min/opt/max input shapes. The base implementation returns None."""
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Hook for subclasses: tensor name -> shape mapping. The base implementation returns None."""
        return None

    def optimize(self, onnx_graph):
        """Clean up, constant-fold and shape-infer ``onnx_graph``; return the optimized ONNX model."""
        opt = Optimizer(onnx_graph)
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        opt.cleanup()
        return opt.get_optimized_onnx_graph()

    def check_dims(self, batch_size, image_height, image_width):
        """Validate the requested batch size and image size and return the latent size.

        Returns:
            ``(latent_height, latent_width)``: each image dimension divided by 8.

        Raises:
            AssertionError: if the batch size is outside ``[min_batch, max_batch]``, if either image
                dimension is not a multiple of 8, or if a latent dimension is outside
                ``[min_latent_shape, max_latent_shape]``.
        """
        assert (
            self.min_batch <= batch_size <= self.max_batch
        ), f"batch_size {batch_size} is outside [{self.min_batch}, {self.max_batch}]"
        # Both dimensions must be multiples of 8; otherwise the `// 8` below would silently truncate.
        assert (
            image_height % 8 == 0 and image_width % 8 == 0
        ), f"image size {image_height}x{image_width} must be a multiple of 8 in both dimensions"
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert (
            self.min_latent_shape <= latent_height <= self.max_latent_shape
        ), f"latent height {latent_height} is outside [{self.min_latent_shape}, {self.max_latent_shape}]"
        assert (
            self.min_latent_shape <= latent_width <= self.max_latent_shape
        ), f"latent width {latent_width} is outside [{self.min_latent_shape}, {self.max_latent_shape}]"
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the min/max bounds for batch, image and latent dimensions.

        With ``static_batch`` / ``static_image_shape`` set, the corresponding min and max both
        collapse to the requested value; otherwise the limits configured in ``__init__`` are used.

        Returns:
            A 10-tuple ``(min_batch, max_batch, min_image_height, max_image_height,
            min_image_width, max_image_width, min_latent_height, max_latent_height,
            min_latent_width, max_latent_width)``.
        """
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_image_height = image_height if static_image_shape else self.min_image_shape
        max_image_height = image_height if static_image_shape else self.max_image_shape
        min_image_width = image_width if static_image_shape else self.min_image_shape
        max_image_width = image_width if static_image_shape else self.max_image_shape
        min_latent_height = latent_height if static_image_shape else self.min_latent_shape
        max_latent_height = latent_height if static_image_shape else self.max_latent_shape
        min_latent_width = latent_width if static_image_shape else self.min_latent_shape
        max_latent_width = latent_width if static_image_shape else self.max_latent_shape
        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )
|
||||
|
||||
|
||||
def get_onnx_path(model_name, onnx_dir, opt=True):
    """Return the ONNX file path for ``model_name`` inside ``onnx_dir``.

    The file name is ``<model_name>.opt.onnx`` when ``opt`` is truthy, else ``<model_name>.onnx``.
    """
    extension = ".opt.onnx" if opt else ".onnx"
    return os.path.join(onnx_dir, model_name + extension)
|
||||
|
||||
|
|
@ -352,8 +213,7 @@ def build_engines(
|
|||
# Optimize onnx
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
logger.info("Generating optimizing model: %s", onnx_opt_path)
|
||||
onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
|
||||
onnx.save(onnx_opt_graph, onnx_opt_path)
|
||||
model_obj.optimize_trt(onnx_path, onnx_opt_path)
|
||||
else:
|
||||
logger.info("Found cached optimized model: %s", onnx_opt_path)
|
||||
|
||||
|
|
@ -403,177 +263,6 @@ def run_engine(engine, feed_dict):
|
|||
return engine.infer(feed_dict)
|
||||
|
||||
|
||||
class CLIP(BaseModel):
    """Shape and export description of the CLIP text encoder."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        super().__init__(
            model=model,
            name="CLIP",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        """Names of the ONNX graph inputs."""
        return ["input_ids"]

    def get_output_names(self):
        """Names of the ONNX graph outputs as exported (``optimize`` later keeps only the first)."""
        return ["text_embeddings", "pooler_output"]

    def get_dynamic_axes(self):
        """Mark axis 0 (batch, "B") of the input and of ``text_embeddings`` as dynamic."""
        batch_axis = {0: "B"}
        return {"input_ids": dict(batch_axis), "text_embeddings": dict(batch_axis)}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the (min, opt, max) shapes of ``input_ids`` after validating the dimensions."""
        self.check_dims(batch_size, image_height, image_width)
        # Only the batch bounds (first two entries) matter for the text encoder.
        min_batch, max_batch = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_image_shape
        )[:2]
        seq_len = self.text_maxlen
        return {"input_ids": [(min_batch, seq_len), (batch_size, seq_len), (max_batch, seq_len)]}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the input/output tensor shapes for the given batch size after validating the dimensions."""
        self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["input_ids"] = (batch_size, self.text_maxlen)
        shapes["text_embeddings"] = (batch_size, self.text_maxlen, self.embedding_dim)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return an all-zero int32 ``(batch_size, text_maxlen)`` tensor on ``self.device`` for export."""
        self.check_dims(batch_size, image_height, image_width)
        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)

    def optimize(self, onnx_graph):
        """Optimize the exported graph, keeping only its first output, renamed to ``text_embeddings``."""
        optimizer = Optimizer(onnx_graph)
        optimizer.select_outputs([0])  # delete graph output#1
        optimizer.cleanup()
        optimizer.fold_constants()
        optimizer.infer_shapes()
        optimizer.select_outputs([0], names=["text_embeddings"])  # rename network output
        optimizer.cleanup()
        return optimizer.get_optimized_onnx_graph()
|
||||
|
||||
|
||||
class UNet(BaseModel):
    """Shape and export description of the UNet.

    Every batch dimension used here is ``2 * batch_size`` (dynamic axis name "2B").
    ``unet_dim`` is the channel count of the ``sample`` input.
    """

    def __init__(
        self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
    ):
        super().__init__(
            model=model,
            name="UNet",
            fp16=fp16,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
            text_maxlen=text_maxlen,
        )
        self.unet_dim = unet_dim

    def get_input_names(self):
        """Names of the ONNX graph inputs."""
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        """Names of the ONNX graph outputs."""
        return ["latent"]

    def get_dynamic_axes(self):
        """Mark batch ("2B") and, for the image-like tensors, height/width ("H"/"W") as dynamic."""
        return {
            "sample": {0: "2B", 2: "H", 3: "W"},
            "encoder_hidden_states": {0: "2B"},
            "latent": {0: "2B", 2: "H", 3: "W"},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the (min, opt, max) shapes of ``sample`` and ``encoder_hidden_states``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        bounds = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        # Entries 2..5 are image-size bounds, which the UNet profile does not use.
        min_batch, max_batch = bounds[0], bounds[1]
        min_latent_height, max_latent_height, min_latent_width, max_latent_width = bounds[6:]

        sample_profile = [
            (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
            (2 * batch_size, self.unet_dim, latent_height, latent_width),
            (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
        ]
        hidden_profile = [
            (2 * batch, self.text_maxlen, self.embedding_dim) for batch in (min_batch, batch_size, max_batch)
        ]
        return {"sample": sample_profile, "encoder_hidden_states": hidden_profile}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the input/output tensor shapes for the given batch and image size."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        doubled_batch = 2 * batch_size
        return {
            "sample": (doubled_batch, self.unet_dim, latent_height, latent_width),
            "timestep": [1],
            "encoder_hidden_states": (doubled_batch, self.text_maxlen, self.embedding_dim),
            # The output always has 4 channels, independent of unet_dim.
            "latent": (doubled_batch, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return random ``(sample, timestep, encoder_hidden_states)`` tensors for ONNX export.

        ``sample`` and ``timestep`` are always float32; only ``encoder_hidden_states`` follows ``self.fp16``.
        """
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        hidden_dtype = torch.float16 if self.fp16 else torch.float32
        sample = torch.randn(
            2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
        )
        timestep = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        encoder_hidden_states = torch.randn(
            2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=hidden_dtype, device=self.device
        )
        return (sample, timestep, encoder_hidden_states)
|
||||
|
||||
|
||||
class VAE(BaseModel):
    """Shape and export description of the VAE decoder (latent -> images)."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        super().__init__(
            model=model,
            name="VAE decoder",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        """Names of the ONNX graph inputs."""
        return ["latent"]

    def get_output_names(self):
        """Names of the ONNX graph outputs."""
        return ["images"]

    def get_dynamic_axes(self):
        """Mark batch and spatial axes as dynamic for both the latent input and the image output."""
        latent_axes = {0: "B", 2: "H", 3: "W"}
        image_axes = {0: "B", 2: "8H", 3: "8W"}
        return {"latent": latent_axes, "images": image_axes}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the (min, opt, max) shapes of the ``latent`` input."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        bounds = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        # Entries 2..5 are image-size bounds, which the latent profile does not use.
        min_batch, max_batch = bounds[0], bounds[1]
        min_latent_height, max_latent_height, min_latent_width, max_latent_width = bounds[6:]
        latent_profile = [
            (min_batch, 4, min_latent_height, min_latent_width),
            (batch_size, 4, latent_height, latent_width),
            (max_batch, 4, max_latent_height, max_latent_width),
        ]
        return {"latent": latent_profile}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the input/output tensor shapes for the given batch and image size."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["latent"] = (batch_size, 4, latent_height, latent_width)
        shapes["images"] = (batch_size, 3, image_height, image_width)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 latent tensor on ``self.device`` for ONNX export."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        latent_shape = (batch_size, 4, latent_height, latent_width)
        return torch.randn(*latent_shape, dtype=torch.float32, device=self.device)
|
||||
|
||||
|
||||
class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime.
|
||||
|
|
@ -644,8 +333,8 @@ class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
|
||||
self.models["unet"] = UNet(
|
||||
self.unet,
|
||||
fp16=True,
|
||||
device=self.torch_device,
|
||||
fp16=True,
|
||||
max_batch_size=self.max_batch_size,
|
||||
embedding_dim=self.embedding_dim,
|
||||
unet_dim=(9 if self.inpaint else 4),
|
||||
|
|
@ -888,23 +577,22 @@ class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch
|
||||
from diffusers import DDIMScheduler
|
||||
model_name_or_path = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
|
||||
scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler")
|
||||
|
||||
pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
model_name_or_path,
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
scheduler=scheduler,
|
||||
image_height=512,
|
||||
image_width=512,
|
||||
max_batch_size=1,
|
||||
max_batch_size=4,
|
||||
)
|
||||
|
||||
# re-use cached folder to save ONNX models and TensorRT Engines
|
||||
pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision="fp16")
|
||||
pipe.set_cached_folder(model_name_or_path, revision="fp16")
|
||||
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,84 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
ONNX Model Optimizer for Stable Diffusion
|
||||
"""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
|
||||
from onnxruntime.transformers.fusion_options import FusionOptions
|
||||
from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
|
||||
from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
|
||||
from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
|
||||
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrtStableDiffusionOptimizer:
    """Optimizes one Stable Diffusion sub-model ("vae", "unet" or "clip") with ONNX Runtime tooling."""

    def __init__(self, model_type: str):
        assert model_type in ["vae", "unet", "clip"]
        self.model_type = model_type
        # Wrapper class used to re-wrap a model after the ONNX Runtime optimization pass.
        self.model_type_class_mapping = {
            "unet": UnetOnnxModel,
            "vae": VaeOnnxModel,
            "clip": ClipOnnxModel,
        }

    def optimize_by_ort(self, onnx_model):
        """Run ONNX Runtime graph optimizations on ``onnx_model`` and return the re-wrapped result.

        Use this step to see the final graph that is executed by ONNX Runtime.
        """
        with tempfile.TemporaryDirectory() as scratch_dir:
            # ONNX Runtime loads the model from disk, so save it to a temporary file first.
            logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
            model_file = str(Path(scratch_dir) / "model.onnx")
            onnx_model.save_model_to_file(model_file)
            # The optimized model is written over the temporary input file.
            optimize_by_onnxruntime(model_file, use_gpu=True, optimized_model_path=model_file)
            reloaded = onnx.load(model_file, load_external_data=True)
            return self.model_type_class_mapping[self.model_type](reloaded)

    def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True):
        """Optimize onnx model using ONNX Runtime transformers optimizer"""
        logger.info(f"Optimize {input_fp32_onnx_path}...")
        fp32_unet = (self.model_type == "unet") and not float16

        fusion_options = FusionOptions(self.model_type)
        if fp32_unet:
            fusion_options.enable_packed_kv = False
            fusion_options.enable_packed_qkv = False

        optimized = optimize_model(
            input_fp32_onnx_path,
            model_type=self.model_type,
            num_heads=0,  # will be deduced from graph
            hidden_size=0,  # will be deduced from graph
            opt_level=0,
            optimization_options=fusion_options,
            use_gpu=True,
        )

        if self.model_type == "clip":
            # remove the pooler_output, and only keep the first output.
            optimized.prune_graph(outputs=["text_embeddings"])

        if float16:
            logger.info("Convert to float16 ...")
            optimized.convert_float_to_float16(keep_io_types=False, op_block_list=["RandomNormalLike"])

        # Note that ORT 1.15 could not save model larger than 2GB, so the ORT pass is
        # skipped for the float32 UNet.
        if not fp32_unet:
            optimized = self.optimize_by_ort(optimized)

        optimized.get_operator_statistics()
        optimized.get_fused_operator_statistics()
        # Only the float32 UNet is saved with external data.
        optimized.save_model_to_file(optimized_onnx_path, use_external_data_format=fp32_unet)
        logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)
|
||||
|
|
@ -3,17 +3,23 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.transformers.io_binding_helper import TypeHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrtCudaSession:
|
||||
"""ONNX Runtime Session for CUDA or TensorRT provider"""
|
||||
"""Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider"""
|
||||
|
||||
def __init__(self, ort_session: ort.InferenceSession, device: torch.device, enable_cuda_graph=False):
|
||||
self.ort_session = ort_session
|
||||
|
|
@ -110,3 +116,111 @@ class OrtCudaSession:
|
|||
self.ort_session.run_with_iobinding(self.io_binding)
|
||||
|
||||
return self.output_tensors
|
||||
|
||||
|
||||
class Engine(OrtCudaSession):
    """An ``OrtCudaSession`` built from an ONNX file for a given execution provider and CUDA device."""

    def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_graph=False):
        self.engine_path = engine_path
        self.provider = provider
        self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph)

        device = torch.device("cuda", device_id)
        # The requested provider comes first; CPUExecutionProvider is listed after it.
        providers = [(provider, self.provider_options), "CPUExecutionProvider"]
        session = ort.InferenceSession(self.engine_path, providers=providers)

        super().__init__(session, device, enable_cuda_graph)

    def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]:
        """Return the provider options dict passed to the inference session."""
        options: Dict[str, Any] = {
            "device_id": device_id,
            "arena_extend_strategy": "kSameAsRequested",
            "enable_cuda_graph": enable_cuda_graph,
        }
        return options
|
||||
|
||||
|
||||
class Engines:
    """Builds and stores one ``Engine`` per model name for a single execution provider."""

    def __init__(self, provider, onnx_opset: int = 14):
        self.provider = provider
        self.engines = {}  # model name -> Engine, filled by build()
        self.onnx_opset = onnx_opset

    @staticmethod
    def get_onnx_path(onnx_dir, model_name):
        """Path of the exported (unoptimized) ONNX file for ``model_name``."""
        return os.path.join(onnx_dir, f"{model_name}.onnx")

    @staticmethod
    def get_engine_path(engine_dir, model_name, profile_id):
        """Path of the optimized ONNX file for ``model_name`` and ``profile_id``."""
        return os.path.join(engine_dir, f"{model_name}{profile_id}.onnx")

    def build(
        self,
        models,
        engine_dir: str,
        onnx_dir: str,
        force_engine_rebuild: bool = False,
        fp16: bool = True,
        device_id: int = 0,
        enable_cuda_graph: bool = False,
    ):
        """Export, optimize and load every model in ``models``.

        For each ``(model_name, model_obj)`` pair: reuse the optimized ONNX file if it exists;
        otherwise export the model to ONNX (unless that file is already cached) and optimize it
        with ``model_obj.optimize_ort``. Finally an ``Engine`` is created for every model and
        stored in ``self.engines``.

        With ``force_engine_rebuild`` set, existing ``onnx_dir`` and ``engine_dir`` directories are
        deleted first.
        """
        profile_id = "_fp16" if fp16 else "_fp32"

        if force_engine_rebuild:
            for stale_dir in (onnx_dir, engine_dir):
                if os.path.isdir(stale_dir):
                    logger.info("Remove existing directory %s since force_engine_rebuild is enabled", stale_dir)
                    shutil.rmtree(stale_dir)

        os.makedirs(engine_dir, exist_ok=True)
        os.makedirs(onnx_dir, exist_ok=True)

        # Export models to ONNX
        for model_name, model_obj in models.items():
            onnx_path = Engines.get_onnx_path(onnx_dir, model_name)
            onnx_opt_path = Engines.get_engine_path(engine_dir, model_name, profile_id)

            if os.path.exists(onnx_opt_path):
                logger.info("Found cached optimized model: %s", onnx_opt_path)
                continue

            if os.path.exists(onnx_path):
                logger.info("Found cached model: %s", onnx_path)
            else:
                logger.info("Exporting model: %s", onnx_path)
                model = model_obj.get_model().to(model_obj.device)
                with torch.inference_mode():
                    # The export always uses a batch of 1 at 512x512.
                    inputs = model_obj.get_sample_input(1, 512, 512)
                    torch.onnx.export(
                        model,
                        inputs,
                        onnx_path,
                        export_params=True,
                        opset_version=self.onnx_opset,
                        do_constant_folding=True,
                        input_names=model_obj.get_input_names(),
                        output_names=model_obj.get_output_names(),
                        dynamic_axes=model_obj.get_dynamic_axes(),
                    )
                # Drop the torch model and clear caches before moving on to the next model.
                del model
                torch.cuda.empty_cache()
                gc.collect()

            # Optimize onnx
            logger.info("Generating optimized model: %s", onnx_opt_path)
            model_obj.optimize_ort(onnx_path, onnx_opt_path, to_fp16=fp16)

        for model_name in models:
            engine_path = Engines.get_engine_path(engine_dir, model_name, profile_id)
            engine = Engine(engine_path, self.provider, device_id=device_id, enable_cuda_graph=enable_cuda_graph)
            logger.info("%s options for %s: %s", self.provider, model_name, engine.provider_options)
            self.engines[model_name] = engine

    def get_engine(self, model_name):
        """Return the engine built for ``model_name`` (raises ``KeyError`` if there is none)."""
        return self.engines[model_name]
|
||||
|
|
|
|||
Loading…
Reference in a new issue