mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: These operators update the state of the instance and therefore should have the instance in the output list. Reviewed By: xianjiec Differential Revision: D4554773 fbshipit-source-id: 556d484fcf58878308aa6b0f7cd7ea2446d3f29e
131 lines
4.4 KiB
Python
131 lines
4.4 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
from caffe2.python import core, workspace
|
|
from caffe2.python.test_util import TestCase
|
|
import numpy as np
|
|
import tempfile
|
|
|
|
|
|
class TestIndexOps(TestCase):
    """Exercise the index operators for string, int32 and int64 keys.

    Operators that modify the index (``IndexLoad``, ``IndexFreeze``) list the
    index blob as both input and output.
    """

    def _test_index_ops(self, entries, dtype, index_create_op):
        """Run the shared index-operator scenario for one key type.

        Args:
            entries: indexable sequence of at least 8 distinct keys. The
                first three are preloaded, the next two are inserted through
                lookups, and the last three are only queried after the index
                is frozen.
            dtype: numpy dtype used to build the key arrays fed to the
                workspace.
            index_create_op: name of the operator that creates the index
                blob (e.g. ``'StringIndexCreate'``).
        """
        workspace.RunOperatorOnce(core.CreateOperator(
            index_create_op,
            [],
            ['index'],
            max_elements=10))

        # Preload the first three keys into the index.
        my_entries = np.array(
            [entries[0], entries[1], entries[2]], dtype=dtype)
        workspace.FeedBlob('entries', my_entries)
        workspace.RunOperatorOnce(core.CreateOperator(
            'IndexLoad',
            ['index', 'entries'],
            ['index']))

        # Query a mix of preloaded and unseen keys. The expected result shows
        # the preloaded key mapping to 1 and the two unseen keys being
        # assigned the next free ids (4 and 5).
        query1 = np.array(
            [entries[0], entries[3], entries[0], entries[4]],
            dtype=dtype)
        workspace.FeedBlob('query1', query1)
        workspace.RunOperatorOnce(core.CreateOperator(
            'IndexGet',
            ['index', 'query1'],
            ['result1']))
        result1 = workspace.FetchBlob('result1')
        np.testing.assert_array_equal([1, 4, 1, 5], result1)

        workspace.RunOperatorOnce(core.CreateOperator(
            'IndexFreeze',
            ['index'],
            ['index']))

        # After freezing, keys that were never inserted are expected to map
        # to 0 while known keys keep their ids.
        query2 = np.array(
            [entries[5], entries[4], entries[0], entries[6], entries[7]],
            dtype=dtype)
        workspace.FeedBlob('query2', query2)
        workspace.RunOperatorOnce(core.CreateOperator(
            'IndexGet',
            ['index', 'query2'],
            ['result2']))
        result2 = workspace.FetchBlob('result2')
        np.testing.assert_array_equal([0, 5, 1, 0, 0], result2)

        # Five keys were inserted but the expected size is 6.
        # NOTE(review): presumably the size also counts the reserved id 0
        # used for unknown keys -- confirm against the operator docs.
        workspace.RunOperatorOnce(core.CreateOperator(
            'IndexSize',
            ['index'],
            ['index_size']))
        size = workspace.FetchBlob('index_size')
        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(size, 6)

        # The stored keys must be the preloaded ones followed by the two
        # inserted through lookups, in insertion order.
        workspace.RunOperatorOnce(core.CreateOperator(
            'IndexStore',
            ['index'],
            ['stored_entries']))
        stored_actual = workspace.FetchBlob('stored_entries')
        new_entries = np.array([entries[3], entries[4]], dtype=dtype)
        np.testing.assert_array_equal(
            np.concatenate((my_entries, new_entries)), stored_actual)

        # Build a second index from the stored keys, skipping the first one;
        # the expected size is therefore one less than the original index.
        workspace.RunOperatorOnce(core.CreateOperator(
            index_create_op,
            [],
            ['index2']))

        workspace.RunOperatorOnce(core.CreateOperator(
            'IndexLoad',
            ['index2', 'stored_entries'],
            ['index2'],
            skip_first_entry=1))

        workspace.RunOperatorOnce(core.CreateOperator(
            'IndexSize',
            ['index2'],
            ['index2_size']))
        index2_size = workspace.FetchBlob('index2_size')
        self.assertEqual(index2_size, 5)

        # test serde
        with tempfile.NamedTemporaryFile() as tmp:
            workspace.RunOperatorOnce(core.CreateOperator(
                'Save',
                ['index'],
                [],
                absolute_path=1,
                db_type='minidb',
                db=tmp.name))
            # frees up the blob
            workspace.FeedBlob('index', np.array([]))
            # reloads the index
            workspace.RunOperatorOnce(core.CreateOperator(
                'Load',
                [],
                ['index'],
                absolute_path=1,
                db_type='minidb',
                db=tmp.name))

            # The reloaded index must return the same ids as before saving.
            query3 = np.array(
                [entries[0], entries[3], entries[0], entries[4], entries[4]],
                dtype=dtype)
            workspace.FeedBlob('query3', query3)
            workspace.RunOperatorOnce(core.CreateOperator(
                'IndexGet', ['index', 'query3'], ['result3']))
            result3 = workspace.FetchBlob('result3')
            np.testing.assert_array_equal([1, 4, 1, 5, 5], result3)

    def test_string_index_ops(self):
        """Run the scenario with string keys."""
        self._test_index_ops([
            'entry1', 'entry2', 'entry3', 'new_entry1',
            'new_entry2', 'miss1', 'miss2', 'miss3',
        ], str, 'StringIndexCreate')

    def test_int_index_ops(self):
        """Run the scenario with 32-bit integer keys."""
        self._test_index_ops(range(8), np.int32, 'IntIndexCreate')

    def test_long_index_ops(self):
        """Run the scenario with 64-bit integer keys."""
        self._test_index_ops(range(8), np.int64, 'LongIndexCreate')
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to the unittest runner, which
    # discovers and runs the test cases defined in this module.
    import unittest

    unittest.main()
|