Export MergeIdLists Caffe2 Operator to PyTorch

Summary: As titled.

Test Plan: buck test //caffe2/caffe2/python/operator_test:torch_integration_test -- test_merge_id_lists

Reviewed By: yf225

Differential Revision: D23076951

fbshipit-source-id: c37dfd93003590eed70b0d46e0151397a402dde6
This commit is contained in:
Edson Romero 2020-08-14 14:41:56 -07:00 committed by Facebook GitHub Bot
parent c8e789e06e
commit 5014cf4a4d
3 changed files with 37 additions and 1 deletions

View file

@ -30,3 +30,7 @@ within a batch. This can be an issue if ID_LIST are order sensitive.
NO_GRADIENT(MergeIdLists);
}
}
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
MergeIdLists,
"_caffe2::MergeIdLists(Tensor[] lengths_and_values) -> (Tensor merged_lengths, Tensor merged_values)",
caffe2::MergeIdListsOp<caffe2::CPUContext>);

View file

@ -5,6 +5,9 @@
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(MergeIdLists);
namespace caffe2 {

View file

@ -8,7 +8,7 @@ import torch
import unittest
from caffe2.python import core, workspace
from hypothesis import given
from hypothesis import given, settings
from scipy.stats import norm
@ -875,6 +875,35 @@ class TorchIntegration(hu.HypothesisTestCase):
)
torch.testing.assert_allclose(expected_output, actual_output.cpu())
@given(lengths_0=st.integers(1, 10), lengths_1=st.integers(1, 10))
@settings(deadline=1000)
def test_merge_id_lists(self, lengths_0, lengths_1):
def _merge_id_lists(lengths, values):
ref_op = core.CreateOperator(
'MergeIdLists',
["lengths_0", "values_0", "lengths_1", "values_1"],
["merged_lengths", "merged_values"]
)
workspace.FeedBlob("lengths_0", lengths[0])
workspace.FeedBlob("values_0", values[0])
workspace.FeedBlob("lengths_1", lengths[1])
workspace.FeedBlob("values_1", values[1])
workspace.RunOperatorOnce(ref_op)
return workspace.FetchBlob("merged_lengths"), workspace.FetchBlob("merged_values")
lengths = [np.array([lengths_0]).astype(np.int32), np.array([lengths_1]).astype(np.int32)]
values = [
np.random.choice(np.arange(0, 10), size=lengths_0, replace=False).astype(np.int32),
np.random.choice(np.arange(10, 20), size=lengths_1, replace=False).astype(np.int32)
]
expected_merged_lengths, expected_merged_values = _merge_id_lists(lengths, values)
output_merged_lengths, output_merged_values = torch.ops._caffe2.MergeIdLists(
[torch.tensor(lengths[0]), torch.tensor(values[0]), torch.tensor(lengths[1]), torch.tensor(values[1])]
)
torch.testing.assert_allclose(expected_merged_lengths, output_merged_lengths)
torch.testing.assert_allclose(expected_merged_values, output_merged_values)
if __name__ == '__main__':
unittest.main()