mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46693 title Test Plan: unit tests Reviewed By: hx89 Differential Revision: D23946770 fbshipit-source-id: f7c3d4a1b4ef3b0e5f56e5a9a30f5003ce9f40b0
699 lines
23 KiB
Python
699 lines
23 KiB
Python
import functools
|
|
import operator
|
|
import string
|
|
|
|
import hypothesis.strategies as st
|
|
import numpy as np
|
|
import numpy.testing as npt
|
|
from caffe2.python import core, dataset, workspace
|
|
from caffe2.python.dataset import Const
|
|
from caffe2.python.schema import (
|
|
FeedRecord,
|
|
FetchRecord,
|
|
Field,
|
|
List,
|
|
Map,
|
|
NewRecord,
|
|
Scalar,
|
|
Struct,
|
|
from_blob_list,
|
|
)
|
|
from caffe2.python.test_util import TestCase
|
|
from hypothesis import given
|
|
|
|
|
|
def _assert_arrays_equal(actual, ref, err_msg):
|
|
if ref.dtype.kind in ("S", "O", "U"):
|
|
np.testing.assert_array_equal(actual, ref, err_msg=err_msg)
|
|
else:
|
|
np.testing.assert_allclose(actual, ref, atol=1e-4, rtol=1e-4, err_msg=err_msg)
|
|
|
|
|
|
def _assert_records_equal(actual, ref):
    """Assert two schema records (`Field` instances) hold the same number of
    field blobs and that each pair of blobs compares equal field-by-field."""
    assert isinstance(actual, Field)
    assert isinstance(ref, Field)
    actual_blobs = actual.field_blobs()
    ref_blobs = ref.field_blobs()
    assert len(actual_blobs) == len(ref_blobs), (
        "Records have different lengths: %d vs. %d"
        % (len(actual_blobs), len(ref_blobs))
    )
    # Field names are taken from the reference record for error reporting.
    for name, d1, d2 in zip(ref.field_names(), actual_blobs, ref_blobs):
        _assert_arrays_equal(d1, d2, err_msg="Mismatch in field %s." % name)
|
|
|
|
|
|
@st.composite
def _sparse_features_map(draw, num_records, **kwargs):
    """Hypothesis strategy generating the flattened blobs of a
    Map(int, List(int)) feature: [map lengths, keys, list lengths, values]."""

    def _draw_ints(count, lo, hi, unique=False):
        # Draw exactly `count` integers in [lo, hi].
        return draw(
            st.lists(
                st.integers(min_value=lo, max_value=hi),
                min_size=count,
                max_size=count,
                unique=unique,
            )
        )

    # One map length per record.
    sparse_maps_lengths = _draw_ints(num_records, 1, 10)
    sparse_maps_total_length = sum(sparse_maps_lengths)

    # One unique key per map entry across all records.
    sparse_keys = _draw_ints(sparse_maps_total_length, 1, 100, unique=True)

    # One value-list length per key.
    sparse_values_lengths = _draw_ints(sparse_maps_total_length, 1, 10)
    total_sparse_values_lengths = sum(sparse_values_lengths)

    # max_value is max int64
    sparse_values = _draw_ints(
        total_sparse_values_lengths, 1, 9223372036854775807
    )

    return [
        sparse_maps_lengths,
        sparse_keys,
        sparse_values_lengths,
        sparse_values,
    ]
|
|
|
|
|
|
@st.composite
def _dense_features_map(draw, num_records, **kwargs):
    """Hypothesis strategy generating the flattened blobs of a
    Map(int, float) feature: [map lengths, keys, values]."""
    # One map length per record.
    lengths = draw(
        st.lists(
            st.integers(min_value=1, max_value=10),
            min_size=num_records,
            max_size=num_records,
        )
    )
    total_length = sum(lengths)

    # One unique key per map entry across all records.
    keys = draw(
        st.lists(
            st.integers(min_value=1, max_value=100),
            min_size=total_length,
            max_size=total_length,
            unique=True,
        )
    )

    # One float value per key.
    values = draw(
        st.lists(st.floats(), min_size=total_length, max_size=total_length)
    )

    return [lengths, keys, values]
