Add onnx export script for segment anything v2 (#22119)

### Description
Add ONNX export script for segment anything v2 (SAM2).

### Limitations
* Does not support video. Only images are supported right now.
* The decoder does not support batch inference.

### Credits
The demo is based on the [SAM2
notebook](https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/image_predictor_example.ipynb),
modified to run with ORT.

The export of the decoder is inspired by
https://github.com/vietanhdev/samexporter.

### Demo
Example output of demo:

![sam2_demo](https://github.com/user-attachments/assets/9a9fa360-8c20-482e-9935-a7aba9cf15de)

### Motivation and Context
To support optimization of SAM2 image segmentation.
Authored by Tianlei Wu on 2024-09-18 14:31:59 -07:00, committed by GitHub (commit a9740d6f96, parent 05acfb90ab).
13 changed files with 1796 additions and 5 deletions

cmake/onnxruntime_python.cmake

@@ -476,6 +476,9 @@ file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPEND
file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py"
)
file(GLOB onnxruntime_python_transformers_models_sam2_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/sam2/*.py"
)
file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py"
)
@@ -547,6 +550,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/sam2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/whisper
@@ -656,6 +660,9 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_phi2_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_sam2_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/sam2/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_stable_diffusion_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion/

onnxruntime/python/tools/transformers/io_binding_helper.py

@@ -1,13 +1,16 @@
import copy
import logging
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import numpy
import torch
from onnxruntime import InferenceSession, RunOptions
# Type alias
ShapeDict = Mapping[str, Union[Tuple, List[int]]]
logger = logging.getLogger(__name__)
@@ -262,7 +265,7 @@ class CudaSession:
)
self.output_tensors[self.buffer_sharing[name]] = tensor
def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
def allocate_buffers(self, shape_dict: ShapeDict):
"""Allocate tensors for I/O Binding"""
if self.enable_cuda_graph:
for name, shape in shape_dict.items():
@@ -346,7 +349,7 @@ class GpuBinding(CudaSession):
self,
ort_session: InferenceSession,
device: torch.device,
shape_dict: Dict[str, Union[Tuple[int], List[int]]],
shape_dict: ShapeDict,
enable_gpu_graph: bool = False,
gpu_graph_id: int = -1,
stream: int = 0,
@@ -406,7 +409,7 @@ class GpuBindingManager:
def get_binding(
self,
shape_dict: Dict[str, Union[Tuple[int], List[int]]],
shape_dict: ShapeDict,
use_cuda_graph: bool = False,
buffer_sharing: Optional[Dict[str, str]] = None,
) -> GpuBinding:
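For reference, the `ShapeDict` alias introduced above only loosens the accepted shape types so that callers can mix tuples and lists. A minimal sketch (the `describe` helper is illustrative):

```python
from typing import List, Mapping, Tuple, Union

# Mirrors the alias added in this diff.
ShapeDict = Mapping[str, Union[Tuple, List[int]]]

def describe(shape_dict: ShapeDict) -> None:
    # Both tuple and list shapes satisfy the alias.
    for name, shape in shape_dict.items():
        print(name, list(shape))

describe({"image": (1, 3, 1024, 1024), "image_embeddings": [1, 256, 64, 64]})
```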

onnxruntime/python/tools/transformers/models/sam2/README.md

@@ -0,0 +1,65 @@
# SAM2 ONNX Model Export
## Setup Environment
It is recommended to set up a machine with Python 3.10, 3.11 or 3.12. Then install [PyTorch 2.4.1](https://pytorch.org/) and ONNX Runtime 1.19.2.
### CPU Only
To install the CPU-only version of PyTorch and ONNX Runtime for exporting and running ONNX models, use the following commands:
```bash
python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
python3 -m pip install onnxruntime onnx opencv-python matplotlib
```
### GPU
If your machine has an NVIDIA GPU, you can install the CUDA version of PyTorch and ONNX Runtime for exporting and running ONNX models:
```bash
python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
python3 -m pip install onnxruntime-gpu onnx opencv-python matplotlib
```
onnxruntime-gpu requires CUDA 12.x, cuDNN 9.x, and other dependencies (such as MSVC Runtime on Windows). For more information, see the [installation guide](https://onnxruntime.ai/docs/install/#python-installs).
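A quick way to verify the GPU setup is to check that the CUDA execution provider is visible to ONNX Runtime:

```python
import onnxruntime

print(onnxruntime.__version__)
# Expect "CUDAExecutionProvider" in this list on a working CUDA/cuDNN setup.
print(onnxruntime.get_available_providers())
```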
## Download Checkpoints
Clone the SAM2 git repository and download the checkpoints:
```bash
git clone https://github.com/facebookresearch/segment-anything-2.git
cd segment-anything-2
python3 -m pip install -e .
cd checkpoints
sh ./download_ckpts.sh
```
On Windows, you can replace `sh ./download_ckpts.sh` with the following commands:
```bash
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt > sam2_hiera_tiny.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt > sam2_hiera_small.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt > sam2_hiera_base_plus.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt > sam2_hiera_large.pt
```
## Export ONNX
To export ONNX models, run the `convert_to_onnx.py` script and specify the segment-anything-2 directory created by the `git clone` command above:
```bash
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2
```
The exported ONNX models will be found in the `sam2_onnx_models` sub-directory. You can change the output directory using the `--output_dir` option.
If you want the model to output multiple masks, add the `--multimask_output` option.
To see all parameters, run the following command:
```bash
python3 convert_to_onnx.py -h
```
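After export, you can sanity-check the generated files with the `onnx` package. The file names below are assumptions for the default `sam2_hiera_large` model type and output directory; see `get_image_encoder_onnx_path` and `get_decoder_onnx_path` in `sam2_utils.py` for the exact naming:

```python
import onnx

# Assumed output paths; adjust to match your --model_type and --output_dir.
for path in [
    "sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
    "sam2_onnx_models/sam2_hiera_large_image_decoder.onnx",
]:
    model = onnx.load(path)
    onnx.checker.check_model(model)  # structural validation only
    print(path, "inputs:", [i.name for i in model.graph.input])
```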
## Run Demo
The exported ONNX models can run on CPU. The demo will output `sam2_demo.png`.
```bash
curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo
```
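To drive the exported encoder without the demo script, a minimal sketch is shown below. The model file name is an assumption based on the default output directory, and the preprocessing is simplified (plain resize plus ImageNet normalization instead of the library's transforms):

```python
import numpy as np
import onnxruntime
from PIL import Image

# Assumed default path; see get_image_encoder_onnx_path for the exact naming.
session = onnxruntime.InferenceSession(
    "sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
    providers=["CPUExecutionProvider"],
)

# The encoder expects a 1024x1024 RGB image, normalized and in NCHW layout.
image = Image.open("truck.jpg").convert("RGB").resize((1024, 1024))
x = np.asarray(image, dtype=np.float32) / 255.0
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
x = ((x - mean) / std).transpose(2, 0, 1)[None]  # [1, 3, 1024, 1024]

image_features_0, image_features_1, image_embeddings = session.run(None, {"image": x})
print(image_features_0.shape, image_features_1.shape, image_embeddings.shape)
```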
## Limitations
- The exported image_decoder model does not support batch mode for now.

onnxruntime/python/tools/transformers/models/sam2/__init__.py

@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py

@@ -0,0 +1,195 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import os
import pathlib
import sys
import torch
from image_decoder import export_decoder_onnx, test_decoder_onnx
from image_encoder import export_image_encoder_onnx, test_image_encoder_onnx
from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx
from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx
from sam2_demo import run_demo, show_all_images
from sam2_utils import build_sam2_model, get_decoder_onnx_path, get_image_encoder_onnx_path, setup_logger
def parse_arguments():
parser = argparse.ArgumentParser(description="Export SAM2 models to ONNX")
parser.add_argument(
"--model_type",
required=False,
type=str,
choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
default="sam2_hiera_large",
help="The model type to export",
)
parser.add_argument(
"--components",
required=False,
nargs="+",
choices=["image_encoder", "mask_decoder", "prompt_encoder", "image_decoder"],
default=["image_encoder", "image_decoder"],
help="Type of ONNX models to export. "
"Note that image_decoder is a combination of prompt_encoder and mask_decoder",
)
parser.add_argument(
"--output_dir",
type=str,
help="The output directory for the ONNX models",
default="sam2_onnx_models",
)
parser.add_argument(
"--dynamic_batch_axes",
required=False,
default=False,
action="store_true",
help="Export image_encoder with dynamic batch axes",
)
parser.add_argument(
"--multimask_output",
required=False,
default=False,
action="store_true",
help="Export mask_decoder or image_decoder with multimask_output",
)
parser.add_argument(
"--disable_dynamic_multimask_via_stability",
required=False,
action="store_true",
help="Disable mask_decoder dynamic_multimask_via_stability, and output first mask only."
"This option will be ignored when multimask_output is True",
)
parser.add_argument(
"--sam2_dir",
required=False,
type=str,
default="./segment-anything-2",
help="The directory of segment-anything-2 git repository",
)
parser.add_argument(
"--overwrite",
required=False,
default=False,
action="store_true",
help="Overwrite onnx model file if exists.",
)
parser.add_argument(
"--demo",
required=False,
default=False,
action="store_true",
help="Run demo with the exported ONNX models.",
)
parser.add_argument(
"--verbose",
required=False,
default=False,
action="store_true",
help="Print verbose information",
)
args = parser.parse_args()
return args
def main():
args = parse_arguments()
checkpoints_dir = os.path.join(args.sam2_dir, "checkpoints")
sam2_config_dir = os.path.join(args.sam2_dir, "sam2_configs")
if not os.path.exists(args.sam2_dir):
raise FileNotFoundError(f"{args.sam2_dir} does not exist. Please specify --sam2_dir correctly.")
if not os.path.exists(checkpoints_dir):
raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.")
if not os.path.exists(sam2_config_dir):
raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.")
if not os.path.exists(os.path.join(checkpoints_dir, f"{args.model_type}.pt")):
raise FileNotFoundError(
f"{checkpoints_dir}/{args.model_type}.pt does not exist. Please download checkpoints under the directory."
)
if args.sam2_dir not in sys.path:
sys.path.append(args.sam2_dir)
pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
sam2_model = build_sam2_model(checkpoints_dir, args.model_type, device="cpu")
for component in args.components:
if component == "image_encoder":
onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type)
if args.overwrite or not os.path.exists(onnx_model_path):
export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose)
test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False)
elif component == "mask_decoder":
onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_mask_decoder.onnx")
if args.overwrite or not os.path.exists(onnx_model_path):
export_mask_decoder_onnx(
sam2_model,
onnx_model_path,
args.multimask_output,
not args.disable_dynamic_multimask_via_stability,
args.verbose,
)
test_mask_decoder_onnx(
sam2_model,
onnx_model_path,
args.multimask_output,
not args.disable_dynamic_multimask_via_stability,
)
elif component == "prompt_encoder":
onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_prompt_encoder.onnx")
if args.overwrite or not os.path.exists(onnx_model_path):
export_prompt_encoder_onnx(sam2_model, onnx_model_path)
test_prompt_encoder_onnx(sam2_model, onnx_model_path)
elif component == "image_decoder":
onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, args.multimask_output)
if args.overwrite or not os.path.exists(onnx_model_path):
export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
if args.demo:
# Export required ONNX models for demo if not already exported.
onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type)
if not os.path.exists(onnx_model_path):
export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose)
onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, True)
if not os.path.exists(onnx_model_path):
export_decoder_onnx(sam2_model, onnx_model_path, True)
onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, False)
if not os.path.exists(onnx_model_path):
export_decoder_onnx(sam2_model, onnx_model_path, False)
ort_image_files = run_demo(checkpoints_dir, args.model_type, engine="ort", onnx_directory=args.output_dir)
print("demo output files for ONNX Runtime:", ort_image_files)
# Get results from torch engine to compare.
torch_image_files = run_demo(checkpoints_dir, args.model_type, engine="torch", onnx_directory=args.output_dir)
print("demo output files for PyTorch:", torch_image_files)
show_all_images(ort_image_files, torch_image_files)
if __name__ == "__main__":
setup_logger(verbose=False)
with torch.no_grad():
main()
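The export helpers can also be called directly from Python instead of through the CLI. A sketch, assuming segment-anything-2 has been cloned and pip-installed as described in the README (paths are illustrative):

```python
import sys

# Make sam2_configs importable, mirroring what main() does with --sam2_dir.
sys.path.append("segment-anything-2")

from image_encoder import export_image_encoder_onnx
from sam2_utils import build_sam2_model, get_image_encoder_onnx_path

sam2_model = build_sam2_model("segment-anything-2/checkpoints", "sam2_hiera_large", device="cpu")
onnx_path = get_image_encoder_onnx_path("sam2_onnx_models", "sam2_hiera_large")
export_image_encoder_onnx(sam2_model, onnx_path, dynamic_batch_axes=False, verbose=False)
```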

onnxruntime/python/tools/transformers/models/sam2/image_decoder.py

@@ -0,0 +1,249 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings
import torch
import torch.nn.functional as F
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
from mask_decoder import SAM2MaskDecoder
from prompt_encoder import SAM2PromptEncoder
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance
from torch import nn
logger = logging.getLogger(__name__)
class SAM2ImageDecoder(nn.Module):
def __init__(
self,
sam_model: SAM2Base,
multimask_output: bool,
dynamic_multimask_via_stability: bool = True,
return_logits: bool = False,
mask_threshold: float = 0.0,
) -> None:
super().__init__()
self.prompt_encoder = SAM2PromptEncoder(sam_model)
self.mask_decoder = SAM2MaskDecoder(sam_model, multimask_output, dynamic_multimask_via_stability)
self.return_logits = return_logits
self.mask_threshold = mask_threshold
@torch.no_grad()
def forward(
self,
image_features_0: torch.Tensor,
image_features_1: torch.Tensor,
image_embeddings: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
input_masks: torch.Tensor,
has_input_masks: torch.Tensor,
original_image_size: torch.Tensor,
):
"""
Decode masks from image features and prompts. Batched images are not supported. H=W=1024.
Args:
image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
coordinate in (x, y) format of the P input points in image of size 1024x1024.
point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
positive (foreground), 0 means negative (background), -1 means padding,
2 (box left upper corner), 3 (box right bottom corner).
input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
Typically coming from a previous iteration.
has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
original_image_size(torch.Tensor): [2]. original image size H_o, W_o.
Returns:
masks (torch.Tensor): [1, M, H_o, W_o] where M=3 or 1. Masks of original image size.
iou_predictions (torch.Tensor): [1, M]. scores for M masks.
low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
"""
sparse_embeddings, dense_embeddings, image_pe = self.prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
low_res_masks, iou_predictions = self.mask_decoder(
image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings
)
# Interpolate the low resolution masks back to the original image size.
masks = F.interpolate(
low_res_masks,
(original_image_size[0], original_image_size[1]),
mode="bilinear",
align_corners=False,  # Note that align_corners=True has fewer mismatches when comparing ORT and PyTorch.
)
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
if not self.return_logits:
masks = masks > self.mask_threshold
return masks, iou_predictions, low_res_masks
def export_decoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
multimask_output: bool = False,
verbose: bool = False,
):
batch_size = 1
image = random_sam2_input_image(batch_size)
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
sam2_decoder = SAM2ImageDecoder(
sam2_model,
multimask_output=multimask_output,
dynamic_multimask_via_stability=True,
).cpu()
num_labels = 2
num_points = 3
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
original_image_size = torch.tensor([1200, 1800], dtype=torch.int32)
example_inputs = (
image_features_0,
image_features_1,
image_embeddings,
point_coords,
point_labels,
input_masks,
has_input_masks,
original_image_size,
)
logger.info("point_coords.shape: %s", point_coords.shape)
logger.info("point_labels.shape: %s", point_labels.shape)
logger.info("input_masks.shape: %s", input_masks.shape)
logger.info("has_input_masks.shape: %s", has_input_masks.shape)
logger.info("original_image_size.shape: %s", original_image_size.shape)
if verbose:
masks, iou_predictions, low_res_masks = sam2_decoder(*example_inputs)
logger.info("masks.shape: %s", masks.shape)
logger.info("iou_predictions.shape: %s", iou_predictions.shape)
logger.info("low_res_masks.shape: %s", low_res_masks.shape)
input_names = [
"image_features_0",
"image_features_1",
"image_embeddings",
"point_coords",
"point_labels",
"input_masks",
"has_input_masks",
"original_image_size",
]
output_names = ["masks", "iou_predictions", "low_res_masks"]
dynamic_axes = {
"point_coords": {0: "num_labels", 1: "num_points"},
"point_labels": {0: "num_labels", 1: "num_points"},
"input_masks": {0: "num_labels"},
"has_input_masks": {0: "num_labels"},
"masks": {0: "num_labels", 2: "original_image_height", 3: "original_image_width"},
"low_res_masks": {0: "num_labels"},
"iou_predictions": {0: "num_labels"},
}
with warnings.catch_warnings():
if not verbose:
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
torch.onnx.export(
sam2_decoder,
example_inputs,
onnx_model_path,
export_params=True,
opset_version=16,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
logger.info("decoder onnx model saved to %s", onnx_model_path)
def test_decoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
multimask_output=False,
):
batch_size = 1
image = random_sam2_input_image(batch_size)
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
sam2_image_decoder = SAM2ImageDecoder(
sam2_model,
multimask_output=multimask_output,
dynamic_multimask_via_stability=True,
).cpu()
num_labels = 1
num_points = 5
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.zeros(1, dtype=torch.float)
original_image_size = torch.tensor([1500, 1500], dtype=torch.int32)
example_inputs = (
image_features_0,
image_features_1,
image_embeddings,
point_coords,
point_labels,
input_masks,
has_input_masks,
original_image_size,
)
masks, iou_predictions, low_res_masks = sam2_image_decoder(*example_inputs)
import onnxruntime
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
inputs = {model_inputs[i].name: example_inputs[i].numpy() for i in range(len(model_inputs))}
outputs = ort_session.run(output_names, inputs)
for i, output_name in enumerate(output_names):
logger.info(f"{output_name}.shape: %s", outputs[i].shape)
ort_masks, ort_iou_predictions, ort_low_res_masks = outputs
if (
compare_tensors_with_tolerance("masks", masks.float(), torch.tensor(ort_masks).float())
and compare_tensors_with_tolerance("iou_predictions", iou_predictions, torch.tensor(ort_iou_predictions))
and compare_tensors_with_tolerance("low_res_masks", low_res_masks, torch.tensor(ort_low_res_masks))
):
print("onnx model has been verified:", onnx_model_path)
else:
print("onnx model verification failed:", onnx_model_path)

onnxruntime/python/tools/transformers/models/sam2/image_encoder.py

@@ -0,0 +1,164 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings
import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
from torch import nn
import onnxruntime
logger = logging.getLogger(__name__)
class SAM2ImageEncoder(nn.Module):
def __init__(self, sam_model: SAM2Base) -> None:
super().__init__()
self.model = sam_model
self.image_encoder = sam_model.image_encoder
self.no_mem_embed = sam_model.no_mem_embed
def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Encodes images into features.
Only supports H=W=1024. If you want to use different image sizes like 512x512,
see https://github.com/facebookresearch/segment-anything-2/issues/138.
Args:
image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
Returns:
image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
"""
backbone_out = self.image_encoder(image)
# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
# Prepare and flatten visual features.
feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
# flatten NxCxHxW to HWxNxC
# TODO: we should avoid this transpose since it will be transposed back to NCHW later.
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
feats = [
feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
][::-1]
return feats[0], feats[1], feats[2]
def export_image_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
dynamic_batch_axes: bool = False,
verbose: bool = False,
):
image = random_sam2_input_image()
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
logger.info("image.shape: %s", image.shape)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
dynamic_axes = None
if dynamic_batch_axes:
dynamic_axes = {
"image": {0: "batch_size"},
"image_features_0": {0: "batch_size"},
"image_features_1": {0: "batch_size"},
"image_embeddings": {0: "batch_size"},
}
with warnings.catch_warnings():
if not verbose:
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
torch.onnx.export(
sam2_encoder,
image,
onnx_model_path,
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=["image"],
output_names=["image_features_0", "image_features_1", "image_embeddings"],
dynamic_axes=dynamic_axes,
)
print("encoder onnx model saved to", onnx_model_path)
def test_image_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
dynamic_batch_axes=False,
):
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
batch_sizes = [1, 2] if dynamic_batch_axes else [1]
for batch_size in batch_sizes:
image = random_sam2_input_image(batch_size)
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())
logger.info("image.shape: %s", image.shape)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
outputs = ort_session.run(output_names, {"image": image.numpy()})
for i, output_name in enumerate(output_names):
logger.info("output %s shape %s", output_name, outputs[i].shape)
ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs
# ONNX Runtime and PyTorch have about 0.75% mismatched elements, but this does not seem to impact segmentation results.
if (
compare_tensors_with_tolerance(
"image_features_0",
image_features_0,
torch.tensor(ort_image_features_0),
mismatch_percentage_tolerance=1,
)
and compare_tensors_with_tolerance(
"image_features_1",
image_features_1,
torch.tensor(ort_image_features_1),
mismatch_percentage_tolerance=1,
)
and compare_tensors_with_tolerance(
"image_embeddings",
image_embeddings,
torch.tensor(ort_image_embeddings),
mismatch_percentage_tolerance=1,
)
):
print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
else:
print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")

onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py

@@ -0,0 +1,208 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings
import torch
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
from prompt_encoder import SAM2PromptEncoder
from sam2.modeling.sam2_base import SAM2Base
from torch import nn
logger = logging.getLogger(__name__)
class SAM2MaskDecoder(nn.Module):
def __init__(
self,
sam_model: SAM2Base,
multimask_output: bool,
dynamic_multimask_via_stability: bool = True,
) -> None:
super().__init__()
self.mask_decoder = sam_model.sam_mask_decoder
self.prompt_encoder = sam_model.sam_prompt_encoder
self.model = sam_model
self.multimask_output = multimask_output
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
@torch.no_grad()
def forward(
self,
image_features_0: torch.Tensor,
image_features_1: torch.Tensor,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_embeddings: torch.Tensor,
dense_embeddings: torch.Tensor,
):
"""
Decode masks from image and prompt embeddings. Only supports H=W=1024.
Args:
image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
image_pe (torch.Tensor): [1, 256, H/16, W/16]. image positional encoding.
sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. embedding for input masks.
Returns:
low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
iou_predictions (torch.Tensor): [1, M]. scores for M masks.
"""
low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
repeat_image=sparse_embeddings.shape[0] > 1, # batch mode
high_res_features=[image_features_0, image_features_1],
)
if self.multimask_output:
low_res_masks = low_res_masks[:, 1:, :, :]
iou_predictions = iou_predictions[:, 1:]
elif self.dynamic_multimask_via_stability:
# When outputting a single mask, if the stability score from the current single-mask
# output (based on output token 0) falls below a threshold, we instead select from
# multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score.
low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(
low_res_masks, iou_predictions
)
else:
low_res_masks = low_res_masks[:, 0:1, :, :]
iou_predictions = iou_predictions[:, 0:1]
return low_res_masks, iou_predictions
def export_mask_decoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
multimask_output: bool,
dynamic_multimask_via_stability: bool = True,
verbose=False,
):
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
image = random_sam2_input_image()
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
# encode a random prompt
num_labels = 2
num_points = 3
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
logger.info("image_pe.shape: %s", image_pe.shape)
sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
logger.info("low_res_masks.shape: %s", low_res_masks.shape)
logger.info("iou_predictions.shape: %s", iou_predictions.shape)
with warnings.catch_warnings():
if not verbose:
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
torch.onnx.export(
sam2_mask_decoder,
inputs,
onnx_model_path,
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=[
"image_features_0",
"image_features_1",
"image_embeddings",
"image_pe",
"sparse_embeddings",
"dense_embeddings",
],
output_names=["low_res_masks", "iou_predictions"],
dynamic_axes={
"sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
"dense_embeddings": {0: "num_labels"},
"low_res_masks": {0: "num_labels"},
"iou_predictions": {0: "num_labels"},
},
)
print("mask decoder onnx model saved to", onnx_model_path)
def test_mask_decoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
multimask_output: bool,
dynamic_multimask_via_stability: bool,
):
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
image = random_sam2_input_image()
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
num_labels = 1
num_points = 5
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
import onnxruntime
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
outputs = ort_session.run(
output_names,
{
"image_features_0": image_features_0.numpy(),
"image_features_1": image_features_1.numpy(),
"image_embeddings": image_embeddings.numpy(),
"image_pe": image_pe.numpy(),
"sparse_embeddings": sparse_embeddings.numpy(),
"dense_embeddings": dense_embeddings.numpy(),
},
)
for i, output_name in enumerate(output_names):
logger.info("output %s shape: %s", output_name, outputs[i].shape)
ort_low_res_masks, ort_iou_predictions = outputs
torch.testing.assert_close(low_res_masks, torch.tensor(ort_low_res_masks), atol=5e-3, rtol=1e-4)
torch.testing.assert_close(iou_predictions, torch.tensor(ort_iou_predictions), atol=5e-3, rtol=1e-4)
print(f"onnx model has been verified: {onnx_model_path}")

onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py

@@ -0,0 +1,189 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance
from torch import nn
logger = logging.getLogger(__name__)
class SAM2PromptEncoder(nn.Module):
def __init__(self, sam_model: SAM2Base):
super().__init__()
self.prompt_encoder = sam_model.sam_prompt_encoder
self.model = sam_model
@torch.no_grad()
def forward(
self,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
input_masks: torch.Tensor,
has_input_masks: torch.Tensor,
):
"""Encode prompts.
Args:
point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
coordinate in (x, y) format of the P input points in image of size 1024x1024.
point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
positive (foreground), 0 means negative (background), -1 means padding,
2 (box left upper corner), 3 (box right bottom corner).
input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
Typically coming from a previous iteration.
has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
Returns:
sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks.
image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding.
"""
sparse_embeddings = self._embed_points(point_coords, point_labels)
dense_embeddings = self._embed_masks(input_masks, has_input_masks)
image_pe = self.prompt_encoder.get_dense_pe()
return sparse_embeddings, dense_embeddings, image_pe
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
point_coords = point_coords + 0.5
padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
point_coords = torch.cat([point_coords, padding_point], dim=1)
point_labels = torch.cat([point_labels, padding_label], dim=1)
# Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0).
point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size
point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
point_embedding = point_embedding * (point_labels != -1)
point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
for i in range(self.prompt_encoder.num_point_embeddings):
point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)
return point_embedding
def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)
mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
logger.info("mask_embedding.shape: %s", mask_embedding.shape)
return mask_embedding
def export_prompt_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
):
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
num_labels = 2
num_points = 3
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
logger.info("point_coords.shape: %s", point_coords.shape)
logger.info("point_labels.shape: %s", point_labels.shape)
logger.info("input_masks.shape: %s", input_masks.shape)
logger.info("has_input_masks.shape: %s", has_input_masks.shape)
logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
logger.info("image_pe.shape: %s", image_pe.shape)
torch.onnx.export(
sam2_prompt_encoder,
(point_coords, point_labels, input_masks, has_input_masks),
onnx_model_path,
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
dynamic_axes={
"point_coords": {0: "num_labels", 1: "num_points"},
"point_labels": {0: "num_labels", 1: "num_points"},
"input_masks": {0: "num_labels"},
"sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
"dense_embeddings": {0: "num_labels"},
},
)
print("prompt encoder onnx model saved to ", onnx_model_path)
def test_prompt_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
):
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
num_labels = 1
num_points = 5
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
import onnxruntime
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
outputs = ort_session.run(
output_names,
{
"point_coords": point_coords.numpy(),
"point_labels": point_labels.numpy(),
"input_masks": input_masks.numpy(),
"has_input_masks": has_input_masks.numpy(),
},
)
for i, output_name in enumerate(output_names):
logger.info("output %s shape: %s", output_name, outputs[i].shape)
ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
if (
compare_tensors_with_tolerance(
"sparse_embeddings",
sparse_embeddings,
torch.tensor(ort_sparse_embeddings),
mismatch_percentage_tolerance=0.2,
)
and compare_tensors_with_tolerance(
"dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2
)
and compare_tensors_with_tolerance(
"image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2
)
):
print(f"onnx model has been verified: {onnx_model_path}")
else:
print(f"onnx model verification failed: {onnx_model_path}")

onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py

@@ -0,0 +1,293 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
from sam2_utils import build_sam2_model
def show_mask(mask, ax, random_color=False, borders=True):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
if borders:
import cv2
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Try to smooth contours
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
)
ax.scatter(
neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
def show_masks(
image,
masks,
scores,
point_coords=None,
box_coords=None,
input_labels=None,
borders=True,
output_image_file_prefix=None,
image_files=None,
):
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca(), borders=borders)
if point_coords is not None:
assert input_labels is not None
show_points(point_coords, input_labels, plt.gca())
if box_coords is not None:
show_box(box_coords, plt.gca())
if len(scores) > 1:
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
plt.axis("off")
if output_image_file_prefix:
filename = f"{output_image_file_prefix}_{i}.png"
if os.path.exists(filename):
os.remove(filename)
plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
if isinstance(image_files, list):
image_files.append(filename)
plt.show(block=False)
plt.close()
def get_predictor(
checkpoint_dir: str,
device: torch.device,
model_type="sam2_hiera_large",
engine="torch",
onnx_directory="sam2_onnx_models",
):
sam2_model = build_sam2_model(checkpoint_dir, model_type, device=device)
if engine == "torch":
predictor = SAM2ImagePredictor(sam2_model)
else:
predictor = SAM2ImageOnnxPredictor(sam2_model, onnx_directory=onnx_directory, model_type=model_type)
return predictor
def run_demo(
checkpoint_dir: str,
model_type="sam2_hiera_large",
engine="torch",
onnx_directory="sam2_onnx_models",
enable_batch=False,
):
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
if use_gpu:
if engine == "torch":
# Turn on tfloat32 for Ampere GPUs.
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
elif engine == "ort":
import onnxruntime
assert use_gpu == ("CUDAExecutionProvider" in onnxruntime.get_available_providers())
np.random.seed(3)
image = Image.open("truck.jpg")
image = np.array(image.convert("RGB"))
predictor = get_predictor(checkpoint_dir, device, model_type, engine, onnx_directory=onnx_directory)
predictor.set_image(image)
prefix = f"sam2_demo_{engine}_"
# The model returns masks, quality predictions for those masks,
# and low resolution mask logits that can be passed to the next iteration of prediction.
# With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
# scores gives the model's own estimation of the quality of these masks.
# For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
# even if only a single mask is desired.
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
image_files = []
show_masks(
image,
masks,
scores,
point_coords=input_point,
input_labels=input_label,
borders=True,
output_image_file_prefix=prefix + "multimask",
image_files=image_files,
)
# Multiple points.
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, scores, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
show_masks(
image,
masks,
scores,
point_coords=input_point,
input_labels=input_label,
output_image_file_prefix=prefix + "multi_points",
image_files=image_files,
)
# Specify a window and a background point.
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, scores, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
show_masks(
image,
masks,
scores,
point_coords=input_point,
input_labels=input_label,
output_image_file_prefix=prefix + "background_point",
image_files=image_files,
)
# Take a box as input
input_box = np.array([425, 600, 700, 875])
masks, scores, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
show_masks(
image,
masks,
scores,
box_coords=input_box,
output_image_file_prefix=prefix + "box",
image_files=image_files,
)
# Combining points and boxes
input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
show_masks(
image,
masks,
scores,
box_coords=input_box,
point_coords=input_point,
input_labels=input_label,
output_image_file_prefix=prefix + "box_and_point",
image_files=image_files,
)
# TODO: support batched prompt inputs
if enable_batch:
input_boxes = np.array(
[
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
]
)
masks, scores, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
show_mask(mask.squeeze(0), plt.gca(), random_color=True)
for box in input_boxes:
show_box(box, plt.gca())
plt.axis("off")
plt.show()
plt.savefig(prefix + "batch_prompt.png")
image_files.append(prefix + "batch_prompt.png")
return image_files
def show_all_images(left_images, right_images):
# Show images in two rows since display screen is horizontal in most cases.
fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80))
for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images)):
left_img = mpimg.imread(left_img_path)
right_img = mpimg.imread(right_img_path)
axes[0, i].imshow(left_img)
axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
axes[0, i].axis("off")
axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])
axes[1, i].imshow(right_img)
axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
axes[1, i].axis("off")
axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])
plt.tight_layout()
plt.savefig("sam2_demo.png", format="png", bbox_inches="tight", dpi=1000)
plt.show()

onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py

@@ -0,0 +1,283 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional, Tuple, Union
import numpy as np
import torch
from PIL.Image import Image
from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2_utils import decoder_shape_dict, encoder_shape_dict, get_decoder_onnx_path, get_image_encoder_onnx_path
from onnxruntime import InferenceSession
from onnxruntime.transformers.io_binding_helper import CudaSession
logger = logging.getLogger(__name__)
def create_ort_session(
onnx_path: str,
session_options=None,
provider="CUDAExecutionProvider",
enable_cuda_graph=False,
use_tf32=True,
) -> InferenceSession:
if provider == "CUDAExecutionProvider":
device_id = torch.cuda.current_device()
provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
provider_options["use_tf32"] = int(use_tf32)
providers = [(provider, provider_options), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
print(f"Using providers: {providers}")
return InferenceSession(onnx_path, session_options, providers=providers)
def create_session(
onnx_path: str, session_options=None, provider="CUDAExecutionProvider", device="cuda", enable_cuda_graph=False
) -> CudaSession:
ort_session = create_ort_session(
onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True
)
cuda_session = CudaSession(ort_session, device=torch.device(device), enable_cuda_graph=enable_cuda_graph)
return cuda_session
class SAM2ImageOnnxPredictor(SAM2ImagePredictor):
def __init__(
self,
sam_model: SAM2Base,
onnx_directory: str = "sam2_onnx_models",
model_type: str = "sam2_hiera_large",
onnx_dtype: torch.dtype = torch.float32,
mask_threshold=0.0,
max_hole_area=0.0,
max_sprinkle_area=0.0,
**kwargs,
) -> None:
"""
Uses SAM-2 to compute the image embedding for an image, and then allow mask prediction given prompts.
Arguments:
sam_model (SAM2Base): The model to use for mask prediction.
onnx_directory (str): The path of the directory that contains encoder and decoder onnx models.
onnx_dtype (torch.dtype): The data type to use for ONNX inputs.
mask_threshold (float): The threshold to convert mask logits to binary masks. Default is 0.0.
max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
"""
super().__init__(
sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area
)
print(self.device)
if torch.cuda.is_available():
provider = "CUDAExecutionProvider"
device = "cuda"
else:
provider = "CPUExecutionProvider"
device = "cpu"
# This model is exported by image_encoder.py.
onnx_path = get_image_encoder_onnx_path(onnx_directory, model_type)
self.encoder_session = create_session(
onnx_path,
session_options=None,
provider=provider,
device=device,
enable_cuda_graph=False,
)
self.onnx_dtype = onnx_dtype
# This model is exported by image_decoder.py. It outputs only one mask.
onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=False)
self.decoder_session = create_session(
onnx_path,
session_options=None,
provider=provider,
device=device,
enable_cuda_graph=False,
)
# This model is exported by image_decoder.py. It outputs multiple (3) masks.
onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=True)
self.decoder_session_multi_out = create_session(
onnx_path,
session_options=None,
provider=provider,
device=device,
enable_cuda_graph=False,
)
@torch.no_grad()
def set_image(self, image: Union[np.ndarray, Image]):
"""
Calculates the image embeddings for the provided image.
Arguments:
image (np.ndarray or PIL Image): The input image to embed in RGB format.
The image should be in HWC format if np.ndarray, or WHC format if PIL Image with pixel values in [0, 255].
"""
self.reset_predictor()
# Transform the image to the form expected by the model
if isinstance(image, np.ndarray):
# For numpy array image, we assume (HxWxC) format.
self._orig_hw = [image.shape[:2]]
elif isinstance(image, Image):
w, h = image.size
self._orig_hw = [(h, w)]
else:
raise NotImplementedError("Image format not supported")
input_image = self._transforms(image)
input_image = input_image[None, ...].to(self.device)
assert (
len(input_image.shape) == 4 and input_image.shape[1] == 3
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
# Computing image embeddings for the provided image
io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
self.encoder_session.allocate_buffers(io_shapes)
feed_dict = {"image": input_image.to(self.onnx_dtype).to(self.device)}
for key, value in feed_dict.items():
logger.debug(f"{key}: {value.shape}, {value.dtype}")
logger.debug(f"encoder onnx: {self.encoder_session.ort_session._model_path}")
ort_outputs = self.encoder_session.infer(feed_dict)
self._features = {
"image_embed": ort_outputs["image_embeddings"],
"high_res_feats": [ort_outputs[f"image_features_{i}"] for i in range(2)],
}
self._is_image_set = True
logging.info("Image embeddings computed.")
@torch.no_grad()
def _predict(
self,
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
img_idx: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the currently set image.
Input prompts are batched torch tensors and are expected to already be
transformed to the input frame using SAM2Transforms.
Arguments:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (torch.Tensor or None): A BxN array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
model, in XYXY format.
mask_input (torch.Tensor): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form Bx1xHxW, where
for SAM, H=W=256. Masks returned by a previous iteration of the
predict method do not need further transformation.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded mask logits
instead of a binary mask.
Returns:
(torch.Tensor): The output masks in BxCxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(torch.Tensor): An array of shape BxC containing the model's
predictions for the quality of each mask.
(torch.Tensor): An array of shape BxCxHxW, where C is the number
of masks and H=W=256. These low res logits can be passed to
a subsequent iteration as mask input.
"""
assert not return_logits  # The ONNX model is exported to return binarized masks, so logits are not available.
if not self._is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
if point_coords is not None:
concat_points = (point_coords, point_labels)
else:
concat_points = None
# Embed prompts
if boxes is not None:
box_coords = boxes.reshape(-1, 2, 2)
box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
box_labels = box_labels.repeat(boxes.size(0), 1)
# we merge "boxes" and "points" into a single "concat_points" input (where
# boxes are added at the beginning) to sam_prompt_encoder
if concat_points is not None:
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
concat_points = (concat_coords, concat_labels)
else:
concat_points = (box_coords, box_labels)
assert concat_points is not None
num_labels = concat_points[0].shape[0]
shape_dict = decoder_shape_dict(
original_image_height=self._orig_hw[img_idx][0],
original_image_width=self._orig_hw[img_idx][1],
num_labels=num_labels,
max_points=concat_points[0].shape[1],
num_masks=3 if multimask_output else 1,
)
if multimask_output:
decoder_session = self.decoder_session_multi_out
else:
decoder_session = self.decoder_session
decoder_session.allocate_buffers(shape_dict)
image_features_0 = self._features["high_res_feats"][0][img_idx].unsqueeze(0)
image_features_1 = self._features["high_res_feats"][1][img_idx].unsqueeze(0)
image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0)
if mask_input is None:
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float, device=self.device)
has_input_masks = torch.zeros(num_labels, dtype=torch.float, device=self.device)
else:
input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1)
has_input_masks = torch.ones(num_labels, dtype=torch.float, device=self.device)
feed_dict = {
"image_embeddings": image_embeddings.contiguous().to(dtype=torch.float32).to(self.device),
"image_features_0": image_features_0.contiguous().to(dtype=torch.float32).to(self.device),
"image_features_1": image_features_1.contiguous().to(dtype=torch.float32).to(self.device),
"point_coords": concat_points[0].to(dtype=torch.float32).to(self.device),
"point_labels": concat_points[1].to(dtype=torch.int32).to(self.device),
"input_masks": input_masks.to(dtype=torch.float32).to(self.device),
"has_input_masks": has_input_masks.to(dtype=torch.float32).to(self.device),
"original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device),
}
for key, value in feed_dict.items():
logger.debug(f"{key}: {value.shape}, {value.dtype}")
logger.debug(f"decoder onnx: {self.decoder_session.ort_session._model_path}")
ort_outputs = decoder_session.infer(feed_dict)
masks = ort_outputs["masks"]
iou_predictions = ort_outputs["iou_predictions"]
low_res_masks = ort_outputs["low_res_masks"]
return torch.Tensor(masks), torch.Tensor(iou_predictions), torch.Tensor(low_res_masks)
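For reference, here is a minimal usage sketch of this predictor in the style of the SAM2 notebook demo. The class name SAM2ImageOnnxPredictor and the inherited predict(...) wrapper are assumptions based on the surrounding code and the SAM2 predictor API, not taken verbatim from this file:

import numpy as np
import torch
from PIL import Image

# Hypothetical wiring; build_sam2_model comes from the sam2_utils helpers below.
sam2_model = build_sam2_model("checkpoints", "sam2_hiera_large", device="cpu")
predictor = SAM2ImageOnnxPredictor(
    sam2_model,
    onnx_directory="sam2_onnx_models",
    model_type="sam2_hiera_large",
    onnx_dtype=torch.float32,
)
predictor.set_image(np.array(Image.open("truck.jpg").convert("RGB")))

# A single foreground click; multimask_output=True returns three candidate masks
# ranked by the iou_predictions scores.
masks, scores, low_res_masks = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)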

View file

@@ -0,0 +1,122 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import os
import torch
from sam2.build_sam import build_sam2
from sam2.modeling.sam2_base import SAM2Base
logger = logging.getLogger(__name__)
def get_model_cfg(model_type) -> str:
assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
if model_type == "sam2_hiera_tiny":
model_cfg = "sam2_hiera_t.yaml"
elif model_type == "sam2_hiera_small":
model_cfg = "sam2_hiera_s.yaml"
elif model_type == "sam2_hiera_base_plus":
model_cfg = "sam2_hiera_b+.yaml"
else:
model_cfg = "sam2_hiera_l.yaml"
return model_cfg
def build_sam2_model(checkpoint_dir: str, model_type: str, device="cpu") -> SAM2Base:
sam2_checkpoint = os.path.join(checkpoint_dir, f"{model_type}.pt")
model_cfg = get_model_cfg(model_type)
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
return sam2_model
def get_decoder_onnx_path(dir: str, model_type, multimask_output) -> str:
return os.path.join(dir, f"{model_type}_decoder" + ("_multi" if multimask_output else "") + ".onnx")
def get_image_encoder_onnx_path(dir: str, model_type) -> str:
return os.path.join(dir, f"{model_type}_image_encoder.onnx")
def encoder_shape_dict(batch_size: int, height: int, width: int):
assert height == 1024 and width == 1024, "Only 1024x1024 images are supported."
return {
"image": [batch_size, 3, height, width],
"image_features_0": [batch_size, 32, height // 4, width // 4],
"image_features_1": [batch_size, 64, height // 8, width // 8],
"image_embeddings": [batch_size, 256, height // 16, width // 16],
}
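
# Example (illustrative): encoder_shape_dict(batch_size=1, height=1024, width=1024)
# maps "image_features_0" to [1, 32, 256, 256] and "image_embeddings" to
# [1, 256, 64, 64], the buffer shapes pre-allocated for I/O binding via
# CudaSession.allocate_buffers.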
def decoder_shape_dict(
original_image_height: int,
original_image_width: int,
num_labels: int = 1,
max_points: int = 16,
num_masks: int = 1,
) -> dict:
height: int = 1024
width: int = 1024
return {
"image_features_0": [1, 32, height // 4, width // 4],
"image_features_1": [1, 64, height // 8, width // 8],
"image_embeddings": [1, 256, height // 16, width // 16],
"point_coords": [num_labels, max_points, 2],
"point_labels": [num_labels, max_points],
"input_masks": [num_labels, 1, height // 4, width // 4],
"has_input_masks": [num_labels],
"original_image_size": [2],
"masks": [num_labels, num_masks, original_image_height, original_image_width],
"iou_predictions": [num_labels, num_masks],
"low_res_masks": [num_labels, num_masks, height // 4, width // 4],
}
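
# Example (illustrative): one box prompt is encoded as 2 points, so with multimask
# output on a 1200x1800 image, decoder_shape_dict(1200, 1800, num_labels=1,
# max_points=2, num_masks=3) maps "masks" to [1, 3, 1200, 1800] and
# "low_res_masks" to [1, 3, 256, 256].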
def compare_tensors_with_tolerance(
name: str,
tensor1: torch.Tensor,
tensor2: torch.Tensor,
atol=5e-3,
rtol=1e-4,
mismatch_percentage_tolerance=0.1,
) -> bool:
assert tensor1.shape == tensor2.shape
a = tensor1.clone().float()
b = tensor2.clone().float()
differences = torch.abs(a - b)
mismatch_count = (differences > (rtol * torch.max(torch.abs(a), torch.abs(b)) + atol)).sum().item()
total_elements = a.numel()
mismatch_percentage = (mismatch_count / total_elements) * 100
passed = mismatch_percentage < mismatch_percentage_tolerance
log_func = logger.error if not passed else logger.info
log_func(
"%s: mismatched elements percentage %.2f (%d/%d). Verification %s (threshold=%.2f).",
name,
mismatch_percentage,
mismatch_count,
total_elements,
"passed" if passed else "failed",
mismatch_percentage_tolerance,
)
return passed
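
# Example (illustrative): verify ONNX Runtime outputs against the PyTorch baseline,
# e.g. compare_tensors_with_tolerance("masks", torch_masks, ort_masks). An element
# counts as mismatched when |a - b| > rtol * max(|a|, |b|) + atol, and verification
# passes while fewer than mismatch_percentage_tolerance percent of elements mismatch.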
def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor:
image = torch.randn(batch_size, 3, image_height, image_width).cpu()
return image
def setup_logger(verbose=True):
if verbose:
logging.basicConfig(format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")
logging.getLogger().setLevel(logging.INFO)
else:
logging.basicConfig(format="[%(message)s")
logging.getLogger().setLevel(logging.WARNING)
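
A hedged sketch of how these utilities combine in an export parity check (the output name follows the predictor code above; the omitted session plumbing is an assumption):

# Minimal parity-check sketch, assuming an exported encoder exists at onnx_path.
setup_logger(verbose=False)
dummy_image = random_sam2_input_image(batch_size=1)
onnx_path = get_image_encoder_onnx_path("sam2_onnx_models", "sam2_hiera_tiny")
# ...run the torch image encoder and the ONNX encoder on dummy_image, then:
# compare_tensors_with_tolerance("image_embeddings", torch_out, ort_out)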

View file

@@ -494,8 +494,9 @@ packages = [
"onnxruntime.transformers.models.llama",
"onnxruntime.transformers.models.longformer",
"onnxruntime.transformers.models.phi2",
"onnxruntime.transformers.models.t5",
"onnxruntime.transformers.models.sam2",
"onnxruntime.transformers.models.stable_diffusion",
"onnxruntime.transformers.models.t5",
"onnxruntime.transformers.models.whisper",
]