From 5014cf4a4d2cc0d0ee18d1efc8818fc8a2c6bd6e Mon Sep 17 00:00:00 2001 From: Edson Romero Date: Fri, 14 Aug 2020 14:41:56 -0700 Subject: [PATCH] Export MergeIdLists Caffe2 Operator to PyTorch Summary: As titled. Test Plan: buck test //caffe2/caffe2/python/operator_test:torch_integration_test -- test_merge_id_lists Reviewed By: yf225 Differential Revision: D23076951 fbshipit-source-id: c37dfd93003590eed70b0d46e0151397a402dde6 --- caffe2/operators/merge_id_lists_op.cc | 4 +++ caffe2/operators/merge_id_lists_op.h | 3 ++ .../operator_test/torch_integration_test.py | 31 ++++++++++++++++++- 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/caffe2/operators/merge_id_lists_op.cc b/caffe2/operators/merge_id_lists_op.cc index 2505123d765..38acb1f71a9 100644 --- a/caffe2/operators/merge_id_lists_op.cc +++ b/caffe2/operators/merge_id_lists_op.cc @@ -30,3 +30,7 @@ within a batch. This can be an issue if ID_LIST are order sensitive. NO_GRADIENT(MergeIdLists); } } +C10_EXPORT_CAFFE2_OP_TO_C10_CPU( + MergeIdLists, + "_caffe2::MergeIdLists(Tensor[] lengths_and_values) -> (Tensor merged_lengths, Tensor merged_values)", + caffe2::MergeIdListsOp); diff --git a/caffe2/operators/merge_id_lists_op.h b/caffe2/operators/merge_id_lists_op.h index 0915c3792cc..2ba88dccf4a 100644 --- a/caffe2/operators/merge_id_lists_op.h +++ b/caffe2/operators/merge_id_lists_op.h @@ -5,6 +5,9 @@ #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" +#include "caffe2/core/export_caffe2_op_to_c10.h" + +C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(MergeIdLists); namespace caffe2 { diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index 5a5b3d8802b..1fff75afd96 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -8,7 +8,7 @@ import torch import unittest from caffe2.python import core, workspace -from hypothesis import given +from hypothesis import given, settings from scipy.stats import norm @@ -875,6 +875,35 @@ class TorchIntegration(hu.HypothesisTestCase): ) torch.testing.assert_allclose(expected_output, actual_output.cpu()) + @given(lengths_0=st.integers(1, 10), lengths_1=st.integers(1, 10)) + @settings(deadline=1000) + def test_merge_id_lists(self, lengths_0, lengths_1): + def _merge_id_lists(lengths, values): + ref_op = core.CreateOperator( + 'MergeIdLists', + ["lengths_0", "values_0", "lengths_1", "values_1"], + ["merged_lengths", "merged_values"] + ) + workspace.FeedBlob("lengths_0", lengths[0]) + workspace.FeedBlob("values_0", values[0]) + workspace.FeedBlob("lengths_1", lengths[1]) + workspace.FeedBlob("values_1", values[1]) + workspace.RunOperatorOnce(ref_op) + return workspace.FetchBlob("merged_lengths"), workspace.FetchBlob("merged_values") + + lengths = [np.array([lengths_0]).astype(np.int32), np.array([lengths_1]).astype(np.int32)] + values = [ + np.random.choice(np.arange(0, 10), size=lengths_0, replace=False).astype(np.int32), + np.random.choice(np.arange(10, 20), size=lengths_1, replace=False).astype(np.int32) + ] + + expected_merged_lengths, expected_merged_values = _merge_id_lists(lengths, values) + output_merged_lengths, output_merged_values = torch.ops._caffe2.MergeIdLists( + [torch.tensor(lengths[0]), torch.tensor(values[0]), torch.tensor(lengths[1]), torch.tensor(values[1])] + ) + torch.testing.assert_allclose(expected_merged_lengths, output_merged_lengths) + torch.testing.assert_allclose(expected_merged_values, output_merged_values) + if __name__ == '__main__': unittest.main()