mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: The `2to3` tool has a `future` fixer that specifically removes these redundant `from __future__` imports; the `caffe2` directory contains the most of them: ```2to3 -f future -w caffe2``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/45033 Reviewed By: seemethere Differential Revision: D23808648 Pulled By: bugra fbshipit-source-id: 38971900f0fe43ab44a9168e57f2307580d36a38
54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
|
|
|
|
|
|
|
|
|
|
from caffe2.python import core
|
|
from hypothesis import given
|
|
|
|
import caffe2.python.hypothesis_test_util as hu
|
|
import caffe2.python.serialized_test.serialized_test_util as serial
|
|
import hypothesis.strategies as st
|
|
import itertools as it
|
|
import numpy as np
|
|
|
|
|
|
class TestMomentsOp(serial.SerializedTestCase):
    """Hypothesis-driven tests for the Caffe2 ``Moments`` operator.

    Each case builds a ``Moments`` op with one input blob ``"X"`` and two
    output blobs (``"mean"``, ``"variance"``), then runs reference, device
    and gradient checks against it.
    """

    def run_moments_test(self, X, axes, keepdims, gc, dc):
        """Build a ``Moments`` op for ``axes``/``keepdims`` and check it on ``X``.

        Args:
            X: array fed to the op as blob ``"X"``.
            axes: iterable of axis indices to reduce over, or ``None`` to omit
                the ``axes`` argument entirely (the NumPy reference then
                reduces over every axis).
            keepdims: forwarded both to the op and to the NumPy reference.
            gc: device option used for the reference and gradient checks.
            dc: device options passed to ``assertDeviceChecks``.
        """
        # Only attach ``axes`` when one was supplied. It is inserted before
        # ``keepdims`` so the operator arguments keep the same order as a
        # direct ``axes=..., keepdims=...`` call would produce.
        op_kwargs = {} if axes is None else {"axes": axes}
        op_kwargs["keepdims"] = keepdims
        op = core.CreateOperator(
            "Moments", ["X"], ["mean", "variance"], **op_kwargs
        )

        def ref(X):
            # NumPy takes a tuple (or None) for multi-axis reductions; callers
            # pass either a range or a tuple from itertools.combinations.
            reduce_over = None if axes is None else tuple(axes)
            return [
                np.mean(X, axis=reduce_over, keepdims=keepdims),
                np.var(X, axis=reduce_over, keepdims=keepdims),
            ]

        self.assertReferenceChecks(gc, op, [X], ref)
        # [0, 1] selects both outputs (index 0 = mean, 1 = variance).
        self.assertDeviceChecks(dc, op, [X], [0, 1])
        self.assertGradientChecks(gc, op, [X], 0, [0, 1])

    @serial.given(X=hu.tensor(dtype=np.float32), keepdims=st.booleans(),
                  num_axes=st.integers(1, 4), **hu.gcs)
    def test_moments(self, X, keepdims, num_axes, gc, dc):
        """Check ``Moments`` with ``axes`` omitted, then per axis subset.

        When ``X`` has fewer dimensions than ``num_axes``, one extra run over
        all of its axes is done. Otherwise every ``num_axes``-sized
        combination of axes is tried.
        """
        # First pass: the ``axes`` argument is left off the operator.
        self.run_moments_test(X, None, keepdims, gc, dc)

        rank = len(X.shape)
        if num_axes > rank:
            axes_choices = [range(rank)]
        else:
            axes_choices = it.combinations(range(rank), num_axes)
        for axes in axes_choices:
            self.run_moments_test(X, axes, keepdims, gc, dc)
|