|
|
|
|
|
|
@st.composite
def _dataset(draw, min_elements=3, max_elements=10, **kwargs):
    """Hypothesis strategy producing a (schema, contents, num_records) tuple:
    a fixed Struct schema, a record of generated contents matching it, and
    the number of top-level entries."""
    schema = Struct(
        # Dense Features Map
        ("floats", Map(Scalar(np.int32), Scalar(np.float32))),
        # Sparse Features Map
        ("int_lists", Map(Scalar(np.int32), List(Scalar(np.int64)))),
        # Complex Type
        ("text", Scalar(str)),
    )

    num_records = draw(
        st.integers(min_value=min_elements, max_value=max_elements)
    )

    dense_contents = draw(_dense_features_map(num_records))
    sparse_contents = draw(_sparse_features_map(num_records))
    text_contents = [
        draw(
            st.lists(
                st.text(alphabet=string.ascii_lowercase),
                min_size=num_records,
                max_size=num_records,
            )
        )
    ]

    # Concatenate all raw contents into one flat blob list (schema order).
    contents_raw = dense_contents + sparse_contents + text_contents
    contents = from_blob_list(schema, contents_raw)

    return (schema, contents, num_records)
|
|
|
|
|
|
class TestDatasetOps(TestCase):
    """Tests for caffe2 dataset operators: record pack/unpack, dataset
    creation/read/write/trim, LastNWindowCollector and CollectTensor."""

    @given(_dataset())
    def test_pack_unpack(self, input):
        """
        Tests if packing and unpacking of the whole dataset is an identity.
        """
        (schema, contents, num_records) = input

        dataset_fields = schema.field_names()

        # Exercise both packing modes of PackRecords.
        for pack_to_single_shared_ptr in (True, False):
            net = core.Net("pack_unpack_net")
            batch = NewRecord(net, contents)
            FeedRecord(batch, contents)

            packed = net.PackRecords(
                batch.field_blobs(),
                1,
                fields=dataset_fields,
                pack_to_single_shared_ptr=pack_to_single_shared_ptr,
            )

            unpacked = packed.UnPackRecords(
                [], len(dataset_fields), fields=dataset_fields
            )

            workspace.RunNetOnce(net)

            # Round trip must be an identity, blob by blob.
            for initial_tensor, unpacked_tensor in zip(batch.field_blobs(), unpacked):
                npt.assert_array_equal(
                    workspace.FetchBlob(initial_tensor),
                    workspace.FetchBlob(unpacked_tensor),
                )

    def test_dataset_ops(self):
        """
        1. Defining the schema of our dataset.

        This example schema could represent, for example, a search query log.
        """
        schema = Struct(
            # fixed size vector, which will be stored as a matrix when batched
            ("dense", Scalar((np.float32, 3))),
            # could represent a feature map from feature ID to float value
            ("floats", Map(Scalar(np.int32), Scalar(np.float32))),
            # could represent a multi-valued categorical feature map
            (
                "int_lists",
                Map(
                    Scalar(np.int32),
                    List(Scalar(np.int64)),
                ),
            ),
            # could represent a multi-valued, weighted categorical feature map
            (
                "id_score_pairs",
                Map(
                    Scalar(np.int32),
                    Map(
                        Scalar(np.int64),
                        Scalar(np.float32),
                        keys_name="ids",
                        values_name="scores",
                    ),
                ),
            ),
            # additional scalar information
            (
                "metadata",
                Struct(
                    ("user_id", Scalar(np.int64)),
                    ("user_embed", Scalar((np.float32, 2))),
                    ("query", Scalar(str)),
                ),
            ),
        )
        """
        This is what the flattened fields for this schema look like, along
        with its type. Each one of these fields will be stored, read and
        written as a tensor.
        """
        expected_fields = [
            ("dense", (np.float32, 3)),
            ("floats:lengths", np.int32),
            ("floats:values:keys", np.int32),
            ("floats:values:values", np.float32),
            ("int_lists:lengths", np.int32),
            ("int_lists:values:keys", np.int32),
            ("int_lists:values:values:lengths", np.int32),
            ("int_lists:values:values:values", np.int64),
            ("id_score_pairs:lengths", np.int32),
            ("id_score_pairs:values:keys", np.int32),
            ("id_score_pairs:values:values:lengths", np.int32),
            ("id_score_pairs:values:values:values:ids", np.int64),
            ("id_score_pairs:values:values:values:scores", np.float32),
            ("metadata:user_id", np.int64),
            ("metadata:user_embed", (np.float32, 2)),
            ("metadata:query", str),
        ]
        zipped = zip(expected_fields, schema.field_names(), schema.field_types())
        for (ref_name, ref_type), name, dtype in zipped:
            # assertEqual, not the deprecated assertEquals alias
            # (removed in Python 3.12).
            self.assertEqual(ref_name, name)
            self.assertEqual(np.dtype(ref_type), dtype)
        """
        2. The contents of our dataset.

        Contents as defined below could represent, for example, a log of
        search queries along with dense, sparse features and metadata.
        The dataset below has 3 top-level entries.
        """
        contents_raw = [
            # dense
            [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3]],
            # floats
            [1, 2, 3],  # len
            [11, 21, 22, 31, 32, 33],  # key
            [1.1, 2.1, 2.2, 3.1, 3.2, 3.3],  # value
            # int lists
            [2, 0, 1],  # len
            [11, 12, 31],  # key
            [2, 4, 3],  # value:len
            [111, 112, 121, 122, 123, 124, 311, 312, 313],  # value:value
            # id score pairs
            [1, 2, 2],  # len
            [11, 21, 22, 31, 32],  # key
            [1, 1, 2, 2, 3],  # value:len
            [111, 211, 221, 222, 311, 312, 321, 322, 323],  # value:ids
            [11.1, 21.1, 22.1, 22.2, 31.1, 31.2, 32.1, 32.2, 32.3],  # val:score
            # metadata
            [123, 234, 456],  # user_id
            [[0.2, 0.8], [0.5, 0.5], [0.7, 0.3]],  # user_embed
            ["dog posts", "friends who like to", "posts about ca"],  # query
        ]
        # convert the above content to ndarrays, checking against the schema
        contents = from_blob_list(schema, contents_raw)
        """
        3. Creating and appending to the dataset.
        We first create an empty dataset with the given schema.
        Then, a Writer is used to append these entries to the dataset.
        """
        ds = dataset.Dataset(schema)
        net = core.Net("init")
        with core.NameScope("init"):
            ds.init_empty(net)

        content_blobs = NewRecord(net, contents)
        FeedRecord(content_blobs, contents)
        writer = ds.writer(init_net=net)
        writer.write_record(net, content_blobs)
        workspace.RunNetOnce(net)
        """
        4. Iterating through the dataset contents.

        If we were to iterate through the top level entries of our dataset,
        this is what we should expect to see:
        """
        entries_raw = [
            (
                [[1.1, 1.2, 1.3]],  # dense
                [1],
                [11],
                [1.1],  # floats
                [2],
                [11, 12],
                [2, 4],
                [111, 112, 121, 122, 123, 124],  # intlst
                [1],
                [11],
                [1],
                [111],
                [11.1],  # id score pairs
                [123],
                [[0.2, 0.8]],
                ["dog posts"],  # metadata
            ),
            (
                [[2.1, 2.2, 2.3]],  # dense
                [2],
                [21, 22],
                [2.1, 2.2],  # floats
                [0],
                [],
                [],
                [],  # int list
                [2],
                [21, 22],
                [1, 2],
                [211, 221, 222],
                [21.1, 22.1, 22.2],
                [234],
                [[0.5, 0.5]],
                ["friends who like to"],  # metadata
            ),
            (
                [[3.1, 3.2, 3.3]],  # dense
                [3],
                [31, 32, 33],
                [3.1, 3.2, 3.3],  # floats
                [1],
                [31],
                [3],
                [311, 312, 313],  # int lst
                [2],
                [31, 32],
                [2, 3],
                [311, 312, 321, 322, 323],
                [31.1, 31.2, 32.1, 32.2, 32.3],  # id score list
                [456],
                [[0.7, 0.3]],
                ["posts about ca"],  # metadata
            ),
            # after the end of the dataset, we will keep getting empty vectors
            ([],) * 16,
            ([],) * 16,
        ]
        entries = [from_blob_list(schema, e) for e in entries_raw]
        """
        Let's go ahead and create the reading nets.
        We will run `read` net multiple times and assert that we are reading the
        entries the way we stated above.
        """
        read_init_net = core.Net("read_init")
        read_next_net = core.Net("read_next")
        reader = ds.reader(read_init_net)
        should_continue, batch = reader.read_record(read_next_net)

        workspace.RunNetOnce(read_init_net)
        workspace.CreateNet(read_next_net, True)

        for entry in entries:
            workspace.RunNet(str(read_next_net))
            actual = FetchRecord(batch)
            _assert_records_equal(actual, entry)
        """
        5. Reading/writing in a single plan

        If all of operations on the data are expressible as Caffe2 operators,
        we don't need to load the data to python, iterating through the dataset
        in a single Plan.

        Where we will process the dataset a little and store it in a second
        dataset. We can reuse the same Reader since it supports reset.
        """
        reset_net = core.Net("reset_net")
        reader.reset(reset_net)
        read_step, batch = reader.execution_step()
        """ We will add the line number * 1000 to the feature ids. """
        process_net = core.Net("process")
        line_no = Const(process_net, 0, dtype=np.int32)
        const_one = Const(process_net, 1000, dtype=np.int32)
        process_net.Add([line_no, const_one], [line_no])
        field = batch.floats.keys.get()
        process_net.Print(field, [])
        process_net.Add([field, line_no], field, broadcast=1, axis=0)
        """ Lets create a second dataset and append to it. """
        ds2 = dataset.Dataset(schema, name="dataset2")
        ds2.init_empty(reset_net)
        writer = ds2.writer(reset_net)
        writer.write_record(process_net, batch)
        # commit is not necessary for DatasetWriter but will add it for
        # generality of the example
        commit_net = core.Net("commit")
        writer.commit(commit_net)
        """ Time to create and run a plan which will do the processing """
        plan = core.Plan("process")
        plan.AddStep(core.execution_step("reset", reset_net))
        plan.AddStep(read_step.AddNet(process_net))
        plan.AddStep(core.execution_step("commit", commit_net))
        workspace.RunPlan(plan)
        """
        Now we should have dataset2 populated.
        """
        ds2_data = FetchRecord(ds2.content())
        field = ds2_data.floats.keys
        # Undo the per-line offset added by process_net before comparing.
        field.set(blob=field.get() - [1000, 2000, 2000, 3000, 3000, 3000])
        _assert_records_equal(contents, ds2_data)
        """
        6. Slicing a dataset

        You can create a new schema from pieces of another schema and reuse
        the same data.
        """
        subschema = Struct(("top_level", schema.int_lists.values))
        int_list_contents = contents.int_lists.values.field_names()
        self.assertEqual(len(subschema.field_names()), len(int_list_contents))
        """
        7. Random Access a dataset

        """
        read_init_net = core.Net("read_init")
        read_next_net = core.Net("read_next")

        idx = np.array([2, 1, 0])
        indices_blob = Const(read_init_net, idx, name="indices")
        reader = ds.random_reader(read_init_net, indices_blob)
        reader.computeoffset(read_init_net)

        should_stop, batch = reader.read_record(read_next_net)

        workspace.CreateNet(read_init_net, True)
        workspace.RunNetOnce(read_init_net)

        workspace.CreateNet(read_next_net, True)

        for i in range(len(entries)):
            k = idx[i] if i in idx else i
            entry = entries[k]
            workspace.RunNet(str(read_next_net))
            actual = FetchRecord(batch)
            _assert_records_equal(actual, entry)
        workspace.RunNet(str(read_next_net))
        self.assertEqual(True, workspace.FetchBlob(should_stop))
        """
        8. Random Access a dataset with loop_over = true

        """
        read_init_net = core.Net("read_init")
        read_next_net = core.Net("read_next")

        idx = np.array([2, 1, 0])
        indices_blob = Const(read_init_net, idx, name="indices")
        reader = ds.random_reader(read_init_net, indices_blob, loop_over=True)
        reader.computeoffset(read_init_net)

        should_stop, batch = reader.read_record(read_next_net)

        workspace.CreateNet(read_init_net, True)
        workspace.RunNetOnce(read_init_net)

        workspace.CreateNet(read_next_net, True)

        # With loop_over=True the reader wraps around and never signals stop.
        for _ in range(len(entries) * 3):
            workspace.RunNet(str(read_next_net))
        self.assertEqual(False, workspace.FetchBlob(should_stop))
        """
        9. Sort and shuffle a dataset

        This sort the dataset using the score of a certain column,
        and then shuffle within each chunk of size batch_size * shuffle_size
        before shuffling the chunks.

        """
        read_init_net = core.Net("read_init")
        read_next_net = core.Net("read_next")

        reader = ds.random_reader(read_init_net)
        reader.sort_and_shuffle(read_init_net, "int_lists:lengths", 1, 2)
        reader.computeoffset(read_init_net)

        should_continue, batch = reader.read_record(read_next_net)

        workspace.CreateNet(read_init_net, True)
        workspace.RunNetOnce(read_init_net)

        workspace.CreateNet(read_next_net, True)

        expected_idx = np.array([2, 1, 0])
        for i in range(len(entries)):
            k = expected_idx[i] if i in expected_idx else i
            entry = entries[k]
            workspace.RunNet(str(read_next_net))
            actual = FetchRecord(batch)
            _assert_records_equal(actual, entry)

        """
        Trim a dataset
        """
        trim_net = core.Net("trim_ds")
        ds.trim(trim_net, multiple_of=2)
        workspace.RunNetOnce(trim_net)
        trimmed = FetchRecord(ds.content())
        EXPECTED_SIZES = [2, 2, 3, 3, 2, 2, 2, 6, 2, 3, 3, 4, 4, 2, 2, 2]
        actual_sizes = [d.shape[0] for d in trimmed.field_blobs()]
        self.assertEqual(EXPECTED_SIZES, actual_sizes)

    def test_last_n_window_ops(self):
        """LastNWindowCollector keeps the last `num_to_collect` rows fed to
        it; once full, rows wrap around in circular-buffer order."""
        collect_net = core.Net("collect_net")
        collect_net.GivenTensorFill(
            [],
            "input",
            shape=[3, 2],
            values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        )
        input_array = np.array(list(range(1, 7)), dtype=np.float32).reshape(3, 2)

        workspace.CreateBlob("output")
        workspace.FeedBlob("next", np.array(0, dtype=np.int32))
        collect_net.LastNWindowCollector(
            ["output", "next", "input"],
            ["output", "next"],
            num_to_collect=7,
        )
        plan = core.Plan("collect_data")
        plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=1))
        workspace.RunPlan(plan)
        reference_result = workspace.FetchBlob("output")
        npt.assert_array_equal(input_array, reference_result)

        # Second run feeds the same 3 rows again (6 total > 7 slots not yet):
        # rows wrap into the circular buffer in this order.
        plan = core.Plan("collect_data")
        plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=2))
        workspace.RunPlan(plan)
        reference_result = workspace.FetchBlob("output")
        npt.assert_array_equal(input_array[[1, 2, 2, 0, 1, 2, 0]], reference_result)

        plan = core.Plan("collect_data")
        plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=3))
        workspace.RunPlan(plan)
        reference_result = workspace.FetchBlob("output")
        npt.assert_array_equal(input_array[[2, 0, 1, 2, 2, 0, 1]], reference_result)

    def test_last_n_window_ops_shape_inference(self):
        """Shape inference for LastNWindowCollector output on 2-D input:
        first dim is num_to_collect, rest follows the input."""
        collect_net = core.Net("collect_net")
        collect_net.GivenTensorFill(
            [],
            "input",
            shape=[3, 2],
            values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        )

        workspace.CreateBlob("output")
        workspace.FeedBlob("next", np.array(0, dtype=np.int32))
        collect_net.LastNWindowCollector(
            ["output", "next", "input"],
            ["output", "next"],
            num_to_collect=7,
        )
        (shapes, types) = workspace.InferShapesAndTypes([collect_net])
        workspace.RunNetOnce(collect_net)

        self.assertTrue(
            np.array_equal(
                shapes["output"], np.array([7, workspace.blobs["output"].shape[1]])
            )
        )

    def test_last_n_window_ops_shape_inference_4d_input(self):
        """Shape inference for LastNWindowCollector output on 4-D input:
        first dim is num_to_collect, trailing dims follow the input."""
        input_shape = [3, 2, 4, 5]
        collect_net = core.Net("collect_net")
        collect_net.GivenTensorFill(
            [],
            "input",
            shape=input_shape,
            values=[
                float(val) for val in range(functools.reduce(operator.mul, input_shape))
            ],
        )

        workspace.CreateBlob("output")
        workspace.FeedBlob("next", np.array(0, dtype=np.int32))
        collect_net.LastNWindowCollector(
            ["output", "next", "input"],
            ["output", "next"],
            num_to_collect=7,
        )
        (shapes, types) = workspace.InferShapesAndTypes([collect_net])
        workspace.RunNetOnce(collect_net)

        self.assertTrue(
            np.array_equal(
                shapes["output"], np.array([7, *list(workspace.blobs["output"].shape[1:])])
            )
        )

    def test_collect_tensor_ops(self):
        """CollectTensor reservoir-samples num_to_collect rows from a stream
        and keeps parallel blobs in sync (same rows collected for each)."""
        init_net = core.Net("init_net")
        blobs = ["blob_1", "blob_2", "blob_3"]
        bvec_map = {}
        ONE = init_net.ConstantFill([], "ONE", shape=[1, 2], value=1)
        for b in blobs:
            init_net.ConstantFill([], [b], shape=[1, 2], value=0)
            bvec_map[b] = b + "_vec"
            init_net.CreateTensorVector([], [bvec_map[b]])

        reader_net = core.Net("reader_net")
        for b in blobs:
            reader_net.Add([b, ONE], [b])

        collect_net = core.Net("collect_net")
        num_to_collect = 1000
        max_example_to_cover = 100000
        bvec = [bvec_map[b] for b in blobs]
        collect_net.CollectTensor(
            bvec + blobs,
            bvec,
            num_to_collect=num_to_collect,
        )

        print("Collect Net Proto: {}".format(collect_net.Proto()))

        plan = core.Plan("collect_data")
        plan.AddStep(core.execution_step("collect_init", init_net))
        plan.AddStep(
            core.execution_step(
                "collect_data", [reader_net, collect_net], num_iter=max_example_to_cover
            )
        )
        workspace.RunPlan(plan)

        # concat the collected tensors
        concat_net = core.Net("concat_net")
        bconcated_map = {}
        bsize_map = {}
        for b in blobs:
            bconcated_map[b] = b + "_concated"
            bsize_map[b] = b + "_size"
            concat_net.ConcatTensorVector([bvec_map[b]], [bconcated_map[b]])
            concat_net.TensorVectorSize([bvec_map[b]], [bsize_map[b]])

        workspace.RunNetOnce(concat_net)

        # check data
        reference_result = workspace.FetchBlob(bconcated_map[blobs[0]])
        self.assertEqual(
            reference_result.shape, (min(num_to_collect, max_example_to_cover), 2)
        )
        size = workspace.FetchBlob(bsize_map[blobs[0]])
        self.assertEqual(tuple(), size.shape)
        self.assertEqual(min(num_to_collect, max_example_to_cover), size.item())

        # Sampling should be roughly uniform over the stream: every histogram
        # bucket gets at least 60% of its expected share.
        hist, _ = np.histogram(
            reference_result[:, 0], bins=10, range=(1, max_example_to_cover)
        )
        print("Sample histogram: {}".format(hist))

        self.assertTrue(all(hist > 0.6 * (num_to_collect / 10)))
        for i in range(1, len(blobs)):
            result = workspace.FetchBlob(bconcated_map[blobs[i]])
            self.assertEqual(reference_result.tolist(), result.tolist())
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test file directly as a script.
    import unittest

    unittest.main()
|