mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: The Cast op on CUDA can now handle an empty batch. Reviewed By: azzolini Differential Revision: D6350138 fbshipit-source-id: 2f3d19f4d42ff34806aa9597690e66f6b4de1a6b
32 lines
1 KiB
Python
32 lines
1 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
from caffe2.python import core
|
|
import caffe2.python.hypothesis_test_util as hu
|
|
|
|
from hypothesis import given
|
|
import numpy as np
|
|
|
|
|
|
class TestCastOp(hu.HypothesisTestCase):
    """Tests for the Cast operator: int32 -> float, on dense and empty inputs.

    Both tests run the same device and gradient checks; the shared logic
    lives in ``_check_int_to_float_cast`` so the two cases differ only in
    the input tensor they build.
    """

    def _check_int_to_float_cast(self, gc, dc, data):
        """Run device and gradient checks for an int32 -> float Cast on *data*.

        ``to=1`` / ``from_type=2`` are presumably TensorProto DataType enum
        values (1=FLOAT, 2=INT32) — TODO confirm against caffe2 proto.
        """
        op = core.CreateOperator('Cast', 'data', 'data_cast', to=1, from_type=2)
        self.assertDeviceChecks(dc, op, [data], [0])
        # The gradient of an int -> float cast is identically 0, so this
        # check is trivially satisfied ("This is actually 0" in the original).
        self.assertGradientChecks(gc, op, [data], 0, [0])

    @given(**hu.gcs)
    def test_cast_int_float(self, gc, dc):
        """Cast a 5x5 int32 tensor to float.

        NOTE: ``np.random.rand`` yields floats in [0, 1), so ``astype(int32)``
        truncates every entry to 0 — the input is an all-zero tensor.
        """
        data = np.random.rand(5, 5).astype(np.int32)
        self._check_int_to_float_cast(gc, dc, data)

    @given(**hu.gcs)
    def test_cast_int_float_empty(self, gc, dc):
        """Cast a zero-length int32 tensor to float (empty-batch regression)."""
        data = np.random.rand(0).astype(np.int32)
        self._check_int_to_float_cast(gc, dc, data)