Fix GemmBatched

Summary: Fix GemmBatched

Reviewed By: Yangqing

Differential Revision: D6678168

fbshipit-source-id: 132117633573600d4e31c1959a0ccbe34416e1f1
Xiaomeng Yang 2018-01-10 18:13:41 -08:00 committed by Facebook Github Bot
parent 9eeb342bf9
commit 0a8a18ca01
8 changed files with 476 additions and 132 deletions
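
What the diff below amounts to: math::GemmBatched drops the four size arguments (A_size, A_batches, B_size, B_batches) in favor of a single batch_size, and every backend touched here (the MKL cblas_sgemm_batch path, the cublasSgemmStridedBatched/cublasHgemmStridedBatched paths, and the per-matrix Gemm fallback loop) derives the per-matrix strides internally as M * K for A, K * N for B, and M * N for C. BatchMatMulOp now passes num_sub_batches as that single count, and new CPU and GPU unit tests cover both the operator and the math routine. As a rough caller sketch only (not part of this commit): the wrapper name and include paths below are illustrative, the trailing scratch/math_type arguments are assumed to keep their defaults as in the new tests, and the shapes mirror those tests (batch of 3, M = 5, N = 6, K = 10, all-ones inputs).

#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/utils/math.h"

// Illustrative sketch: multiply 3 row-major 5x10 matrices by 3 row-major
// 10x6 matrices using the post-fix GemmBatched signature (single batch_size).
void GemmBatchedCallerSketch(caffe2::CPUContext* context) {
  const int batch_size = 3, M = 5, N = 6, K = 10;
  std::vector<float> A(batch_size * M * K, 1.0f); // per-matrix stride M * K
  std::vector<float> B(batch_size * K * N, 1.0f); // per-matrix stride K * N
  std::vector<float> C(batch_size * M * N, 0.0f); // per-matrix stride M * N
  caffe2::math::GemmBatched<float, caffe2::CPUContext>(
      CblasNoTrans,
      CblasNoTrans,
      batch_size, // replaces the old A_size/A_batches/B_size/B_batches quartet
      M,
      N,
      K,
      1.0f, // alpha
      A.data(),
      B.data(),
      0.0f, // beta
      C.data(),
      context);
  // Each output element is a dot product of K ones, i.e. 10.0f, the value the
  // new BatchMatMul and GemmBatched tests below expect.
}

The CUDA specializations take the same argument list; only the context type and the device pointers differ.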


@@ -219,9 +219,6 @@ class BatchMatMulOp final : public Operator<Context> {
size_t A_stride = 1; // How far to increment A pointer each itr
size_t B_stride = 1; // How far to increment B pointer each itr
size_t Y_stride = 1; // How far to increment Y pointer each itr
// How large the slices of A and B we are operating on at each iteration
// are.
size_t A_slice_size, B_slice_size;
// How many "inner batches" we have. That is, the product of sizes for
// the slices excluding M, K, and N, for their respective matrices.
size_t num_sub_batches = 1;
@@ -235,12 +232,9 @@ class BatchMatMulOp final : public Operator<Context> {
num_sub_batches *= *(first_r_itr + i);
}
}
A_slice_size = A_stride;
B_stride = 0;
B_slice_size = B.size();
} else {
A_stride = 0;
A_slice_size = A.size();
auto second_r_itr = dims_B.rbegin();
auto output_r_itr = new_dims.rbegin();
for (size_t i = 0; i < num_inner_dims; ++i) {
@@ -250,7 +244,6 @@ class BatchMatMulOp final : public Operator<Context> {
num_sub_batches *= *(second_r_itr + i);
}
}
B_slice_size = B_stride;
}
size_t num_outer_batches = 1;
@@ -280,9 +273,6 @@ class BatchMatMulOp final : public Operator<Context> {
math::GemmBatched<T, Context, Engine>(
trans_a_ ? CblasTrans : CblasNoTrans,
trans_b_ ? CblasTrans : CblasNoTrans,
A_slice_size,
num_sub_batches,
B_slice_size,
num_sub_batches,
M,
N,


@@ -0,0 +1,107 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/batch_matmul_op.h"
namespace caffe2 {
namespace {
using testing::ElementsAreArray;
using testing::Test;
class BatchMatMulOpGPUTest : public Test {
protected:
void SetUp() override {
if (!HasCudaGPU()) {
return;
}
option_.set_device_type(CUDA);
cuda_context_ = make_unique<CUDAContext>(option_);
def_.set_name("test");
def_.set_type("BatchMatMul");
def_.add_input("A");
def_.add_input("B");
def_.add_output("Y");
def_.mutable_device_option()->set_device_type(CUDA);
}
void AddConstInput(
const std::vector<TIndex>& dims,
const float value,
const string& name) {
Blob* blob = ws_.CreateBlob(name);
auto* tensor = blob->GetMutable<Tensor<CUDAContext>>();
tensor->Resize(dims);
math::Set<float, CUDAContext>(
tensor->size(),
value,
tensor->mutable_data<float>(),
cuda_context_.get());
}
void VerifyOutput(const std::vector<TIndex>& dims, const float value) const {
const Blob* Y_blob = ws_.GetBlob("Y");
ASSERT_NE(nullptr, Y_blob);
const auto& Y = Y_blob->Get<Tensor<CUDAContext>>();
TensorCPU Y_cpu(Y);
ASSERT_THAT(Y_cpu.dims(), ElementsAreArray(dims));
for (int i = 0; i < Y_cpu.size(); ++i) {
EXPECT_FLOAT_EQ(value, Y_cpu.data<float>()[i]);
}
}
DeviceOption option_;
std::unique_ptr<CUDAContext> cuda_context_;
Workspace ws_;
OperatorDef def_;
};
TEST_F(BatchMatMulOpGPUTest, BatchMatMulOpGPUNormalTest) {
if (!HasCudaGPU()) {
return;
}
AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
AddConstInput(std::vector<TIndex>{3, 10, 6}, 1.0f, "B");
std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
ASSERT_NE(nullptr, op);
ASSERT_TRUE(op->Run());
VerifyOutput(std::vector<TIndex>{3, 5, 6}, 10.0f);
}
TEST_F(BatchMatMulOpGPUTest, BatchMatMulOpGPUBroadcastTest) {
if (!HasCudaGPU()) {
return;
}
auto* arg = def_.add_arg();
arg->set_name("broadcast");
arg->set_i(1);
AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
AddConstInput(std::vector<TIndex>{2, 3, 10, 6}, 1.0f, "B");
std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
ASSERT_NE(nullptr, op);
ASSERT_TRUE(op->Run());
VerifyOutput(std::vector<TIndex>{2, 3, 5, 6}, 10.0f);
}
} // namespace
} // namespace caffe2


@@ -0,0 +1,94 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "caffe2/operators/batch_matmul_op.h"
namespace caffe2 {
namespace {
using testing::ElementsAreArray;
using testing::Test;
class BatchMatMulOpTest : public Test {
protected:
void SetUp() override {
cpu_context_ = make_unique<CPUContext>(option_);
def_.set_name("test");
def_.set_type("BatchMatMul");
def_.add_input("A");
def_.add_input("B");
def_.add_output("Y");
}
void AddConstInput(
const std::vector<TIndex>& dims,
const float value,
const string& name) {
Blob* blob = ws_.CreateBlob(name);
auto* tensor = blob->GetMutable<TensorCPU>();
tensor->Resize(dims);
math::Set<float, CPUContext>(
tensor->size(),
value,
tensor->mutable_data<float>(),
cpu_context_.get());
}
void VerifyOutput(const std::vector<TIndex>& dims, const float value) const {
const Blob* Y_blob = ws_.GetBlob("Y");
ASSERT_NE(nullptr, Y_blob);
const auto& Y = Y_blob->Get<TensorCPU>();
ASSERT_THAT(Y.dims(), ElementsAreArray(dims));
for (int i = 0; i < Y.size(); ++i) {
EXPECT_FLOAT_EQ(value, Y.data<float>()[i]);
}
}
DeviceOption option_;
std::unique_ptr<CPUContext> cpu_context_;
Workspace ws_;
OperatorDef def_;
};
TEST_F(BatchMatMulOpTest, BatchMatMulOpNormalTest) {
AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
AddConstInput(std::vector<TIndex>{3, 10, 6}, 1.0f, "B");
std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
ASSERT_NE(nullptr, op);
ASSERT_TRUE(op->Run());
VerifyOutput(std::vector<TIndex>{3, 5, 6}, 10.0f);
}
TEST_F(BatchMatMulOpTest, BatchMatMulOpBroadcastTest) {
auto* arg = def_.add_arg();
arg->set_name("broadcast");
arg->set_i(1);
AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
AddConstInput(std::vector<TIndex>{2, 3, 10, 6}, 1.0f, "B");
std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
ASSERT_NE(nullptr, op);
ASSERT_TRUE(op->Run());
VerifyOutput(std::vector<TIndex>{2, 3, 5, 6}, 10.0f);
}
} // namespace
} // namespace caffe2


@@ -252,10 +252,7 @@ template <typename T, class Context, class Engine = DefaultEngine>
void GemmBatched(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int A_size,
const int A_batches,
const int B_size,
const int B_batches,
const int batch_size,
const int M,
const int N,
const int K,


@@ -34,6 +34,7 @@
#include <numeric>
#include <random>
#include <unordered_set>
#include <vector>
#include "caffe2/utils/math.h"
#include "caffe2/utils/cpu_neon.h"
@@ -425,10 +426,7 @@ template <>
void GemmBatched<float, CPUContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int A_size,
const int A_batches,
const int B_size,
const int B_batches,
const int batch_size,
const int M,
const int N,
const int K,
@@ -440,25 +438,55 @@ void GemmBatched<float, CPUContext>(
CPUContext* context,
Tensor<CPUContext>*, /* scratch */
TensorProto::DataType /* math_type */) {
const int a_stride = M * K;
const int b_stride = K * N;
const int c_stride = M * N;
auto a_offset = A_size / A_batches;
auto b_offset = B_size / B_batches;
auto y_offset = M * N;
#ifdef CAFFE2_USE_MKL
const int lda = (TransA == CblasNoTrans) ? K : M;
const int ldb = (TransB == CblasNoTrans) ? N : K;
std::vector<const float*> a_array(batch_size, nullptr);
std::vector<const float*> b_array(batch_size, nullptr);
std::vector<float*> c_array(batch_size, nullptr);
for (int i = 0; i < batch_size; ++i) {
a_array[i] = A + a_stride * i;
b_array[i] = B + b_stride * i;
c_array[i] = C + c_stride * i;
}
cblas_sgemm_batch(
CblasRowMajor,
&TransA,
&TransB,
&M,
&N,
&K,
&alpha,
a_array.data(),
&lda,
b_array.data(),
&ldb,
&beta,
c_array.data(),
&N, // ldc_array
1,
&batch_size);
#else // CAFFE2_USE_MKL
// loop over matrices in the batch
for (int i = 0; i < A_batches; ++i) {
for (int i = 0; i < batch_size; ++i) {
math::Gemm<float, CPUContext>(
TransA,
TransB,
M,
N,
K,
1,
A + a_offset * i,
B + b_offset * i,
0,
C + y_offset * i,
alpha,
A + a_stride * i,
B + b_stride * i,
beta,
C + c_stride * i,
context);
}
#endif
}
////////////////////////////////////////////////////////////////////////////////


@@ -281,10 +281,7 @@ template <>
void GemmBatched<float, CUDAContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int A_size,
const int A_batches,
const int B_size,
const int B_batches,
const int batch_size,
const int M,
const int N,
const int K,
@@ -296,55 +293,53 @@ void GemmBatched<float, CUDAContext>(
CUDAContext* context,
Tensor<CUDAContext>* scratch,
TensorProto::DataType math_type) {
const int a_stride = M * K;
const int b_stride = K * N;
const int c_stride = M * N;
#if __CUDACC_VER_MAJOR__ < 8
auto a_offset = A_size / A_batches;
auto b_offset = B_size / B_batches;
auto y_offset = M * N;
// loop over matrices in the batch
for (int i = 0; i < A_batches; ++i) {
for (int i = 0; i < batch_size; ++i) {
math::Gemm<float, CUDAContext>(
TransA,
TransB,
M,
N,
K,
1,
A + a_offset * i,
B + b_offset * i,
0,
C + y_offset * i,
alpha,
A + a_stride * i,
B + b_stride * i,
beta,
C + c_stride * i,
context);
}
#else
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
const int lda = (TransA == CblasNoTrans) ? K : M;
const int ldb = (TransB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBLAS_ENFORCE(cublasSgemmStridedBatched(
context->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
B_size / B_batches, // B stride
A,
lda,
A_size / A_batches, // A stride
&beta,
C,
N,
M*N, // C stride
A_batches));
context->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
b_stride,
A,
lda,
a_stride,
&beta,
C,
N,
c_stride,
batch_size));
#endif
}
@@ -368,10 +363,7 @@ template <>
void GemmBatched<float16, CUDAContext>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int A_size,
const int A_batches,
const int B_size,
const int B_batches,
const int batch_size,
const int M,
const int N,
const int K,
@@ -383,24 +375,23 @@ void GemmBatched<float16, CUDAContext>(
CUDAContext* context,
Tensor<CUDAContext>* scratch,
TensorProto::DataType math_type) {
const int a_stride = M * K;
const int b_stride = K * N;
const int c_stride = M * N;
#if __CUDACC_VER_MAJOR__ < 8
auto a_offset = A_size / A_batches;
auto b_offset = B_size / B_batches;
auto y_offset = M * N;
// loop over matrices in the batch
for (int i = 0; i < A_batches; ++i) {
for (int i = 0; i < batch_size; ++i) {
math::Gemm<float16, CUDAContext>(
TransA,
TransB,
M,
N,
K,
1,
A + a_offset * i,
B + b_offset * i,
0,
C + y_offset * i,
alpha,
A + a_stride * i,
B + b_stride * i,
beta,
C + c_stride * i,
context);
}
#else
@@ -410,11 +401,13 @@ void GemmBatched<float16, CUDAContext>(
// 3) math_type == FLOAT16, scratch == nullptr = batched Hgemm
if (scratch != nullptr) {
const int A_size = a_stride * batch_size;
const int B_size = b_stride * batch_size;
// cast, cublasSgemmStridedBatched, cast
size_t in_elems = A_size + B_size;
size_t out_elems = A_batches*M*N;
size_t out_elems = c_stride * batch_size;
scratch->Resize(in_elems+out_elems);
scratch->Resize(in_elems + out_elems);
float* scratch_ptr = scratch->mutable_data<float>();
float* A_fp32 = scratch_ptr;
@@ -432,13 +425,10 @@ void GemmBatched<float16, CUDAContext>(
context->cuda_stream()>>>(B_size, (half*)B, B_fp32);
// run fp32 batched Gemm
GemmBatched<float,CUDAContext>(
GemmBatched<float, CUDAContext>(
TransA,
TransB,
A_size,
A_batches,
B_size,
B_batches,
batch_size,
M,
N,
K,
@@ -450,35 +440,33 @@ void GemmBatched<float16, CUDAContext>(
context);
// cast result back to fp16
FloatToHalfKernel<<<CAFFE_GET_BLOCKS(A_batches*M*N),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(A_batches*M*N, C_fp32, (half*)C);
FloatToHalfKernel<<<
CAFFE_GET_BLOCKS(batch_size * M * N),
CAFFE_CUDA_NUM_THREADS,
0,
context->cuda_stream()>>>(batch_size * M * N, C_fp32, (half*)C);
} else {
if (math_type == TensorProto_DataType_FLOAT) {
auto a_offset = A_size / A_batches;
auto b_offset = B_size / B_batches;
auto y_offset = M * N;
// loop over matrices in the batch
for (int i = 0; i < A_batches; ++i) {
for (int i = 0; i < batch_size; ++i) {
math::Gemm<float16, CUDAContext>(
TransA,
TransB,
M,
N,
K,
1,
A + a_offset * i,
B + b_offset * i,
0,
C + y_offset * i,
alpha,
A + a_stride * i,
B + b_stride * i,
beta,
C + c_stride * i,
context);
}
} else if (math_type == TensorProto_DataType_FLOAT16) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
const int lda = (TransA == CblasNoTrans) ? K : M;
const int ldb = (TransB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
@@ -488,24 +476,24 @@ void GemmBatched<float16, CUDAContext>(
auto alpha_fp16 = convert::floatToHalf(alpha);
auto beta_fp16 = convert::floatToHalf(beta);
CUBLAS_ENFORCE(cublasHgemmStridedBatched(
context->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha_fp16,
(const __half*)B,
ldb,
B_size / B_batches,
(const __half*)A,
lda,
A_size / A_batches,
&beta_fp16,
(__half*)C,
N,
M*N,
A_batches));
context->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha_fp16,
(const __half*)B,
ldb,
b_stride,
(const __half*)A,
lda,
a_stride,
&beta_fp16,
(__half*)C,
N,
c_stride,
batch_size));
}
}
#endif
@@ -605,10 +593,7 @@ template <>
void GemmBatched<float, CUDAContext, TensorCoreEngine>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int A_size,
const int A_batches,
const int B_size,
const int B_batches,
const int batch_size,
const int M,
const int N,
const int K,
@@ -623,10 +608,7 @@ void GemmBatched<float, CUDAContext, TensorCoreEngine>(
return GemmBatched<float, CUDAContext, DefaultEngine>(
TransA,
TransB,
A_size,
A_batches,
B_size,
B_batches,
batch_size,
M,
N,
K,
@@ -644,10 +626,7 @@ template <>
void GemmBatched<float16, CUDAContext, TensorCoreEngine>(
const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int A_size,
const int A_batches,
const int B_size,
const int B_batches,
const int batch_size,
const int M,
const int N,
const int K,
@@ -662,10 +641,7 @@ void GemmBatched<float16, CUDAContext, TensorCoreEngine>(
return GemmBatched<float16, CUDAContext, DefaultEngine>(
TransA,
TransB,
A_size,
A_batches,
B_size,
B_batches,
batch_size,
M,
N,
K,


@@ -15,8 +15,11 @@
*/
#include <iostream>
#include <memory>
#include <vector>
#include <gtest/gtest.h>
#include "caffe2/core/context.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/flags.h"
@@ -254,4 +257,84 @@ TEST(MathUtilGPUTest, testCopyVector) {
[](int i) { return 5.0f - i; });
}
namespace {
class GemmBatchedGPUTest
: public testing::TestWithParam<testing::tuple<bool, bool>> {
protected:
void SetUp() override {
if (!HasCudaGPU()) {
return;
}
option_.set_device_type(CUDA);
cuda_context_ = make_unique<CUDAContext>(option_);
Blob* X_blob = ws_.CreateBlob("X");
Blob* W_blob = ws_.CreateBlob("W");
Blob* Y_blob = ws_.CreateBlob("Y");
X_ = X_blob->GetMutable<Tensor<CUDAContext>>();
W_ = W_blob->GetMutable<Tensor<CUDAContext>>();
Y_ = Y_blob->GetMutable<Tensor<CUDAContext>>();
X_->Resize(std::vector<TIndex>{3, 5, 10});
W_->Resize(std::vector<TIndex>{3, 6, 10});
Y_->Resize(std::vector<TIndex>{3, 5, 6});
math::Set<float, CUDAContext>(
X_->size(), 1.0f, X_->mutable_data<float>(), cuda_context_.get());
math::Set<float, CUDAContext>(
W_->size(), 1.0f, W_->mutable_data<float>(), cuda_context_.get());
trans_X_ = std::get<0>(GetParam());
trans_W_ = std::get<1>(GetParam());
}
void RunGemmBatched(const float alpha, const float beta) {
math::GemmBatched(
trans_X_ ? CblasTrans : CblasNoTrans,
trans_W_ ? CblasTrans : CblasNoTrans,
3,
5,
6,
10,
alpha,
X_->template data<float>(),
W_->template data<float>(),
beta,
Y_->template mutable_data<float>(),
cuda_context_.get());
}
void VerifyOutput(const float value) const {
TensorCPU Y_cpu(*Y_);
for (int i = 0; i < Y_cpu.size(); ++i) {
EXPECT_FLOAT_EQ(value, Y_cpu.template data<float>()[i]);
}
}
Workspace ws_;
DeviceOption option_;
std::unique_ptr<CUDAContext> cuda_context_;
Tensor<CUDAContext>* X_ = nullptr;
Tensor<CUDAContext>* W_ = nullptr;
Tensor<CUDAContext>* Y_ = nullptr;
bool trans_X_;
bool trans_W_;
};
TEST_P(GemmBatchedGPUTest, GemmBatchedGPUFloatTest) {
if (!HasCudaGPU()) {
return;
}
RunGemmBatched(1.0f, 0.0f);
VerifyOutput(10.0f);
RunGemmBatched(1.0f, 0.5f);
VerifyOutput(15.0f);
RunGemmBatched(0.5f, 1.0f);
VerifyOutput(20.0f);
}
INSTANTIATE_TEST_CASE_P(
GemmBatchedGPUTrans,
GemmBatchedGPUTest,
testing::Combine(testing::Bool(), testing::Bool()));
} // namespace
} // namespace caffe2


@@ -14,7 +14,11 @@
* limitations under the License.
*/
#include <memory>
#include <vector>
#include <gtest/gtest.h>
#include "caffe2/core/blob.h"
#include "caffe2/core/context.h"
#include "caffe2/core/tensor.h"
@@ -116,6 +120,71 @@ TEST(MathTest, GemmNoTransTrans) {
}
}
namespace {
class GemmBatchedTest
: public testing::TestWithParam<testing::tuple<bool, bool>> {
protected:
void SetUp() override {
cpu_context_ = make_unique<CPUContext>(option_);
X_.Resize(std::vector<TIndex>{3, 5, 10});
W_.Resize(std::vector<TIndex>{3, 6, 10});
Y_.Resize(std::vector<TIndex>{3, 5, 6});
math::Set<float, CPUContext>(
X_.size(), 1, X_.mutable_data<float>(), cpu_context_.get());
math::Set<float, CPUContext>(
W_.size(), 1, W_.mutable_data<float>(), cpu_context_.get());
trans_X_ = std::get<0>(GetParam());
trans_W_ = std::get<1>(GetParam());
}
void RunGemmBatched(const float alpha, const float beta) {
math::GemmBatched(
trans_X_ ? CblasTrans : CblasNoTrans,
trans_W_ ? CblasTrans : CblasNoTrans,
3,
5,
6,
10,
alpha,
X_.template data<float>(),
W_.template data<float>(),
beta,
Y_.template mutable_data<float>(),
cpu_context_.get());
}
void VerifyOutput(const float value) const {
for (int i = 0; i < Y_.size(); ++i) {
EXPECT_FLOAT_EQ(value, Y_.template data<float>()[i]);
}
}
DeviceOption option_;
std::unique_ptr<CPUContext> cpu_context_;
TensorCPU X_;
TensorCPU W_;
TensorCPU Y_;
bool trans_X_;
bool trans_W_;
};
TEST_P(GemmBatchedTest, GemmBatchedFloatTest) {
RunGemmBatched(1.0f, 0.0f);
VerifyOutput(10.0f);
RunGemmBatched(1.0f, 0.5f);
VerifyOutput(15.0f);
RunGemmBatched(0.5f, 1.0f);
VerifyOutput(20.0f);
}
INSTANTIATE_TEST_CASE_P(
GemmBatchedTrans,
GemmBatchedTest,
testing::Combine(testing::Bool(), testing::Bool()));
} // namespace
TEST(MathTest, GemvNoTrans) {
DeviceOption option;
CPUContext cpu_context(option);