mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
register BoxWithNMSLimit with C10
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17956 Reviewed By: houseroad Differential Revision: D14417300 fbshipit-source-id: eb5e2ba84513b3b7bfa509dc442424b13fe9148f
This commit is contained in:
parent
d895d30876
commit
f4e35d30ed
3 changed files with 193 additions and 45 deletions
|
|
@ -295,3 +295,28 @@ SHOULD_NOT_DO_GRADIENT(BoxWithNMSLimit);
|
|||
|
||||
} // namespace
|
||||
} // namespace caffe2
|
||||
|
||||
C10_REGISTER_CAFFE2_OPERATOR_CPU(
|
||||
BoxWithNMSLimit,
|
||||
(std::vector<c10::Argument>{
|
||||
c10::Argument("scores"),
|
||||
c10::Argument("boxes"),
|
||||
c10::Argument("batch_splits"),
|
||||
c10::Argument("score_thresh", FloatType::get()),
|
||||
c10::Argument("nms", FloatType::get()),
|
||||
c10::Argument("detections_per_im", IntType::get()),
|
||||
c10::Argument("soft_nms_enabled", BoolType::get()),
|
||||
c10::Argument("soft_nms_method", StringType::get()),
|
||||
c10::Argument("soft_nms_sigma", FloatType::get()),
|
||||
c10::Argument("soft_nms_min_score_thres", FloatType::get()),
|
||||
c10::Argument("rotated", BoolType::get()),
|
||||
}),
|
||||
(std::vector<c10::Argument>{
|
||||
c10::Argument("scores"),
|
||||
c10::Argument("boxes"),
|
||||
c10::Argument("classes"),
|
||||
c10::Argument("batch_splits"),
|
||||
// c10::Argument("keeps"),
|
||||
// c10::Argument("keeps_size"),
|
||||
}),
|
||||
caffe2::BoxWithNMSLimitOp<caffe2::CPUContext>);
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@
|
|||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
C10_DECLARE_CAFFE2_OPERATOR(BoxWithNMSLimit)
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
// C++ implementation of function insert_box_results_with_nms_and_limit()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,4 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
from caffe2.python import core, workspace
|
||||
import torch
|
||||
|
|
@ -44,16 +41,33 @@ def generate_rois_rotated(roi_counts, im_dims):
|
|||
# [batch_id, ctr_x, ctr_y, w, h, angle]
|
||||
rotated_rois = np.empty((rois.shape[0], 6)).astype(np.float32)
|
||||
rotated_rois[:, 0] = rois[:, 0] # batch_id
|
||||
rotated_rois[:, 1] = (rois[:, 1] + rois[:, 3]) / 2. # ctr_x = (x1 + x2) / 2
|
||||
rotated_rois[:, 2] = (rois[:, 2] + rois[:, 4]) / 2. # ctr_y = (y1 + y2) / 2
|
||||
rotated_rois[:, 1] = (rois[:, 1] + rois[:, 3]) / 2.0 # ctr_x = (x1 + x2) / 2
|
||||
rotated_rois[:, 2] = (rois[:, 2] + rois[:, 4]) / 2.0 # ctr_y = (y1 + y2) / 2
|
||||
rotated_rois[:, 3] = rois[:, 3] - rois[:, 1] + 1.0 # w = x2 - x1 + 1
|
||||
rotated_rois[:, 4] = rois[:, 4] - rois[:, 2] + 1.0 # h = y2 - y1 + 1
|
||||
rotated_rois[:, 5] = np.random.uniform(-90.0, 90.0) # angle in degrees
|
||||
return rotated_rois
|
||||
|
||||
|
||||
class TorchIntegration(hu.HypothesisTestCase):
|
||||
def create_bbox_transform_inputs(roi_counts, num_classes, rotated):
|
||||
batch_size = len(roi_counts)
|
||||
total_rois = sum(roi_counts)
|
||||
im_dims = np.random.randint(100, 600, batch_size)
|
||||
rois = (
|
||||
generate_rois_rotated(roi_counts, im_dims)
|
||||
if rotated
|
||||
else generate_rois(roi_counts, im_dims)
|
||||
)
|
||||
box_dim = 5 if rotated else 4
|
||||
deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32)
|
||||
im_info = np.zeros((batch_size, 3)).astype(np.float32)
|
||||
im_info[:, 0] = im_dims
|
||||
im_info[:, 1] = im_dims
|
||||
im_info[:, 2] = 1.0
|
||||
return rois, deltas, im_info
|
||||
|
||||
|
||||
class TorchIntegration(hu.HypothesisTestCase):
|
||||
@given(
|
||||
roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10),
|
||||
num_classes=st.integers(1, 10),
|
||||
|
|
@ -75,20 +89,9 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
"""
|
||||
Test with rois for multiple images in a batch
|
||||
"""
|
||||
batch_size = len(roi_counts)
|
||||
total_rois = sum(roi_counts)
|
||||
im_dims = np.random.randint(100, 600, batch_size)
|
||||
rois = (
|
||||
generate_rois_rotated(roi_counts, im_dims)
|
||||
if rotated
|
||||
else generate_rois(roi_counts, im_dims)
|
||||
rois, deltas, im_info = create_bbox_transform_inputs(
|
||||
roi_counts, num_classes, rotated
|
||||
)
|
||||
box_dim = 5 if rotated else 4
|
||||
deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32)
|
||||
im_info = np.zeros((batch_size, 3)).astype(np.float32)
|
||||
im_info[:, 0] = im_dims
|
||||
im_info[:, 1] = im_dims
|
||||
im_info[:, 2] = 1.0
|
||||
|
||||
def bbox_transform_ref():
|
||||
ref_op = core.CreateOperator(
|
||||
|
|
@ -108,14 +111,102 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
|
||||
box_out = torch.tensor(bbox_transform_ref())
|
||||
a, b = torch.ops._caffe2.BBoxTransform(
|
||||
torch.tensor(rois), torch.tensor(deltas),
|
||||
torch.tensor(im_info),
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
False, rotated, angle_bound_on,
|
||||
-90, 90, clip_angle_thresh)
|
||||
torch.tensor(rois),
|
||||
torch.tensor(deltas),
|
||||
torch.tensor(im_info),
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
False,
|
||||
rotated,
|
||||
angle_bound_on,
|
||||
-90,
|
||||
90,
|
||||
clip_angle_thresh,
|
||||
)
|
||||
|
||||
torch.testing.assert_allclose(box_out, a)
|
||||
|
||||
@given(
|
||||
roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10),
|
||||
num_classes=st.integers(1, 10),
|
||||
rotated=st.booleans(),
|
||||
angle_bound_on=st.booleans(),
|
||||
clip_angle_thresh=st.sampled_from([-1.0, 1.0]),
|
||||
**hu.gcs_cpu_only
|
||||
)
|
||||
def test_box_with_nms_limits(
|
||||
self,
|
||||
roi_counts,
|
||||
num_classes,
|
||||
rotated,
|
||||
angle_bound_on,
|
||||
clip_angle_thresh,
|
||||
gc,
|
||||
dc,
|
||||
):
|
||||
rotated = False # FIXME remove this after rotation is supported
|
||||
rois, deltas, im_info = create_bbox_transform_inputs(
|
||||
roi_counts, num_classes, rotated
|
||||
)
|
||||
pred_bbox, batch_splits = [
|
||||
t.detach().numpy()
|
||||
for t in torch.ops._caffe2.BBoxTransform(
|
||||
torch.tensor(rois),
|
||||
torch.tensor(deltas),
|
||||
torch.tensor(im_info),
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
False,
|
||||
rotated,
|
||||
angle_bound_on,
|
||||
-90,
|
||||
90,
|
||||
clip_angle_thresh,
|
||||
)
|
||||
]
|
||||
class_prob = np.random.randn(sum(roi_counts), num_classes).astype(np.float32)
|
||||
score_thresh = 0.5
|
||||
nms_thresh = 0.5
|
||||
topk_per_image = sum(roi_counts) / 2
|
||||
|
||||
def box_with_nms_limit_ref():
|
||||
input_blobs = ["class_prob", "pred_bbox", "batch_splits"]
|
||||
output_blobs = ["score_nms", "bbox_nms", "class_nms", "batch_splits_nms"]
|
||||
ref_op = core.CreateOperator(
|
||||
"BoxWithNMSLimit",
|
||||
input_blobs,
|
||||
output_blobs,
|
||||
score_thresh=float(score_thresh),
|
||||
nms=float(nms_thresh),
|
||||
detections_per_im=int(topk_per_image),
|
||||
soft_nms_enabled=False,
|
||||
soft_nms_method="linear",
|
||||
soft_nms_sigma=0.5,
|
||||
soft_nms_min_score_thres=0.001,
|
||||
rotated=rotated,
|
||||
)
|
||||
workspace.FeedBlob("class_prob", class_prob)
|
||||
workspace.FeedBlob("pred_bbox", pred_bbox)
|
||||
workspace.FeedBlob("batch_splits", batch_splits)
|
||||
workspace.RunOperatorOnce(ref_op)
|
||||
return (workspace.FetchBlob(b) for b in output_blobs)
|
||||
|
||||
output_refs = box_with_nms_limit_ref()
|
||||
outputs = torch.ops._caffe2.BoxWithNMSLimit(
|
||||
torch.tensor(class_prob),
|
||||
torch.tensor(pred_bbox),
|
||||
torch.tensor(batch_splits),
|
||||
score_thresh=float(score_thresh),
|
||||
nms=float(nms_thresh),
|
||||
detections_per_im=int(topk_per_image),
|
||||
soft_nms_enabled=False,
|
||||
soft_nms_method="linear",
|
||||
soft_nms_sigma=0.5,
|
||||
soft_nms_min_score_thres=0.001,
|
||||
rotated=rotated,
|
||||
)
|
||||
|
||||
for o, o_ref in zip(outputs, output_refs):
|
||||
torch.testing.assert_allclose(o, o_ref)
|
||||
|
||||
@given(
|
||||
A=st.integers(min_value=4, max_value=4),
|
||||
H=st.integers(min_value=10, max_value=10),
|
||||
|
|
@ -124,8 +215,11 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
)
|
||||
def test_generate_proposals(self, A, H, W, img_count):
|
||||
scores = np.ones((img_count, A, H, W)).astype(np.float32)
|
||||
bbox_deltas = np.linspace(0, 10, num=img_count*4*A*H*W).reshape(
|
||||
(img_count, 4*A, H, W)).astype(np.float32)
|
||||
bbox_deltas = (
|
||||
np.linspace(0, 10, num=img_count * 4 * A * H * W)
|
||||
.reshape((img_count, 4 * A, H, W))
|
||||
.astype(np.float32)
|
||||
)
|
||||
im_info = np.ones((img_count, 3)).astype(np.float32) / 10
|
||||
anchors = np.ones((A, 4)).astype(np.float32)
|
||||
|
||||
|
|
@ -147,9 +241,20 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
rois = torch.tensor(rois)
|
||||
rois_probs = torch.tensor(rois_probs)
|
||||
a, b = torch.ops._caffe2.GenerateProposals(
|
||||
torch.tensor(scores), torch.tensor(bbox_deltas),
|
||||
torch.tensor(im_info), torch.tensor(anchors),
|
||||
2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0)
|
||||
torch.tensor(scores),
|
||||
torch.tensor(bbox_deltas),
|
||||
torch.tensor(im_info),
|
||||
torch.tensor(anchors),
|
||||
2.0,
|
||||
6000,
|
||||
300,
|
||||
0.7,
|
||||
16,
|
||||
True,
|
||||
-90,
|
||||
90,
|
||||
1.0,
|
||||
)
|
||||
torch.testing.assert_allclose(rois, a)
|
||||
torch.testing.assert_allclose(rois_probs, b)
|
||||
|
||||
|
|
@ -241,11 +346,14 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
H=st.integers(min_value=10, max_value=10),
|
||||
W=st.integers(min_value=8, max_value=8),
|
||||
img_count=st.integers(min_value=3, max_value=3),
|
||||
)
|
||||
)
|
||||
def test_generate_proposals_cuda(self, A, H, W, img_count):
|
||||
scores = np.ones((img_count, A, H, W)).astype(np.float32)
|
||||
bbox_deltas = np.linspace(0, 10, num=img_count*4*A*H*W).reshape(
|
||||
(img_count, 4*A, H, W)).astype(np.float32)
|
||||
bbox_deltas = (
|
||||
np.linspace(0, 10, num=img_count * 4 * A * H * W)
|
||||
.reshape((img_count, 4 * A, H, W))
|
||||
.astype(np.float32)
|
||||
)
|
||||
im_info = np.ones((img_count, 3)).astype(np.float32) / 10
|
||||
anchors = np.ones((A, 4)).astype(np.float32)
|
||||
|
||||
|
|
@ -267,9 +375,20 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
rois = torch.tensor(rois)
|
||||
rois_probs = torch.tensor(rois_probs)
|
||||
a, b = torch.ops._caffe2.GenerateProposals(
|
||||
torch.tensor(scores).cuda(), torch.tensor(bbox_deltas).cuda(),
|
||||
torch.tensor(im_info).cuda(), torch.tensor(anchors).cuda(),
|
||||
2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0)
|
||||
torch.tensor(scores).cuda(),
|
||||
torch.tensor(bbox_deltas).cuda(),
|
||||
torch.tensor(im_info).cuda(),
|
||||
torch.tensor(anchors).cuda(),
|
||||
2.0,
|
||||
6000,
|
||||
300,
|
||||
0.7,
|
||||
16,
|
||||
True,
|
||||
-90,
|
||||
90,
|
||||
1.0,
|
||||
)
|
||||
torch.testing.assert_allclose(rois, a.cpu())
|
||||
torch.testing.assert_allclose(rois_probs, b.cpu())
|
||||
|
||||
|
|
@ -281,13 +400,15 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
)
|
||||
def _test_roi_align(self, N, C, H, W, device):
|
||||
def rand_roi():
|
||||
return np.array([
|
||||
float(int(N * np.random.rand())),
|
||||
0.5 * np.random.rand() * W,
|
||||
0.5 * np.random.rand() * H,
|
||||
(0.5 + 0.5 * np.random.rand()) * W,
|
||||
(0.5 + 0.5 * np.random.rand()) * H,
|
||||
]).astype(np.float32)
|
||||
return np.array(
|
||||
[
|
||||
float(int(N * np.random.rand())),
|
||||
0.5 * np.random.rand() * W,
|
||||
0.5 * np.random.rand() * H,
|
||||
(0.5 + 0.5 * np.random.rand()) * W,
|
||||
(0.5 + 0.5 * np.random.rand()) * H,
|
||||
]
|
||||
).astype(np.float32)
|
||||
|
||||
feature = np.random.randn(N, C, H, W).astype(np.float32)
|
||||
rois = np.array([rand_roi() for _ in range(10)])
|
||||
|
|
@ -300,7 +421,7 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
spatial_scale=1.0,
|
||||
pooled_h=3,
|
||||
pooled_w=3,
|
||||
sampling_ratio=0
|
||||
sampling_ratio=0,
|
||||
)
|
||||
workspace.FeedBlob("feature", _feature)
|
||||
workspace.FeedBlob("rois", _rois)
|
||||
|
|
@ -315,7 +436,7 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
spatial_scale=1.0,
|
||||
pooled_h=3,
|
||||
pooled_w=3,
|
||||
sampling_ratio=0
|
||||
sampling_ratio=0,
|
||||
)
|
||||
torch.testing.assert_allclose(roi_feature_ref, roi_feature.cpu())
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue