Add onnx export script for segment anything v2 (#22119)
### Description
Add ONNX export script for segment anything v2 (SAM2).

### Limitations
* Does not support video; only images are supported right now.
* The decoder does not support batch inference.

### Credits
The demo is based on the [SAM2 notebook](https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/image_predictor_example.ipynb), modified to run with ORT. The export of the decoder is inspired by https://github.com/vietanhdev/samexporter.

### Demo
Example output of the demo: (image omitted)

### Motivation and Context
To support optimization of SAM2 image segmentation.
This commit is contained in:
parent 05acfb90ab · commit a9740d6f96

13 changed files with 1796 additions and 5 deletions
cmake/onnxruntime_python.cmake

```diff
@@ -476,6 +476,9 @@ file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
 file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py"
 )
+file(GLOB onnxruntime_python_transformers_models_sam2_src CONFIGURE_DEPENDS
+  "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/sam2/*.py"
+)
 file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py"
 )
@@ -547,6 +550,7 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2
+  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/sam2
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/whisper
@@ -656,6 +660,9 @@ add_custom_command(
   COMMAND ${CMAKE_COMMAND} -E copy
       ${onnxruntime_python_transformers_models_phi2_src}
       $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2/
+  COMMAND ${CMAKE_COMMAND} -E copy
+      ${onnxruntime_python_transformers_models_sam2_src}
+      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/sam2/
   COMMAND ${CMAKE_COMMAND} -E copy
       ${onnxruntime_python_transformers_models_stable_diffusion_src}
       $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion/
```
onnxruntime/python/tools/transformers/io_binding_helper.py

```diff
@@ -1,13 +1,16 @@
 import copy
 import logging
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
 
 import numpy
 import torch
 
 from onnxruntime import InferenceSession, RunOptions
 
+# Type alias
+ShapeDict = Mapping[str, Union[Tuple, List[int]]]
+
 logger = logging.getLogger(__name__)
 
 
@@ -262,7 +265,7 @@ class CudaSession:
             )
             self.output_tensors[self.buffer_sharing[name]] = tensor
 
-    def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
+    def allocate_buffers(self, shape_dict: ShapeDict):
         """Allocate tensors for I/O Binding"""
         if self.enable_cuda_graph:
             for name, shape in shape_dict.items():
@@ -346,7 +349,7 @@ class GpuBinding(CudaSession):
         self,
         ort_session: InferenceSession,
         device: torch.device,
-        shape_dict: Dict[str, Union[Tuple[int], List[int]]],
+        shape_dict: ShapeDict,
         enable_gpu_graph: bool = False,
         gpu_graph_id: int = -1,
         stream: int = 0,
@@ -406,7 +409,7 @@ class GpuBindingManager:
 
     def get_binding(
         self,
-        shape_dict: Dict[str, Union[Tuple[int], List[int]]],
+        shape_dict: ShapeDict,
         use_cuda_graph: bool = False,
         buffer_sharing: Optional[Dict[str, str]] = None,
     ) -> GpuBinding:
```
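Not from the diff itself: a quick illustration of what the new `ShapeDict` alias accepts, namely any mapping from tensor name to a tuple or list shape. The tensor names below are made up for the example.

```python
from typing import List, Mapping, Tuple, Union

# Same alias as added in io_binding_helper.py above.
ShapeDict = Mapping[str, Union[Tuple, List[int]]]

# Keys are I/O binding tensor names (illustrative, not from this PR);
# tuple shapes and list shapes can be mixed freely.
example_shapes: ShapeDict = {
    "input_ids": (1, 128),
    "logits": [1, 128, 51200],
}

for name, shape in example_shapes.items():
    print(name, list(shape))
```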
onnxruntime/python/tools/transformers/models/sam2/README.md (new file, 65 lines)
@@ -0,0 +1,65 @@

# SAM2 ONNX Model Export

## Setup Environment
It is recommended to set up a machine with Python 3.10, 3.11 or 3.12, then install [PyTorch 2.4.1](https://pytorch.org/) and ONNX Runtime 1.19.2.

### CPU Only
To install the CPU-only version of PyTorch and ONNX Runtime for exporting and running ONNX models, use the following commands:
```
python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
python3 -m pip install onnxruntime onnx opencv-python matplotlib
```

### GPU
If your machine has an NVIDIA GPU, you can install the CUDA version of PyTorch and ONNX Runtime for exporting and running ONNX models:

```
python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
python3 -m pip install onnxruntime-gpu onnx opencv-python matplotlib
```

onnxruntime-gpu requires CUDA 12.x, cuDNN 9.x, and other dependencies (such as the MSVC Runtime on Windows). For more information, see the [installation guide](https://onnxruntime.ai/docs/install/#python-installs).

## Download Checkpoints

Clone the SAM2 git repository and download the checkpoints:
```bash
git clone https://github.com/facebookresearch/segment-anything-2.git
cd segment-anything-2
python3 -m pip install -e .
cd checkpoints
sh ./download_ckpts.sh
```

On Windows, you can replace `sh ./download_ckpts.sh` with the following commands:
```bash
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt > sam2_hiera_tiny.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt > sam2_hiera_small.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt > sam2_hiera_base_plus.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt > sam2_hiera_large.pt
```

## Export ONNX
To export ONNX models, run the convert_to_onnx.py script and specify the segment-anything-2 directory created by the above git clone command:
```bash
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2
```

The exported ONNX models can be found in the sam2_onnx_models sub-directory. You can change the output directory using the `--output_dir` option.

If you want the model to output multiple masks, append the `--multimask_output` option.

To see all parameters, run the following command:
```bash
python3 convert_to_onnx.py -h
```

## Run Demo
The exported ONNX models can run on a CPU. The demo will output sam2_demo.png.
```bash
curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo
```

## Limitations
- The exported image_decoder model does not support batch mode for now.
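Not part of the PR: a minimal sketch of driving the exported encoder and decoder directly with ONNX Runtime, in case you want to skip the bundled demo. The two .onnx file names are placeholders (check your `--output_dir` for the actual names); the input names, shapes, and dtypes mirror the export code later in this diff.

```python
import numpy as np
import onnxruntime

# Placeholder file names; use the paths produced by convert_to_onnx.py.
encoder = onnxruntime.InferenceSession("sam2_onnx_models/sam2_hiera_large_image_encoder.onnx")
decoder = onnxruntime.InferenceSession("sam2_onnx_models/sam2_hiera_large_image_decoder.onnx")

# SAM2 expects a 1024x1024 RGB input; real preprocessing (resize/normalize) is omitted here.
image = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
features_0, features_1, embeddings = encoder.run(None, {"image": image})

# One foreground point (label 1) at pixel (500, 375) in the 1024x1024 frame.
masks, iou_predictions, low_res_masks = decoder.run(
    None,
    {
        "image_features_0": features_0,
        "image_features_1": features_1,
        "image_embeddings": embeddings,
        "point_coords": np.array([[[500.0, 375.0]]], dtype=np.float32),
        "point_labels": np.array([[1]], dtype=np.int32),
        "input_masks": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_input_masks": np.zeros(1, dtype=np.float32),
        "original_image_size": np.array([1024, 1024], dtype=np.int32),
    },
)
print(masks.shape, iou_predictions.shape, low_res_masks.shape)
```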
onnxruntime/python/tools/transformers/models/sam2/__init__.py (new file, 12 lines)

@@ -0,0 +1,12 @@
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys

sys.path.append(os.path.dirname(__file__))

transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
```
onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py (new file, 195 lines)

@@ -0,0 +1,195 @@
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import os
import pathlib
import sys

import torch
from image_decoder import export_decoder_onnx, test_decoder_onnx
from image_encoder import export_image_encoder_onnx, test_image_encoder_onnx
from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx
from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx
from sam2_demo import run_demo, show_all_images
from sam2_utils import build_sam2_model, get_decoder_onnx_path, get_image_encoder_onnx_path, setup_logger


def parse_arguments():
    parser = argparse.ArgumentParser(description="Export SAM2 models to ONNX")

    parser.add_argument(
        "--model_type",
        required=False,
        type=str,
        choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
        default="sam2_hiera_large",
        help="The model type to export",
    )

    parser.add_argument(
        "--components",
        required=False,
        nargs="+",
        choices=["image_encoder", "mask_decoder", "prompt_encoder", "image_decoder"],
        default=["image_encoder", "image_decoder"],
        help="Type of ONNX models to export. "
        "Note that image_decoder is a combination of prompt_encoder and mask_decoder",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        help="The output directory for the ONNX models",
        default="sam2_onnx_models",
    )

    parser.add_argument(
        "--dynamic_batch_axes",
        required=False,
        default=False,
        action="store_true",
        help="Export image_encoder with dynamic batch axes",
    )

    parser.add_argument(
        "--multimask_output",
        required=False,
        default=False,
        action="store_true",
        help="Export mask_decoder or image_decoder with multimask_output",
    )

    parser.add_argument(
        "--disable_dynamic_multimask_via_stability",
        required=False,
        action="store_true",
        help="Disable mask_decoder dynamic_multimask_via_stability, and output first mask only. "
        "This option will be ignored when multimask_output is True",
    )

    parser.add_argument(
        "--sam2_dir",
        required=False,
        type=str,
        default="./segment-anything-2",
        help="The directory of segment-anything-2 git repository",
    )

    parser.add_argument(
        "--overwrite",
        required=False,
        default=False,
        action="store_true",
        help="Overwrite onnx model file if exists.",
    )

    parser.add_argument(
        "--demo",
        required=False,
        default=False,
        action="store_true",
        help="Run demo with the exported ONNX models.",
    )

    parser.add_argument(
        "--verbose",
        required=False,
        default=False,
        action="store_true",
        help="Print verbose information",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()

    checkpoints_dir = os.path.join(args.sam2_dir, "checkpoints")
    sam2_config_dir = os.path.join(args.sam2_dir, "sam2_configs")
    if not os.path.exists(args.sam2_dir):
        raise FileNotFoundError(f"{args.sam2_dir} does not exist. Please specify --sam2_dir correctly.")

    if not os.path.exists(checkpoints_dir):
        raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.")

    if not os.path.exists(sam2_config_dir):
        raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.")

    if not os.path.exists(os.path.join(checkpoints_dir, f"{args.model_type}.pt")):
        raise FileNotFoundError(
            f"{checkpoints_dir}/{args.model_type}.pt does not exist. Please download checkpoints under the directory."
        )

    if args.sam2_dir not in sys.path:
        sys.path.append(args.sam2_dir)

    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    sam2_model = build_sam2_model(checkpoints_dir, args.model_type, device="cpu")

    for component in args.components:
        if component == "image_encoder":
            onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type)
            if args.overwrite or not os.path.exists(onnx_model_path):
                export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose)
            test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False)

        elif component == "mask_decoder":
            onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_mask_decoder.onnx")
            if args.overwrite or not os.path.exists(onnx_model_path):
                export_mask_decoder_onnx(
                    sam2_model,
                    onnx_model_path,
                    args.multimask_output,
                    not args.disable_dynamic_multimask_via_stability,
                    args.verbose,
                )
            test_mask_decoder_onnx(
                sam2_model,
                onnx_model_path,
                args.multimask_output,
                not args.disable_dynamic_multimask_via_stability,
            )
        elif component == "prompt_encoder":
            onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_prompt_encoder.onnx")
            if args.overwrite or not os.path.exists(onnx_model_path):
                export_prompt_encoder_onnx(sam2_model, onnx_model_path)
            test_prompt_encoder_onnx(sam2_model, onnx_model_path)
        elif component == "image_decoder":
            onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, args.multimask_output)
            if args.overwrite or not os.path.exists(onnx_model_path):
                export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
            test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)

    if args.demo:
        # Export required ONNX models for demo if not already exported.
        onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type)
        if not os.path.exists(onnx_model_path):
            export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose)

        onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, True)
        if not os.path.exists(onnx_model_path):
            export_decoder_onnx(sam2_model, onnx_model_path, True)

        onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, False)
        if not os.path.exists(onnx_model_path):
            export_decoder_onnx(sam2_model, onnx_model_path, False)

        ort_image_files = run_demo(checkpoints_dir, args.model_type, engine="ort", onnx_directory=args.output_dir)
        print("demo output files for ONNX Runtime:", ort_image_files)

        # Get results from torch engine to compare.
        torch_image_files = run_demo(checkpoints_dir, args.model_type, engine="torch", onnx_directory=args.output_dir)
        print("demo output files for PyTorch:", torch_image_files)

        show_all_images(ort_image_files, torch_image_files)


if __name__ == "__main__":
    setup_logger(verbose=False)
    with torch.no_grad():
        main()
```
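The script above drives the per-component helpers through argparse; a hedged sketch of calling them programmatically instead, assuming it runs from the sam2 directory with the segment-anything-2 repo installed so the sibling modules import cleanly (output file names are placeholders):

```python
import torch
from image_decoder import export_decoder_onnx
from image_encoder import export_image_encoder_onnx
from sam2_utils import build_sam2_model

# build_sam2_model loads <checkpoints_dir>/<model_type>.pt, as in main() above.
sam2_model = build_sam2_model("segment-anything-2/checkpoints", "sam2_hiera_large", device="cpu")

with torch.no_grad():  # the script wraps main() the same way
    export_image_encoder_onnx(sam2_model, "sam2_hiera_large_image_encoder.onnx")
    export_decoder_onnx(sam2_model, "sam2_hiera_large_image_decoder.onnx", multimask_output=False)
```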
onnxruntime/python/tools/transformers/models/sam2/image_decoder.py (new file, 249 lines)

@@ -0,0 +1,249 @@
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings

import torch
import torch.nn.functional as F
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
from mask_decoder import SAM2MaskDecoder
from prompt_encoder import SAM2PromptEncoder
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance
from torch import nn

logger = logging.getLogger(__name__)


class SAM2ImageDecoder(nn.Module):
    def __init__(
        self,
        sam_model: SAM2Base,
        multimask_output: bool,
        dynamic_multimask_via_stability: bool = True,
        return_logits: bool = False,
        mask_threshold: float = 0.0,
    ) -> None:
        super().__init__()
        self.prompt_encoder = SAM2PromptEncoder(sam_model)
        self.mask_decoder = SAM2MaskDecoder(sam_model, multimask_output, dynamic_multimask_via_stability)
        self.return_logits = return_logits
        self.mask_threshold = mask_threshold

    @torch.no_grad()
    def forward(
        self,
        image_features_0: torch.Tensor,
        image_features_1: torch.Tensor,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        input_masks: torch.Tensor,
        has_input_masks: torch.Tensor,
        original_image_size: torch.Tensor,
    ):
        """
        Decode masks from image features and prompts. Batched images are not supported. H=W=1024.

        Args:
            image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. High resolution features of level 0 from image encoder.
            image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. High resolution features of level 1 from image encoder.
            image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. Image embedding from image encoder.
            point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype; contains the absolute pixel
                coordinates in (x, y) format of the P input points in an image of size 1024x1024.
            point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
                positive (foreground), 0 means negative (background), -1 means padding,
                2 (box left upper corner), 3 (box right bottom corner).
            input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
                Typically coming from a previous iteration.
            has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
            original_image_size (torch.Tensor): [2]. Original image size H_o, W_o.

        Returns:
            masks (torch.Tensor): [1, M, H_o, W_o] where M=3 or 1. Masks of original image size.
            iou_predictions (torch.Tensor): [1, M]. Scores for M masks.
            low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. Low resolution masks.
        """
        sparse_embeddings, dense_embeddings, image_pe = self.prompt_encoder(
            point_coords, point_labels, input_masks, has_input_masks
        )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings
        )

        # Interpolate the low resolution masks back to the original image size.
        masks = F.interpolate(
            low_res_masks,
            (original_image_size[0], original_image_size[1]),
            mode="bilinear",
            align_corners=False,  # Note that align_corners=True has fewer mismatches when comparing ORT and PyTorch.
        )

        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
        if not self.return_logits:
            masks = masks > self.mask_threshold

        return masks, iou_predictions, low_res_masks


def export_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool = False,
    verbose: bool = False,
):
    batch_size = 1
    image = random_sam2_input_image(batch_size)
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)

    logger.info("image_features_0.shape: %s", image_features_0.shape)
    logger.info("image_features_1.shape: %s", image_features_1.shape)
    logger.info("image_embeddings.shape: %s", image_embeddings.shape)

    sam2_decoder = SAM2ImageDecoder(
        sam2_model,
        multimask_output=multimask_output,
        dynamic_multimask_via_stability=True,
    ).cpu()

    num_labels = 2
    num_points = 3
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)
    original_image_size = torch.tensor([1200, 1800], dtype=torch.int32)

    example_inputs = (
        image_features_0,
        image_features_1,
        image_embeddings,
        point_coords,
        point_labels,
        input_masks,
        has_input_masks,
        original_image_size,
    )

    logger.info("point_coords.shape: %s", point_coords.shape)
    logger.info("point_labels.shape: %s", point_labels.shape)
    logger.info("input_masks.shape: %s", input_masks.shape)
    logger.info("has_input_masks.shape: %s", has_input_masks.shape)
    logger.info("original_image_size.shape: %s", original_image_size.shape)

    if verbose:
        masks, iou_predictions, low_res_masks = sam2_decoder(*example_inputs)
        logger.info("masks.shape: %s", masks.shape)
        logger.info("iou_predictions.shape: %s", iou_predictions.shape)
        logger.info("low_res_masks.shape: %s", low_res_masks.shape)

    input_names = [
        "image_features_0",
        "image_features_1",
        "image_embeddings",
        "point_coords",
        "point_labels",
        "input_masks",
        "has_input_masks",
        "original_image_size",
    ]

    output_names = ["masks", "iou_predictions", "low_res_masks"]

    dynamic_axes = {
        "point_coords": {0: "num_labels", 1: "num_points"},
        "point_labels": {0: "num_labels", 1: "num_points"},
        "input_masks": {0: "num_labels"},
        "has_input_masks": {0: "num_labels"},
        "masks": {0: "num_labels", 2: "original_image_height", 3: "original_image_width"},
        "low_res_masks": {0: "num_labels"},
        "iou_predictions": {0: "num_labels"},
    }

    with warnings.catch_warnings():
        if not verbose:
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            warnings.filterwarnings("ignore", category=UserWarning)

        torch.onnx.export(
            sam2_decoder,
            example_inputs,
            onnx_model_path,
            export_params=True,
            opset_version=16,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    logger.info("decoder onnx model saved to %s", onnx_model_path)


def test_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output=False,
):
    batch_size = 1
    image = random_sam2_input_image(batch_size)
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)

    sam2_image_decoder = SAM2ImageDecoder(
        sam2_model,
        multimask_output=multimask_output,
        dynamic_multimask_via_stability=True,
    ).cpu()

    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.zeros(1, dtype=torch.float)
    original_image_size = torch.tensor([1500, 1500], dtype=torch.int32)

    example_inputs = (
        image_features_0,
        image_features_1,
        image_embeddings,
        point_coords,
        point_labels,
        input_masks,
        has_input_masks,
        original_image_size,
    )

    masks, iou_predictions, low_res_masks = sam2_image_decoder(*example_inputs)

    import onnxruntime

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())

    model_inputs = ort_session.get_inputs()
    input_names = [model_inputs[i].name for i in range(len(model_inputs))]
    logger.info("input_names: %s", input_names)

    model_outputs = ort_session.get_outputs()
    output_names = [model_outputs[i].name for i in range(len(model_outputs))]
    logger.info("output_names: %s", output_names)
    inputs = {model_inputs[i].name: example_inputs[i].numpy() for i in range(len(model_inputs))}
    outputs = ort_session.run(output_names, inputs)

    for i, output_name in enumerate(output_names):
        logger.info(f"{output_name}.shape: %s", outputs[i].shape)

    ort_masks, ort_iou_predictions, ort_low_res_masks = outputs
    if (
        compare_tensors_with_tolerance("masks", masks.float(), torch.tensor(ort_masks).float())
        and compare_tensors_with_tolerance("iou_predictions", iou_predictions, torch.tensor(ort_iou_predictions))
        and compare_tensors_with_tolerance("low_res_masks", low_res_masks, torch.tensor(ort_low_res_masks))
    ):
        print("onnx model has been verified:", onnx_model_path)
    else:
        print("onnx model verification failed:", onnx_model_path)
```
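Given the label convention in the docstring above (1 foreground, 0 background, -1 padding, 2 and 3 box corners), a box prompt is just two labeled points; a small sketch of encoding a box this way:

```python
import torch

# A box from (425, 600) to (700, 875) in the 1024x1024 input frame, encoded as
# two points: label 2 marks the top-left corner, label 3 the bottom-right corner.
point_coords = torch.tensor([[[425.0, 600.0], [700.0, 875.0]]])  # [L=1, P=2, 2], float32
point_labels = torch.tensor([[2, 3]], dtype=torch.int32)         # [L=1, P=2]
print(point_coords.shape, point_labels.shape)
```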
onnxruntime/python/tools/transformers/models/sam2/image_encoder.py (new file, 164 lines)

@@ -0,0 +1,164 @@
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings

import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
from torch import nn

import onnxruntime

logger = logging.getLogger(__name__)


class SAM2ImageEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base) -> None:
        super().__init__()
        self.model = sam_model
        self.image_encoder = sam_model.image_encoder
        self.no_mem_embed = sam_model.no_mem_embed

    def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encodes images into features.

        Only supports H=W=1024. If you want to use different image sizes like 512x512,
        see https://github.com/facebookresearch/segment-anything-2/issues/138.

        Args:
            image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.

        Returns:
            image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
            image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
            image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
        """
        backbone_out = self.image_encoder(image)

        # Precompute projected level 0 and level 1 features in SAM decoder
        # to avoid running it again on every SAM click.
        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])

        # Prepare and flatten visual features.
        feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        # Flatten NxCxHxW to HWxNxC.
        # TODO: we should avoid this transpose since it will be transposed back to NCHW later.
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]

        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
        ][::-1]

        return feats[0], feats[1], feats[2]


def export_image_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    dynamic_batch_axes: bool = False,
    verbose: bool = False,
):
    image = random_sam2_input_image()

    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
    logger.info("image.shape: %s", image.shape)
    logger.info("image_features_0.shape: %s", image_features_0.shape)
    logger.info("image_features_1.shape: %s", image_features_1.shape)
    logger.info("image_embeddings.shape: %s", image_embeddings.shape)

    dynamic_axes = None
    if dynamic_batch_axes:
        dynamic_axes = {
            "image": {0: "batch_size"},
            "image_features_0": {0: "batch_size"},
            "image_features_1": {0: "batch_size"},
            "image_embeddings": {0: "batch_size"},
        }

    with warnings.catch_warnings():
        if not verbose:
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            warnings.filterwarnings("ignore", category=UserWarning)
        torch.onnx.export(
            sam2_encoder,
            image,
            onnx_model_path,
            export_params=True,
            opset_version=17,
            do_constant_folding=True,
            input_names=["image"],
            output_names=["image_features_0", "image_features_1", "image_embeddings"],
            dynamic_axes=dynamic_axes,
        )

    print("encoder onnx model saved to", onnx_model_path)


def test_image_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    dynamic_batch_axes=False,
):
    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())

    model_inputs = ort_session.get_inputs()
    input_names = [model_inputs[i].name for i in range(len(model_inputs))]
    logger.info("input_names: %s", input_names)

    model_outputs = ort_session.get_outputs()
    output_names = [model_outputs[i].name for i in range(len(model_outputs))]
    logger.info("output_names: %s", output_names)

    batch_sizes = [1, 2] if dynamic_batch_axes else [1]
    for batch_size in batch_sizes:
        image = random_sam2_input_image(batch_size)

        sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
        image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())

        logger.info("image.shape: %s", image.shape)
        logger.info("image_features_0.shape: %s", image_features_0.shape)
        logger.info("image_features_1.shape: %s", image_features_1.shape)
        logger.info("image_embeddings.shape: %s", image_embeddings.shape)

        outputs = ort_session.run(output_names, {"image": image.numpy()})
        for i, output_name in enumerate(output_names):
            logger.info("output %s shape %s", output_name, outputs[i].shape)
        ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs

        # ONNX Runtime and PyTorch have about 0.75% mismatched elements, but this
        # does not seem to impact segmentation results.
        if (
            compare_tensors_with_tolerance(
                "image_features_0",
                image_features_0,
                torch.tensor(ort_image_features_0),
                mismatch_percentage_tolerance=1,
            )
            and compare_tensors_with_tolerance(
                "image_features_1",
                image_features_1,
                torch.tensor(ort_image_features_1),
                mismatch_percentage_tolerance=1,
            )
            and compare_tensors_with_tolerance(
                "image_embeddings",
                image_embeddings,
                torch.tensor(ort_image_embeddings),
                mismatch_percentage_tolerance=1,
            )
        ):
            print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
        else:
            print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")
```
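The flatten/permute round trip in `forward` (see the TODO above) exists only so `no_mem_embed` can be added in HWxNxC layout; a standalone sketch of that round trip on a dummy feature map:

```python
import torch

# A dummy NCHW feature map with the same layout as the top FPN level.
x = torch.randn(1, 256, 64, 64)

# Flatten NxCxHxW to HWxNxC, as SAM2ImageEncoder.forward does.
flat = x.flatten(2).permute(2, 0, 1)  # [4096, 1, 256]

# no_mem_embed has shape [1, 1, 256] in SAM2 and broadcasts over the HW axis.
no_mem_embed = torch.zeros(1, 1, 256)
flat = flat + no_mem_embed

# Transpose back to NCHW, as in the `feats` list comprehension.
restored = flat.permute(1, 2, 0).reshape(1, -1, 64, 64)
assert restored.shape == x.shape
```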
onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py (new file, 208 lines)

@@ -0,0 +1,208 @@
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings

import torch
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
from prompt_encoder import SAM2PromptEncoder
from sam2.modeling.sam2_base import SAM2Base
from torch import nn

logger = logging.getLogger(__name__)


class SAM2MaskDecoder(nn.Module):
    def __init__(
        self,
        sam_model: SAM2Base,
        multimask_output: bool,
        dynamic_multimask_via_stability: bool = True,
    ) -> None:
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model
        self.multimask_output = multimask_output
        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability

    @torch.no_grad()
    def forward(
        self,
        image_features_0: torch.Tensor,
        image_features_1: torch.Tensor,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_embeddings: torch.Tensor,
        dense_embeddings: torch.Tensor,
    ):
        """
        Decode masks from image and prompt embeddings. Only supports H=W=1024.

        Args:
            image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. High resolution features of level 0 from image encoder.
            image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. High resolution features of level 1 from image encoder.
            image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. Image embedding from image encoder.
            image_pe (torch.Tensor): [1, 256, H/16, W/16]. Image positional encoding.
            sparse_embeddings (torch.Tensor): [L, P+1, 256]. Embedding for points and boxes.
            dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. Embedding for input masks.

        Returns:
            low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. Low resolution masks.
            iou_predictions (torch.Tensor): [1, M]. Scores for M masks.
        """
        low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            repeat_image=sparse_embeddings.shape[0] > 1,  # batch mode
            high_res_features=[image_features_0, image_features_1],
        )

        if self.multimask_output:
            low_res_masks = low_res_masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        elif self.dynamic_multimask_via_stability:
            # When outputting a single mask, if the stability score from the current single-mask
            # output (based on output token 0) falls below a threshold, we instead select from
            # multi-mask outputs (based on output tokens 1~3) the mask with the highest predicted IoU score.
            low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(
                low_res_masks, iou_predictions
            )
        else:
            low_res_masks = low_res_masks[:, 0:1, :, :]
            iou_predictions = iou_predictions[:, 0:1]

        return low_res_masks, iou_predictions


def export_mask_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool,
    dynamic_multimask_via_stability: bool = True,
    verbose=False,
):
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    image = random_sam2_input_image()
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
    logger.info("image_features_0.shape: %s", image_features_0.shape)
    logger.info("image_features_1.shape: %s", image_features_1.shape)
    logger.info("image_embeddings.shape: %s", image_embeddings.shape)

    # Encode a random prompt.
    num_labels = 2
    num_points = 3
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
    logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
    logger.info("image_pe.shape: %s", image_pe.shape)

    sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
    inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
    low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
    logger.info("low_res_masks.shape: %s", low_res_masks.shape)
    logger.info("iou_predictions.shape: %s", iou_predictions.shape)

    with warnings.catch_warnings():
        if not verbose:
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            warnings.filterwarnings("ignore", category=UserWarning)
        torch.onnx.export(
            sam2_mask_decoder,
            inputs,
            onnx_model_path,
            export_params=True,
            opset_version=18,
            do_constant_folding=True,
            input_names=[
                "image_features_0",
                "image_features_1",
                "image_embeddings",
                "image_pe",
                "sparse_embeddings",
                "dense_embeddings",
            ],
            output_names=["low_res_masks", "iou_predictions"],
            dynamic_axes={
                "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
                "dense_embeddings": {0: "num_labels"},
                "low_res_masks": {0: "num_labels"},
                "iou_predictions": {0: "num_labels"},
            },
        )

    print("mask decoder onnx model saved to", onnx_model_path)


def test_mask_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool,
    dynamic_multimask_via_stability: bool,
):
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    image = random_sam2_input_image()
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)

    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
    input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
    inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
    low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)

    import onnxruntime

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())

    model_inputs = ort_session.get_inputs()
    input_names = [model_inputs[i].name for i in range(len(model_inputs))]
    logger.info("input_names: %s", input_names)

    model_outputs = ort_session.get_outputs()
    output_names = [model_outputs[i].name for i in range(len(model_outputs))]
    logger.info("output_names: %s", output_names)

    outputs = ort_session.run(
        output_names,
        {
            "image_features_0": image_features_0.numpy(),
            "image_features_1": image_features_1.numpy(),
            "image_embeddings": image_embeddings.numpy(),
            "image_pe": image_pe.numpy(),
            "sparse_embeddings": sparse_embeddings.numpy(),
            "dense_embeddings": dense_embeddings.numpy(),
        },
    )

    for i, output_name in enumerate(output_names):
        logger.info("output %s shape: %s", output_name, outputs[i].shape)

    ort_low_res_masks, ort_iou_predictions = outputs
    torch.testing.assert_close(low_res_masks, torch.tensor(ort_low_res_masks), atol=5e-3, rtol=1e-4)
    torch.testing.assert_close(iou_predictions, torch.tensor(ort_iou_predictions), atol=5e-3, rtol=1e-4)
    print(f"onnx model has been verified: {onnx_model_path}")
```
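A toy illustration of the output-token selection above: `predict_masks` yields four mask channels (token 0 for single-mask output plus tokens 1 to 3 for multi-mask output, per the comments in `forward`), and the branches just slice them. Random tensors stand in for real predictions here.

```python
import torch

# Four mask tokens, as referenced in the comments above (token 0, tokens 1~3).
low_res_masks = torch.randn(1, 4, 256, 256)
iou_predictions = torch.rand(1, 4)

multimask_output = True
if multimask_output:
    selected = low_res_masks[:, 1:, :, :]   # three multi-mask outputs
else:
    selected = low_res_masks[:, 0:1, :, :]  # the single-mask output
print(selected.shape)  # torch.Size([1, 3, 256, 256]) when multimask_output is True
```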
onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py (new file, 189 lines)

@@ -0,0 +1,189 @@
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance
from torch import nn

logger = logging.getLogger(__name__)


class SAM2PromptEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base):
        super().__init__()
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model

    @torch.no_grad()
    def forward(
        self,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        input_masks: torch.Tensor,
        has_input_masks: torch.Tensor,
    ):
        """Encode prompts.

        Args:
            point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype; contains the absolute pixel
                coordinates in (x, y) format of the P input points in an image of size 1024x1024.
            point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
                positive (foreground), 0 means negative (background), -1 means padding,
                2 (box left upper corner), 3 (box right bottom corner).
            input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
                Typically coming from a previous iteration.
            has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.

        Returns:
            sparse_embeddings (torch.Tensor): [L, P+1, 256]. Embedding for points and boxes.
            dense_embeddings (torch.Tensor): [L, 256, 64, 64]. Embedding for input masks.
            image_pe (torch.Tensor, optional): [1, 256, 64, 64]. Image positional encoding.
        """
        sparse_embeddings = self._embed_points(point_coords, point_labels)
        dense_embeddings = self._embed_masks(input_masks, has_input_masks)
        image_pe = self.prompt_encoder.get_dense_pe()

        return sparse_embeddings, dense_embeddings, image_pe

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        point_coords = point_coords + 0.5

        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        # Note that the input coordinates are based on image size 1024x1024. Here we normalize them to [0.0, 1.0).
        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)

        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
        mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
        no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)
        mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
        logger.info("mask_embedding.shape: %s", mask_embedding.shape)
        return mask_embedding


def export_prompt_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
):
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    num_labels = 2
    num_points = 3
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    logger.info("point_coords.shape: %s", point_coords.shape)
    logger.info("point_labels.shape: %s", point_labels.shape)
    logger.info("input_masks.shape: %s", input_masks.shape)
    logger.info("has_input_masks.shape: %s", has_input_masks.shape)

    logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
    logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
    logger.info("image_pe.shape: %s", image_pe.shape)

    torch.onnx.export(
        sam2_prompt_encoder,
        (point_coords, point_labels, input_masks, has_input_masks),
        onnx_model_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
        output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
        dynamic_axes={
            "point_coords": {0: "num_labels", 1: "num_points"},
            "point_labels": {0: "num_labels", 1: "num_points"},
            "input_masks": {0: "num_labels"},
            "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
            "dense_embeddings": {0: "num_labels"},
        },
    )

    print("prompt encoder onnx model saved to", onnx_model_path)


def test_prompt_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
):
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    import onnxruntime

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())

    model_inputs = ort_session.get_inputs()
    input_names = [model_inputs[i].name for i in range(len(model_inputs))]
    logger.info("input_names: %s", input_names)

    model_outputs = ort_session.get_outputs()
    output_names = [model_outputs[i].name for i in range(len(model_outputs))]
    logger.info("output_names: %s", output_names)

    outputs = ort_session.run(
        output_names,
        {
            "point_coords": point_coords.numpy(),
            "point_labels": point_labels.numpy(),
            "input_masks": input_masks.numpy(),
            "has_input_masks": has_input_masks.numpy(),
        },
    )

    for i, output_name in enumerate(output_names):
        logger.info("output %s shape: %s", output_name, outputs[i].shape)

    ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
    if (
        compare_tensors_with_tolerance(
            "sparse_embeddings",
            sparse_embeddings,
            torch.tensor(ort_sparse_embeddings),
            mismatch_percentage_tolerance=0.2,
        )
        and compare_tensors_with_tolerance(
            "dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2
        )
        and compare_tensors_with_tolerance(
            "image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2
        )
    ):
        print(f"onnx model has been verified: {onnx_model_path}")
    else:
        print(f"onnx model verification failed: {onnx_model_path}")
```
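The normalization in `_embed_points` shifts coordinates to pixel centers before dividing by the 1024 input size; a tiny numeric check:

```python
import torch

image_size = 1024  # SAM2 input size (self.model.image_size)
point = torch.tensor([[[512.0, 256.0]]])

# Shift by half a pixel, then normalize to [0.0, 1.0), as _embed_points does.
normalized = (point + 0.5) / image_size
print(normalized)  # ~tensor([[[0.5005, 0.2505]]])
```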
onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py (new file, 293 lines)
@ -0,0 +1,293 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (R) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os
|
||||
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.patches import Rectangle
|
||||
from PIL import Image
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
|
||||
from sam2_utils import build_sam2_model
|
||||
|
||||
|
||||
def show_mask(mask, ax, random_color=False, borders=True):
|
||||
if random_color:
|
||||
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
||||
else:
|
||||
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
||||
h, w = mask.shape[-2:]
|
||||
mask = mask.astype(np.uint8)
|
||||
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
||||
if borders:
|
||||
import cv2
|
||||
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
||||
# Try to smooth contours
|
||||
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
|
||||
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
|
||||
ax.imshow(mask_image)
|
||||
|
||||
|
||||
def show_points(coords, labels, ax, marker_size=375):
|
||||
pos_points = coords[labels == 1]
|
||||
neg_points = coords[labels == 0]
|
||||
ax.scatter(
|
||||
pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
|
||||
)
|
||||
ax.scatter(
|
||||
neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
|
||||
)
|
||||
|
||||
|
||||
def show_box(box, ax):
|
||||
x0, y0 = box[0], box[1]
|
||||
w, h = box[2] - box[0], box[3] - box[1]
|
||||
ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
|
||||
|
||||
|
||||
def show_masks(
|
||||
image,
|
||||
masks,
|
||||
scores,
|
||||
point_coords=None,
|
||||
box_coords=None,
|
||||
input_labels=None,
|
||||
borders=True,
|
||||
output_image_file_prefix=None,
|
||||
image_files=None,
|
||||
):
|
||||
for i, (mask, score) in enumerate(zip(masks, scores)):
|
||||
plt.figure(figsize=(10, 10))
|
||||
plt.imshow(image)
|
||||
show_mask(mask, plt.gca(), borders=borders)
|
||||
if point_coords is not None:
|
||||
assert input_labels is not None
|
||||
show_points(point_coords, input_labels, plt.gca())
|
||||
|
||||
if box_coords is not None:
|
||||
show_box(box_coords, plt.gca())
|
||||
|
||||
if len(scores) > 1:
|
||||
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
|
||||
|
||||
plt.axis("off")
|
||||
if output_image_file_prefix:
|
||||
filename = f"{output_image_file_prefix}_{i}.png"
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
|
||||
if isinstance(image_files, list):
|
||||
image_files.append(filename)
|
||||
plt.show(block=False)
|
||||
plt.close()
|
||||
|
||||
|
||||
def get_predictor(
|
||||
checkpoint_dir: str,
|
||||
device: torch.device,
|
||||
model_type="sam2_hiera_large",
|
||||
engine="torch",
|
||||
onnx_directory="sam2_onnx_models",
|
||||
):
|
||||
sam2_model = build_sam2_model(checkpoint_dir, model_type, device=device)
|
||||
if engine == "torch":
|
||||
predictor = SAM2ImagePredictor(sam2_model)
|
||||
else:
|
||||
predictor = SAM2ImageOnnxPredictor(sam2_model, onnx_directory=onnx_directory, model_type=model_type)
|
||||
return predictor
|
||||
|
||||
|
||||
def run_demo(
    checkpoint_dir: str,
    model_type="sam2_hiera_large",
    engine="torch",
    onnx_directory="sam2_onnx_models",
    enable_batch=False,
):
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")

    if use_gpu:
        if engine == "torch":
            # Turn on tfloat32 for Ampere GPUs.
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        elif engine == "ort":
            import onnxruntime

            assert use_gpu == ("CUDAExecutionProvider" in onnxruntime.get_available_providers())

    np.random.seed(3)
    image = Image.open("truck.jpg")
    image = np.array(image.convert("RGB"))

    predictor = get_predictor(checkpoint_dir, device, model_type, engine, onnx_directory=onnx_directory)

    predictor.set_image(image)
    prefix = f"sam2_demo_{engine}_"

    # The model returns masks, quality predictions for those masks,
    # and low resolution mask logits that can be passed to the next iteration of prediction.
    # With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
    # scores gives the model's own estimation of the quality of these masks.
    # For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
    # even if only a single mask is desired.
    input_point = np.array([[500, 375]])
    input_label = np.array([1])
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    image_files = []
    show_masks(
        image,
        masks,
        scores,
        point_coords=input_point,
        input_labels=input_label,
        borders=True,
        output_image_file_prefix=prefix + "multimask",
        image_files=image_files,
    )

    # Multiple points.
    input_point = np.array([[500, 375], [1125, 625]])
    input_label = np.array([1, 1])
    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )
    show_masks(
        image,
        masks,
        scores,
        point_coords=input_point,
        input_labels=input_label,
        output_image_file_prefix=prefix + "multi_points",
        image_files=image_files,
    )

    # Specify a window and a background point.
    input_point = np.array([[500, 375], [1125, 625]])
    input_label = np.array([1, 0])
    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )
    show_masks(
        image,
        masks,
        scores,
        point_coords=input_point,
        input_labels=input_label,
        output_image_file_prefix=prefix + "background_point",
        image_files=image_files,
    )

    # Take a box as input.
    input_box = np.array([425, 600, 700, 875])
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    show_masks(
        image,
        masks,
        scores,
        box_coords=input_box,
        output_image_file_prefix=prefix + "box",
        image_files=image_files,
    )

    # Combine points and boxes.
    input_box = np.array([425, 600, 700, 875])
    input_point = np.array([[575, 750]])
    input_label = np.array([0])

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )
    show_masks(
        image,
        masks,
        scores,
        box_coords=input_box,
        point_coords=input_point,
        input_labels=input_label,
        output_image_file_prefix=prefix + "box_and_point",
        image_files=image_files,
    )

    # TODO: support batched prompt inputs
    if enable_batch:
        input_boxes = np.array(
            [
                [75, 275, 1725, 850],
                [425, 600, 700, 875],
                [1375, 550, 1650, 800],
                [1240, 675, 1400, 750],
            ]
        )
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        for mask in masks:
            show_mask(mask.squeeze(0), plt.gca(), random_color=True)
        for box in input_boxes:
            show_box(box, plt.gca())
        plt.axis("off")
        plt.savefig(prefix + "batch_prompt.png")  # Save before show() so the figure is not blank.
        plt.show()
        image_files.append(prefix + "batch_prompt.png")
    return image_files


def show_all_images(left_images, right_images):
    # Show images in two rows since the display screen is horizontal in most cases.
    fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80))
    for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images)):
        left_img = mpimg.imread(left_img_path)
        right_img = mpimg.imread(right_img_path)

        axes[0, i].imshow(left_img)
        axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
        axes[0, i].axis("off")
        axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])

        axes[1, i].imshow(right_img)
        axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
        axes[1, i].axis("off")
        axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])

    plt.tight_layout()
    plt.savefig("sam2_demo.png", format="png", bbox_inches="tight", dpi=1000)
    plt.show()
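

# A minimal end-to-end sketch (not part of the upstream script; the checkpoint
# directory below is a hypothetical path): run the demo once per engine and
# compare the saved outputs side by side with show_all_images.
#
#   torch_images = run_demo("checkpoints", engine="torch")
#   ort_images = run_demo("checkpoints", engine="ort")
#   show_all_images(torch_images, ort_images)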

@ -0,0 +1,283 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from sam2_utils import decoder_shape_dict, encoder_shape_dict, get_decoder_onnx_path, get_image_encoder_onnx_path
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_ort_session(
|
||||
onnx_path: str,
|
||||
session_options=None,
|
||||
provider="CUDAExecutionProvider",
|
||||
enable_cuda_graph=False,
|
||||
use_tf32=True,
|
||||
) -> InferenceSession:
|
||||
if provider == "CUDAExecutionProvider":
|
||||
device_id = torch.cuda.current_device()
|
||||
provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
|
||||
provider_options["use_tf32"] = int(use_tf32)
|
||||
providers = [(provider, provider_options), "CPUExecutionProvider"]
|
||||
else:
|
||||
providers = ["CPUExecutionProvider"]
|
||||
print(f"Using providers: {providers}")
|
||||
return InferenceSession(onnx_path, session_options, providers=providers)
|
||||
|
||||
|
||||
def create_session(
|
||||
onnx_path: str, session_options=None, provider="CUDAExecutionProvider", device="cuda", enable_cuda_graph=False
|
||||
) -> CudaSession:
|
||||
ort_session = create_ort_session(
|
||||
onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True
|
||||
)
|
||||
cuda_session = CudaSession(ort_session, device=torch.device(device), enable_cuda_graph=enable_cuda_graph)
|
||||
return cuda_session
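

# A minimal usage sketch (not part of the upstream script): it assumes an encoder
# was already exported to sam2_onnx_models by image_encoder.py for the
# sam2_hiera_large model type; both values here are illustrative.
def _example_create_session_usage():
    onnx_path = get_image_encoder_onnx_path("sam2_onnx_models", "sam2_hiera_large")
    session = create_session(onnx_path, provider="CPUExecutionProvider", device="cpu")
    # CudaSession binds fixed-shape I/O buffers, so shapes must be declared before inference.
    session.allocate_buffers(encoder_shape_dict(batch_size=1, height=1024, width=1024))
    outputs = session.infer({"image": torch.rand(1, 3, 1024, 1024)})
    return outputs["image_embeddings"]  # Shape: [1, 256, 64, 64]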


class SAM2ImageOnnxPredictor(SAM2ImagePredictor):
    def __init__(
        self,
        sam_model: SAM2Base,
        onnx_directory: str = "sam2_onnx_models",
        model_type: str = "sam2_hiera_large",
        onnx_dtype: torch.dtype = torch.float32,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        **kwargs,
    ) -> None:
        """
        Uses SAM-2 to compute the image embedding for an image, and then allows mask prediction given prompts.

        Arguments:
          sam_model (SAM2Base): The model to use for mask prediction.
          onnx_directory (str): The path of the directory that contains encoder and decoder onnx models.
          onnx_dtype (torch.dtype): The data type to use for ONNX inputs.
          mask_threshold (float): The threshold to convert mask logits to binary masks. Default is 0.0.
          max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
            the maximum area of max_hole_area in low_res_masks.
          max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
            the maximum area of max_sprinkle_area in low_res_masks.
        """
        super().__init__(
            sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area
        )

        logger.debug("predictor device: %s", self.device)
        if torch.cuda.is_available():
            provider = "CUDAExecutionProvider"
            device = "cuda"
        else:
            provider = "CPUExecutionProvider"
            device = "cpu"

        # This model is exported by image_encoder.py.
        onnx_path = get_image_encoder_onnx_path(onnx_directory, model_type)
        self.encoder_session = create_session(
            onnx_path,
            session_options=None,
            provider=provider,
            device=device,
            enable_cuda_graph=False,
        )
        self.onnx_dtype = onnx_dtype

        # This model is exported by image_decoder.py. It outputs only one mask.
        onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=False)
        self.decoder_session = create_session(
            onnx_path,
            session_options=None,
            provider=provider,
            device=device,
            enable_cuda_graph=False,
        )

        # This model is exported by image_decoder.py. It outputs multiple (3) masks.
        onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=True)
        self.decoder_session_multi_out = create_session(
            onnx_path,
            session_options=None,
            provider=provider,
            device=device,
            enable_cuda_graph=False,
        )

    @torch.no_grad()
    def set_image(self, image: Union[np.ndarray, Image]):
        """
        Calculates the image embeddings for the provided image.

        Arguments:
          image (np.ndarray or PIL Image): The input image to embed in RGB format.
            The image should be in HWC format if np.ndarray, or WHC format if PIL Image with pixel values in [0, 255].
        """
        self.reset_predictor()

        # Transform the image to the form expected by the model.
        if isinstance(image, np.ndarray):
            # For a numpy array image, we assume (HxWxC) format.
            self._orig_hw = [image.shape[:2]]
        elif isinstance(image, Image):
            w, h = image.size
            self._orig_hw = [(h, w)]
        else:
            raise NotImplementedError("Image format not supported")

        input_image = self._transforms(image)
        input_image = input_image[None, ...].to(self.device)

        assert (
            len(input_image.shape) == 4 and input_image.shape[1] == 3
        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"

        # Compute image embeddings for the provided image.
        io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
        self.encoder_session.allocate_buffers(io_shapes)

        feed_dict = {"image": input_image.to(self.onnx_dtype).to(self.device)}

        for key, value in feed_dict.items():
            logger.debug(f"{key}: {value.shape}, {value.dtype}")
        logger.debug(f"encoder onnx: {self.encoder_session.ort_session._model_path}")

        ort_outputs = self.encoder_session.infer(feed_dict)

        self._features = {
            "image_embed": ort_outputs["image_embeddings"],
            "high_res_feats": [ort_outputs[f"image_features_{i}"] for i in range(2)],
        }
        self._is_image_set = True
        logger.info("Image embeddings computed.")

    @torch.no_grad()
    def _predict(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        img_idx: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using SAM2Transforms.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        assert not return_logits  # The ONNX model is exported to return boolean masks.

        if not self._is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        if point_coords is not None:
            concat_points = (point_coords, point_labels)
        else:
            concat_points = None

        # Embed prompts.
        if boxes is not None:
            box_coords = boxes.reshape(-1, 2, 2)
            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
            box_labels = box_labels.repeat(boxes.size(0), 1)
            # We merge "boxes" and "points" into a single "concat_points" input (where
            # boxes are added at the beginning) to sam_prompt_encoder.
            if concat_points is not None:
                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
                concat_points = (concat_coords, concat_labels)
            else:
                concat_points = (box_coords, box_labels)

        assert concat_points is not None
        num_labels = concat_points[0].shape[0]
        shape_dict = decoder_shape_dict(
            original_image_height=self._orig_hw[img_idx][0],
            original_image_width=self._orig_hw[img_idx][1],
            num_labels=num_labels,
            max_points=concat_points[0].shape[1],
            num_masks=3 if multimask_output else 1,
        )
        if multimask_output:
            decoder_session = self.decoder_session_multi_out
        else:
            decoder_session = self.decoder_session

        decoder_session.allocate_buffers(shape_dict)

        image_features_0 = self._features["high_res_feats"][0][img_idx].unsqueeze(0)
        image_features_1 = self._features["high_res_feats"][1][img_idx].unsqueeze(0)
        image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0)

        if mask_input is None:
            input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float, device=self.device)
            has_input_masks = torch.zeros(num_labels, dtype=torch.float, device=self.device)
        else:
            input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1)
            has_input_masks = torch.ones(num_labels, dtype=torch.float, device=self.device)

        feed_dict = {
            "image_embeddings": image_embeddings.contiguous().to(dtype=torch.float32).to(self.device),
            "image_features_0": image_features_0.contiguous().to(dtype=torch.float32).to(self.device),
            "image_features_1": image_features_1.contiguous().to(dtype=torch.float32).to(self.device),
            "point_coords": concat_points[0].to(dtype=torch.float32).to(self.device),
            "point_labels": concat_points[1].to(dtype=torch.int32).to(self.device),
            "input_masks": input_masks.to(dtype=torch.float32).to(self.device),
            "has_input_masks": has_input_masks.to(dtype=torch.float32).to(self.device),
            "original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device),
        }

        for key, value in feed_dict.items():
            logger.debug(f"{key}: {value.shape}, {value.dtype}")
        # Log the session actually selected above (single- or multi-mask decoder).
        logger.debug(f"decoder onnx: {decoder_session.ort_session._model_path}")

        ort_outputs = decoder_session.infer(feed_dict)

        masks = ort_outputs["masks"]
        iou_predictions = ort_outputs["iou_predictions"]
        low_res_masks = ort_outputs["low_res_masks"]

        return torch.Tensor(masks), torch.Tensor(iou_predictions), torch.Tensor(low_res_masks)
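

# A minimal usage sketch (illustrative, not part of the upstream file): the
# inherited SAM2ImagePredictor.predict() transforms numpy prompts and dispatches
# to the overridden _predict() above, so the ONNX predictor is a drop-in
# replacement for the PyTorch one. The prompt values mirror the demo script.
#
#   predictor = SAM2ImageOnnxPredictor(sam2_model, onnx_directory="sam2_onnx_models")
#   predictor.set_image(image)  # numpy HWC RGB array
#   masks, scores, low_res_masks = predictor.predict(
#       point_coords=np.array([[500, 375]]), point_labels=np.array([1]), multimask_output=True
#   )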

122 onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py Normal file

@ -0,0 +1,122 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import os

import torch
from sam2.build_sam import build_sam2
from sam2.modeling.sam2_base import SAM2Base

logger = logging.getLogger(__name__)


def get_model_cfg(model_type) -> str:
    assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
    if model_type == "sam2_hiera_tiny":
        model_cfg = "sam2_hiera_t.yaml"
    elif model_type == "sam2_hiera_small":
        model_cfg = "sam2_hiera_s.yaml"
    elif model_type == "sam2_hiera_base_plus":
        model_cfg = "sam2_hiera_b+.yaml"
    else:
        model_cfg = "sam2_hiera_l.yaml"
    return model_cfg


def build_sam2_model(checkpoint_dir: str, model_type: str, device="cpu") -> SAM2Base:
    sam2_checkpoint = os.path.join(checkpoint_dir, f"{model_type}.pt")
    model_cfg = get_model_cfg(model_type)
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
    return sam2_model


def get_decoder_onnx_path(dir: str, model_type, multimask_output) -> str:
    return os.path.join(dir, f"{model_type}_decoder" + ("_multi" if multimask_output else "") + ".onnx")


def get_image_encoder_onnx_path(dir: str, model_type) -> str:
    return os.path.join(dir, f"{model_type}_image_encoder.onnx")


def encoder_shape_dict(batch_size: int, height: int, width: int):
    assert height == 1024 and width == 1024, "Only 1024x1024 images are supported."
    return {
        "image": [batch_size, 3, height, width],
        "image_features_0": [batch_size, 32, height // 4, width // 4],
        "image_features_1": [batch_size, 64, height // 8, width // 8],
        "image_embeddings": [batch_size, 256, height // 16, width // 16],
    }


def decoder_shape_dict(
    original_image_height: int,
    original_image_width: int,
    num_labels: int = 1,
    max_points: int = 16,
    num_masks: int = 1,
) -> dict:
    height: int = 1024
    width: int = 1024
    return {
        "image_features_0": [1, 32, height // 4, width // 4],
        "image_features_1": [1, 64, height // 8, width // 8],
        "image_embeddings": [1, 256, height // 16, width // 16],
        "point_coords": [num_labels, max_points, 2],
        "point_labels": [num_labels, max_points],
        "input_masks": [num_labels, 1, height // 4, width // 4],
        "has_input_masks": [num_labels],
        "original_image_size": [2],
        "masks": [num_labels, num_masks, original_image_height, original_image_width],
        "iou_predictions": [num_labels, num_masks],
        "low_res_masks": [num_labels, num_masks, height // 4, width // 4],
    }
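

# A worked example of the decoder I/O shapes (the helper below is illustrative,
# not part of the upstream script): a 1200x1800 image with two box prompts,
# where each box is encoded as 2 labeled corner points, and single-mask output.
def _example_decoder_shapes() -> dict:
    shapes = decoder_shape_dict(1200, 1800, num_labels=2, max_points=2, num_masks=1)
    assert shapes["point_coords"] == [2, 2, 2]
    assert shapes["masks"] == [2, 1, 1200, 1800]  # Masks are resized to the original image.
    assert shapes["low_res_masks"] == [2, 1, 256, 256]  # Low-res logits stay at 1024 // 4.
    return shapes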


def compare_tensors_with_tolerance(
    name: str,
    tensor1: torch.Tensor,
    tensor2: torch.Tensor,
    atol=5e-3,
    rtol=1e-4,
    mismatch_percentage_tolerance=0.1,
) -> bool:
    assert tensor1.shape == tensor2.shape
    a = tensor1.clone().float()
    b = tensor2.clone().float()

    differences = torch.abs(a - b)
    mismatch_count = (differences > (rtol * torch.max(torch.abs(a), torch.abs(b)) + atol)).sum().item()

    total_elements = a.numel()
    mismatch_percentage = (mismatch_count / total_elements) * 100

    passed = mismatch_percentage < mismatch_percentage_tolerance

    log_func = logger.error if not passed else logger.info
    log_func(
        "%s: mismatched elements percentage %.2f (%d/%d). Verification %s (threshold=%.2f).",
        name,
        mismatch_percentage,
        mismatch_count,
        total_elements,
        "passed" if passed else "failed",
        mismatch_percentage_tolerance,
    )

    return passed
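

# A minimal usage sketch (the helper below is illustrative, not part of the
# upstream script): unlike torch.allclose, the comparison above tolerates a
# small percentage of outlier elements, which suits ONNX-vs-PyTorch parity checks.
def _example_compare_tensors() -> bool:
    a = torch.zeros(2000)
    b = a.clone()
    b[0] = 1.0  # One outlier: 1/2000 = 0.05% mismatched, under the 0.1% budget.
    return compare_tensors_with_tolerance("example", a, b)  # True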


def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor:
    image = torch.randn(batch_size, 3, image_height, image_width).cpu()
    return image


def setup_logger(verbose=True):
    if verbose:
        logging.basicConfig(format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.basicConfig(format="%(message)s")
        logging.getLogger().setLevel(logging.WARNING)

3 setup.py

@ -494,8 +494,9 @@ packages = [
    "onnxruntime.transformers.models.llama",
    "onnxruntime.transformers.models.longformer",
    "onnxruntime.transformers.models.phi2",
    "onnxruntime.transformers.models.t5",
    "onnxruntime.transformers.models.sam2",
    "onnxruntime.transformers.models.stable_diffusion",
    "onnxruntime.transformers.models.t5",
    "onnxruntime.transformers.models.whisper",
]