mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Export MergeIdLists Caffe2 Operator to PyTorch
Summary: As titled. Test Plan: buck test //caffe2/caffe2/python/operator_test:torch_integration_test -- test_merge_id_lists Reviewed By: yf225 Differential Revision: D23076951 fbshipit-source-id: c37dfd93003590eed70b0d46e0151397a402dde6
This commit is contained in:
parent
c8e789e06e
commit
5014cf4a4d
3 changed files with 37 additions and 1 deletions
|
|
@ -30,3 +30,7 @@ within a batch. This can be an issue if ID_LIST are order sensitive.
|
|||
NO_GRADIENT(MergeIdLists);
|
||||
}
|
||||
}
|
||||
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
|
||||
MergeIdLists,
|
||||
"_caffe2::MergeIdLists(Tensor[] lengths_and_values) -> (Tensor merged_lengths, Tensor merged_values)",
|
||||
caffe2::MergeIdListsOp<caffe2::CPUContext>);
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@
|
|||
#include <vector>
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/core/export_caffe2_op_to_c10.h"
|
||||
|
||||
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(MergeIdLists);
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import unittest
|
||||
|
||||
from caffe2.python import core, workspace
|
||||
from hypothesis import given
|
||||
from hypothesis import given, settings
|
||||
from scipy.stats import norm
|
||||
|
||||
|
||||
|
|
@ -875,6 +875,35 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
)
|
||||
torch.testing.assert_allclose(expected_output, actual_output.cpu())
|
||||
|
||||
@given(lengths_0=st.integers(1, 10), lengths_1=st.integers(1, 10))
|
||||
@settings(deadline=1000)
|
||||
def test_merge_id_lists(self, lengths_0, lengths_1):
|
||||
def _merge_id_lists(lengths, values):
|
||||
ref_op = core.CreateOperator(
|
||||
'MergeIdLists',
|
||||
["lengths_0", "values_0", "lengths_1", "values_1"],
|
||||
["merged_lengths", "merged_values"]
|
||||
)
|
||||
workspace.FeedBlob("lengths_0", lengths[0])
|
||||
workspace.FeedBlob("values_0", values[0])
|
||||
workspace.FeedBlob("lengths_1", lengths[1])
|
||||
workspace.FeedBlob("values_1", values[1])
|
||||
workspace.RunOperatorOnce(ref_op)
|
||||
return workspace.FetchBlob("merged_lengths"), workspace.FetchBlob("merged_values")
|
||||
|
||||
lengths = [np.array([lengths_0]).astype(np.int32), np.array([lengths_1]).astype(np.int32)]
|
||||
values = [
|
||||
np.random.choice(np.arange(0, 10), size=lengths_0, replace=False).astype(np.int32),
|
||||
np.random.choice(np.arange(10, 20), size=lengths_1, replace=False).astype(np.int32)
|
||||
]
|
||||
|
||||
expected_merged_lengths, expected_merged_values = _merge_id_lists(lengths, values)
|
||||
output_merged_lengths, output_merged_values = torch.ops._caffe2.MergeIdLists(
|
||||
[torch.tensor(lengths[0]), torch.tensor(values[0]), torch.tensor(lengths[1]), torch.tensor(values[1])]
|
||||
)
|
||||
torch.testing.assert_allclose(expected_merged_lengths, output_merged_lengths)
|
||||
torch.testing.assert_allclose(expected_merged_values, output_merged_values)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue