mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Simple FindOp for CPU and GPU which searches an unordered index for a list of unordered needles. The CPU version might be faster if the index / needles were sorted first, but we can get back to that later. The CUDA op is also fairly brute-force, but highly parallel. Since the index and the queries are smallish, at least in the use case currently in mind (the Machine Translation team's word-candidate search), I think this is a sufficient start. Note that this is much simpler than the Index class of ops, which allow modifying the index, etc. Since CUDA ops are more complex to implement for the full Index functionality, I decided to make a separate op with this very simple functionality. Differential Revision: D4910131 fbshipit-source-id: 6df35c9e3c71d5392a500d5b98fd708ab0c8e587
51 lines
1.3 KiB
Python
51 lines
1.3 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core
|
|
import hypothesis.strategies as st
|
|
from hypothesis import given
|
|
|
|
|
|
import caffe2.python.hypothesis_test_util as hu
|
|
|
|
import numpy as np
|
|
|
|
|
|
class TestFindOperator(hu.HypothesisTestCase):
    """Hypothesis-driven reference checks for the Caffe2 ``Find`` operator.

    The operator is fed an index tensor ``idx`` and a tensor of needles ``X``
    and its single output is compared against ``_find_reference``.
    """

    @staticmethod
    def _find_reference(idx, X):
        """NumPy reference implementation of ``Find``.

        Args:
            idx: array of index values; it may contain duplicates.
            X: array of needles to look up in ``idx``.

        Returns:
            A one-element list (the form ``assertReferenceChecks`` expects)
            holding an int32 array shaped like ``X``. Each entry is the
            position of the LAST occurrence of the needle in ``idx``, or -1
            if the needle does not occur. For the 1-D ``X`` used in this
            test, this matches the previous flattened output exactly.
            NOTE(review): confirm the op's output shape for N-D needles
            before relying on the shape-preserving behaviour.
        """
        haystack = np.asarray(idx).ravel()
        needles = np.asarray(X)
        res = np.full(needles.shape, -1, dtype=np.int32)
        for pos, needle in np.ndenumerate(needles):
            hits = np.flatnonzero(haystack == needle)
            if hits.size:
                # Duplicates are expected (small value range); the reference
                # reports the last matching position.
                res[pos] = hits[-1]
        return [res]

    @given(n=st.sampled_from([1, 4, 8, 31, 79, 150]),
           idxsize=st.sampled_from([2, 4, 8, 1000, 5000]),
           seed=st.integers(min_value=0, max_value=2 ** 31 - 1),
           **hu.gcs)
    def test_find(self, n, idxsize, seed, gc, dc):
        """Compare ``Find`` on random integer data against the reference."""
        maxval = 10
        # Draw the data from a generator seeded by Hypothesis, not from the
        # global np.random state, so the data is part of the Hypothesis
        # example and can be varied, replayed and shrunk.
        rng = np.random.RandomState(seed)
        # Values lie in [0, maxval), so needles usually hit the index (often
        # several times), while the tiny index sizes still produce misses.
        X = (rng.rand(n) * maxval).astype(np.int32)
        idx = (rng.rand(idxsize) * maxval).astype(np.int32)

        op = core.CreateOperator(
            "Find",
            ["idx", "X"],
            ["y"],
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[idx, X],
            reference=self._find_reference,
        )
|