Summary: Instead of doing GEMMs in a for-loop (which is not parallelized), it is much better to do the batched matmuls using CUDA 8's new strided batched version of GEMM. With the MT team's test, we get a 5-10% improvement in overall walltime, so it is a significant improvement:

Without batched gemm:
I0328 10:46:48.118605 58068 prof_dag_net.cc:136] 424.757 ms/iter ( 283.878 ms/iter) RecurrentNetwork
I0328 10:46:48.118609 58068 prof_dag_net.cc:136] 352.603 ms/iter ( 265.85 ms/iter) RecurrentNetworkGradient

With batched gemm:
I0328 10:53:48.169996 85617 prof_dag_net.cc:136] 407.438 ms/iter ( 269.564 ms/iter) RecurrentNetwork
I0328 10:53:48.169999 85617 prof_dag_net.cc:136] 322.393 ms/iter ( 287.625 ms/iter) RecurrentNetworkGradient

Reviewed By: jamesr66a
Differential Revision: D4788272
fbshipit-source-id: 210e8b94c1e036b6ef0f039ce000d455258651f4
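For illustration, the equivalence the change relies on can be sketched in numpy: a batch of C independent (M, K) x (K, N) GEMMs computed in a for-loop gives the same result as a single batched matmul call. This is only a sketch; np.matmul stands in here for cuBLAS's strided batched GEMM (cublasSgemmStridedBatched), and the shapes are arbitrary.

import numpy as np

C, M, K, N = 4, 3, 5, 2
X = np.random.rand(C, M, K).astype(np.float32)
Y = np.random.rand(C, K, N).astype(np.float32)

# For-loop version: C independent GEMM calls, one per batch entry.
out_loop = np.empty((C, M, N), dtype=np.float32)
for i in range(C):
    out_loop[i] = X[i].dot(Y[i])

# Batched version: one call over the whole batch; on GPU this maps
# to a single strided batched GEMM launch instead of C launches.
out_batched = np.matmul(X, Y)

assert np.allclose(out_loop, out_batched)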
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

from hypothesis import given
import hypothesis.strategies as st

from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu


class TestMatMul(hu.HypothesisTestCase):
    @given(M=st.integers(min_value=1, max_value=10),
           K=st.integers(min_value=1, max_value=10),
           N=st.integers(min_value=1, max_value=10),
           trans_a=st.booleans(),
           trans_b=st.booleans(),
           **hu.gcs)
    def test_matmul(self, M, K, N, trans_a, trans_b, gc, dc):
        X = np.random.rand(M, K).astype(np.float32) - 0.5
        if trans_a:
            X = X.transpose()

        Y = np.random.rand(K, N).astype(np.float32) - 0.5
        if trans_b:
            Y = Y.transpose()

        op = core.CreateOperator(
            'MatMul', ['X', 'Y'], 'out',
            trans_a=trans_a, trans_b=trans_b)

        def matmul_ref(X, Y, trans_a, trans_b):
            XX = X.transpose() if trans_a else X
            YY = Y.transpose() if trans_b else Y
            return (XX.dot(YY),)

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, [X, Y, trans_a, trans_b],
                                   matmul_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        # Gradient check wrt X
        self.assertGradientChecks(gc, op, [X, Y], 0, [0])
        # Gradient check wrt Y
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])


class TestBatchMatMul(hu.HypothesisTestCase):
    @given(C=st.integers(min_value=1, max_value=10),
           M=st.integers(min_value=1, max_value=10),
           K=st.integers(min_value=1, max_value=10),
           N=st.integers(min_value=1, max_value=10),
           trans_a=st.booleans(),
           trans_b=st.booleans(),
           **hu.gcs)
    def test_batch_matmul(self, C, M, K, N, trans_a, trans_b, gc, dc):
        X = np.random.rand(C, M, K).astype(np.float32) - 0.5
        if trans_a:
            X = X.swapaxes(1, 2)

        Y = np.random.rand(C, K, N).astype(np.float32) - 0.5
        if trans_b:
            Y = Y.swapaxes(1, 2)

        op = core.CreateOperator(
            'BatchMatMul', ['X', 'Y'], 'out',
            trans_a=trans_a, trans_b=trans_b)

        def matmul_ref(X, Y, trans_a, trans_b):
            # Reference computes the batch as C independent gemms; per the
            # commit summary above, the CUDA op under test performs the same
            # computation in a single strided batched gemm call.
            XX = X.swapaxes(1, 2) if trans_a else X
            YY = Y.swapaxes(1, 2) if trans_b else Y
            output = np.zeros((C, M, N)).astype(XX.dtype)
            for i in range(C):
                output[i] = XX[i].dot(YY[i])
            return (output,)

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, [X, Y, trans_a, trans_b],
                                   matmul_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        # Gradient check wrt X
        self.assertGradientChecks(gc, op, [X, Y], 0, [0])
        # Gradient check wrt Y
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])


if __name__ == "__main__":
    import unittest
    unittest.main()
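For reference, a minimal standalone invocation of the BatchMatMul operator exercised above, assuming a Caffe2 build is importable; the shapes are illustrative only:

import numpy as np
from caffe2.python import core, workspace

X = np.random.rand(8, 4, 5).astype(np.float32)
Y = np.random.rand(8, 5, 6).astype(np.float32)
workspace.FeedBlob('X', X)
workspace.FeedBlob('Y', Y)
# Run a single BatchMatMul op, mirroring the operator built in the test.
workspace.RunOperatorOnce(
    core.CreateOperator('BatchMatMul', ['X', 'Y'], 'out'))
print(workspace.FetchBlob('out').shape)  # (8, 4, 6)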