From 0a8a18ca0124d93df03ea4b291f484fe72103e75 Mon Sep 17 00:00:00 2001
From: Xiaomeng Yang
Date: Wed, 10 Jan 2018 18:13:41 -0800
Subject: [PATCH] Fix GemmBatched

Summary:
Fix GemmBatched. Collapse the (A_size, A_batches, B_size, B_batches) arguments
into a single batch_size and derive the per-matrix strides M*K, K*N and M*N
inside the implementations; forward alpha and beta to the per-matrix fallback
loops instead of hard-coding them to 1 and 0; add an MKL cblas_sgemm_batch path
on CPU; and add CPU/GPU unit tests for math::GemmBatched and the BatchMatMul
operator.

Reviewed By: Yangqing

Differential Revision: D6678168

fbshipit-source-id: 132117633573600d4e31c1959a0ccbe34416e1f1
---
 caffe2/operators/batch_matmul_op.h           |  10 -
 caffe2/operators/batch_matmul_op_gpu_test.cc | 107 +++++++++++
 caffe2/operators/batch_matmul_op_test.cc     |  94 ++++++++++
 caffe2/utils/math.h                          |   5 +-
 caffe2/utils/math_cpu.cc                     |  54 ++++--
 caffe2/utils/math_gpu.cu                     | 186 ++++++++-----------
 caffe2/utils/math_gpu_test.cc                |  83 +++++++++
 caffe2/utils/math_test.cc                    |  69 +++++++
 8 files changed, 476 insertions(+), 132 deletions(-)
 create mode 100644 caffe2/operators/batch_matmul_op_gpu_test.cc
 create mode 100644 caffe2/operators/batch_matmul_op_test.cc
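Usage note (illustration only, not part of the changed files): the sketch below
shows how a caller is expected to invoke the simplified GemmBatched signature.
It mirrors the call made by the new math_test.cc; the function name
BatchedGemmExample and the pointer arguments are placeholders, and the layout
assumption is the one the new code makes (batch_size matrices stored back to
back, with per-matrix strides M*K, K*N and M*N).

    #include "caffe2/core/context.h"
    #include "caffe2/utils/math.h"

    namespace caffe2 {

    // Computes C[i] = alpha * A[i] * B[i] + beta * C[i] for i in [0, batch_size),
    // where each A[i] is M x K, each B[i] is K x N and each C[i] is M x N.
    void BatchedGemmExample(const float* A, const float* B, float* C, CPUContext* ctx) {
      const int batch_size = 3, M = 5, N = 6, K = 10;
      math::GemmBatched<float, CPUContext>(
          CblasNoTrans, // TransA
          CblasNoTrans, // TransB
          batch_size,   // replaces the old A_size/A_batches/B_size/B_batches quadruple
          M, N, K,
          1.0f,         // alpha, now forwarded to the per-matrix fallback loop
          A, B,
          0.0f,         // beta, likewise forwarded
          C, ctx);
    }

    } // namespace caffe2

On CUDA >= 8 the same call lowers to cublasSgemmStridedBatched with those
strides; B and A are swapped in the cuBLAS call because cuBLAS is column-major,
so computing C^T = B^T * A^T yields the row-major C.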
diff --git a/caffe2/operators/batch_matmul_op.h b/caffe2/operators/batch_matmul_op.h
index b36566bed0c..ff8cd535ade 100644
--- a/caffe2/operators/batch_matmul_op.h
+++ b/caffe2/operators/batch_matmul_op.h
@@ -219,9 +219,6 @@ class BatchMatMulOp final : public Operator<Context> {
       size_t A_stride = 1; // How far to increment A pointer each itr
       size_t B_stride = 1; // How far to increment B pointer each itr
       size_t Y_stride = 1; // How far to increment Y pointer each itr
-      // How large the slices of A and B we are operating on at each iteration
-      // are.
-      size_t A_slice_size, B_slice_size;
       // How many "inner batches" we have. That is, the product of sizes for
       // the slices excluding M, K, and N, for their respective matrices.
       size_t num_sub_batches = 1;
@@ -235,12 +232,9 @@ class BatchMatMulOp final : public Operator<Context> {
           num_sub_batches *= *(first_r_itr + i);
         }
       }
-      A_slice_size = A_stride;
       B_stride = 0;
-      B_slice_size = B.size();
     } else {
       A_stride = 0;
-      A_slice_size = A.size();
       auto second_r_itr = dims_B.rbegin();
       auto output_r_itr = new_dims.rbegin();
       for (size_t i = 0; i < num_inner_dims; ++i) {
@@ -250,7 +244,6 @@ class BatchMatMulOp final : public Operator<Context> {
           num_sub_batches *= *(second_r_itr + i);
         }
       }
-      B_slice_size = B_stride;
     }
 
     size_t num_outer_batches = 1;
@@ -280,9 +273,6 @@ class BatchMatMulOp final : public Operator<Context> {
       math::GemmBatched(
           trans_a_ ? CblasTrans : CblasNoTrans,
           trans_b_ ? CblasTrans : CblasNoTrans,
-          A_slice_size,
-          num_sub_batches,
-          B_slice_size,
           num_sub_batches,
           M,
           N,
diff --git a/caffe2/operators/batch_matmul_op_gpu_test.cc b/caffe2/operators/batch_matmul_op_gpu_test.cc
new file mode 100644
index 00000000000..d1d42310380
--- /dev/null
+++ b/caffe2/operators/batch_matmul_op_gpu_test.cc
@@ -0,0 +1,107 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+
+#include
+#include
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/batch_matmul_op.h"
+
+namespace caffe2 {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::Test;
+
+class BatchMatMulOpGPUTest : public Test {
+ protected:
+  void SetUp() override {
+    if (!HasCudaGPU()) {
+      return;
+    }
+    option_.set_device_type(CUDA);
+    cuda_context_ = make_unique<CUDAContext>(option_);
+    def_.set_name("test");
+    def_.set_type("BatchMatMul");
+    def_.add_input("A");
+    def_.add_input("B");
+    def_.add_output("Y");
+    def_.mutable_device_option()->set_device_type(CUDA);
+  }
+
+  void AddConstInput(
+      const std::vector<TIndex>& dims,
+      const float value,
+      const string& name) {
+    Blob* blob = ws_.CreateBlob(name);
+    auto* tensor = blob->GetMutable<Tensor<CUDAContext>>();
+    tensor->Resize(dims);
+    math::Set<float, CUDAContext>(
+        tensor->size(),
+        value,
+        tensor->mutable_data<float>(),
+        cuda_context_.get());
+  }
+
+  void VerifyOutput(const std::vector<TIndex>& dims, const float value) const {
+    const Blob* Y_blob = ws_.GetBlob("Y");
+    ASSERT_NE(nullptr, Y_blob);
+    const auto& Y = Y_blob->Get<Tensor<CUDAContext>>();
+    TensorCPU Y_cpu(Y);
+    ASSERT_THAT(Y_cpu.dims(), ElementsAreArray(dims));
+    for (int i = 0; i < Y_cpu.size(); ++i) {
+      EXPECT_FLOAT_EQ(value, Y_cpu.data<float>()[i]);
+    }
+  }
+
+  DeviceOption option_;
+  std::unique_ptr<CUDAContext> cuda_context_;
+  Workspace ws_;
+  OperatorDef def_;
+};
+
+TEST_F(BatchMatMulOpGPUTest, BatchMatMulOpGPUNormalTest) {
+  if (!HasCudaGPU()) {
+    return;
+  }
+  AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
+  AddConstInput(std::vector<TIndex>{3, 10, 6}, 1.0f, "B");
+  std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
+  ASSERT_NE(nullptr, op);
+  ASSERT_TRUE(op->Run());
+  VerifyOutput(std::vector<TIndex>{3, 5, 6}, 10.0f);
+}
+
+TEST_F(BatchMatMulOpGPUTest, BatchMatMulOpGPUBroadcastTest) {
+  if (!HasCudaGPU()) {
+    return;
+  }
+  auto* arg = def_.add_arg();
+  arg->set_name("broadcast");
+  arg->set_i(1);
+  AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
+  AddConstInput(std::vector<TIndex>{2, 3, 10, 6}, 1.0f, "B");
+  std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
+  ASSERT_NE(nullptr, op);
+  ASSERT_TRUE(op->Run());
+  VerifyOutput(std::vector<TIndex>{2, 3, 5, 6}, 10.0f);
+}
+
+} // namespace
+} // namespace caffe2
diff --git a/caffe2/operators/batch_matmul_op_test.cc b/caffe2/operators/batch_matmul_op_test.cc
new file mode 100644
index 00000000000..aed80c3d0d9
--- /dev/null
+++ b/caffe2/operators/batch_matmul_op_test.cc
@@ -0,0 +1,94 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+
+#include
+#include
+
+#include "caffe2/operators/batch_matmul_op.h"
+
+namespace caffe2 {
+namespace {
+
+using testing::ElementsAreArray;
+using testing::Test;
+
+class BatchMatMulOpTest : public Test {
+ protected:
+  void SetUp() override {
+    cpu_context_ = make_unique<CPUContext>(option_);
+    def_.set_name("test");
+    def_.set_type("BatchMatMul");
+    def_.add_input("A");
+    def_.add_input("B");
+    def_.add_output("Y");
+  }
+
+  void AddConstInput(
+      const std::vector<TIndex>& dims,
+      const float value,
+      const string& name) {
+    Blob* blob = ws_.CreateBlob(name);
+    auto* tensor = blob->GetMutable<TensorCPU>();
+    tensor->Resize(dims);
+    math::Set<float, CPUContext>(
+        tensor->size(),
+        value,
+        tensor->mutable_data<float>(),
+        cpu_context_.get());
+  }
+
+  void VerifyOutput(const std::vector<TIndex>& dims, const float value) const {
+    const Blob* Y_blob = ws_.GetBlob("Y");
+    ASSERT_NE(nullptr, Y_blob);
+    const auto& Y = Y_blob->Get<TensorCPU>();
+    ASSERT_THAT(Y.dims(), ElementsAreArray(dims));
+    for (int i = 0; i < Y.size(); ++i) {
+      EXPECT_FLOAT_EQ(value, Y.data<float>()[i]);
+    }
+  }
+
+  DeviceOption option_;
+  std::unique_ptr<CPUContext> cpu_context_;
+  Workspace ws_;
+  OperatorDef def_;
+};
+
+TEST_F(BatchMatMulOpTest, BatchMatMulOpNormalTest) {
+  AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
+  AddConstInput(std::vector<TIndex>{3, 10, 6}, 1.0f, "B");
+  std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
+  ASSERT_NE(nullptr, op);
+  ASSERT_TRUE(op->Run());
+  VerifyOutput(std::vector<TIndex>{3, 5, 6}, 10.0f);
+}
+
+TEST_F(BatchMatMulOpTest, BatchMatMulOpBroadcastTest) {
+  auto* arg = def_.add_arg();
+  arg->set_name("broadcast");
+  arg->set_i(1);
+  AddConstInput(std::vector<TIndex>{3, 5, 10}, 1.0f, "A");
+  AddConstInput(std::vector<TIndex>{2, 3, 10, 6}, 1.0f, "B");
+  std::unique_ptr<OperatorBase> op(CreateOperator(def_, &ws_));
+  ASSERT_NE(nullptr, op);
+  ASSERT_TRUE(op->Run());
+  VerifyOutput(std::vector<TIndex>{2, 3, 5, 6}, 10.0f);
+}
+
+} // namespace
+} // namespace caffe2
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index a86558658d7..80b15c73964 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -252,10 +252,7 @@ template <typename T, class Context, class Engine = DefaultEngine>
 void GemmBatched(
     const CBLAS_TRANSPOSE TransA,
     const CBLAS_TRANSPOSE TransB,
-    const int A_size,
-    const int A_batches,
-    const int B_size,
-    const int B_batches,
+    const int batch_size,
     const int M,
     const int N,
     const int K,
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index dab3ea4719a..7502e01c8ff 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -34,6 +34,7 @@
 #include
 #include
 #include
+#include
 
 #include "caffe2/utils/math.h"
 #include "caffe2/utils/cpu_neon.h"
@@ -425,10 +426,7 @@ template <>
 void GemmBatched<float, CPUContext>(
     const CBLAS_TRANSPOSE TransA,
     const CBLAS_TRANSPOSE TransB,
-    const int A_size,
-    const int A_batches,
-    const int B_size,
-    const int B_batches,
+    const int batch_size,
     const int M,
     const int N,
     const int K,
@@ -440,25 +438,55 @@ void GemmBatched(
     CPUContext* context,
     Tensor<CPUContext>*, /* scratch */
     TensorProto::DataType /* math_type */) {
+  const int a_stride = M * K;
+  const int b_stride = K * N;
+  const int c_stride = M * N;
-  auto a_offset = A_size / A_batches;
-  auto b_offset = B_size / B_batches;
-  auto y_offset = M * N;
+#ifdef CAFFE2_USE_MKL
+  const int lda = (TransA == CblasNoTrans) ? K : M;
+  const int ldb = (TransB == CblasNoTrans) ? N : K;
+  std::vector<const float*> a_array(batch_size, nullptr);
+  std::vector<const float*> b_array(batch_size, nullptr);
+  std::vector<float*> c_array(batch_size, nullptr);
+  for (int i = 0; i < batch_size; ++i) {
+    a_array[i] = A + a_stride * i;
+    b_array[i] = B + b_stride * i;
+    c_array[i] = C + c_stride * i;
+  }
+  cblas_sgemm_batch(
+      CblasRowMajor,
+      &TransA,
+      &TransB,
+      &M,
+      &N,
+      &K,
+      &alpha,
+      a_array.data(),
+      &lda,
+      b_array.data(),
+      &ldb,
+      &beta,
+      c_array.data(),
+      &N,
+      1,
+      &batch_size);
+#else // CAFFE2_USE_MKL
   // loop over matrices in the batch
-  for (int i = 0; i < A_batches; ++i) {
+  for (int i = 0; i < batch_size; ++i) {
     math::Gemm<float, CPUContext>(
         TransA,
         TransB,
         M,
         N,
         K,
-        1,
-        A + a_offset * i,
-        B + b_offset * i,
-        0,
-        C + y_offset * i,
+        alpha,
+        A + a_stride * i,
+        B + b_stride * i,
+        beta,
+        C + c_stride * i,
         context);
   }
+#endif
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 8942b6ff820..3d078a0986a 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -281,10 +281,7 @@ template <>
 void GemmBatched<float, CUDAContext>(
     const CBLAS_TRANSPOSE TransA,
     const CBLAS_TRANSPOSE TransB,
-    const int A_size,
-    const int A_batches,
-    const int B_size,
-    const int B_batches,
+    const int batch_size,
     const int M,
     const int N,
     const int K,
@@ -296,55 +293,53 @@ void GemmBatched(
     CUDAContext* context,
     Tensor<CUDAContext>* scratch,
     TensorProto::DataType math_type) {
-
+  const int a_stride = M * K;
+  const int b_stride = K * N;
+  const int c_stride = M * N;
 #if __CUDACC_VER_MAJOR__ < 8
-  auto a_offset = A_size / A_batches;
-  auto b_offset = B_size / B_batches;
-  auto y_offset = M * N;
   // loop over matrices in the batch
-  for (int i = 0; i < A_batches; ++i) {
+  for (int i = 0; i < batch_size; ++i) {
     math::Gemm<float, CUDAContext>(
         TransA,
         TransB,
         M,
         N,
         K,
-        1,
-        A + a_offset * i,
-        B + b_offset * i,
-        0,
-        C + y_offset * i,
+        alpha,
+        A + a_stride * i,
+        B + b_stride * i,
+        beta,
+        C + c_stride * i,
         context);
   }
 #else
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
-  int lda = (TransA == CblasNoTrans) ? K : M;
-  int ldb = (TransB == CblasNoTrans) ? N : K;
-
+  const int lda = (TransA == CblasNoTrans) ? K : M;
+  const int ldb = (TransB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   CUBLAS_ENFORCE(cublasSgemmStridedBatched(
-    context->cublas_handle(),
-    cuTransB,
-    cuTransA,
-    N,
-    M,
-    K,
-    &alpha,
-    B,
-    ldb,
-    B_size / B_batches, // B stride
-    A,
-    lda,
-    A_size / A_batches, // A stride
-    &beta,
-    C,
-    N,
-    M*N, // C stride
-    A_batches));
+      context->cublas_handle(),
+      cuTransB,
+      cuTransA,
+      N,
+      M,
+      K,
+      &alpha,
+      B,
+      ldb,
+      b_stride,
+      A,
+      lda,
+      a_stride,
+      &beta,
+      C,
+      N,
+      c_stride,
+      batch_size));
 #endif
 }
 
@@ -368,10 +363,7 @@ template <>
 void GemmBatched<float16, CUDAContext>(
     const CBLAS_TRANSPOSE TransA,
     const CBLAS_TRANSPOSE TransB,
-    const int A_size,
-    const int A_batches,
-    const int B_size,
-    const int B_batches,
+    const int batch_size,
     const int M,
     const int N,
     const int K,
@@ -383,24 +375,23 @@ void GemmBatched(
     CUDAContext* context,
     Tensor<CUDAContext>* scratch,
     TensorProto::DataType math_type) {
-
+  const int a_stride = M * K;
+  const int b_stride = K * N;
+  const int c_stride = M * N;
 #if __CUDACC_VER_MAJOR__ < 8
-  auto a_offset = A_size / A_batches;
-  auto b_offset = B_size / B_batches;
-  auto y_offset = M * N;
   // loop over matrices in the batch
-  for (int i = 0; i < A_batches; ++i) {
+  for (int i = 0; i < batch_size; ++i) {
     math::Gemm<float16, CUDAContext>(
         TransA,
         TransB,
         M,
         N,
         K,
-        1,
-        A + a_offset * i,
-        B + b_offset * i,
-        0,
-        C + y_offset * i,
+        alpha,
+        A + a_stride * i,
+        B + b_stride * i,
+        beta,
+        C + c_stride * i,
         context);
   }
 #else
@@ -410,11 +401,13 @@ void GemmBatched(
   // 3) math_type == FLOAT16, scratch == nullptr = batched Hgemm
 
   if (scratch != nullptr) {
+    const int A_size = a_stride * batch_size;
+    const int B_size = b_stride * batch_size;
     // cast, cublasSgemmStridedBatched, cast
     size_t in_elems = A_size + B_size;
-    size_t out_elems = A_batches*M*N;
+    size_t out_elems = c_stride * batch_size;
 
-    scratch->Resize(in_elems+out_elems);
+    scratch->Resize(in_elems + out_elems);
     float* scratch_ptr = scratch->mutable_data<float>();
 
     float* A_fp32 = scratch_ptr;
@@ -432,13 +425,10 @@ void GemmBatched(
         context->cuda_stream()>>>(B_size, (half*)B, B_fp32);
 
     // run fp32 batched Gemm
-    GemmBatched(
+    GemmBatched(
         TransA,
         TransB,
-        A_size,
-        A_batches,
-        B_size,
-        B_batches,
+        batch_size,
         M,
         N,
         K,
@@ -450,35 +440,33 @@ void GemmBatched(
         context);
 
     // cast result back to fp16
-    FloatToHalfKernel<<<CAFFE_GET_BLOCKS(A_batches * M * N), CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(A_batches*M*N, C_fp32, (half*)C);
+    FloatToHalfKernel<<<
+        CAFFE_GET_BLOCKS(batch_size * M * N),
+        CAFFE_CUDA_NUM_THREADS,
+        0,
+        context->cuda_stream()>>>(batch_size * M * N, C_fp32, (half*)C);
   } else {
     if (math_type == TensorProto_DataType_FLOAT) {
-      auto a_offset = A_size / A_batches;
-      auto b_offset = B_size / B_batches;
-      auto y_offset = M * N;
       // loop over matrices in the batch
-      for (int i = 0; i < A_batches; ++i) {
+      for (int i = 0; i < batch_size; ++i) {
         math::Gemm<float16, CUDAContext>(
             TransA,
             TransB,
            M,
            N,
            K,
-            1,
-            A + a_offset * i,
-            B + b_offset * i,
-            0,
-            C + y_offset * i,
+            alpha,
+            A + a_stride * i,
+            B + b_stride * i,
+            beta,
+            C + c_stride * i,
            context);
      }
    } else if (math_type == TensorProto_DataType_FLOAT16) {
      // Note that cublas follows fortran order, so the order is different from
      // the cblas convention.
-      int lda = (TransA == CblasNoTrans) ? K : M;
-      int ldb = (TransB == CblasNoTrans) ? N : K;
+      const int lda = (TransA == CblasNoTrans) ? K : M;
+      const int ldb = (TransB == CblasNoTrans) ? N : K;
      cublasOperation_t cuTransA =
          (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
      cublasOperation_t cuTransB =
@@ -488,24 +476,24 @@ void GemmBatched(
       auto alpha_fp16 = convert::floatToHalf(alpha);
       auto beta_fp16 = convert::floatToHalf(beta);
       CUBLAS_ENFORCE(cublasHgemmStridedBatched(
-        context->cublas_handle(),
-        cuTransB,
-        cuTransA,
-        N,
-        M,
-        K,
-        &alpha_fp16,
-        (const __half*)B,
-        ldb,
-        B_size / B_batches,
-        (const __half*)A,
-        lda,
-        A_size / A_batches,
-        &beta_fp16,
-        (__half*)C,
-        N,
-        M*N,
-        A_batches));
+          context->cublas_handle(),
+          cuTransB,
+          cuTransA,
+          N,
+          M,
+          K,
+          &alpha_fp16,
+          (const __half*)B,
+          ldb,
+          b_stride,
+          (const __half*)A,
+          lda,
+          a_stride,
+          &beta_fp16,
+          (__half*)C,
+          N,
+          c_stride,
+          batch_size));
     }
   }
 #endif
@@ -605,10 +593,7 @@ template <>
 void GemmBatched(
     const CBLAS_TRANSPOSE TransA,
     const CBLAS_TRANSPOSE TransB,
-    const int A_size,
-    const int A_batches,
-    const int B_size,
-    const int B_batches,
+    const int batch_size,
     const int M,
     const int N,
     const int K,
@@ -623,10 +608,7 @@ void GemmBatched(
   return GemmBatched(
       TransA,
       TransB,
-      A_size,
-      A_batches,
-      B_size,
-      B_batches,
+      batch_size,
       M,
       N,
       K,
@@ -644,10 +626,7 @@ template <>
 void GemmBatched(
     const CBLAS_TRANSPOSE TransA,
     const CBLAS_TRANSPOSE TransB,
-    const int A_size,
-    const int A_batches,
-    const int B_size,
-    const int B_batches,
+    const int batch_size,
     const int M,
     const int N,
     const int K,
@@ -662,10 +641,7 @@ void GemmBatched(
   return GemmBatched(
       TransA,
       TransB,
-      A_size,
-      A_batches,
-      B_size,
-      B_batches,
+      batch_size,
      M,
      N,
      K,
diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc
index ffc0c4b80ea..38312c9bbd8 100644
--- a/caffe2/utils/math_gpu_test.cc
+++ b/caffe2/utils/math_gpu_test.cc
@@ -15,8 +15,11 @@
  */
 
 #include
+#include
+#include
 #include
+
 #include "caffe2/core/context.h"
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/core/flags.h"
@@ -254,4 +257,84 @@ TEST(MathUtilGPUTest, testCopyVector) {
       [](int i) { return 5.0f - i; });
 }
 
+namespace {
+
+class GemmBatchedGPUTest
+    : public testing::TestWithParam<std::tuple<bool, bool>> {
+ protected:
+  void SetUp() override {
+    if (!HasCudaGPU()) {
+      return;
+    }
+    option_.set_device_type(CUDA);
+    cuda_context_ = make_unique<CUDAContext>(option_);
+    Blob* X_blob = ws_.CreateBlob("X");
+    Blob* W_blob = ws_.CreateBlob("W");
+    Blob* Y_blob = ws_.CreateBlob("Y");
+    X_ = X_blob->GetMutable<Tensor<CUDAContext>>();
+    W_ = W_blob->GetMutable<Tensor<CUDAContext>>();
+    Y_ = Y_blob->GetMutable<Tensor<CUDAContext>>();
+    X_->Resize(std::vector<TIndex>{3, 5, 10});
+    W_->Resize(std::vector<TIndex>{3, 6, 10});
+    Y_->Resize(std::vector<TIndex>{3, 5, 6});
+    math::Set<float, CUDAContext>(
+        X_->size(), 1.0f, X_->mutable_data<float>(), cuda_context_.get());
+    math::Set<float, CUDAContext>(
+        W_->size(), 1.0f, W_->mutable_data<float>(), cuda_context_.get());
+    trans_X_ = std::get<0>(GetParam());
+    trans_W_ = std::get<1>(GetParam());
+  }
+
+  void RunGemmBatched(const float alpha, const float beta) {
+    math::GemmBatched<float, CUDAContext>(
+        trans_X_ ? CblasTrans : CblasNoTrans,
+        trans_W_ ? CblasTrans : CblasNoTrans,
+        3,
+        5,
+        6,
+        10,
+        alpha,
+        X_->template data<float>(),
+        W_->template data<float>(),
+        beta,
+        Y_->template mutable_data<float>(),
+        cuda_context_.get());
+  }
+
+  void VerifyOutput(const float value) const {
+    TensorCPU Y_cpu(*Y_);
+    for (int i = 0; i < Y_cpu.size(); ++i) {
+      EXPECT_FLOAT_EQ(value, Y_cpu.template data<float>()[i]);
+    }
+  }
+
+  Workspace ws_;
+  DeviceOption option_;
+  std::unique_ptr<CUDAContext> cuda_context_;
+  Tensor<CUDAContext>* X_ = nullptr;
+  Tensor<CUDAContext>* W_ = nullptr;
+  Tensor<CUDAContext>* Y_ = nullptr;
+  bool trans_X_;
+  bool trans_W_;
+};
+
+TEST_P(GemmBatchedGPUTest, GemmBatchedGPUFloatTest) {
+  if (!HasCudaGPU()) {
+    return;
+  }
+  RunGemmBatched(1.0f, 0.0f);
+  VerifyOutput(10.0f);
+  RunGemmBatched(1.0f, 0.5f);
+  VerifyOutput(15.0f);
+  RunGemmBatched(0.5f, 1.0f);
+  VerifyOutput(20.0f);
+}
+
+INSTANTIATE_TEST_CASE_P(
+    GemmBatchedGPUTrans,
+    GemmBatchedGPUTest,
+    testing::Combine(testing::Bool(), testing::Bool()));
+
+} // namespace
+
 } // namespace caffe2
diff --git a/caffe2/utils/math_test.cc b/caffe2/utils/math_test.cc
index 55227a07526..7b55c0c3e8c 100644
--- a/caffe2/utils/math_test.cc
+++ b/caffe2/utils/math_test.cc
@@ -14,7 +14,11 @@
  * limitations under the License.
  */
 
+#include
+#include
+
 #include
+
 #include "caffe2/core/blob.h"
 #include "caffe2/core/context.h"
 #include "caffe2/core/tensor.h"
@@ -116,6 +120,71 @@ TEST(MathTest, GemmNoTransTrans) {
   }
 }
 
+namespace {
+
+class GemmBatchedTest
+    : public testing::TestWithParam<std::tuple<bool, bool>> {
+ protected:
+  void SetUp() override {
+    cpu_context_ = make_unique<CPUContext>(option_);
+    X_.Resize(std::vector<TIndex>{3, 5, 10});
+    W_.Resize(std::vector<TIndex>{3, 6, 10});
+    Y_.Resize(std::vector<TIndex>{3, 5, 6});
+    math::Set<float, CPUContext>(
+        X_.size(), 1, X_.mutable_data<float>(), cpu_context_.get());
+    math::Set<float, CPUContext>(
+        W_.size(), 1, W_.mutable_data<float>(), cpu_context_.get());
+    trans_X_ = std::get<0>(GetParam());
+    trans_W_ = std::get<1>(GetParam());
+  }
+
+  void RunGemmBatched(const float alpha, const float beta) {
+    math::GemmBatched<float, CPUContext>(
+        trans_X_ ? CblasTrans : CblasNoTrans,
+        trans_W_ ? CblasTrans : CblasNoTrans,
+        3,
+        5,
+        6,
+        10,
+        alpha,
+        X_.template data<float>(),
+        W_.template data<float>(),
+        beta,
+        Y_.template mutable_data<float>(),
+        cpu_context_.get());
+  }
+
+  void VerifyOutput(const float value) const {
+    for (int i = 0; i < Y_.size(); ++i) {
+      EXPECT_FLOAT_EQ(value, Y_.template data<float>()[i]);
+    }
+  }
+
+  DeviceOption option_;
+  std::unique_ptr<CPUContext> cpu_context_;
+  TensorCPU X_;
+  TensorCPU W_;
+  TensorCPU Y_;
+  bool trans_X_;
+  bool trans_W_;
+};
+
+TEST_P(GemmBatchedTest, GemmBatchedFloatTest) {
+  RunGemmBatched(1.0f, 0.0f);
+  VerifyOutput(10.0f);
+  RunGemmBatched(1.0f, 0.5f);
+  VerifyOutput(15.0f);
+  RunGemmBatched(0.5f, 1.0f);
+  VerifyOutput(20.0f);
+}
+
+INSTANTIATE_TEST_CASE_P(
+    GemmBatchedTrans,
+    GemmBatchedTest,
+    testing::Combine(testing::Bool(), testing::Bool()));
+
+} // namespace
+
 TEST(MathTest, GemvNoTrans) {
   DeviceOption option;
   CPUContext cpu_context(option);