mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Added the ability to pass 'tiles' and 'axis' to the Tile operator as inputs, as an alternative to arguments. If provided, the input values override the argument values. Now with proper CUDA code. Differential Revision: D4930347 fbshipit-source-id: b44b032b327c7d7bddfce63abf4e3289d7e74bfb
78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import numpy as np
|
|
|
|
from hypothesis import given
|
|
import hypothesis.strategies as st
|
|
|
|
from caffe2.python import core
|
|
import caffe2.python.hypothesis_test_util as hu
|
|
|
|
|
|
def _tile_ref(X, tiles, axis):
    """NumPy reference implementation of the Tile operator.

    Repeats ``X`` ``tiles`` times along dimension ``axis``; every other
    dimension is left unchanged (repeat count 1).

    Args:
        X: input ndarray (any rank).
        tiles: repeat count, either a Python int or a one-element integer
            array (the form used when it is fed as an operator input).
        axis: dimension to tile; int or one-element array, like ``tiles``.

    Returns:
        A 1-tuple containing the tiled array (one entry per operator output).
    """
    # ``.item()`` turns both plain ints and one-element int32 arrays into a
    # Python scalar. Using a one-element ndarray directly as a list index
    # raises TypeError in current NumPy, which broke the input-based test.
    tiles = int(np.asarray(tiles).item())
    axis = int(np.asarray(axis).item())

    reps = [1] * X.ndim
    reps[axis] = tiles
    return (np.tile(X, reps),)


class TestTile(hu.HypothesisTestCase):
    """Hypothesis-driven tests for the Caffe2 ``Tile`` operator."""

    @given(M=st.integers(min_value=1, max_value=10),
           K=st.integers(min_value=1, max_value=10),
           N=st.integers(min_value=1, max_value=10),
           tiles=st.integers(min_value=1, max_value=3),
           axis=st.integers(min_value=0, max_value=2),
           **hu.gcs)
    def test_tile(self, M, K, N, tiles, axis, gc, dc):
        """Tile configured through the ``tiles`` / ``axis`` arguments."""
        X = np.random.rand(M, K, N).astype(np.float32)

        op = core.CreateOperator(
            'Tile', ['X'], 'out',
            tiles=tiles,
            axis=axis,
        )

        # Check against numpy reference.
        # NOTE(review): the op declares only one input ('X'), yet three values
        # are listed. This presumably relies on assertReferenceChecks feeding
        # blobs only for the names in op.input while passing every listed
        # value to the reference function -- confirm in hypothesis_test_util.
        self.assertReferenceChecks(gc, op, [X, tiles, axis], _tile_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X], [0])
        # Gradient check wrt X
        self.assertGradientChecks(gc, op, [X], 0, [0])

    @given(M=st.integers(min_value=1, max_value=10),
           K=st.integers(min_value=1, max_value=10),
           N=st.integers(min_value=1, max_value=10),
           tiles=st.integers(min_value=1, max_value=3),
           axis=st.integers(min_value=0, max_value=2),
           **hu.gcs)
    def test_tilewinput(self, M, K, N, tiles, axis, gc, dc):
        """Tile configured through the optional ``tiles`` / ``axis`` inputs."""
        X = np.random.rand(M, K, N).astype(np.float32)

        # tiles/axis are fed to the op as one-element int32 blobs.
        tiles_arg = np.array([tiles], dtype=np.int32)
        axis_arg = np.array([axis], dtype=np.int32)

        op = core.CreateOperator(
            'Tile', ['X', 'tiles', 'axis'], 'out',
        )
        inputs = [X, tiles_arg, axis_arg]

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, inputs, _tile_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, inputs, [0])
        # Gradient check wrt X (input 0)
        self.assertGradientChecks(gc, op, inputs, 0, [0])


if __name__ == "__main__":
    # Only when this file is executed directly: run the tests defined in
    # this module through the standard-library unittest runner.
    from unittest import main as run_module_tests
    run_module_tests()