mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix GemmBatched
Summary: Fix GemmBatched Reviewed By: Yangqing Differential Revision: D6678168 fbshipit-source-id: 132117633573600d4e31c1959a0ccbe34416e1f1
This commit is contained in:
parent
9eeb342bf9
commit
0a8a18ca01
8 changed files with 476 additions and 132 deletions
|
|
@ -219,9 +219,6 @@ class BatchMatMulOp final : public Operator<Context> {
|
|||
size_t A_stride = 1; // How far to increment A pointer each itr
|
||||
size_t B_stride = 1; // How far to increment B pointer each itr
|
||||
size_t Y_stride = 1; // How far to increment Y pointer each itr
|
||||
// How large the slices of A and B we are operating on at each iteration
|
||||
// are.
|
||||
size_t A_slice_size, B_slice_size;
|
||||
// How many "inner batches" we have. That is, the product of sizes for
|
||||
// the slices excluding M, K, and N, for their respective matrices.
|
||||
size_t num_sub_batches = 1;
|
||||
|
|
@ -235,12 +232,9 @@ class BatchMatMulOp final : public Operator<Context> {
|
|||
num_sub_batches *= *(first_r_itr + i);
|
||||
}
|
||||
}
|
||||
A_slice_size = A_stride;
|
||||
B_stride = 0;
|
||||
B_slice_size = B.size();
|
||||
} else {
|
||||
A_stride = 0;
|
||||
A_slice_size = A.size();
|
||||
auto second_r_itr = dims_B.rbegin();
|
||||
auto output_r_itr = new_dims.rbegin();
|
||||
for (size_t i = 0; i < num_inner_dims; ++i) {
|
||||
|
|
@ -250,7 +244,6 @@ class BatchMatMulOp final : public Operator<Context> {
|
|||
num_sub_batches *= *(second_r_itr + i);
|
||||
}
|
||||
}
|
||||
B_slice_size = B_stride;
|
||||
}
|
||||
|
||||
size_t num_outer_batches = 1;
|
||||
|
|
@ -280,9 +273,6 @@ class BatchMatMulOp final : public Operator<Context> {
|
|||
math::GemmBatched<T, Context, Engine>(
|
||||
trans_a_ ? CblasTrans : CblasNoTrans,
|
||||
trans_b_ ? CblasTrans : CblasNoTrans,
|
||||
A_slice_size,
|
||||
num_sub_batches,
|
||||
B_slice_size,
|
||||
num_sub_batches,
|
||||
M,
|
||||
N,
|
||||
|
|
|
|||
107
caffe2/operators/batch_matmul_op_gpu_test.cc
Normal file
107
caffe2/operators/batch_matmul_op_gpu_test.cc
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/operators/batch_matmul_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace {
|
||||
|
||||
using testing::ElementsAreArray;
|
||||
using testing::Test;
|
||||
|
||||
class BatchMatMulOpGPUTest : public Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
if (!HasCudaGPU()) {
|
||||
return;
|
||||
}
|
||||
option_.set_device_type(CUDA);
|
||||
cuda_context_ = make_unique<CUDAContext>(option_);
|
||||
def_.set_name("test");
|
||||
def_.set_type("BatchMatMul");
|
||||
def_.add_input("A");
|
||||
def_.add_input("B");
|
||||
def_.add_output("Y");
|
||||
def_.mutable_device_option()->set_device_type(CUDA);
|
||||
}
|
||||
|
||||
void AddConstInput(
|
||||
const std::vector<TIndex>& dims,
|
||||
const float value,
|
||||
const string& name) {
|
||||
Blob* blob = ws_.CreateBlob(name);
|
||||
auto* tensor = blob->GetMutable<Tensor<CUDAContext>>();
|
||||
tensor->Resize(dims);
|
||||
math::Set<float, CUDAContext>(
|
||||
tensor->size(),
|
||||
value,
|
||||
tensor->mutable_data<float>(),
|
||||
cuda_context_.get());
|
||||
}
|
||||
|
||||
void VerifyOutput(const std::vector<TIndex>& dims, const float value) const {
|
||||
const Blob* Y_blob = ws_.GetBlob("Y");
|
||||
ASSERT_NE(nullptr, Y_blob);
|
||||
const auto& Y = Y_blob->Get<Tensor<CUDAContext>>();
|
||||
TensorCPU Y_cpu(Y);
|
||||
ASSERT_THAT(Y_cpu.dims(), ElementsAreArray(dims));
|
||||
for (int i = 0; i < Y_cpu.size(); ++i) {
|
||||
EXPECT_FLOAT_EQ(value, Y_cpu.data<float>()[i]);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceOption option_;
|
||||
std::unique_ptr<CUDAContext> cuda_context_;
|
||||
Workspace ws_;
|
||||
OperatorDef def_;
|
||||
};
|
||||
|
||||
TEST_F(BatchMatMulOpGPUTest, BatchMatMulOpGPUNormalTest) {
|
||||
if (!HasCudaGPU()) {
|
||||
return;
|
||||
}
|
||||
AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
|
||||
AddConstInput(std::vector<TIndex>{3, 10, 6}, 1.0f, "B");
|
||||
std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
|
||||
ASSERT_NE(nullptr, op);
|
||||
ASSERT_TRUE(op->Run());
|
||||
VerifyOutput(std::vector<TIndex>{3, 5, 6}, 10.0f);
|
||||
}
|
||||
|
||||
TEST_F(BatchMatMulOpGPUTest, BatchMatMulOpGPUBroadcastTest) {
|
||||
if (!HasCudaGPU()) {
|
||||
return;
|
||||
}
|
||||
auto* arg = def_.add_arg();
|
||||
arg->set_name("broadcast");
|
||||
arg->set_i(1);
|
||||
AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
|
||||
AddConstInput(std::vector<TIndex>{2, 3, 10, 6}, 1.0f, "B");
|
||||
std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
|
||||
ASSERT_NE(nullptr, op);
|
||||
ASSERT_TRUE(op->Run());
|
||||
VerifyOutput(std::vector<TIndex>{2, 3, 5, 6}, 10.0f);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace caffe2
|
||||
94
caffe2/operators/batch_matmul_op_test.cc
Normal file
94
caffe2/operators/batch_matmul_op_test.cc
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "caffe2/operators/batch_matmul_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace {
|
||||
|
||||
using testing::ElementsAreArray;
|
||||
using testing::Test;
|
||||
|
||||
class BatchMatMulOpTest : public Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
cpu_context_ = make_unique<CPUContext>(option_);
|
||||
def_.set_name("test");
|
||||
def_.set_type("BatchMatMul");
|
||||
def_.add_input("A");
|
||||
def_.add_input("B");
|
||||
def_.add_output("Y");
|
||||
}
|
||||
|
||||
void AddConstInput(
|
||||
const std::vector<TIndex>& dims,
|
||||
const float value,
|
||||
const string& name) {
|
||||
Blob* blob = ws_.CreateBlob(name);
|
||||
auto* tensor = blob->GetMutable<TensorCPU>();
|
||||
tensor->Resize(dims);
|
||||
math::Set<float, CPUContext>(
|
||||
tensor->size(),
|
||||
value,
|
||||
tensor->mutable_data<float>(),
|
||||
cpu_context_.get());
|
||||
}
|
||||
|
||||
void VerifyOutput(const std::vector<TIndex>& dims, const float value) const {
|
||||
const Blob* Y_blob = ws_.GetBlob("Y");
|
||||
ASSERT_NE(nullptr, Y_blob);
|
||||
const auto& Y = Y_blob->Get<TensorCPU>();
|
||||
ASSERT_THAT(Y.dims(), ElementsAreArray(dims));
|
||||
for (int i = 0; i < Y.size(); ++i) {
|
||||
EXPECT_FLOAT_EQ(value, Y.data<float>()[i]);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceOption option_;
|
||||
std::unique_ptr<CPUContext> cpu_context_;
|
||||
Workspace ws_;
|
||||
OperatorDef def_;
|
||||
};
|
||||
|
||||
TEST_F(BatchMatMulOpTest, BatchMatMulOpNormalTest) {
|
||||
AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
|
||||
AddConstInput(std::vector<TIndex>{3, 10, 6}, 1.0f, "B");
|
||||
std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
|
||||
ASSERT_NE(nullptr, op);
|
||||
ASSERT_TRUE(op->Run());
|
||||
VerifyOutput(std::vector<TIndex>{3, 5, 6}, 10.0f);
|
||||
}
|
||||
|
||||
TEST_F(BatchMatMulOpTest, BatchMatMulOpBroadcastTest) {
|
||||
auto* arg = def_.add_arg();
|
||||
arg->set_name("broadcast");
|
||||
arg->set_i(1);
|
||||
AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
|
||||
AddConstInput(std::vector<TIndex>{2, 3, 10, 6}, 1.0f, "B");
|
||||
std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
|
||||
ASSERT_NE(nullptr, op);
|
||||
ASSERT_TRUE(op->Run());
|
||||
VerifyOutput(std::vector<TIndex>{2, 3, 5, 6}, 10.0f);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace caffe2
|
||||
|
|
@ -252,10 +252,7 @@ template <typename T, class Context, class Engine = DefaultEngine>
|
|||
void GemmBatched(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int A_size,
|
||||
const int A_batches,
|
||||
const int B_size,
|
||||
const int B_batches,
|
||||
const int batch_size,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@
|
|||
#include <numeric>
|
||||
#include <random>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "caffe2/utils/math.h"
|
||||
#include "caffe2/utils/cpu_neon.h"
|
||||
|
|
@ -425,10 +426,7 @@ template <>
|
|||
void GemmBatched<float, CPUContext>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int A_size,
|
||||
const int A_batches,
|
||||
const int B_size,
|
||||
const int B_batches,
|
||||
const int batch_size,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
|
|
@ -440,25 +438,55 @@ void GemmBatched<float, CPUContext>(
|
|||
CPUContext* context,
|
||||
Tensor<CPUContext>*, /* scratch */
|
||||
TensorProto::DataType /* math_type */) {
|
||||
const int a_stride = M * K;
|
||||
const int b_stride = K * N;
|
||||
const int c_stride = M * N;
|
||||
|
||||
auto a_offset = A_size / A_batches;
|
||||
auto b_offset = B_size / B_batches;
|
||||
auto y_offset = M * N;
|
||||
#ifdef CAFFE2_USE_MKL
|
||||
const int lda = (TransA == CblasNoTrans) ? K : M;
|
||||
const int ldb = (TransB == CblasNoTrans) ? N : K;
|
||||
std::vector<const float*> a_array(batch_size, nullptr);
|
||||
std::vector<const float*> b_array(batch_size, nullptr);
|
||||
std::vector<float*> c_array(batch_size, nullptr);
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
a_array[i] = A + a_stride * i;
|
||||
b_array[i] = B + b_stride * i;
|
||||
c_array[i] = C + c_stride * i;
|
||||
}
|
||||
cblas_sgemm_batch(
|
||||
CblasRowMajor,
|
||||
&TransA,
|
||||
&TransB,
|
||||
&M,
|
||||
&N,
|
||||
&K,
|
||||
&alpha,
|
||||
a_array.data(),
|
||||
&lda,
|
||||
b_array.data(),
|
||||
&ldb,
|
||||
&beta,
|
||||
c_array.data(),
|
||||
&N, // ldc_array
|
||||
1,
|
||||
&batch_size);
|
||||
#else // CAFFE2_USE_MKL
|
||||
// loop over matrices in the batch
|
||||
for (int i = 0; i < A_batches; ++i) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
math::Gemm<float, CPUContext>(
|
||||
TransA,
|
||||
TransB,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1,
|
||||
A + a_offset * i,
|
||||
B + b_offset * i,
|
||||
0,
|
||||
C + y_offset * i,
|
||||
alpha,
|
||||
A + a_stride * i,
|
||||
B + b_stride * i,
|
||||
beta,
|
||||
C + c_stride * i,
|
||||
context);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
|||
|
|
@ -281,10 +281,7 @@ template <>
|
|||
void GemmBatched<float, CUDAContext>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int A_size,
|
||||
const int A_batches,
|
||||
const int B_size,
|
||||
const int B_batches,
|
||||
const int batch_size,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
|
|
@ -296,55 +293,53 @@ void GemmBatched<float, CUDAContext>(
|
|||
CUDAContext* context,
|
||||
Tensor<CUDAContext>* scratch,
|
||||
TensorProto::DataType math_type) {
|
||||
|
||||
const int a_stride = M * K;
|
||||
const int b_stride = K * N;
|
||||
const int c_stride = M * N;
|
||||
#if __CUDACC_VER_MAJOR__ < 8
|
||||
auto a_offset = A_size / A_batches;
|
||||
auto b_offset = B_size / B_batches;
|
||||
auto y_offset = M * N;
|
||||
// loop over matrices in the batch
|
||||
for (int i = 0; i < A_batches; ++i) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
math::Gemm<float, CUDAContext>(
|
||||
TransA,
|
||||
TransB,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1,
|
||||
A + a_offset * i,
|
||||
B + b_offset * i,
|
||||
0,
|
||||
C + y_offset * i,
|
||||
alpha,
|
||||
A + a_stride * i,
|
||||
B + b_stride * i,
|
||||
beta,
|
||||
C + c_stride * i,
|
||||
context);
|
||||
}
|
||||
#else
|
||||
// Note that cublas follows fortran order, so the order is different from
|
||||
// the cblas convention.
|
||||
int lda = (TransA == CblasNoTrans) ? K : M;
|
||||
int ldb = (TransB == CblasNoTrans) ? N : K;
|
||||
|
||||
const int lda = (TransA == CblasNoTrans) ? K : M;
|
||||
const int ldb = (TransB == CblasNoTrans) ? N : K;
|
||||
cublasOperation_t cuTransA =
|
||||
(TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
cublasOperation_t cuTransB =
|
||||
(TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
CUBLAS_ENFORCE(cublasSgemmStridedBatched(
|
||||
context->cublas_handle(),
|
||||
cuTransB,
|
||||
cuTransA,
|
||||
N,
|
||||
M,
|
||||
K,
|
||||
&alpha,
|
||||
B,
|
||||
ldb,
|
||||
B_size / B_batches, // B stride
|
||||
A,
|
||||
lda,
|
||||
A_size / A_batches, // A stride
|
||||
&beta,
|
||||
C,
|
||||
N,
|
||||
M*N, // C stride
|
||||
A_batches));
|
||||
context->cublas_handle(),
|
||||
cuTransB,
|
||||
cuTransA,
|
||||
N,
|
||||
M,
|
||||
K,
|
||||
&alpha,
|
||||
B,
|
||||
ldb,
|
||||
b_stride,
|
||||
A,
|
||||
lda,
|
||||
a_stride,
|
||||
&beta,
|
||||
C,
|
||||
N,
|
||||
c_stride,
|
||||
batch_size));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -368,10 +363,7 @@ template <>
|
|||
void GemmBatched<float16, CUDAContext>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int A_size,
|
||||
const int A_batches,
|
||||
const int B_size,
|
||||
const int B_batches,
|
||||
const int batch_size,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
|
|
@ -383,24 +375,23 @@ void GemmBatched<float16, CUDAContext>(
|
|||
CUDAContext* context,
|
||||
Tensor<CUDAContext>* scratch,
|
||||
TensorProto::DataType math_type) {
|
||||
|
||||
const int a_stride = M * K;
|
||||
const int b_stride = K * N;
|
||||
const int c_stride = M * N;
|
||||
#if __CUDACC_VER_MAJOR__ < 8
|
||||
auto a_offset = A_size / A_batches;
|
||||
auto b_offset = B_size / B_batches;
|
||||
auto y_offset = M * N;
|
||||
// loop over matrices in the batch
|
||||
for (int i = 0; i < A_batches; ++i) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
math::Gemm<float16, CUDAContext>(
|
||||
TransA,
|
||||
TransB,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1,
|
||||
A + a_offset * i,
|
||||
B + b_offset * i,
|
||||
0,
|
||||
C + y_offset * i,
|
||||
alpha,
|
||||
A + a_stride * i,
|
||||
B + b_stride * i,
|
||||
beta,
|
||||
C + c_stride * i,
|
||||
context);
|
||||
}
|
||||
#else
|
||||
|
|
@ -410,11 +401,13 @@ void GemmBatched<float16, CUDAContext>(
|
|||
// 3) math_type == FLOAT16, scratch == nullptr = batched Hgemm
|
||||
|
||||
if (scratch != nullptr) {
|
||||
const int A_size = a_stride * batch_size;
|
||||
const int B_size = b_stride * batch_size;
|
||||
// cast, cublasSgemmStridedBatched, cast
|
||||
size_t in_elems = A_size + B_size;
|
||||
size_t out_elems = A_batches*M*N;
|
||||
size_t out_elems = c_stride * batch_size;
|
||||
|
||||
scratch->Resize(in_elems+out_elems);
|
||||
scratch->Resize(in_elems + out_elems);
|
||||
float* scratch_ptr = scratch->mutable_data<float>();
|
||||
|
||||
float* A_fp32 = scratch_ptr;
|
||||
|
|
@ -432,13 +425,10 @@ void GemmBatched<float16, CUDAContext>(
|
|||
context->cuda_stream()>>>(B_size, (half*)B, B_fp32);
|
||||
|
||||
// run fp32 batched Gemm
|
||||
GemmBatched<float,CUDAContext>(
|
||||
GemmBatched<float, CUDAContext>(
|
||||
TransA,
|
||||
TransB,
|
||||
A_size,
|
||||
A_batches,
|
||||
B_size,
|
||||
B_batches,
|
||||
batch_size,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
|
|
@ -450,35 +440,33 @@ void GemmBatched<float16, CUDAContext>(
|
|||
context);
|
||||
|
||||
// cast result back to fp16
|
||||
FloatToHalfKernel<<<CAFFE_GET_BLOCKS(A_batches*M*N),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
context->cuda_stream()>>>(A_batches*M*N, C_fp32, (half*)C);
|
||||
FloatToHalfKernel<<<
|
||||
CAFFE_GET_BLOCKS(batch_size * M * N),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
context->cuda_stream()>>>(batch_size * M * N, C_fp32, (half*)C);
|
||||
} else {
|
||||
if (math_type == TensorProto_DataType_FLOAT) {
|
||||
auto a_offset = A_size / A_batches;
|
||||
auto b_offset = B_size / B_batches;
|
||||
auto y_offset = M * N;
|
||||
// loop over matrices in the batch
|
||||
for (int i = 0; i < A_batches; ++i) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
math::Gemm<float16, CUDAContext>(
|
||||
TransA,
|
||||
TransB,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1,
|
||||
A + a_offset * i,
|
||||
B + b_offset * i,
|
||||
0,
|
||||
C + y_offset * i,
|
||||
alpha,
|
||||
A + a_stride * i,
|
||||
B + b_stride * i,
|
||||
beta,
|
||||
C + c_stride * i,
|
||||
context);
|
||||
}
|
||||
} else if (math_type == TensorProto_DataType_FLOAT16) {
|
||||
// Note that cublas follows fortran order, so the order is different from
|
||||
// the cblas convention.
|
||||
int lda = (TransA == CblasNoTrans) ? K : M;
|
||||
int ldb = (TransB == CblasNoTrans) ? N : K;
|
||||
const int lda = (TransA == CblasNoTrans) ? K : M;
|
||||
const int ldb = (TransB == CblasNoTrans) ? N : K;
|
||||
cublasOperation_t cuTransA =
|
||||
(TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
cublasOperation_t cuTransB =
|
||||
|
|
@ -488,24 +476,24 @@ void GemmBatched<float16, CUDAContext>(
|
|||
auto alpha_fp16 = convert::floatToHalf(alpha);
|
||||
auto beta_fp16 = convert::floatToHalf(beta);
|
||||
CUBLAS_ENFORCE(cublasHgemmStridedBatched(
|
||||
context->cublas_handle(),
|
||||
cuTransB,
|
||||
cuTransA,
|
||||
N,
|
||||
M,
|
||||
K,
|
||||
&alpha_fp16,
|
||||
(const __half*)B,
|
||||
ldb,
|
||||
B_size / B_batches,
|
||||
(const __half*)A,
|
||||
lda,
|
||||
A_size / A_batches,
|
||||
&beta_fp16,
|
||||
(__half*)C,
|
||||
N,
|
||||
M*N,
|
||||
A_batches));
|
||||
context->cublas_handle(),
|
||||
cuTransB,
|
||||
cuTransA,
|
||||
N,
|
||||
M,
|
||||
K,
|
||||
&alpha_fp16,
|
||||
(const __half*)B,
|
||||
ldb,
|
||||
b_stride,
|
||||
(const __half*)A,
|
||||
lda,
|
||||
a_stride,
|
||||
&beta_fp16,
|
||||
(__half*)C,
|
||||
N,
|
||||
c_stride,
|
||||
batch_size));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
@ -605,10 +593,7 @@ template <>
|
|||
void GemmBatched<float, CUDAContext, TensorCoreEngine>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int A_size,
|
||||
const int A_batches,
|
||||
const int B_size,
|
||||
const int B_batches,
|
||||
const int batch_size,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
|
|
@ -623,10 +608,7 @@ void GemmBatched<float, CUDAContext, TensorCoreEngine>(
|
|||
return GemmBatched<float, CUDAContext, DefaultEngine>(
|
||||
TransA,
|
||||
TransB,
|
||||
A_size,
|
||||
A_batches,
|
||||
B_size,
|
||||
B_batches,
|
||||
batch_size,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
|
|
@ -644,10 +626,7 @@ template <>
|
|||
void GemmBatched<float16, CUDAContext, TensorCoreEngine>(
|
||||
const CBLAS_TRANSPOSE TransA,
|
||||
const CBLAS_TRANSPOSE TransB,
|
||||
const int A_size,
|
||||
const int A_batches,
|
||||
const int B_size,
|
||||
const int B_batches,
|
||||
const int batch_size,
|
||||
const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
|
|
@ -662,10 +641,7 @@ void GemmBatched<float16, CUDAContext, TensorCoreEngine>(
|
|||
return GemmBatched<float16, CUDAContext, DefaultEngine>(
|
||||
TransA,
|
||||
TransB,
|
||||
A_size,
|
||||
A_batches,
|
||||
B_size,
|
||||
B_batches,
|
||||
batch_size,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
|
|
|
|||
|
|
@ -15,8 +15,11 @@
|
|||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/core/flags.h"
|
||||
|
|
@ -254,4 +257,84 @@ TEST(MathUtilGPUTest, testCopyVector) {
|
|||
[](int i) { return 5.0f - i; });
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class GemmBatchedGPUTest
|
||||
: public testing::TestWithParam<testing::tuple<bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
if (!HasCudaGPU()) {
|
||||
return;
|
||||
}
|
||||
option_.set_device_type(CUDA);
|
||||
cuda_context_ = make_unique<CUDAContext>(option_);
|
||||
Blob* X_blob = ws_.CreateBlob("X");
|
||||
Blob* W_blob = ws_.CreateBlob("W");
|
||||
Blob* Y_blob = ws_.CreateBlob("Y");
|
||||
X_ = X_blob->GetMutable<Tensor<CUDAContext>>();
|
||||
W_ = W_blob->GetMutable<Tensor<CUDAContext>>();
|
||||
Y_ = Y_blob->GetMutable<Tensor<CUDAContext>>();
|
||||
X_->Resize(std::vector<TIndex>{3, 5, 10});
|
||||
W_->Resize(std::vector<TIndex>{3, 6, 10});
|
||||
Y_->Resize(std::vector<TIndex>{3, 5, 6});
|
||||
math::Set<float, CUDAContext>(
|
||||
X_->size(), 1.0f, X_->mutable_data<float>(), cuda_context_.get());
|
||||
math::Set<float, CUDAContext>(
|
||||
W_->size(), 1.0f, W_->mutable_data<float>(), cuda_context_.get());
|
||||
trans_X_ = std::get<0>(GetParam());
|
||||
trans_W_ = std::get<1>(GetParam());
|
||||
}
|
||||
|
||||
void RunGemmBatched(const float alpha, const float beta) {
|
||||
math::GemmBatched(
|
||||
trans_X_ ? CblasTrans : CblasNoTrans,
|
||||
trans_W_ ? CblasTrans : CblasNoTrans,
|
||||
3,
|
||||
5,
|
||||
6,
|
||||
10,
|
||||
alpha,
|
||||
X_->template data<float>(),
|
||||
W_->template data<float>(),
|
||||
beta,
|
||||
Y_->template mutable_data<float>(),
|
||||
cuda_context_.get());
|
||||
}
|
||||
|
||||
void VerifyOutput(const float value) const {
|
||||
TensorCPU Y_cpu(*Y_);
|
||||
for (int i = 0; i < Y_cpu.size(); ++i) {
|
||||
EXPECT_FLOAT_EQ(value, Y_cpu.template data<float>()[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Workspace ws_;
|
||||
DeviceOption option_;
|
||||
std::unique_ptr<CUDAContext> cuda_context_;
|
||||
Tensor<CUDAContext>* X_ = nullptr;
|
||||
Tensor<CUDAContext>* W_ = nullptr;
|
||||
Tensor<CUDAContext>* Y_ = nullptr;
|
||||
bool trans_X_;
|
||||
bool trans_W_;
|
||||
};
|
||||
|
||||
TEST_P(GemmBatchedGPUTest, GemmBatchedGPUFloatTest) {
|
||||
if (!HasCudaGPU()) {
|
||||
return;
|
||||
}
|
||||
RunGemmBatched(1.0f, 0.0f);
|
||||
VerifyOutput(10.0f);
|
||||
RunGemmBatched(1.0f, 0.5f);
|
||||
VerifyOutput(15.0f);
|
||||
RunGemmBatched(0.5f, 1.0f);
|
||||
VerifyOutput(20.0f);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
GemmBatchedGPUTrans,
|
||||
GemmBatchedGPUTest,
|
||||
testing::Combine(testing::Bool(), testing::Bool()));
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace caffe2
|
||||
|
|
|
|||
|
|
@ -14,7 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "caffe2/core/blob.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/tensor.h"
|
||||
|
|
@ -116,6 +120,71 @@ TEST(MathTest, GemmNoTransTrans) {
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class GemmBatchedTest
|
||||
: public testing::TestWithParam<testing::tuple<bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
cpu_context_ = make_unique<CPUContext>(option_);
|
||||
X_.Resize(std::vector<TIndex>{3, 5, 10});
|
||||
W_.Resize(std::vector<TIndex>{3, 6, 10});
|
||||
Y_.Resize(std::vector<TIndex>{3, 5, 6});
|
||||
math::Set<float, CPUContext>(
|
||||
X_.size(), 1, X_.mutable_data<float>(), cpu_context_.get());
|
||||
math::Set<float, CPUContext>(
|
||||
W_.size(), 1, W_.mutable_data<float>(), cpu_context_.get());
|
||||
trans_X_ = std::get<0>(GetParam());
|
||||
trans_W_ = std::get<1>(GetParam());
|
||||
}
|
||||
|
||||
void RunGemmBatched(const float alpha, const float beta) {
|
||||
math::GemmBatched(
|
||||
trans_X_ ? CblasTrans : CblasNoTrans,
|
||||
trans_W_ ? CblasTrans : CblasNoTrans,
|
||||
3,
|
||||
5,
|
||||
6,
|
||||
10,
|
||||
alpha,
|
||||
X_.template data<float>(),
|
||||
W_.template data<float>(),
|
||||
beta,
|
||||
Y_.template mutable_data<float>(),
|
||||
cpu_context_.get());
|
||||
}
|
||||
|
||||
void VerifyOutput(const float value) const {
|
||||
for (int i = 0; i < Y_.size(); ++i) {
|
||||
EXPECT_FLOAT_EQ(value, Y_.template data<float>()[i]);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceOption option_;
|
||||
std::unique_ptr<CPUContext> cpu_context_;
|
||||
TensorCPU X_;
|
||||
TensorCPU W_;
|
||||
TensorCPU Y_;
|
||||
bool trans_X_;
|
||||
bool trans_W_;
|
||||
};
|
||||
|
||||
TEST_P(GemmBatchedTest, GemmBatchedFloatTest) {
|
||||
RunGemmBatched(1.0f, 0.0f);
|
||||
VerifyOutput(10.0f);
|
||||
RunGemmBatched(1.0f, 0.5f);
|
||||
VerifyOutput(15.0f);
|
||||
RunGemmBatched(0.5f, 1.0f);
|
||||
VerifyOutput(20.0f);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
GemmBatchedTrans,
|
||||
GemmBatchedTest,
|
||||
testing::Combine(testing::Bool(), testing::Bool()));
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(MathTest, GemvNoTrans) {
|
||||
DeviceOption option;
|
||||
CPUContext cpu_context(option);
|
||||
|
|
|
|||
Loading…
Reference in a new issue