Refactoring of Stable Diffusion scripts (#17138)

Reduce duplicated code in two stable diffusion pipelines (CUDA and TensorRT). Move the common code to models.py
This commit is contained in:
Tianlei Wu 2023-08-15 09:36:31 -07:00 committed by GitHub
parent 5e971bc51a
commit 3aba736ee2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 594 additions and 716 deletions

View file

@ -0,0 +1,368 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# Copyright 2023 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Models used in Stable diffusion.
"""
import logging
import onnx
import onnx_graphsurgeon as gs
import torch
from onnx import shape_inference
from ort_optimizer import OrtStableDiffusionOptimizer
from polygraphy.backend.onnx.loader import fold_constants
logger = logging.getLogger(__name__)
class TrtOptimizer:
    """Wrap an ONNX GraphSurgeon graph and expose the passes used before building TensorRT engines."""

    def __init__(self, onnx_graph):
        """Import ``onnx_graph`` with GraphSurgeon and keep the result in ``self.graph``."""
        self.graph = gs.import_onnx(onnx_graph)

    def cleanup(self):
        """Call ``cleanup()`` and then ``toposort()`` on the wrapped graph."""
        self.graph.cleanup().toposort()

    def get_optimized_onnx_graph(self):
        """Export the wrapped graph back to an ONNX model and return it."""
        return gs.export_onnx(self.graph)

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs at the indices in ``keep``; optionally rename them.

        Args:
            keep: Indices into the current ``graph.outputs`` list, in the order to keep.
            names: Optional new names, applied positionally to the kept outputs.
        """
        kept_outputs = []
        for index in keep:
            kept_outputs.append(self.graph.outputs[index])
        self.graph.outputs = kept_outputs

        if not names:
            return
        for position, new_name in enumerate(names):
            self.graph.outputs[position].name = new_name

    def fold_constants(self):
        """Export the graph, fold constants with polygraphy, and re-import the result."""
        exported = gs.export_onnx(self.graph)
        folded = fold_constants(exported, allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(folded)

    def infer_shapes(self):
        """Run ONNX shape inference on the exported graph and re-import it.

        Raises:
            TypeError: If the exported model's ``ByteSize()`` exceeds 2147483648 bytes.
        """
        exported = gs.export_onnx(self.graph)
        if exported.ByteSize() > 2147483648:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")
        self.graph = gs.import_onnx(shape_inference.infer_shapes(exported))
class BaseModel:
    """Shared export/optimization metadata for the Stable Diffusion sub-models.

    Subclasses override the ``get_*`` hooks to describe their ONNX input and
    output names, dynamic axes, sample inputs and tensor shapes.
    """

    def __init__(self, model, name, device="cuda", fp16=False, max_batch_size=16, embedding_dim=768, text_maxlen=77):
        """Store the wrapped model and the batch/resolution limits used by the shape checks.

        Args:
            model: Object returned unchanged by ``get_model``.
            name: Model name; "CLIP" and "UNet" map to the optimizer types
                "clip" and "unet", any other name maps to "vae".
            device: Device stored for use when creating sample inputs.
            fp16: Half-precision flag stored on the instance.
            max_batch_size: Largest batch size accepted by ``check_dims``.
            embedding_dim: Text embedding size stored on the instance.
            text_maxlen: Token sequence length stored on the instance.
        """
        self.model = model
        self.name = name
        self.fp16 = fp16
        self.device = device

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256  # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

        self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae"
        self.ort_optimizer = OrtStableDiffusionOptimizer(self.model_type)

    def get_model(self):
        """Return the wrapped model object."""
        return self.model

    def get_input_names(self):
        """Hook for subclasses; the base implementation returns ``None``."""

    def get_output_names(self):
        """Hook for subclasses; the base implementation returns ``None``."""

    def get_dynamic_axes(self):
        """Hook for subclasses; the base implementation returns ``None``."""
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        """Hook for subclasses; the base implementation returns ``None``."""

    def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """For TensorRT EP: build a file-name suffix describing the batch and image ranges.

        The batch part is ``_b_<batch>`` for a static batch, else ``_b_<min>_<max>``.
        Unless the model is named "CLIP", the image height/width (static) or their
        min/max ranges (dynamic) are appended as well.
        """
        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)[:6]

        profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}"

        if self.name != "CLIP":
            if static_image_shape:
                profile_id += f"_h_{image_height}_w_{image_width}"
            else:
                profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}"

        return profile_id

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """For TensorRT: hook for subclasses; the base implementation returns ``None``."""
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Hook for subclasses; the base implementation returns ``None``."""
        return None

    def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True):
        """Delegate to ``self.ort_optimizer.optimize`` with the given paths and fp16 flag."""
        self.ort_optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16)

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        """Load an ONNX file, run the ``TrtOptimizer`` passes, and save the result.

        Passes run in this order: cleanup, constant folding, shape inference, cleanup.
        """
        onnx_graph = onnx.load(input_onnx_path)
        opt = TrtOptimizer(onnx_graph)
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        opt.cleanup()
        onnx_opt_graph = opt.get_optimized_onnx_graph()
        onnx.save(onnx_opt_graph, optimized_onnx_path)

    def check_dims(self, batch_size, image_height, image_width):
        """Validate a request and return ``(latent_height, latent_width)``.

        Raises:
            AssertionError: If the batch size is outside ``[min_batch, max_batch]``,
                if either image dimension is not a multiple of 8, or if a latent
                dimension is outside ``[min_latent_shape, max_latent_shape]``.
        """
        assert self.min_batch <= batch_size <= self.max_batch
        # Both sides must be multiples of 8; otherwise the floor division below
        # would silently truncate the latent size for the misaligned side.
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert self.min_latent_shape <= latent_height <= self.max_latent_shape
        assert self.min_latent_shape <= latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the min/max batch, image and latent dimensions as a 10-tuple.

        A static batch or static image shape pins both min and max to the
        requested value; otherwise the limits stored on the instance are used.

        Returns:
            ``(min_batch, max_batch, min_image_height, max_image_height,
            min_image_width, max_image_width, min_latent_height,
            max_latent_height, min_latent_width, max_latent_width)``.
        """
        latent_height = image_height // 8
        latent_width = image_width // 8

        if static_batch:
            min_batch = max_batch = batch_size
        else:
            min_batch, max_batch = self.min_batch, self.max_batch

        if static_image_shape:
            min_image_height = max_image_height = image_height
            min_image_width = max_image_width = image_width
            min_latent_height = max_latent_height = latent_height
            min_latent_width = max_latent_width = latent_width
        else:
            min_image_height = min_image_width = self.min_image_shape
            max_image_height = max_image_width = self.max_image_shape
            min_latent_height = min_latent_width = self.min_latent_shape
            max_latent_height = max_latent_width = self.max_latent_shape

        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )
class CLIP(BaseModel):
    """Describes the CLIP text encoder: one ``input_ids`` input, one ``text_embeddings`` output."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        """Register the model under the name "CLIP" with the given device and limits."""
        super().__init__(
            model,
            "CLIP",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        """Return the ONNX input names."""
        return ["input_ids"]

    def get_output_names(self):
        """Return the ONNX output names."""
        return ["text_embeddings"]

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping: axis 0 of both tensors is the batch "B"."""
        batch_axis = {0: "B"}
        return {"input_ids": dict(batch_axis), "text_embeddings": dict(batch_axis)}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the (min, opt, max) shapes of ``input_ids`` after validating the request."""
        self.check_dims(batch_size, image_height, image_width)
        dims = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        min_batch, max_batch = dims[0], dims[1]
        seq_len = self.text_maxlen
        return {"input_ids": [(min_batch, seq_len), (batch_size, seq_len), (max_batch, seq_len)]}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the concrete tensor shapes for the given batch size after validating the request."""
        self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["input_ids"] = (batch_size, self.text_maxlen)
        shapes["text_embeddings"] = (batch_size, self.text_maxlen, self.embedding_dim)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return an all-zero int32 token tensor of shape ``(batch_size, text_maxlen)`` on ``self.device``."""
        self.check_dims(batch_size, image_height, image_width)
        token_shape = (batch_size, self.text_maxlen)
        return torch.zeros(*token_shape, dtype=torch.int32, device=self.device)

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        """Optimize the CLIP ONNX file for TensorRT, keeping only the first graph output.

        The first output is selected before optimization and renamed to
        ``text_embeddings`` after shape inference.
        """
        optimizer = TrtOptimizer(onnx.load(input_onnx_path))
        optimizer.select_outputs([0])  # delete graph output#1
        optimizer.cleanup()
        optimizer.fold_constants()
        optimizer.infer_shapes()
        optimizer.select_outputs([0], names=["text_embeddings"])  # rename network output
        optimizer.cleanup()
        onnx.save(optimizer.get_optimized_onnx_graph(), optimized_onnx_path)
class UNet(BaseModel):
    """Describes the UNet: inputs ``sample``/``timestep``/``encoder_hidden_states``, output ``latent``.

    Batch dimensions are ``2 * batch_size`` throughout (dynamic axis label "2B").
    """

    def __init__(
        self,
        model,
        device="cuda",
        fp16=False,  # used by TRT
        max_batch_size=16,
        embedding_dim=768,
        text_maxlen=77,
        unet_dim=4,
    ):
        """Register the model under the name "UNet" and remember the sample channel count ``unet_dim``."""
        super().__init__(
            model,
            "UNet",
            device=device,
            fp16=fp16,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
            text_maxlen=text_maxlen,
        )
        self.unet_dim = unet_dim

    def get_input_names(self):
        """Return the ONNX input names."""
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        """Return the ONNX output names."""
        return ["latent"]

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping for the batch and spatial dimensions."""
        spatial_axes = {0: "2B", 2: "H", 3: "W"}
        return {
            "sample": dict(spatial_axes),
            "encoder_hidden_states": {0: "2B"},
            "latent": dict(spatial_axes),
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return (min, opt, max) shapes for ``sample`` and ``encoder_hidden_states``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dims = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        min_batch, max_batch = dims[0], dims[1]
        min_latent_height, max_latent_height, min_latent_width, max_latent_width = dims[6:]

        sample_profile = [
            (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
            (2 * batch_size, self.unet_dim, latent_height, latent_width),
            (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
        ]
        hidden_profile = [
            (2 * batch, self.text_maxlen, self.embedding_dim) for batch in (min_batch, batch_size, max_batch)
        ]
        return {"sample": sample_profile, "encoder_hidden_states": hidden_profile}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the concrete tensor shapes for the given batch size and image size."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        doubled_batch = 2 * batch_size
        return {
            "sample": (doubled_batch, self.unet_dim, latent_height, latent_width),
            "timestep": [1],
            "encoder_hidden_states": (doubled_batch, self.text_maxlen, self.embedding_dim),
            "latent": (doubled_batch, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a ``(sample, timestep, encoder_hidden_states)`` tuple of tensors on ``self.device``.

        ``sample`` and ``timestep`` are float32; ``encoder_hidden_states`` is
        float16 when ``self.fp16`` is set and float32 otherwise.
        """
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        if self.fp16:
            hidden_dtype = torch.float16
        else:
            hidden_dtype = torch.float32

        # Tensors are created in the same order as the returned tuple so the
        # random generator is consumed identically.
        sample = torch.randn(
            2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
        )
        timestep = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        hidden_states = torch.randn(
            2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=hidden_dtype, device=self.device
        )
        return (sample, timestep, hidden_states)
class VAE(BaseModel):
    """Describes the VAE decoder: input ``latent`` (4 channels), output ``images`` (3 channels)."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        """Register the model under the name "VAE Decoder" with the given device and limits."""
        super().__init__(
            model,
            "VAE Decoder",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        """Return the ONNX input names."""
        return ["latent"]

    def get_output_names(self):
        """Return the ONNX output names."""
        return ["images"]

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping; image axes are labelled "8H"/"8W" against latent "H"/"W"."""
        latent_axes = {0: "B", 2: "H", 3: "W"}
        image_axes = {0: "B", 2: "8H", 3: "8W"}
        return {"latent": latent_axes, "images": image_axes}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the (min, opt, max) shapes of the ``latent`` input."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dims = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        min_batch, max_batch = dims[0], dims[1]
        min_latent_height, max_latent_height, min_latent_width, max_latent_width = dims[6:]

        smallest = (min_batch, 4, min_latent_height, min_latent_width)
        requested = (batch_size, 4, latent_height, latent_width)
        largest = (max_batch, 4, max_latent_height, max_latent_width)
        return {"latent": [smallest, requested, largest]}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the concrete ``latent`` and ``images`` shapes for the given request."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["latent"] = (batch_size, 4, latent_height, latent_width)
        shapes["images"] = (batch_size, 3, image_height, image_width)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 latent tensor of shape ``(batch_size, 4, H/8, W/8)`` on ``self.device``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        latent_shape = (batch_size, 4, latent_height, latent_width)
        return torch.randn(*latent_shape, dtype=torch.float32, device=self.device)

View file

@ -31,9 +31,8 @@ pip install numpy>=1.24.1 onnx>=1.13.0 coloredlogs protobuf==3.20.3 psutil sympy
pip install onnxruntime-gpu
"""
import gc
import logging
import os
import shutil
from typing import List, Optional, Union
import torch
@ -44,385 +43,13 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import DIFFUSERS_CACHE, logging
from diffusers.utils import DIFFUSERS_CACHE
from huggingface_hub import snapshot_download
from ort_utils import OrtCudaSession
from models import CLIP, VAE, UNet
from ort_utils import Engines
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import onnxruntime as ort
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Engine(OrtCudaSession):
    """An ``OrtCudaSession`` created from an ONNX file with the given provider plus a CPU fallback."""

    def __init__(self, engine_path, provider, device_id: int = 0, enable_cuda_graph=False):
        """Create the inference session for ``engine_path`` and initialize the base session.

        Args:
            engine_path: Path of the ONNX model to load.
            provider: Execution provider name, paired with ``self.provider_options``.
            device_id: CUDA device index.
            enable_cuda_graph: Passed to the provider options and to the base class.
        """
        self.engine_path = engine_path
        self.provider = provider
        self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph)

        cuda_device = torch.device("cuda", device_id)
        provider_list = [(provider, self.provider_options), "CPUExecutionProvider"]
        session = ort.InferenceSession(self.engine_path, providers=provider_list)

        super().__init__(session, cuda_device, enable_cuda_graph)

    def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool):
        """Return the provider-options dict for the given device and CUDA-graph flag."""
        options = {"device_id": device_id}
        options["arena_extend_strategy"] = "kSameAsRequested"
        options["enable_cuda_graph"] = enable_cuda_graph
        return options
class OrtStableDiffusionOptimizer:
    """Optimize a Stable Diffusion ONNX model ("vae", "unet" or "clip") with the ONNX Runtime tools."""

    def __init__(self, model_type: str):
        """Remember the model type and the ONNX model wrapper class for each type.

        Raises:
            AssertionError: If ``model_type`` is not "vae", "unet" or "clip".
        """
        assert model_type in ["vae", "unet", "clip"]
        self.model_type = model_type
        self.model_type_class_mapping = dict(
            unet=UnetOnnxModel,
            vae=VaeOnnxModel,
            clip=ClipOnnxModel,
        )

    def optimize_by_ort(self, onnx_model):
        """Round-trip ``onnx_model`` through ONNX Runtime's graph optimizer and re-wrap the result."""
        import tempfile
        from pathlib import Path

        import onnx

        # Use this step to see the final graph that executed by Onnx Runtime.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save to a temporary file so that we can load it with Onnx Runtime.
            logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
            tmp_model_path = Path(tmp_dir) / "model.onnx"
            onnx_model.save_model_to_file(str(tmp_model_path))

            # The optimized model is written to the same temporary path.
            model_file = str(tmp_model_path)
            optimize_by_onnxruntime(model_file, use_gpu=True, optimized_model_path=model_file)
            reloaded = onnx.load(model_file, load_external_data=True)
            wrapper_class = self.model_type_class_mapping[self.model_type]
            return wrapper_class(reloaded)

    def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True):
        """Optimize onnx model using ONNX Runtime transformers optimizer"""
        logger.info(f"Optimize {input_fp32_onnx_path}...")
        is_unet = self.model_type == "unet"

        fusion_options = FusionOptions(self.model_type)
        if is_unet and not float16:
            fusion_options.enable_packed_kv = False
            fusion_options.enable_packed_qkv = False

        m = optimize_model(
            input_fp32_onnx_path,
            model_type=self.model_type,
            num_heads=0,  # will be deduced from graph
            hidden_size=0,  # will be deduced from graph
            opt_level=0,
            optimization_options=fusion_options,
            use_gpu=True,
        )

        if self.model_type == "clip":
            m.prune_graph(outputs=["text_embeddings"])  # remove the pooler_output, and only keep the first output.

        if float16:
            logger.info("Convert to float16 ...")
            m.convert_float_to_float16(
                keep_io_types=False,
                op_block_list=["RandomNormalLike"],
            )

        # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16
        if float16 or not is_unet:
            m = self.optimize_by_ort(m)

        m.get_operator_statistics()
        m.get_fused_operator_statistics()
        m.save_model_to_file(optimized_onnx_path, use_external_data_format=is_unet and not float16)
        logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)
class BaseModel:
    """Shared export/optimization metadata for the Stable Diffusion sub-models (CUDA pipeline).

    Subclasses override the ``get_*`` hooks to describe their ONNX input and
    output names, dynamic axes, sample inputs and tensor shapes.
    """

    def __init__(self, model, name, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
        """Store the wrapped model and the batch/resolution limits used by the shape checks.

        Args:
            model: Object returned unchanged by ``get_model``.
            name: Model name; "CLIP" and "UNet" map to the optimizer types
                "clip" and "unet", any other name maps to "vae".
            device: Device stored for use when creating sample inputs.
            max_batch_size: Largest batch size accepted by ``check_dims``.
            embedding_dim: Text embedding size stored on the instance.
            text_maxlen: Token sequence length stored on the instance.
        """
        self.model = model
        self.name = name
        self.device = device

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256  # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

        self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae"
        self.optimizer = OrtStableDiffusionOptimizer(self.model_type)

    def get_model(self):
        """Return the wrapped model object."""
        return self.model

    def get_input_names(self):
        """Hook for subclasses; the base implementation returns ``None``."""

    def get_output_names(self):
        """Hook for subclasses; the base implementation returns ``None``."""

    def get_dynamic_axes(self):
        """Hook for subclasses; the base implementation returns ``None``."""
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        """Hook for subclasses; the base implementation returns ``None``."""

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Hook for subclasses; the base implementation returns ``None``."""
        return None

    def optimize(self, input_fp32_onnx_path, optimized_onnx_path, fp16):
        """Delegate to ``self.optimizer.optimize`` with the given paths and fp16 flag."""
        self.optimizer.optimize(input_fp32_onnx_path, optimized_onnx_path, fp16)

    def check_dims(self, batch_size, image_height, image_width):
        """Validate a request and return ``(latent_height, latent_width)``.

        Raises:
            AssertionError: If the batch size is outside ``[min_batch, max_batch]``,
                if either image dimension is not a multiple of 8, or if a latent
                dimension is outside ``[min_latent_shape, max_latent_shape]``.
        """
        assert self.min_batch <= batch_size <= self.max_batch
        # Both sides must be multiples of 8; otherwise the floor division below
        # would silently truncate the latent size for the misaligned side.
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert self.min_latent_shape <= latent_height <= self.max_latent_shape
        assert self.min_latent_shape <= latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the min/max batch, image and latent dimensions as a 10-tuple.

        A static batch or static image shape pins both min and max to the
        requested value; otherwise the limits stored on the instance are used.

        Returns:
            ``(min_batch, max_batch, min_image_height, max_image_height,
            min_image_width, max_image_width, min_latent_height,
            max_latent_height, min_latent_width, max_latent_width)``.
        """
        latent_height = image_height // 8
        latent_width = image_width // 8

        if static_batch:
            min_batch = max_batch = batch_size
        else:
            min_batch, max_batch = self.min_batch, self.max_batch

        if static_image_shape:
            min_image_height = max_image_height = image_height
            min_image_width = max_image_width = image_width
            min_latent_height = max_latent_height = latent_height
            min_latent_width = max_latent_width = latent_width
        else:
            min_image_height = min_image_width = self.min_image_shape
            max_image_height = max_image_width = self.max_image_shape
            min_latent_height = min_latent_width = self.min_latent_shape
            max_latent_height = max_latent_width = self.max_latent_shape

        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )
def get_onnx_path(model_name, onnx_dir):
    """Return ``<onnx_dir>/<model_name>.onnx``, the path of the exported ONNX model."""
    file_name = model_name + ".onnx"
    return os.path.join(onnx_dir, file_name)
def get_engine_path(engine_dir, model_name, profile_id):
    """Return ``<engine_dir>/<model_name><profile_id>.onnx``, the path of the optimized model."""
    file_name = model_name + profile_id + ".onnx"
    return os.path.join(engine_dir, file_name)
def build_engines(
    models,
    engine_dir,
    onnx_dir,
    onnx_opset,
    force_engine_rebuild: bool = False,
    fp16: bool = True,
    provider: str = "CUDAExecutionProvider",
    device_id: int = 0,
    enable_cuda_graph: bool = False,
):
    """Export, optimize and load every model in ``models``; return a dict of ``Engine`` objects.

    Cached files are reused: an existing optimized model skips both export and
    optimization, and an existing exported model skips only the export.

    Args:
        models: Mapping of model name to model wrapper object.
        engine_dir: Directory holding the optimized ONNX models.
        onnx_dir: Directory holding the exported ONNX models.
        onnx_opset: Opset version passed to ``torch.onnx.export``.
        force_engine_rebuild: If set, both directories are removed first.
        fp16: Selects the "_fp16" or "_fp32" file suffix and is passed to the optimizer.
        provider: Execution provider name passed to ``Engine``.
        device_id: Device index passed to ``Engine``.
        enable_cuda_graph: Passed to ``Engine``.

    Returns:
        Dict mapping each model name to its ``Engine``.
    """
    precision_suffix = "_fp16" if fp16 else "_fp32"

    if force_engine_rebuild:
        for stale_dir in (onnx_dir, engine_dir):
            if os.path.isdir(stale_dir):
                logger.info("Remove existing directory %s since force_engine_rebuild is enabled", stale_dir)
                shutil.rmtree(stale_dir)

    for required_dir in (engine_dir, onnx_dir):
        if not os.path.isdir(required_dir):
            os.makedirs(required_dir)

    # Export models to ONNX
    for model_name, model_obj in models.items():
        onnx_path = get_onnx_path(model_name, onnx_dir)
        onnx_opt_path = get_engine_path(engine_dir, model_name, precision_suffix)

        if os.path.exists(onnx_opt_path):
            logger.info("Found cached optimized model: %s", onnx_opt_path)
            continue

        if os.path.exists(onnx_path):
            logger.info("Found cached model: %s", onnx_path)
        else:
            logger.info("Exporting model: %s", onnx_path)
            model = model_obj.get_model().to(model_obj.device)
            with torch.inference_mode():
                sample_inputs = model_obj.get_sample_input(1, 512, 512)
                torch.onnx.export(
                    model,
                    sample_inputs,
                    onnx_path,
                    export_params=True,
                    opset_version=onnx_opset,
                    do_constant_folding=True,
                    input_names=model_obj.get_input_names(),
                    output_names=model_obj.get_output_names(),
                    dynamic_axes=model_obj.get_dynamic_axes(),
                )
            # Drop the model reference and release cached memory before optimizing.
            del model
            torch.cuda.empty_cache()
            gc.collect()

        # Optimize onnx
        logger.info("Generating optimized model: %s", onnx_opt_path)
        model_obj.optimize(onnx_path, onnx_opt_path, fp16)

    loaded = {}
    for model_name in models:
        engine = Engine(
            get_engine_path(engine_dir, model_name, precision_suffix),
            provider,
            device_id=device_id,
            enable_cuda_graph=enable_cuda_graph,
        )
        logger.info("%s options for %s: %s", provider, model_name, engine.provider_options)
        loaded[model_name] = engine

    return loaded
def run_engine(engine, feed_dict):
    """Call ``engine.infer`` with ``feed_dict`` and return its result."""
    outputs = engine.infer(feed_dict)
    return outputs
class CLIP(BaseModel):
    """Describes the CLIP text encoder: input ``input_ids``, outputs ``text_embeddings`` and ``pooler_output``."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        """Register the model under the name "CLIP" with the given device and limits."""
        super().__init__(
            model,
            "CLIP",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        """Return the ONNX input names."""
        return ["input_ids"]

    def get_output_names(self):
        """Return the ONNX output names."""
        return ["text_embeddings", "pooler_output"]

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping: axis 0 of both tensors is the batch "B"."""
        batch_axis = {0: "B"}
        return {"input_ids": dict(batch_axis), "text_embeddings": dict(batch_axis)}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the concrete tensor shapes after validating the request.

        ``pooler_output`` is intentionally not listed (its shape would be
        ``(batch_size, self.embedding_dim)``).
        """
        self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["input_ids"] = (batch_size, self.text_maxlen)
        shapes["text_embeddings"] = (batch_size, self.text_maxlen, self.embedding_dim)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return an all-zero int32 token tensor of shape ``(batch_size, text_maxlen)`` on ``self.device``."""
        self.check_dims(batch_size, image_height, image_width)
        token_shape = (batch_size, self.text_maxlen)
        return torch.zeros(*token_shape, dtype=torch.int32, device=self.device)
class UNet(BaseModel):
    """Describes the UNet: inputs ``sample``/``timestep``/``encoder_hidden_states``, output ``latent``.

    Batch dimensions are ``2 * batch_size`` throughout (dynamic axis label "2B").
    """

    def __init__(
        self,
        model,
        device="cuda",
        max_batch_size=16,
        embedding_dim=768,
        text_maxlen=77,
        unet_dim=4,
    ):
        """Register the model under the name "UNet" and remember the sample channel count ``unet_dim``."""
        super().__init__(
            model,
            "UNet",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
            text_maxlen=text_maxlen,
        )
        self.unet_dim = unet_dim

    def get_input_names(self):
        """Return the ONNX input names."""
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        """Return the ONNX output names."""
        return ["latent"]

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping for the batch and spatial dimensions."""
        spatial_axes = {0: "2B", 2: "H", 3: "W"}
        return {
            "sample": dict(spatial_axes),
            "encoder_hidden_states": {0: "2B"},
            "latent": dict(spatial_axes),
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the concrete tensor shapes for the given batch size and image size."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        doubled_batch = 2 * batch_size
        return {
            "sample": (doubled_batch, self.unet_dim, latent_height, latent_width),
            "timestep": [1],
            "encoder_hidden_states": (doubled_batch, self.text_maxlen, self.embedding_dim),
            "latent": (doubled_batch, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a ``(sample, timestep, encoder_hidden_states)`` tuple of float32 tensors on ``self.device``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)

        # Tensors are created in the same order as the returned tuple so the
        # random generator is consumed identically.
        sample = torch.randn(
            2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
        )
        timestep = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        hidden_states = torch.randn(
            2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=torch.float32, device=self.device
        )
        return (sample, timestep, hidden_states)
class VAE(BaseModel):
    """Describes the VAE decoder: input ``latent`` (4 channels), output ``images`` (3 channels)."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        """Register the model under the name "VAE Decoder" with the given device and limits."""
        super().__init__(
            model,
            "VAE Decoder",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        """Return the ONNX input names."""
        return ["latent"]

    def get_output_names(self):
        """Return the ONNX output names."""
        return ["images"]

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping; image axes are labelled "8H"/"8W" against latent "H"/"W"."""
        latent_axes = {0: "B", 2: "H", 3: "W"}
        image_axes = {0: "B", 2: "8H", 3: "8W"}
        return {"latent": latent_axes, "images": image_axes}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return the concrete ``latent`` and ``images`` shapes for the given request."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        shapes = {}
        shapes["latent"] = (batch_size, 4, latent_height, latent_width)
        shapes["images"] = (batch_size, 3, image_height, image_width)
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 latent tensor of shape ``(batch_size, 4, H/8, W/8)`` on ``self.device``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        latent_shape = (batch_size, 4, latent_height, latent_width)
        return torch.randn(*latent_shape, dtype=torch.float32, device=self.device)
logger = logging.getLogger(__name__)
class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
@ -457,7 +84,6 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
self.unet_in_channels = unet.config.in_channels
self.inpaint = False
self.onnx_opset = onnx_opset
self.onnx_dir = onnx_dir
self.engine_dir = engine_dir
self.force_engine_rebuild = force_engine_rebuild
@ -466,9 +92,8 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
self.max_batch_size = 16
self.models = {} # loaded in __load_models()
self.engines = {} # loaded in build_engines()
self.engines = Engines("CUDAExecutionProvider", onnx_opset)
self.provider = "CUDAExecutionProvider"
self.fp16 = False
def __load_models(self):
@ -484,6 +109,7 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
self.models["unet"] = UNet(
self.unet,
device=self.torch_device,
fp16=self.fp16,
max_batch_size=self.max_batch_size,
embedding_dim=self.embedding_dim,
unet_dim=(9 if self.inpaint else 4),
@ -529,18 +155,16 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
self.torch_device = torch.device(torch_device)
# load models
self.fp16 = torch_dtype == torch.float16
self.__load_models()
# build engines
self.fp16 = torch_dtype == torch.float16
self.engines = build_engines(
self.engines.build(
self.models,
self.engine_dir,
self.onnx_dir,
self.onnx_opset,
force_engine_rebuild=self.force_engine_rebuild,
fp16=self.fp16,
provider=self.provider,
device_id=self.torch_device.index or torch.cuda.current_device(),
enable_cuda_graph=self.enable_cuda_graph,
)
@ -582,7 +206,9 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
)
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone()
text_embeddings = (
self.engines.get_engine("clip").infer({"input_ids": text_input_ids})["text_embeddings"].clone()
)
# Tokenize negative prompt
uncond_input_ids = (
@ -597,7 +223,7 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
.to(self.torch_device)
)
uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"]
uncond_embeddings = self.engines.get_engine("clip").infer({"input_ids": uncond_input_ids})["text_embeddings"]
# Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
@ -618,8 +244,7 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
timestep_float = timestep.to(torch.float16) if self.fp16 else timestep.to(torch.float32)
# Predict the noise residual
noise_pred = run_engine(
self.engines["unet"],
noise_pred = self.engines.get_engine("unet").infer(
{"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
)["latent"]
@ -633,14 +258,16 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
return latents
def __decode_latent(self, latents):
images = run_engine(self.engines["vae"], {"latent": latents})["images"]
images = self.engines.get_engine("vae").infer({"latent": latents})["images"]
images = (images / 2 + 0.5).clamp(0, 1)
return images.cpu().permute(0, 2, 3, 1).float().numpy()
def __allocate_buffers(self, image_height, image_width, batch_size):
# Allocate output tensors for I/O bindings
for model_name, obj in self.models.items():
self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width))
self.engines.get_engine(model_name).allocate_buffers(
obj.get_shape_dict(batch_size, image_height, image_width)
)
@torch.no_grad()
def __call__(
@ -736,9 +363,6 @@ class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline):
if __name__ == "__main__":
import torch
from diffusers import DDIMScheduler
model_name_or_path = "runwayml/stable-diffusion-v1-5"
scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler")

View file

@ -37,8 +37,6 @@ import os
import shutil
from typing import List, Optional, Union
import onnx
import onnx_graphsurgeon as gs
import torch
from cuda import cudart
from diffusers.models import AutoencoderKL, UNet2DConditionModel
@ -50,9 +48,8 @@ from diffusers.pipelines.stable_diffusion import (
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import DIFFUSERS_CACHE, logging
from huggingface_hub import snapshot_download
from onnx import shape_inference
from models import CLIP, VAE, UNet
from ort_utils import OrtCudaSession
from polygraphy.backend.onnx.loader import fold_constants
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import onnxruntime as ort
@ -124,142 +121,6 @@ class Engine(OrtCudaSession):
return trt_ep_options
class Optimizer:
    """Wrap an ONNX GraphSurgeon graph and expose the passes used before building TensorRT engines."""

    def __init__(self, onnx_graph):
        """Import ``onnx_graph`` with GraphSurgeon and keep the result in ``self.graph``."""
        self.graph = gs.import_onnx(onnx_graph)

    def cleanup(self):
        """Call ``cleanup()`` and then ``toposort()`` on the wrapped graph."""
        self.graph.cleanup().toposort()

    def get_optimized_onnx_graph(self):
        """Export the wrapped graph back to an ONNX model and return it."""
        return gs.export_onnx(self.graph)

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs at the indices in ``keep``; optionally rename them.

        Args:
            keep: Indices into the current ``graph.outputs`` list, in the order to keep.
            names: Optional new names, applied positionally to the kept outputs.
        """
        kept_outputs = []
        for index in keep:
            kept_outputs.append(self.graph.outputs[index])
        self.graph.outputs = kept_outputs

        if not names:
            return
        for position, new_name in enumerate(names):
            self.graph.outputs[position].name = new_name

    def fold_constants(self):
        """Export the graph, fold constants with polygraphy, and re-import the result."""
        exported = gs.export_onnx(self.graph)
        folded = fold_constants(exported, allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(folded)

    def infer_shapes(self):
        """Run ONNX shape inference on the exported graph and re-import it.

        Raises:
            TypeError: If the exported model's ``ByteSize()`` exceeds 2147483648 bytes.
        """
        exported = gs.export_onnx(self.graph)
        if exported.ByteSize() > 2147483648:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")
        self.graph = gs.import_onnx(shape_inference.infer_shapes(exported))
class BaseModel:
def __init__(self, model, name, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
self.model = model
self.name = name
self.fp16 = fp16
self.device = device
self.min_batch = 1
self.max_batch = max_batch_size
self.min_image_shape = 256 # min image resolution: 256x256
self.max_image_shape = 1024 # max image resolution: 1024x1024
self.min_latent_shape = self.min_image_shape // 8
self.max_latent_shape = self.max_image_shape // 8
self.embedding_dim = embedding_dim
self.text_maxlen = text_maxlen
def get_model(self):
return self.model
def get_input_names(self):
pass
def get_output_names(self):
pass
def get_dynamic_axes(self):
return None
def get_sample_input(self, batch_size, image_height, image_width):
pass
def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape):
(
min_batch,
max_batch,
min_image_height,
max_image_height,
min_image_width,
max_image_width,
_,
_,
_,
_,
) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}"
if self.name != "CLIP":
if static_image_shape:
profile_id += f"_h_{image_height}_w_{image_width}"
else:
profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}"
return profile_id
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
return None
def get_shape_dict(self, batch_size, image_height, image_width):
return None
def optimize(self, onnx_graph):
opt = Optimizer(onnx_graph)
opt.cleanup()
opt.fold_constants()
opt.infer_shapes()
opt.cleanup()
return opt.get_optimized_onnx_graph()
def check_dims(self, batch_size, image_height, image_width):
assert batch_size >= self.min_batch and batch_size <= self.max_batch
assert image_height % 8 == 0 or image_width % 8 == 0
latent_height = image_height // 8
latent_width = image_width // 8
assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
return (latent_height, latent_width)
def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
    """Compute the min/max bounds for batch, image and latent dimensions.

    With ``static_batch`` the batch bounds collapse to ``batch_size``;
    otherwise they come from ``self.min_batch`` / ``self.max_batch``.
    With ``static_image_shape`` the image and latent bounds collapse to the
    requested size (latent = image // 8); otherwise they come from the
    ``self.min_*_shape`` / ``self.max_*_shape`` attributes.

    Returns:
        A 10-tuple ``(min_batch, max_batch, min_image_height,
        max_image_height, min_image_width, max_image_width,
        min_latent_height, max_latent_height, min_latent_width,
        max_latent_width)``.
    """
    if static_batch:
        min_batch = max_batch = batch_size
    else:
        min_batch, max_batch = self.min_batch, self.max_batch

    latent_height, latent_width = image_height // 8, image_width // 8

    if static_image_shape:
        min_image_height = max_image_height = image_height
        min_image_width = max_image_width = image_width
        min_latent_height = max_latent_height = latent_height
        min_latent_width = max_latent_width = latent_width
    else:
        min_image_height = min_image_width = self.min_image_shape
        max_image_height = max_image_width = self.max_image_shape
        min_latent_height = min_latent_width = self.min_latent_shape
        max_latent_height = max_latent_width = self.max_latent_shape

    return (
        min_batch,
        max_batch,
        min_image_height,
        max_image_height,
        min_image_width,
        max_image_width,
        min_latent_height,
        max_latent_height,
        min_latent_width,
        max_latent_width,
    )
def get_onnx_path(model_name, onnx_dir, opt=True):
    """Return the path of a model's ONNX file inside ``onnx_dir``.

    The file name is ``<model_name>.opt.onnx`` when ``opt`` is true and
    ``<model_name>.onnx`` otherwise.
    """
    opt_marker = ".opt" if opt else ""
    file_name = model_name + opt_marker + ".onnx"
    return os.path.join(onnx_dir, file_name)
@ -352,8 +213,7 @@ def build_engines(
# Optimize onnx
if not os.path.exists(onnx_opt_path):
logger.info("Generating optimizing model: %s", onnx_opt_path)
onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
onnx.save(onnx_opt_graph, onnx_opt_path)
model_obj.optimize_trt(onnx_path, onnx_opt_path)
else:
logger.info("Found cached optimized model: %s", onnx_opt_path)
@ -403,177 +263,6 @@ def run_engine(engine, feed_dict):
return engine.infer(feed_dict)
class CLIP(BaseModel):
    """Export and profile description for the model registered as "CLIP"."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        super().__init__(
            model=model,
            name="CLIP",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        """Return the single graph input name."""
        return ["input_ids"]

    def get_output_names(self):
        """Return the two graph output names used at export time."""
        return ["text_embeddings", "pooler_output"]

    def get_dynamic_axes(self):
        """Mark axis 0 of ``input_ids`` and ``text_embeddings`` as dynamic ("B")."""
        batch_axis = {0: "B"}
        return {"input_ids": dict(batch_axis), "text_embeddings": dict(batch_axis)}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return [min, requested, max] shapes for ``input_ids``.

        The arguments are validated with ``check_dims`` first; only the
        batch bounds from ``get_minmax_dims`` are used.
        """
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, *_ = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_image_shape
        )
        seq_len = self.text_maxlen
        return {"input_ids": [(min_batch, seq_len), (batch_size, seq_len), (max_batch, seq_len)]}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete shapes of ``input_ids`` and ``text_embeddings``."""
        self.check_dims(batch_size, image_height, image_width)
        ids_shape = (batch_size, self.text_maxlen)
        embeddings_shape = (batch_size, self.text_maxlen, self.embedding_dim)
        return {"input_ids": ids_shape, "text_embeddings": embeddings_shape}

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return an all-zero int32 tensor of shape (batch_size, text_maxlen) on ``self.device``."""
        self.check_dims(batch_size, image_height, image_width)
        sample_ids = torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
        return sample_ids

    def optimize(self, onnx_graph):
        """Optimize the graph, keeping only output 0 and naming it ``text_embeddings``."""
        optimizer = Optimizer(onnx_graph)
        # Drop graph output #1 before running the passes.
        optimizer.select_outputs([0])
        optimizer.cleanup()
        optimizer.fold_constants()
        optimizer.infer_shapes()
        # Rename the remaining network output.
        optimizer.select_outputs([0], names=["text_embeddings"])
        optimizer.cleanup()
        optimized_graph = optimizer.get_optimized_onnx_graph()
        return optimized_graph
class UNet(BaseModel):
    """Export and profile description for the model registered as "UNet".

    All batch dimensions produced here are ``2 * batch_size`` (the dynamic
    axis is labelled "2B").
    """

    def __init__(
        self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
    ):
        super().__init__(
            model=model,
            name="UNet",
            fp16=fp16,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
            text_maxlen=text_maxlen,
        )
        # Channel count of the "sample" input.
        self.unet_dim = unet_dim

    def get_input_names(self):
        """Return the three graph input names, in export order."""
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        """Return the single graph output name."""
        return ["latent"]

    def get_dynamic_axes(self):
        """Return dynamic axes: batch ("2B") everywhere, plus H/W on sample and latent."""
        spatial_axes = {0: "2B", 2: "H", 3: "W"}
        return {
            "sample": dict(spatial_axes),
            "encoder_hidden_states": {0: "2B"},
            "latent": dict(spatial_axes),
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return [min, requested, max] shapes for ``sample`` and ``encoder_hidden_states``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        # Image bounds (positions 2-5) are not needed; only batch and latent bounds are.
        min_batch, max_batch, *_, min_lat_h, max_lat_h, min_lat_w, max_lat_w = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_image_shape
        )

        def sample_shape(batch, height, width):
            return (2 * batch, self.unet_dim, height, width)

        def hidden_shape(batch):
            return (2 * batch, self.text_maxlen, self.embedding_dim)

        return {
            "sample": [
                sample_shape(min_batch, min_lat_h, min_lat_w),
                sample_shape(batch_size, latent_height, latent_width),
                sample_shape(max_batch, max_lat_h, max_lat_w),
            ],
            "encoder_hidden_states": [
                hidden_shape(min_batch),
                hidden_shape(batch_size),
                hidden_shape(max_batch),
            ],
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete shapes for every input and output tensor."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        doubled_batch = 2 * batch_size
        return {
            "sample": (doubled_batch, self.unet_dim, latent_height, latent_width),
            "timestep": [1],
            "encoder_hidden_states": (doubled_batch, self.text_maxlen, self.embedding_dim),
            # The output channel count is fixed at 4 here, independent of unet_dim.
            "latent": (doubled_batch, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a (sample, timestep, encoder_hidden_states) tuple of tensors on ``self.device``.

        ``sample`` and ``timestep`` are always float32; only
        ``encoder_hidden_states`` follows ``self.fp16``.
        """
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        hidden_dtype = torch.float16 if self.fp16 else torch.float32
        doubled_batch = 2 * batch_size

        sample = torch.randn(
            doubled_batch, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
        )
        timestep = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        hidden_states = torch.randn(
            doubled_batch, self.text_maxlen, self.embedding_dim, dtype=hidden_dtype, device=self.device
        )
        return (sample, timestep, hidden_states)
class VAE(BaseModel):
    """Export and profile description for the model registered as "VAE decoder"."""

    def __init__(self, model, device, max_batch_size, embedding_dim):
        super().__init__(
            model=model,
            name="VAE decoder",
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        )

    def get_input_names(self):
        """Return the single graph input name."""
        return ["latent"]

    def get_output_names(self):
        """Return the single graph output name."""
        return ["images"]

    def get_dynamic_axes(self):
        """Return dynamic axes: batch and spatial axes on both ``latent`` and ``images``."""
        latent_axes = {0: "B", 2: "H", 3: "W"}
        image_axes = {0: "B", 2: "8H", 3: "8W"}
        return {"latent": latent_axes, "images": image_axes}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return [min, requested, max] shapes for the 4-channel ``latent`` input."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        # Image bounds (positions 2-5) are not needed; only batch and latent bounds are.
        min_batch, max_batch, *_, min_lat_h, max_lat_h, min_lat_w, max_lat_w = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_image_shape
        )
        smallest = (min_batch, 4, min_lat_h, min_lat_w)
        requested = (batch_size, 4, latent_height, latent_width)
        largest = (max_batch, 4, max_lat_h, max_lat_w)
        return {"latent": [smallest, requested, largest]}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete shapes of ``latent`` (4 channels) and ``images`` (3 channels)."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        latent_shape = (batch_size, 4, latent_height, latent_width)
        images_shape = (batch_size, 3, image_height, image_width)
        return {"latent": latent_shape, "images": images_shape}

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 latent tensor on ``self.device``."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        sample_latent = torch.randn(
            batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device
        )
        return sample_latent
class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime.
@ -644,8 +333,8 @@ class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline):
self.models["unet"] = UNet(
self.unet,
fp16=True,
device=self.torch_device,
fp16=True,
max_batch_size=self.max_batch_size,
embedding_dim=self.embedding_dim,
unet_dim=(9 if self.inpaint else 4),
@ -888,23 +577,22 @@ class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline):
if __name__ == "__main__":
import torch
from diffusers import DDIMScheduler
model_name_or_path = "runwayml/stable-diffusion-v1-5"
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler")
pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1",
model_name_or_path,
revision="fp16",
torch_dtype=torch.float16,
scheduler=scheduler,
image_height=512,
image_width=512,
max_batch_size=1,
max_batch_size=4,
)
# re-use cached folder to save ONNX models and TensorRT Engines
pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision="fp16")
pipe.set_cached_folder(model_name_or_path, revision="fp16")
pipe = pipe.to("cuda")

View file

@ -0,0 +1,84 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
ONNX Model Optimizer for Stable Diffusion
"""
import logging
import tempfile
from pathlib import Path
import onnx
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
logger = logging.getLogger(__name__)
class OrtStableDiffusionOptimizer:
    """Optimize a Stable Diffusion ONNX model ("vae", "unet" or "clip") with ONNX Runtime tools."""

    def __init__(self, model_type: str):
        """Remember the model type and the OnnxModel wrapper class used for it.

        Args:
            model_type: One of "vae", "unet" or "clip".

        Raises:
            AssertionError: If ``model_type`` is not one of the supported values.
        """
        assert model_type in ["vae", "unet", "clip"]
        self.model_type = model_type
        self.model_type_class_mapping = {
            "unet": UnetOnnxModel,
            "vae": VaeOnnxModel,
            "clip": ClipOnnxModel,
        }

    def optimize_by_ort(self, onnx_model):
        """Round-trip ``onnx_model`` through ONNX Runtime's graph optimizer.

        The model is saved to a temporary directory, optimized in place by
        ``optimize_by_onnxruntime`` and loaded back before the directory is
        removed.

        Returns:
            The reloaded model wrapped in the OnnxModel class for ``self.model_type``.
        """
        # Use this step to see the final graph that is executed by ONNX Runtime.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save to a temporary file so that we can load it with ONNX Runtime.
            logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
            tmp_model_path = Path(tmp_dir) / "model.onnx"
            onnx_model.save_model_to_file(str(tmp_model_path))
            # The optimized model overwrites the temporary input file.
            ort_optimized_model_path = tmp_model_path
            optimize_by_onnxruntime(
                str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path)
            )
            # Load everything while the temporary directory still exists.
            model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
            return self.model_type_class_mapping[self.model_type](model)

    def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True):
        """Optimize onnx model using ONNX Runtime transformers optimizer.

        Args:
            input_fp32_onnx_path: Path of the ONNX model to optimize.
            optimized_onnx_path: Path where the optimized model is written.
            float16: When true, convert the model to float16 after fusion.
        """
        # Lazy %-style arguments, consistent with the other log calls in this class.
        logger.info("Optimize %s...", input_fp32_onnx_path)

        fusion_options = FusionOptions(self.model_type)
        if self.model_type == "unet" and not float16:
            # Packed KV / QKV fusions are turned off for the float32 UNet.
            fusion_options.enable_packed_kv = False
            fusion_options.enable_packed_qkv = False

        m = optimize_model(
            input_fp32_onnx_path,
            model_type=self.model_type,
            num_heads=0,  # will be deduced from graph
            hidden_size=0,  # will be deduced from graph
            opt_level=0,
            optimization_options=fusion_options,
            use_gpu=True,
        )

        if self.model_type == "clip":
            m.prune_graph(outputs=["text_embeddings"])  # remove the pooler_output, and only keep the first output.

        if float16:
            logger.info("Convert to float16 ...")
            m.convert_float_to_float16(
                keep_io_types=False,
                op_block_list=["RandomNormalLike"],
            )

        # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16
        if float16 or (self.model_type != "unet"):
            m = self.optimize_by_ort(m)

        m.get_operator_statistics()
        m.get_fused_operator_statistics()
        # Only the float32 UNet is written with external data.
        m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16)
        logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)

View file

@ -3,17 +3,23 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import gc
import logging
import os
import shutil
from collections import OrderedDict
from typing import Dict
from typing import Any, Dict
import torch
import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import TypeHelper
logger = logging.getLogger(__name__)
class OrtCudaSession:
"""ONNX Runtime Session for CUDA or TensorRT provider"""
"""Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider"""
def __init__(self, ort_session: ort.InferenceSession, device: torch.device, enable_cuda_graph=False):
self.ort_session = ort_session
@ -110,3 +116,111 @@ class OrtCudaSession:
self.ort_session.run_with_iobinding(self.io_binding)
return self.output_tensors
class Engine(OrtCudaSession):
    """An ``OrtCudaSession`` that creates its own InferenceSession from an ONNX file."""

    def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_graph=False):
        """Create the inference session for ``engine_path`` and bind it to a CUDA device.

        Args:
            engine_path: Path of the ONNX model to load.
            provider: Name of the execution provider listed first; the CPU
                provider is listed after it.
            device_id: CUDA device index.
            enable_cuda_graph: Forwarded to the provider options and to the base class.
        """
        self.engine_path = engine_path
        self.provider = provider
        self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph)

        cuda_device = torch.device("cuda", device_id)
        provider_list = [
            (provider, self.provider_options),
            "CPUExecutionProvider",
        ]
        session = ort.InferenceSession(self.engine_path, providers=provider_list)

        super().__init__(session, cuda_device, enable_cuda_graph)

    def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]:
        """Return the provider options dict used when creating the session."""
        options: Dict[str, Any] = {
            "device_id": device_id,
            "arena_extend_strategy": "kSameAsRequested",
            "enable_cuda_graph": enable_cuda_graph,
        }
        return options
class Engines:
    """Build and hold one ``Engine`` per model name for a given execution provider."""

    def __init__(self, provider, onnx_opset: int = 14):
        self.provider = provider
        self.engines = {}
        self.onnx_opset = onnx_opset

    @staticmethod
    def get_onnx_path(onnx_dir, model_name):
        """Return ``<onnx_dir>/<model_name>.onnx``."""
        file_name = model_name + ".onnx"
        return os.path.join(onnx_dir, file_name)

    @staticmethod
    def get_engine_path(engine_dir, model_name, profile_id):
        """Return ``<engine_dir>/<model_name><profile_id>.onnx``."""
        file_name = model_name + profile_id + ".onnx"
        return os.path.join(engine_dir, file_name)

    def build(
        self,
        models,
        engine_dir: str,
        onnx_dir: str,
        force_engine_rebuild: bool = False,
        fp16: bool = True,
        device_id: int = 0,
        enable_cuda_graph: bool = False,
    ):
        """Export, optimize and load every model in ``models``.

        For each ``(name, model_obj)`` pair, the optimized model in
        ``engine_dir`` is reused when it exists. Otherwise the raw ONNX file
        in ``onnx_dir`` is reused or exported with ``torch.onnx.export``, and
        then optimized through ``model_obj.optimize_ort``. Finally an
        ``Engine`` is created for each model and stored in ``self.engines``.

        Args:
            models: Mapping of model name to model description object.
            engine_dir: Directory for optimized models.
            onnx_dir: Directory for exported ONNX models.
            force_engine_rebuild: Delete both directories first when true.
            fp16: Selects the "_fp16" / "_fp32" profile suffix and is passed
                to ``optimize_ort`` as ``to_fp16``.
            device_id: CUDA device index passed to each ``Engine``.
            enable_cuda_graph: Passed to each ``Engine``.
        """
        profile_id = "_fp16" if fp16 else "_fp32"

        if force_engine_rebuild:
            for stale_dir in (onnx_dir, engine_dir):
                if os.path.isdir(stale_dir):
                    logger.info("Remove existing directory %s since force_engine_rebuild is enabled", stale_dir)
                    shutil.rmtree(stale_dir)

        for required_dir in (engine_dir, onnx_dir):
            if not os.path.isdir(required_dir):
                os.makedirs(required_dir)

        # Export models to ONNX
        for model_name, model_obj in models.items():
            onnx_path = Engines.get_onnx_path(onnx_dir, model_name)
            onnx_opt_path = Engines.get_engine_path(engine_dir, model_name, profile_id)

            if os.path.exists(onnx_opt_path):
                logger.info("Found cached optimized model: %s", onnx_opt_path)
                continue

            if os.path.exists(onnx_path):
                logger.info("Found cached model: %s", onnx_path)
            else:
                logger.info("Exporting model: %s", onnx_path)
                model = model_obj.get_model().to(model_obj.device)
                with torch.inference_mode():
                    # Sample inputs for export use batch size 1 and a 512x512 image.
                    inputs = model_obj.get_sample_input(1, 512, 512)
                    torch.onnx.export(
                        model,
                        inputs,
                        onnx_path,
                        export_params=True,
                        opset_version=self.onnx_opset,
                        do_constant_folding=True,
                        input_names=model_obj.get_input_names(),
                        output_names=model_obj.get_output_names(),
                        dynamic_axes=model_obj.get_dynamic_axes(),
                    )
                # Drop the local model reference and clear caches before optimizing.
                del model
                torch.cuda.empty_cache()
                gc.collect()

            # Optimize onnx
            logger.info("Generating optimized model: %s", onnx_opt_path)
            model_obj.optimize_ort(onnx_path, onnx_opt_path, to_fp16=fp16)

        for model_name in models:
            engine_path = Engines.get_engine_path(engine_dir, model_name, profile_id)
            engine = Engine(engine_path, self.provider, device_id=device_id, enable_cuda_graph=enable_cuda_graph)
            logger.info("%s options for %s: %s", self.provider, model_name, engine.provider_options)
            self.engines[model_name] = engine

    def get_engine(self, model_name):
        """Return the engine stored under ``model_name`` (``KeyError`` if absent)."""
        return self.engines[model_name]