onnxruntime/onnxruntime/test/mlas/unittest.cpp
Dmitri Smirnov d1b1cdc5c4
Replace GSL with GSL-LITE submodule and fix up refs (#1920)
Remove gsl submodule and replace with a local copy of gsl-lite
  Refactor for onnxruntime::make_unique
  gsl::span size and index are now size_t
  Remove lambda auto argument type detection.
  Remove constexpr from fail_fast in gsl due to Linux not being happy.
  Comment out std::stream support due to MacOS std lib broken.
  Move make_unique into include/core/common so it is accessible for server builds.
  Relax requirements for onnxruntime/test/providers/cpu/ml/write_scores_test.cc
  due to x86 build.
  Add ONNXRUNTIME_ROOT to Server Lib includes so gsl is recognized
2019-10-01 12:43:29 -07:00

2021 lines
72 KiB
C++

/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
unittest.cpp
Abstract:
This module implements unit tests of the MLAS library.
--*/
#include <stdio.h>
#include <memory.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <mlas.h>
#if defined(_WIN32)
#include <windows.h>
#else
#include <sys/mman.h>
#endif
#if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
#include "core/platform/threadpool.h"
#endif
#include "core/common/make_unique.h"
//
// Provide _countof (number of elements in a statically sized array) when the
// toolchain has not already defined it.
//
#if !defined(_countof)
#define _countof(_Array) (sizeof(_Array) / sizeof(_Array[0]))
#endif
//
// Feature switch set for 64-bit x86 targets only.
// NOTE(review): presumably enables the double precision GEMM tests; the use
// site is not visible in this part of the file.
//
#if defined(_M_AMD64) || defined(__x86_64__)
#define MLAS_HAS_DGEMM
#endif
//
// Feature switch set for 32-bit and 64-bit x86 targets; it guards the
// definition of MlasQgemmU8X8Test below.
//
#if defined(_M_IX86) || defined(__i386__) || defined(_M_AMD64) || defined(__x86_64__)
#define MLAS_HAS_QGEMM_U8X8
#endif
//
// Thread pool handed to the MlasGemm and MlasPool calls made by the tests.
// Starts out as nullptr.
//
MLAS_THREADPOOL* threadpool = nullptr;
//
// Supplies scratch buffers of T whose last element sits directly below an
// inaccessible guard region, so that an access past the end of the buffer
// touches unmapped memory instead of neighboring data.
//
// The object owns a single reservation that is grown on demand and released
// by the destructor.
//
template <typename T>
class MatrixGuardBuffer
{
public:
    MatrixGuardBuffer()
    {
        _BaseBuffer = nullptr;
        _BaseBufferSize = 0;
        _ElementsAllocated = 0;
        //
        // Must be initialized: GetBuffer(0) on a fresh object skips the
        // allocation path and derives its result from this pointer.
        //
        _GuardAddress = nullptr;
    }

    ~MatrixGuardBuffer(void)
    {
        ReleaseBuffer();
    }

    //
    // Returns a buffer of Elements elements that ends exactly at the guard
    // region. Every element is filled with a repeating sequence of the values
    // -23..23 converted to T.
    //
    // Growing the buffer releases the previous reservation, so pointers
    // returned by earlier calls must not be used after a larger request.
    //
    // Throws std::bad_alloc if the address range cannot be reserved or
    // committed; the object is left empty in that case.
    //
    T* GetBuffer(size_t Elements)
    {
        //
        // Check if the internal buffer needs to be reallocated.
        //

        if (Elements > _ElementsAllocated) {

            ReleaseBuffer();

            //
            // Reserve a virtual address range for the allocation plus an unmapped
            // guard region.
            //

            constexpr size_t BufferAlignment = 64 * 1024;
            constexpr size_t GuardPadding = 256 * 1024;

            size_t BytesToAllocate = ((Elements * sizeof(T)) + BufferAlignment - 1) & ~(BufferAlignment - 1);
            size_t BytesToReserve = BytesToAllocate + GuardPadding;

#if defined(_WIN32)
            void* Reservation = VirtualAlloc(NULL, BytesToReserve, MEM_RESERVE, PAGE_NOACCESS);

            if (Reservation == nullptr) {
                throw std::bad_alloc();
            }
#else
            void* Reservation = mmap(0, BytesToReserve, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);

            //
            // mmap reports failure with MAP_FAILED, not with a null pointer.
            //

            if (Reservation == MAP_FAILED) {
                throw std::bad_alloc();
            }
#endif

            _BaseBuffer = Reservation;
            _BaseBufferSize = BytesToReserve;

            //
            // Commit the number of bytes for the allocation leaving the upper
            // guard region as unmapped.
            //

#if defined(_WIN32)
            if (VirtualAlloc(_BaseBuffer, BytesToAllocate, MEM_COMMIT, PAGE_READWRITE) == nullptr) {
                ReleaseBuffer();
                throw std::bad_alloc();
            }
#else
            if (mprotect(_BaseBuffer, BytesToAllocate, PROT_READ | PROT_WRITE) != 0) {
                ReleaseBuffer();
                throw std::bad_alloc();
            }
#endif

            _ElementsAllocated = BytesToAllocate / sizeof(T);
            _GuardAddress = reinterpret_cast<T*>(static_cast<unsigned char*>(_BaseBuffer) + BytesToAllocate);
        }

        //
        // Position the buffer so that it ends at the guard address and fill
        // it with a deterministic pattern.
        //

        T* GuardAddress = _GuardAddress;
        T* buffer = GuardAddress - Elements;

        const int MinimumFillValue = -23;
        const int MaximumFillValue = 23;

        int FillValue = MinimumFillValue;
        T* FillAddress = buffer;

        while (FillAddress < GuardAddress) {

            *FillAddress++ = static_cast<T>(FillValue);

            FillValue++;

            if (FillValue > MaximumFillValue) {
                FillValue = MinimumFillValue;
            }
        }

        return buffer;
    }

    //
    // Returns the reservation, if any, to the operating system and resets the
    // object to its empty state.
    //
    void ReleaseBuffer(void)
    {
        if (_BaseBuffer != nullptr) {

#if defined(_WIN32)
            VirtualFree(_BaseBuffer, 0, MEM_RELEASE);
#else
            munmap(_BaseBuffer, _BaseBufferSize);
#endif

            _BaseBuffer = nullptr;
            _BaseBufferSize = 0;
        }

        _ElementsAllocated = 0;
        _GuardAddress = nullptr;
    }

private:
    size_t _ElementsAllocated;  // capacity of the committed region, in elements
    void* _BaseBuffer;          // start of the reserved range, or nullptr
    size_t _BaseBufferSize;     // size of the reserved range in bytes
    T* _GuardAddress;           // first address of the guard region
};
//
// Abstract interface implemented by each MLAS test suite in this file.
//
class MlasTestBase
{
public:
    virtual ~MlasTestBase() = default;

    //
    // Contains tests that run quickly as part of a checkin integration to
    // sanity check that the functionality is working.
    //
    virtual void ExecuteShort() = 0;

    //
    // Contains tests that can run slowly to more exhaustively test that
    // functionality is working across a broader range of parameters.
    //
    virtual void ExecuteLong() = 0;
};
template <typename T>
class MlasFgemmTest : public MlasTestBase
{
private:
void
Test(
size_t M,
size_t N,
size_t K,
float alpha,
float beta
)
{
const T* A = BufferA.GetBuffer(K * M);
const T* B = BufferB.GetBuffer(N * K);
T* C = BufferC.GetBuffer(N * M);
T* CReference = BufferCReference.GetBuffer(N * M);
Test(CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, K, B, N, beta, C, CReference, N);
Test(CblasNoTrans, CblasTrans, M, N, K, alpha, A, K, B, K, beta, C, CReference, N);
Test(CblasTrans, CblasNoTrans, M, N, K, alpha, A, M, B, N, beta, C, CReference, N);
Test(CblasTrans, CblasTrans, M, N, K, alpha, A, M, B, K, beta, C, CReference, N);
}
void
Test(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
float alpha,
const T* A,
size_t lda,
const T* B,
size_t ldb,
float beta,
T* C,
T* CReference,
size_t ldc
)
{
std::fill_n(C, M * N, -0.5f);
std::fill_n(CReference, M * N, -0.5f);
MlasGemm(TransA, TransB, M, N, K, T(alpha), A, lda, B, ldb, T(beta), C, ldc, threadpool);
ReferenceSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, CReference, ldc);
for (size_t f = 0; f < M * N; f++) {
// Sensitive to comparing positive/negative zero.
if (C[f] != CReference[f]) {
printf("mismatch TransA=%d, TransB=%d, M=%zd, N=%zd, K=%zd, alpha=%f, beta=%f %f %f!\n", TransA, TransB, M, N, K, alpha, beta, float(C[f]), float(CReference[f]));
}
}
}
void
ReferenceSgemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
float alpha,
const T* A,
size_t lda,
const T* B,
size_t ldb,
float beta,
T* C,
size_t ldc
)
{
if (TransA == CblasNoTrans) {
if (TransB == CblasNoTrans) {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + (m * lda);
const T* b = B + n;
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += ldb;
a += 1;
}
*c = (*c * beta) + (sum * alpha);
}
}
} else {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + (m * lda);
const T* b = B + (n * ldb);
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += 1;
a += 1;
}
*c = (*c * beta) + (sum * alpha);
}
}
}
} else {
if (TransB == CblasNoTrans) {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + m;
const T* b = B + n;
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += ldb;
a += lda;
}
*c = (*c * beta) + (sum * alpha);
}
}
} else {
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const T* a = A + m;
const T* b = B + (n * ldb);
T* c = C + (m * ldc) + n;
T sum = 0.0f;
for (size_t k = 0; k < K; k++) {
sum += (*b * *a);
b += 1;
a += lda;
}
*c = (*c * beta) + (sum * alpha);
}
}
}
}
}
MatrixGuardBuffer<T> BufferA;
MatrixGuardBuffer<T> BufferB;
MatrixGuardBuffer<T> BufferC;
MatrixGuardBuffer<T> BufferCReference;
public:
void
ExecuteShort(
void
) override
{
for (size_t b = 1; b < 16; b++) {
Test(b, b, b, 1.0f, 0.0f);
}
for (size_t b = 16; b <= 256; b <<= 1) {
Test(b, b, b, 1.0f, 0.0f);
}
for (size_t b = 256; b < 320; b += 32) {
Test(b, b, b, 1.0f, 0.0f);
}
}
void
ExecuteLong(
void
) override
{
static const float multipliers[] = { 0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f };
for (size_t N = 1; N < 128; N++) {
for (size_t K = 1; K < 128; K++) {
for (size_t a = 0; a < _countof(multipliers); a++) {
for (size_t b = 0; b < _countof(multipliers); b++) {
Test(1, N, K, multipliers[a], multipliers[b]);
}
}
}
}
for (size_t a = 0; a < _countof(multipliers); a++) {
float alpha = multipliers[a];
for (size_t b = 0; b < _countof(multipliers); b++) {
float beta = multipliers[b];
for (size_t M = 16; M < 160; M += 32) {
for (size_t N = 16; N < 160; N += 32) {
static const size_t ks[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320 };
for (size_t k = 0; k < _countof(ks); k++) {
size_t K = ks[k];
Test(M, N, K, alpha, beta);
Test(M + 1, N, K, alpha, beta);
Test(M, N + 1, K, alpha, beta);
Test(M + 1, N + 1, K, alpha, beta);
Test(M + 3, N + 2, K, alpha, beta);
Test(M + 4, N, K, alpha, beta);
Test(M, N + 4, K, alpha, beta);
Test(M + 4, N + 4, K, alpha, beta);
Test(M + 3, N + 7, K, alpha, beta);
Test(M + 8, N, K, alpha, beta);
Test(M, N + 8, K, alpha, beta);
Test(M + 12, N + 12, K, alpha, beta);
Test(M + 13, N, K, alpha, beta);
Test(M, N + 15, K, alpha, beta);
Test(M + 15, N + 15, K, alpha, beta);
}
}
printf("a %zd/%zd b %zd/%zd M %zd\n", a, _countof(multipliers), b, _countof(multipliers), M);
}
}
}
for (size_t M = 1; M < 160; M++) {
for (size_t N = 1; N < 160; N++) {
for (size_t K = 1; K < 160; K++) {
Test(M, N, K, 1.0f, 0.0f);
}
}
printf("M %zd\n", M);
}
for (size_t M = 160; M < 320; M += 24) {
for (size_t N = 112; N < 320; N += 24) {
for (size_t K = 1; K < 16; K++) {
Test(M, N, K, 1.0f, 0.0f);
}
for (size_t K = 16; K < 160; K += 32) {
Test(M, N, K, 1.0f, 0.0f);
}
}
printf("M %zd\n", M);
}
}
};
#ifdef MLAS_HAS_QGEMM_U8X8
template <typename xint8_t>
class MlasQgemmU8X8Test : public MlasTestBase
{
private:
void
Test(
size_t M,
size_t N,
size_t K,
uint8_t offa,
uint8_t offb
)
{
const uint8_t* A = BufferA.GetBuffer(K * M);
const xint8_t* B = BufferB.GetBuffer(N * K);
int32_t* C = BufferC.GetBuffer(N * M);
int32_t* CReference = BufferCReference.GetBuffer(N * M);
Test(M, N, K, A, K, offa, B, N, xint8_t(offb), C, CReference, N);
}
void
Test(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const xint8_t* B,
size_t ldb,
xint8_t offb,
int32_t* C,
int32_t* CReference,
size_t ldc
)
{
std::fill_n(C, M * N, -1);
std::fill_n(CReference, M * N, -1);
MlasGemm(M, N, K, A, lda, offa, B, ldb, offb, C, ldc, threadpool);
ReferenceQgemm(M, N, K, A, lda, offa, B, ldb, offb, CReference, ldc);
for (size_t f = 0; f < M * N; f++) {
if (C[f] != CReference[f]) {
printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d, offb=%d!\n", M, N, K, offa, offb);
}
}
}
void
ReferenceQgemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const xint8_t* B,
size_t ldb,
xint8_t offb,
int32_t* C,
size_t ldc
)
{
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
const uint8_t* a = A + (m * lda);
const xint8_t* b = B + n;
int32_t* c = C + (m * ldc) + n;
int32_t sum = 0;
for (size_t k = 0; k < K; k++) {
sum += ((int32_t(*b) - offb) * (int32_t(*a) - offa));
b += ldb;
a += 1;
}
*c = sum;
}
}
}
MatrixGuardBuffer<uint8_t> BufferA;
MatrixGuardBuffer<xint8_t> BufferB;
MatrixGuardBuffer<int32_t> BufferC;
MatrixGuardBuffer<int32_t> BufferCReference;
public:
void
ExecuteShort(
void
) override
{
for (size_t b = 1; b < 16; b++) {
Test(b, b, b, 14, 211);
}
for (size_t b = 16; b <= 256; b <<= 1) {
Test(b, b, b, 34, 1);
}
for (size_t b = 256; b < 320; b += 32) {
Test(b, b, b, 85, 173);
}
}
void
ExecuteLong(
void
) override
{
static const uint8_t zero_points[] = { 0, 18, 75, 128, 157, 231, 255 };
for (size_t a = 0; a < _countof(zero_points); a++) {
uint8_t offa = zero_points[a];
for (size_t b = 0; b < _countof(zero_points); b++) {
uint8_t offb = zero_points[b];
for (size_t M = 16; M < 160; M += 32) {
for (size_t N = 16; N < 160; N += 32) {
static const size_t ks[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320 };
for (size_t k = 0; k < _countof(ks); k++) {
size_t K = ks[k];
Test(M, N, K, offa, offb);
Test(M + 1, N, K, offa, offb);
Test(M, N + 1, K, offa, offb);
Test(M + 1, N + 1, K, offa, offb);
Test(M + 3, N + 2, K, offa, offb);
Test(M + 4, N, K, offa, offb);
Test(M, N + 4, K, offa, offb);
Test(M + 4, N + 4, K, offa, offb);
Test(M + 3, N + 7, K, offa, offb);
Test(M + 8, N, K, offa, offb);
Test(M, N + 8, K, offa, offb);
Test(M + 12, N + 12, K, offa, offb);
Test(M + 13, N, K, offa, offb);
Test(M, N + 15, K, offa, offb);
Test(M + 15, N + 15, K, offa, offb);
}
}
printf("a %zd/%zd b %zd/%zd M %zd\n", a, _countof(zero_points), b, _countof(zero_points), M);
}
}
}
for (size_t M = 1; M < 160; M++) {
for (size_t N = 1; N < 160; N++) {
for (size_t K = 1; K < 160; K++) {
Test(M, N, K, 18, 24);
}
}
printf("M %zd\n", M);
}
for (size_t M = 160; M < 320; M += 24) {
for (size_t N = 112; N < 320; N += 24) {
for (size_t K = 1; K < 16; K++) {
Test(M, N, K, 1, 3);
}
for (size_t K = 16; K < 160; K += 32) {
Test(M, N, K, 5, 7);
}
}
printf("M %zd\n", M);
}
}
};
#endif
//
// Validates 2D convolution by comparing the output of the virtual MlasConv2D
// hook against a reference built from an explicit IM2COL transform, MlasGemm
// and a bias add.
//
class MlasConv2DTest : public MlasTestBase
{
protected:
    //
    // Runs a single convolution configuration. Configurations whose computed
    // output height or width is not positive are skipped.
    //
    void
    Test(
        size_t BatchCount, size_t GroupCount, size_t InputChannels,
        size_t InputHeight, size_t InputWidth,
        size_t FilterCount,
        size_t KernelHeight, size_t KernelWidth,
        size_t PaddingLeftHeight, size_t PaddingLeftWidth,
        size_t PaddingRightHeight, size_t PaddingRightWidth,
        size_t DilationHeight, size_t DilationWidth,
        size_t StrideHeight, size_t StrideWidth
        )
    {
        //
        // N.B. The division truncates toward zero, so a dilated kernel that
        // overhangs the padded input by less than one stride still produces
        // an extent of one.
        //
        auto OutputExtent = [](size_t Input, size_t PadLeft, size_t PadRight, size_t Dilation, size_t Kernel, size_t Stride) -> int64_t {
            const int64_t PaddedInput = int64_t(Input) + int64_t(PadLeft) + int64_t(PadRight);
            const int64_t KernelSpan = int64_t(Dilation) * (int64_t(Kernel) - 1) + 1;
            return (PaddedInput - KernelSpan) / int64_t(Stride) + 1;
        };

        const int64_t OutputHeight64 = OutputExtent(InputHeight, PaddingLeftHeight, PaddingRightHeight, DilationHeight, KernelHeight, StrideHeight);
        const int64_t OutputWidth64 = OutputExtent(InputWidth, PaddingLeftWidth, PaddingRightWidth, DilationWidth, KernelWidth, StrideWidth);

        if (OutputHeight64 <= 0 || OutputWidth64 <= 0) {
            return;
        }

        const size_t OutputHeight = size_t(OutputHeight64);
        const size_t OutputWidth = size_t(OutputWidth64);

        const size_t InputSize = InputHeight * InputWidth;
        const size_t KernelSize = KernelHeight * KernelWidth;
        const size_t OutputSize = OutputHeight * OutputWidth;

        const size_t InputElements = BatchCount * GroupCount * InputChannels * InputSize;
        const size_t FilterElements = GroupCount * FilterCount * InputChannels * KernelSize;
        const size_t BiasElements = GroupCount * FilterCount;
        const size_t OutputElements = BatchCount * GroupCount * FilterCount * OutputSize;

        const float* Input = BufferInput.GetBuffer(InputElements);
        const float* Filter = BufferFilter.GetBuffer(FilterElements);
        const float* Bias = BufferBias.GetBuffer(BiasElements);
        float* Output = BufferOutput.GetBuffer(OutputElements);
        float* OutputReference = BufferOutputReference.GetBuffer(OutputElements);

        MlasConv2D(BatchCount, GroupCount, InputChannels,
                   InputHeight, InputWidth,
                   FilterCount,
                   KernelHeight, KernelWidth,
                   PaddingLeftHeight, PaddingLeftWidth,
                   PaddingRightHeight, PaddingRightWidth,
                   DilationHeight, DilationWidth,
                   StrideHeight, StrideWidth,
                   OutputHeight, OutputWidth,
                   Input, Filter, Bias, Output);

        ReferenceConv2D(BatchCount, GroupCount, InputChannels,
                        InputHeight, InputWidth,
                        FilterCount,
                        KernelHeight, KernelWidth,
                        PaddingLeftHeight, PaddingLeftWidth,
                        DilationHeight, DilationWidth,
                        StrideHeight, StrideWidth,
                        OutputHeight, OutputWidth,
                        Input, Filter, Bias, OutputReference);

        const bool Mismatch = memcmp(Output, OutputReference, OutputElements * sizeof(float)) != 0;

        if (Mismatch) {
            printf("mismatch: batch=%zd,group=%zd,input(%zd,%zd,%zd),filter=%zd,kernel(%zd,%zd)!!!\n",
                BatchCount, GroupCount, InputChannels, InputHeight, InputWidth, FilterCount,
                KernelHeight, KernelWidth);
        }
    }

    //
    // Implementation under test: MlasConvPrepare followed by MlasConv with an
    // identity activation. Both calls receive a null thread pool argument.
    // Derived classes may override this to test another convolution path.
    //
    virtual
    void
    MlasConv2D(
        size_t BatchCount, size_t GroupCount, size_t InputChannels,
        size_t InputHeight, size_t InputWidth,
        size_t FilterCount,
        size_t KernelHeight, size_t KernelWidth,
        size_t PaddingLeftHeight, size_t PaddingLeftWidth,
        size_t PaddingRightHeight, size_t PaddingRightWidth,
        size_t DilationHeight, size_t DilationWidth,
        size_t StrideHeight, size_t StrideWidth,
        size_t OutputHeight, size_t OutputWidth,
        const float* Input,
        const float* Filter,
        const float* Bias,
        float* Output
        )
    {
        const size_t Dimensions = 2;

        int64_t InputShape[] = { int64_t(InputHeight), int64_t(InputWidth) };
        int64_t KernelShape[] = { int64_t(KernelHeight), int64_t(KernelWidth) };
        int64_t DilationShape[] = { int64_t(DilationHeight), int64_t(DilationWidth) };
        int64_t Padding[] = {
            int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth),
            int64_t(PaddingRightHeight), int64_t(PaddingRightWidth)
        };
        int64_t StrideShape[] = { int64_t(StrideHeight), int64_t(StrideWidth) };
        int64_t OutputShape[] = { int64_t(OutputHeight), int64_t(OutputWidth) };

        MLAS_ACTIVATION Activation;
        Activation.ActivationKind = MlasIdentityActivation;

        MLAS_CONV_PARAMETERS Parameters;
        size_t WorkingBufferSize;

        MlasConvPrepare(&Parameters, Dimensions, BatchCount, GroupCount, InputChannels,
                        InputShape, KernelShape, DilationShape, Padding, StrideShape,
                        OutputShape, FilterCount, &Activation, &WorkingBufferSize, nullptr);

        float* WorkingBuffer = BufferWorking.GetBuffer(WorkingBufferSize);

        MlasConv(&Parameters, Input, Filter, Bias, WorkingBuffer, Output, nullptr);
    }

    //
    // Reference convolution. For every batch and group the input planes are
    // expanded with IM2COL, multiplied by the group's filters with MlasGemm,
    // and the per-filter bias is then added to the result.
    //
    void
    ReferenceConv2D(
        size_t BatchCount, size_t GroupCount, size_t InputChannels,
        size_t InputHeight, size_t InputWidth,
        size_t FilterCount,
        size_t KernelHeight, size_t KernelWidth,
        size_t PaddingLeftHeight, size_t PaddingLeftWidth,
        size_t DilationHeight, size_t DilationWidth,
        size_t StrideHeight, size_t StrideWidth,
        size_t OutputHeight, size_t OutputWidth,
        const float* Input,
        const float* Filter,
        const float* Bias,
        float* Output
        )
    {
        const size_t InputSize = InputHeight * InputWidth;
        const size_t OutputSize = OutputHeight * OutputWidth;
        const size_t KernelSize = KernelHeight * KernelWidth;
        const size_t K = InputChannels * KernelSize;
        const size_t Im2ColElements = OutputSize * K;

        for (size_t b = 0; b < BatchCount; b++) {
            for (size_t g = 0; g < GroupCount; g++) {

                //
                // Input and output advance across batches and groups; the
                // filter and bias restart at the beginning for every batch.
                //
                const size_t GroupIndex = b * GroupCount + g;
                const float* GroupInput = Input + GroupIndex * InputChannels * InputSize;
                const float* GroupFilter = Filter + g * FilterCount * InputChannels * KernelSize;
                const float* GroupBias = Bias + g * FilterCount;
                float* GroupOutput = Output + GroupIndex * FilterCount * OutputSize;

                //
                // Transform the image using IM2COL and invoke the GEMM.
                //
                float* Im2Col = BufferIm2Col.GetBuffer(Im2ColElements);
                size_t Im2ColIndex = 0;

                for (size_t c = 0; c < InputChannels; c++) {
                    const float* Plane = GroupInput + c * InputSize;

                    for (size_t ky = 0; ky < KernelHeight; ky++) {
                        for (size_t kx = 0; kx < KernelWidth; kx++) {
                            for (size_t oh = 0; oh < OutputHeight; oh++) {

                                //
                                // Coordinates that land in the leading padding
                                // wrap around to huge unsigned values, so the
                                // single upper bound test below rejects both
                                // sides of the padding.
                                //
                                const size_t ih = oh * StrideHeight + ky * DilationHeight - PaddingLeftHeight;

                                for (size_t ow = 0; ow < OutputWidth; ow++) {
                                    const size_t iw = ow * StrideWidth + kx * DilationWidth - PaddingLeftWidth;

                                    float Value = 0;
                                    if (ih < InputHeight && iw < InputWidth) {
                                        Value = Plane[ih * InputWidth + iw];
                                    }
                                    Im2Col[Im2ColIndex++] = Value;
                                }
                            }
                        }
                    }
                }

                MlasGemm(CblasNoTrans, CblasNoTrans, FilterCount, OutputSize, K, 1.0f,
                         GroupFilter, K, Im2Col, OutputSize, 0.0f, GroupOutput, OutputSize, threadpool);

                //
                // Apply the bias.
                //
                for (size_t f = 0; f < FilterCount; f++) {
                    const float biasValue = GroupBias[f];
                    float* FilterOutput = GroupOutput + f * OutputSize;

                    for (size_t o = 0; o < OutputSize; o++) {
                        FilterOutput[o] += biasValue;
                    }
                }
            }
        }
    }

    MatrixGuardBuffer<float> BufferInput;
    MatrixGuardBuffer<float> BufferFilter;
    MatrixGuardBuffer<float> BufferBias;
    MatrixGuardBuffer<float> BufferOutput;
    MatrixGuardBuffer<float> BufferOutputReference;
    MatrixGuardBuffer<float> BufferWorking;
    MatrixGuardBuffer<float> BufferIm2Col;

public:
    void
    ExecuteShort() override
    {
        for (unsigned Size = 1; Size < 256; Size <<= 1) {
            Test(1, 1, 16, Size, Size, 32, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(1, 1, 16, Size, Size, 32, 3, 3, 0, 0, 0, 0, 1, 1, 2, 2);
            Test(1, 1, 16, Size, Size, 32, 3, 3, 0, 0, 0, 0, 2, 2, 1, 1);
            Test(1, 1, 16, Size, Size, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1);
            Test(1, 1, 16, Size, Size, 32, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(1, 1, 16, Size, Size, 32, Size, 1, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(1, 1, 16, Size, Size, 32, 1, Size, 0, 0, 0, 0, 1, 1, 1, 1);
        }
    }

    void
    ExecuteLong() override
    {
        static const unsigned cs[] = { 32, 14, 1 };
        static const unsigned is[] = { 53, 11, 5, 1 };
        static const unsigned KernelSizes[] = { 1, 2, 3, 5 };

        for (unsigned KernelHeight = 1; KernelHeight <= 32; KernelHeight++) {
            Test(4, 18, 1, 32, 89, 48, KernelHeight, 89, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(4, 18, 1, 32, 89, 48, KernelHeight, 89, 1, 1, 1, 1, 1, 1, 1, 1);
            Test(4, 18, 2, 32, 89, 48, KernelHeight, 89, 0, 0, 0, 0, 1, 1, 1, 1);
        }

        for (unsigned Batch = 1; Batch < 64; Batch++) {
            Test(Batch, 1, 64, 11, 11, 128, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1);
        }

        for (unsigned InputChannels : cs) {
            for (unsigned InputHeight : is) {
                for (unsigned InputWidth : is) {
                    fprintf(stderr, "Handling %ux%ux%u\n", InputChannels, InputHeight, InputWidth);

                    for (unsigned FilterCount : cs) {
                        for (unsigned KernelHeight : KernelSizes) {
                            for (unsigned KernelWidth : KernelSizes) {

                                //
                                // Each padding value is one bit of the counter
                                // (0 or 1), with p0 the most significant so
                                // that p3 varies fastest.
                                //
                                for (unsigned pad = 0; pad < 16; pad++) {
                                    const unsigned p0 = (pad >> 3) & 1;
                                    const unsigned p1 = (pad >> 2) & 1;
                                    const unsigned p2 = (pad >> 1) & 1;
                                    const unsigned p3 = pad & 1;

                                    //
                                    // Dilations and strides take the values 1
                                    // and 2, with the stride width varying
                                    // fastest.
                                    //
                                    for (unsigned ds = 0; ds < 16; ds++) {
                                        const unsigned dh = 1 + ((ds >> 3) & 1);
                                        const unsigned dw = 1 + ((ds >> 2) & 1);
                                        const unsigned sh = 1 + ((ds >> 1) & 1);
                                        const unsigned sw = 1 + (ds & 1);

                                        Test(1, 1, InputChannels, InputHeight, InputWidth, FilterCount,
                                             KernelHeight, KernelWidth, p0, p1, p2, p3, dh, dw, sh, sw);
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
};
//
// Variant of MlasConv2DTest that routes the convolution under test through
// the NCHWc entry points: the filter, bias and input are reordered or padded
// as needed, MlasNchwcConv is invoked, and the result is reordered back into
// the caller's output buffer.
//
class MlasNchwcConv2DTest : public MlasConv2DTest
{
protected:
    void
    MlasConv2D(
        size_t BatchCount, size_t GroupCount, size_t InputChannels,
        size_t InputHeight, size_t InputWidth,
        size_t FilterCount,
        size_t KernelHeight, size_t KernelWidth,
        size_t PaddingLeftHeight, size_t PaddingLeftWidth,
        size_t PaddingRightHeight, size_t PaddingRightWidth,
        size_t DilationHeight, size_t DilationWidth,
        size_t StrideHeight, size_t StrideWidth,
        size_t OutputHeight, size_t OutputWidth,
        const float* Input,
        const float* Filter,
        const float* Bias,
        float* Output
        ) override
    {
        //
        // Rounds a channel count up using the block size as a bit mask.
        //
        auto RoundUpToBlock = [this](size_t Channels) -> size_t {
            return (Channels + BlockSize - 1) & ~(BlockSize - 1);
        };

        const size_t TotalInputChannels = GroupCount * InputChannels;
        const size_t TotalOutputChannels = GroupCount * FilterCount;
        const size_t NchwcInputChannels = RoundUpToBlock(TotalInputChannels);
        const size_t NchwcOutputChannels = RoundUpToBlock(TotalOutputChannels);

        int64_t InputShape[] = { int64_t(BatchCount), int64_t(TotalInputChannels), int64_t(InputHeight), int64_t(InputWidth) };
        int64_t FilterShape[] = { int64_t(TotalOutputChannels), int64_t(InputChannels), int64_t(KernelHeight), int64_t(KernelWidth) };
        int64_t OutputShape[] = { int64_t(BatchCount), int64_t(TotalOutputChannels), int64_t(OutputHeight), int64_t(OutputWidth) };

        int64_t KernelShape[] = { int64_t(KernelHeight), int64_t(KernelWidth) };
        int64_t DilationShape[] = { int64_t(DilationHeight), int64_t(DilationWidth) };
        int64_t Padding[] = {
            int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth),
            int64_t(PaddingRightHeight), int64_t(PaddingRightWidth)
        };
        int64_t StrideShape[] = { int64_t(StrideHeight), int64_t(StrideWidth) };

        //
        // Select the type of convolution that will be performed:
        //   depthwise         - reorder the input, OIHWBo filter
        //   NCHWc/pointwise   - reorder the input, OIHWBiBo filter
        //   NCHW (otherwise)  - keep the input as is, OIHWBo filter
        //
        const bool IsDepthwise = (GroupCount > 1 && InputChannels == 1 && FilterCount == 1);
        const bool IsBlockedInput = !IsDepthwise && (InputChannels >= BlockSize);
        const bool DoReorderInput = IsDepthwise || IsBlockedInput;
        const bool ReorderFilterOIHWBo = !IsBlockedInput;

        //
        // Reorder the filter buffer as needed.
        //
        const size_t FilterInputChannels = ReorderFilterOIHWBo ? InputChannels : NchwcInputChannels;
        const size_t NchwcFilterElements = NchwcOutputChannels * FilterInputChannels * KernelHeight * KernelWidth;
        float* ReorderedFilter = BufferNchwcFilter.GetBuffer(NchwcFilterElements);

        if (ReorderFilterOIHWBo) {
            MlasReorderFilterOIHWBo(FilterShape, Filter, ReorderedFilter);
        } else {
            MlasReorderFilterOIHWBiBo(FilterShape, Filter, ReorderedFilter);
        }

        //
        // Align the bias buffer to the filter count if needed; the added
        // entries are zero.
        //
        if (Bias != nullptr && TotalOutputChannels < NchwcOutputChannels) {
            float* AlignedBias = BufferNchwcBias.GetBuffer(NchwcOutputChannels);

            for (size_t Channel = 0; Channel < TotalOutputChannels; Channel++) {
                AlignedBias[Channel] = Bias[Channel];
            }
            for (size_t Channel = TotalOutputChannels; Channel < NchwcOutputChannels; Channel++) {
                AlignedBias[Channel] = 0.0f;
            }

            Bias = AlignedBias;
        }

        //
        // Reorder the input buffer if needed.
        //
        if (DoReorderInput) {
            const size_t NchwcInputElements = BatchCount * NchwcInputChannels * InputHeight * InputWidth;
            float* NchwcInput = BufferNchwcInput.GetBuffer(NchwcInputElements);

            MlasReorderInput(InputShape, Input, NchwcInput);

            Input = NchwcInput;
            InputShape[1] = int64_t(NchwcInputChannels);
        }

        int64_t NchwcOutputShape[] = { int64_t(BatchCount), int64_t(NchwcOutputChannels), int64_t(OutputHeight), int64_t(OutputWidth) };

        const size_t NchwcOutputElements = BatchCount * NchwcOutputChannels * OutputHeight * OutputWidth;
        float* NchwcOutput = BufferNchwcOutput.GetBuffer(NchwcOutputElements);

        MLAS_ACTIVATION Activation;
        Activation.ActivationKind = MlasIdentityActivation;

        MlasNchwcConv(2, InputShape, KernelShape, DilationShape, Padding, StrideShape,
                      NchwcOutputShape, GroupCount, Input, ReorderedFilter, Bias, NchwcOutput,
                      &Activation, true, nullptr);

        //
        // Reorder the output buffer.
        //
        MlasReorderOutput(OutputShape, NchwcOutput, Output);
    }

    const size_t BlockSize = MlasNchwcGetBlockSize();

    MatrixGuardBuffer<float> BufferNchwcInput;
    MatrixGuardBuffer<float> BufferNchwcFilter;
    MatrixGuardBuffer<float> BufferNchwcBias;
    MatrixGuardBuffer<float> BufferNchwcOutput;

public:
    void
    ExecuteLong() override
    {
        // N.B. InputChannels must be a multiple of 4 if the count is greater
        // than the block size.
        static const unsigned cis[] = { 32, 20, 5, 1 };
        static const unsigned cos[] = { 64, 15, 1 };
        static const unsigned is[] = { 27, 11, 5, 1 };
        static const unsigned KernelSizes[] = { 1, 2, 3, 5 };

        // Depthwise convolutions.
        for (unsigned Groups = 16; Groups < 256; Groups <<= 1) {
            Test(1, Groups, 1, 28, 28, 1, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(1, Groups, 1, 28, 28, 1, 3, 3, 0, 0, 0, 0, 1, 1, 2, 2);
            Test(1, Groups, 1, 28, 28, 1, 3, 3, 0, 0, 0, 0, 2, 2, 1, 1);
            Test(1, Groups, 1, 28, 28, 1, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1);
            Test(1, Groups, 1, 28, 28, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(1, Groups, 1, 28, 28, 1, Groups, 1, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(12, Groups, 1, 11, 11, 1, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1);
        }

        // Test varying FilterCounts.
        for (unsigned Filters = 1; Filters < 128; Filters++) {
            Test(1, 1, 3, 34, 34, Filters, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(1, 1, 16, 34, 34, Filters, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(1, 1, 16, 34, 34, Filters, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1);
        }

        for (unsigned KernelHeight = 1; KernelHeight <= 32; KernelHeight++) {
            Test(4, 18, 1, 32, 89, 48, KernelHeight, 89, 0, 0, 0, 0, 1, 1, 1, 1);
            Test(4, 18, 1, 32, 89, 48, KernelHeight, 89, 1, 1, 1, 1, 1, 1, 1, 1);
            Test(4, 18, 2, 32, 89, 48, KernelHeight, 89, 0, 0, 0, 0, 1, 1, 1, 1);
        }

        for (unsigned Batch = 1; Batch < 64; Batch++) {
            Test(Batch, 1, 64, 11, 11, 128, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1);
        }

        for (unsigned InputChannels : cis) {
            for (unsigned InputHeight : is) {
                for (unsigned InputWidth : is) {
                    fprintf(stderr, "Handling %ux%ux%u\n", InputChannels, InputHeight, InputWidth);

                    for (unsigned FilterCount : cos) {
                        for (unsigned KernelHeight : KernelSizes) {
                            for (unsigned KernelWidth : KernelSizes) {

                                //
                                // Each padding value is a two bit field of the
                                // counter (0 through 3), with p0 the most
                                // significant so that p3 varies fastest.
                                //
                                for (unsigned pad = 0; pad < 256; pad++) {
                                    const unsigned p0 = (pad >> 6) & 3;
                                    const unsigned p1 = (pad >> 4) & 3;
                                    const unsigned p2 = (pad >> 2) & 3;
                                    const unsigned p3 = pad & 3;

                                    //
                                    // Dilations and strides take the values 1
                                    // and 2, with the stride width varying
                                    // fastest.
                                    //
                                    for (unsigned ds = 0; ds < 16; ds++) {
                                        const unsigned dh = 1 + ((ds >> 3) & 1);
                                        const unsigned dw = 1 + ((ds >> 2) & 1);
                                        const unsigned sh = 1 + ((ds >> 1) & 1);
                                        const unsigned sw = 1 + (ds & 1);

                                        Test(1, 1, InputChannels, InputHeight, InputWidth, FilterCount,
                                             KernelHeight, KernelWidth, p0, p1, p2, p3, dh, dw, sh, sw);
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
};
class MlasPool2DTest : public MlasTestBase
{
protected:
//
// Runs one pooling configuration through maximum pooling and both average
// pooling modes, comparing MlasPool2D against the reference implementation
// each time. Configurations whose computed output height or width is not
// positive are skipped.
//
void
Test(
    size_t BatchCount, size_t InputChannels,
    size_t InputHeight, size_t InputWidth,
    size_t KernelHeight, size_t KernelWidth,
    size_t PaddingLeftHeight, size_t PaddingLeftWidth,
    size_t PaddingRightHeight, size_t PaddingRightWidth,
    size_t StrideHeight, size_t StrideWidth
    )
{
    //
    // Pooling uses a fixed dilation of one.
    //
    // N.B. The division truncates toward zero, so a kernel that overhangs the
    // padded input by less than one stride still produces an extent of one.
    //
    auto OutputExtent = [](size_t Input, size_t PadLeft, size_t PadRight, size_t Kernel, size_t Stride) -> int64_t {
        const size_t Dilation = 1;
        const int64_t PaddedInput = int64_t(Input) + int64_t(PadLeft) + int64_t(PadRight);
        const int64_t KernelSpan = int64_t(Dilation) * (int64_t(Kernel) - 1) + 1;
        return (PaddedInput - KernelSpan) / int64_t(Stride) + 1;
    };

    const int64_t OutputHeight64 = OutputExtent(InputHeight, PaddingLeftHeight, PaddingRightHeight, KernelHeight, StrideHeight);
    const int64_t OutputWidth64 = OutputExtent(InputWidth, PaddingLeftWidth, PaddingRightWidth, KernelWidth, StrideWidth);

    if (OutputHeight64 <= 0 || OutputWidth64 <= 0) {
        return;
    }

    int64_t InputShape[] = { int64_t(BatchCount), int64_t(InputChannels), int64_t(InputHeight), int64_t(InputWidth) };
    int64_t KernelShape[] = { int64_t(KernelHeight), int64_t(KernelWidth) };
    int64_t Padding[] = {
        int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth),
        int64_t(PaddingRightHeight), int64_t(PaddingRightWidth)
    };
    int64_t StrideShape[] = { int64_t(StrideHeight), int64_t(StrideWidth) };
    int64_t OutputShape[] = { int64_t(BatchCount), int64_t(InputChannels), OutputHeight64, OutputWidth64 };

    const size_t InputBufferElements = size_t(InputShape[0] * InputShape[1] * InputShape[2] * InputShape[3]);
    const size_t OutputBufferElements = size_t(OutputShape[0] * OutputShape[1] * OutputShape[2] * OutputShape[3]);

    const float* Input = BufferInput.GetBuffer(InputBufferElements);
    float* Output = BufferOutput.GetBuffer(OutputBufferElements);
    float* OutputReference = BufferOutputReference.GetBuffer(OutputBufferElements);

    auto OutputsDiffer = [&]() -> bool {
        return memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)) != 0;
    };

    //
    // Maximum pooling.
    //
    MlasPool2D(MlasMaximumPooling, InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output);
    ReferenceMaximumPool2D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference);

    if (OutputsDiffer()) {
        printf("mismatch: maximum input(%zd,%zd,%zd),kernel(%zd,%zd)!!!\n",
            InputChannels, InputHeight, InputWidth, KernelHeight, KernelWidth);
    }

    //
    // Average pooling that divides by the count of non-padding elements.
    //
    MlasPool2D(MlasAveragePoolingExcludePad, InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output);
    ReferenceAveragePool2D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference, false);

    if (OutputsDiffer()) {
        printf("mismatch: averageexcpad input(%zd,%zd,%zd),kernel(%zd,%zd)!!!\n",
            InputChannels, InputHeight, InputWidth, KernelHeight, KernelWidth);
    }

    //
    // Average pooling that divides by the full kernel size.
    //
    MlasPool2D(MlasAveragePoolingIncludePad, InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output);
    ReferenceAveragePool2D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference, true);

    if (OutputsDiffer()) {
        printf("mismatch: averageincpad input(%zd,%zd,%zd),kernel(%zd,%zd)!!!\n",
            InputChannels, InputHeight, InputWidth, KernelHeight, KernelWidth);
    }
}
//
// Implementation under test: forwards to MlasPool for two spatial dimensions
// using the global thread pool. Virtual so that it can be overridden.
//
virtual
void
MlasPool2D(
    MLAS_POOLING_KIND PoolingKind,
    const int64_t* InputShape, const int64_t* KernelShape,
    const int64_t* Padding, const int64_t* StrideShape,
    const int64_t* OutputShape,
    const float* Input, float* Output
    )
{
    const int SpatialDimensions = 2;

    MlasPool(PoolingKind, SpatialDimensions, InputShape, KernelShape, Padding, StrideShape,
             OutputShape, Input, Output, threadpool);
}
//
// Reference maximum pooling over each of the InputShape[0] * InputShape[1]
// planes of size InputShape[2] x InputShape[3].
//
// Padding is ordered { top, left, bottom, right }. Each output element is the
// largest input value inside the kernel window after the window is clipped to
// the input bounds; a fully clipped window yields the lowest finite float.
//
void
ReferenceMaximumPool2D(
    const int64_t* InputShape,
    const int64_t* KernelShape,
    const int64_t* Padding,
    const int64_t* StrideShape,
    const float* Input,
    float* Output
    )
{
    const int64_t PlaneCount = InputShape[0] * InputShape[1];
    const int64_t InputHeight = InputShape[2];
    const int64_t InputWidth = InputShape[3];
    const int64_t KernelHeight = KernelShape[0];
    const int64_t KernelWidth = KernelShape[1];
    const int64_t PaddingTop = Padding[0];
    const int64_t PaddingLeft = Padding[1];
    const int64_t StrideHeight = StrideShape[0];
    const int64_t StrideWidth = StrideShape[1];

    const int64_t OutputHeight = (InputHeight + PaddingTop + Padding[2] - KernelHeight) / StrideHeight + 1;
    const int64_t OutputWidth = (InputWidth + PaddingLeft + Padding[3] - KernelWidth) / StrideWidth + 1;

    for (int64_t Plane = 0; Plane < PlaneCount; Plane++) {
        const float* PlaneInput = Input + Plane * (InputHeight * InputWidth);
        float* PlaneOutput = Output + Plane * (OutputHeight * OutputWidth);

        for (int64_t oy = 0; oy < OutputHeight; oy++) {
            const int64_t WindowTop = oy * StrideHeight - PaddingTop;
            const int64_t RowBegin = (std::max)(WindowTop, int64_t(0));
            const int64_t RowEnd = (std::min)(WindowTop + KernelHeight, InputHeight);

            for (int64_t ox = 0; ox < OutputWidth; ox++) {
                const int64_t WindowLeft = ox * StrideWidth - PaddingLeft;
                const int64_t ColumnBegin = (std::max)(WindowLeft, int64_t(0));
                const int64_t ColumnEnd = (std::min)(WindowLeft + KernelWidth, InputWidth);

                float Largest = std::numeric_limits<float>::lowest();

                for (int64_t y = RowBegin; y < RowEnd; y++) {
                    const float* Row = PlaneInput + y * InputWidth;
                    for (int64_t x = ColumnBegin; x < ColumnEnd; x++) {
                        Largest = (std::max)(Largest, Row[x]);
                    }
                }

                PlaneOutput[oy * OutputWidth + ox] = Largest;
            }
        }
    }
}
//
// Reference average pooling over each of the InputShape[0] * InputShape[1]
// planes of size InputShape[2] x InputShape[3].
//
// Padding is ordered { top, left, bottom, right }. Each output element is the
// sum of the input values inside the kernel window clipped to the input
// bounds, divided by the full kernel size when CountIncludePad is true or by
// the number of elements in the clipped window otherwise.
//
void
ReferenceAveragePool2D(
    const int64_t* InputShape,
    const int64_t* KernelShape,
    const int64_t* Padding,
    const int64_t* StrideShape,
    const float* Input,
    float* Output,
    bool CountIncludePad
    )
{
    const int64_t PlaneCount = InputShape[0] * InputShape[1];
    const int64_t InputHeight = InputShape[2];
    const int64_t InputWidth = InputShape[3];
    const int64_t KernelHeight = KernelShape[0];
    const int64_t KernelWidth = KernelShape[1];
    const int64_t PaddingTop = Padding[0];
    const int64_t PaddingLeft = Padding[1];
    const int64_t StrideHeight = StrideShape[0];
    const int64_t StrideWidth = StrideShape[1];

    const int64_t OutputHeight = (InputHeight + PaddingTop + Padding[2] - KernelHeight) / StrideHeight + 1;
    const int64_t OutputWidth = (InputWidth + PaddingLeft + Padding[3] - KernelWidth) / StrideWidth + 1;

    for (int64_t Plane = 0; Plane < PlaneCount; Plane++) {
        const float* PlaneInput = Input + Plane * (InputHeight * InputWidth);
        float* PlaneOutput = Output + Plane * (OutputHeight * OutputWidth);

        for (int64_t oy = 0; oy < OutputHeight; oy++) {
            const int64_t WindowTop = oy * StrideHeight - PaddingTop;
            const int64_t RowBegin = (std::max)(WindowTop, int64_t(0));
            const int64_t RowEnd = (std::min)(WindowTop + KernelHeight, InputHeight);

            for (int64_t ox = 0; ox < OutputWidth; ox++) {
                const int64_t WindowLeft = ox * StrideWidth - PaddingLeft;
                const int64_t ColumnBegin = (std::max)(WindowLeft, int64_t(0));
                const int64_t ColumnEnd = (std::min)(WindowLeft + KernelWidth, InputWidth);

                float Sum = 0.0f;

                for (int64_t y = RowBegin; y < RowEnd; y++) {
                    const float* Row = PlaneInput + y * InputWidth;
                    for (int64_t x = ColumnBegin; x < ColumnEnd; x++) {
                        Sum += Row[x];
                    }
                }

                const int64_t Divisor = CountIncludePad ?
                    (KernelHeight * KernelWidth) :
                    ((RowEnd - RowBegin) * (ColumnEnd - ColumnBegin));

                Sum /= Divisor;

                PlaneOutput[oy * OutputWidth + ox] = Sum;
            }
        }
    }
}
//
// Float scratch buffers owned by the test object; MatrixGuardBuffer only
// reallocates when more elements are requested than it currently holds.
// NOTE(review): presumably these hold the pooling input, the MLAS output and
// the reference output for Test(), mirroring MlasPool3DTest -- the 2D Test()
// body is outside this view, confirm there.
//
MatrixGuardBuffer<float> BufferInput;
MatrixGuardBuffer<float> BufferOutput;
MatrixGuardBuffer<float> BufferOutputReference;
public:
void
ExecuteShort(
void
) override
{
for (unsigned i = 1; i < 256; i <<= 1) {
Test(1, 16, i, i, 3, 3, 0, 0, 0, 0, 1, 1);
Test(1, 16, i, i, 3, 3, 0, 0, 0, 0, 2, 2);
Test(1, 16, i, i, 3, 3, 0, 0, 0, 0, 1, 1);
Test(1, 16, i, i, 3, 3, 1, 1, 1, 1, 1, 1);
Test(1, 16, i, i, 1, 1, 0, 0, 0, 0, 1, 1);
Test(1, 16, i, i, i, 1, 0, 0, 0, 0, 1, 1);
Test(1, 16, i, i, 1, i, 0, 0, 0, 0, 1, 1);
}
}
//
// Exhaustive 2D pooling pass. First sweeps one input extent from 1 to 2057
// with a fixed kernel/padding configuration, then walks every pair of extents
// from a fixed list and, for each pair, every kernel size up to 5 (never larger
// than the input extent), every stride from 1 to 3 and every padding smaller
// than the kernel.
//
void
ExecuteLong(
    void
    ) override
{
    static const unsigned Extents[] = { 53, 17, 11, 5, 4, 3, 2, 1 };

    for (unsigned Sweep = 1; Sweep < 2058; Sweep++) {
        Test(1, 1, 4, Sweep, 2, 4, 0, 2, 0, 1, 1, 1);
    }

    for (const unsigned Height : Extents) {
        for (const unsigned Width : Extents) {

            fprintf(stderr, "Handling %ux%u\n", Height, Width);

            //
            // Kernels that span the whole input, one full dimension, or the other.
            //

            Test(1, 1, Height, Width, Height, Width, 0, 0, 0, 0, 1, 1);
            Test(1, 1, Height, Width, Height, 1, 0, 0, 0, 0, 1, 1);
            Test(1, 1, Height, Width, 1, Width, 0, 0, 0, 0, 1, 1);

            //
            // Kernel extents are capped at both 5 and the matching input extent.
            //

            const unsigned KernelHeightLimit = (std::min)(Height, 5u);
            const unsigned KernelWidthLimit = (std::min)(Width, 5u);

            for (unsigned kh = 1; kh <= KernelHeightLimit; kh++) {
                for (unsigned kw = 1; kw <= KernelWidthLimit; kw++) {
                    for (unsigned sh = 1; sh <= 3; sh++) {
                        for (unsigned sw = 1; sw <= 3; sw++) {
                            for (unsigned p0 = 0; p0 < kh; p0++) {
                                for (unsigned p1 = 0; p1 < kw; p1++) {
                                    for (unsigned p2 = 0; p2 < kh; p2++) {
                                        for (unsigned p3 = 0; p3 < kw; p3++) {
                                            Test(5, 3, Height, Width, kh, kw, p0, p1, p2, p3, sh, sw);
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
};
//
// 2D pooling test that reuses MlasPool2DTest but routes the pooling call
// through the blocked path: the input is reordered with MlasReorderInput,
// pooled with MlasNchwcPool and reordered back with MlasReorderOutput.
// ExecuteShort is inherited unchanged; ExecuteLong is replaced by a smaller
// sweep.
//
class MlasNchwcPool2DTest : public MlasPool2DTest
{
protected:
    void
    MlasPool2D(
        MLAS_POOLING_KIND PoolingKind,
        const int64_t* InputShape,
        const int64_t* KernelShape,
        const int64_t* Padding,
        const int64_t* StrideShape,
        const int64_t* OutputShape,
        const float* Input,
        float* Output
        ) override
    {
        //
        // Pad the channel count up to the block size. The mask arithmetic
        // assumes BlockSize is a power of two.
        //

        const size_t BlockMask = BlockSize - 1;
        const size_t NchwcChannels = (size_t(InputShape[1]) + BlockMask) & ~BlockMask;

        const auto ElementCount = [](const int64_t (&Shape)[4]) {
            return size_t(Shape[0]) * size_t(Shape[1]) * size_t(Shape[2]) * size_t(Shape[3]);
        };

        //
        // The blocked tensors keep the batch and spatial extents of the
        // caller's tensors and only differ in the padded channel count.
        //

        int64_t NchwcInputShape[] = { InputShape[0], int64_t(NchwcChannels), InputShape[2], InputShape[3] };
        int64_t NchwcOutputShape[] = { OutputShape[0], int64_t(NchwcChannels), OutputShape[2], OutputShape[3] };

        float* NchwcInput = BufferNchwcInput.GetBuffer(ElementCount(NchwcInputShape));
        float* NchwcOutput = BufferNchwcOutput.GetBuffer(ElementCount(NchwcOutputShape));

        MlasReorderInput(InputShape, Input, NchwcInput);

        //
        // NOTE(review): the nullptr after KernelShape is presumably the
        // dilation shape and the trailing nullptr the thread pool (the global
        // `threadpool` is not passed here) -- confirm against the MlasNchwcPool
        // declaration.
        //

        MlasNchwcPool(PoolingKind,
                      2,
                      NchwcInputShape,
                      KernelShape,
                      nullptr,
                      Padding,
                      StrideShape,
                      NchwcOutputShape,
                      NchwcInput,
                      NchwcOutput,
                      nullptr);

        MlasReorderOutput(OutputShape, NchwcOutput, Output);
    }

    //
    // Scratch storage for the reordered input and the blocked pooling output.
    //

    MatrixGuardBuffer<float> BufferNchwcInput;
    MatrixGuardBuffer<float> BufferNchwcOutput;

    const size_t BlockSize = MlasNchwcGetBlockSize();

public:
    //
    // Reduced exhaustive sweep: three input extents, with channel counts of
    // 12, 32 and 68 for the full-span kernels and 32 for the nested sweep over
    // kernel size, stride and padding.
    //
    void
    ExecuteLong(
        void
        ) override
    {
        static const unsigned Extents[] = { 53, 11, 1 };

        for (const unsigned Height : Extents) {
            for (const unsigned Width : Extents) {

                fprintf(stderr, "Handling %ux%u\n", Height, Width);

                Test(1, 12, Height, Width, Height, Width, 0, 0, 0, 0, 1, 1);
                Test(1, 32, Height, Width, Height, 1, 0, 0, 0, 0, 1, 1);
                Test(1, 68, Height, Width, 1, Width, 0, 0, 0, 0, 1, 1);

                //
                // Kernel extents are capped at both 5 and the matching input extent.
                //

                const unsigned KernelHeightLimit = (std::min)(Height, 5u);
                const unsigned KernelWidthLimit = (std::min)(Width, 5u);

                for (unsigned kh = 1; kh <= KernelHeightLimit; kh++) {
                    for (unsigned kw = 1; kw <= KernelWidthLimit; kw++) {
                        for (unsigned sh = 1; sh <= 3; sh++) {
                            for (unsigned sw = 1; sw <= 3; sw++) {
                                for (unsigned p0 = 0; p0 < kh; p0++) {
                                    for (unsigned p1 = 0; p1 < kw; p1++) {
                                        for (unsigned p2 = 0; p2 < kh; p2++) {
                                            for (unsigned p3 = 0; p3 < kw; p3++) {
                                                Test(1, 32, Height, Width, kh, kw, p0, p1, p2, p3, sh, sw);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
};
//
// Tests MlasPool with three spatial dimensions (depth, height, width) by
// comparing its output bit-for-bit against the straightforward reference
// implementations in this class. Mismatches are reported with printf.
//
class MlasPool3DTest : public MlasTestBase
{
protected:
    //
    // Runs maximum pooling, average pooling excluding padding and average
    // pooling including padding for a single configuration and prints a
    // message for every kind whose MLAS output differs from the reference.
    //
    // Configurations whose computed output depth, height or width is not
    // positive are silently skipped.
    //
    void
    Test(
        size_t BatchCount,
        size_t InputChannels,
        size_t InputDepth,
        size_t InputHeight,
        size_t InputWidth,
        size_t KernelDepth,
        size_t KernelHeight,
        size_t KernelWidth,
        size_t PaddingLeftDepth,
        size_t PaddingLeftHeight,
        size_t PaddingLeftWidth,
        size_t PaddingRightDepth,
        size_t PaddingRightHeight,
        size_t PaddingRightWidth,
        size_t StrideDepth,
        size_t StrideHeight,
        size_t StrideWidth
        )
    {
        const int64_t InputShape[] = { int64_t(BatchCount), int64_t(InputChannels), int64_t(InputDepth), int64_t(InputHeight), int64_t(InputWidth) };
        const int64_t KernelShape[] = { int64_t(KernelDepth), int64_t(KernelHeight), int64_t(KernelWidth) };
        const int64_t Padding[] = { int64_t(PaddingLeftDepth), int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth), int64_t(PaddingRightDepth), int64_t(PaddingRightHeight), int64_t(PaddingRightWidth) };
        const int64_t StrideShape[] = { int64_t(StrideDepth), int64_t(StrideHeight), int64_t(StrideWidth) };

        //
        // Output extent per spatial dimension:
        //   (input + leading pad + trailing pad - kernel) / stride + 1
        // computed in signed arithmetic so an oversized kernel is detectable.
        //

        int64_t OutputShape[] = { int64_t(BatchCount), int64_t(InputChannels), 0, 0, 0 };

        for (size_t dim = 0; dim < 3; dim++) {
            const int64_t PaddedExtent = InputShape[2 + dim] + Padding[dim] + Padding[3 + dim];
            OutputShape[2 + dim] = (PaddedExtent - KernelShape[dim]) / StrideShape[dim] + 1;
        }

        if (OutputShape[2] <= 0 || OutputShape[3] <= 0 || OutputShape[4] <= 0) {
            return;
        }

        const size_t InputBufferElements = size_t(InputShape[0] * InputShape[1] * InputShape[2] * InputShape[3] * InputShape[4]);
        const size_t OutputBufferElements = size_t(OutputShape[0] * OutputShape[1] * OutputShape[2] * OutputShape[3] * OutputShape[4]);

        const float* Input = BufferInput.GetBuffer(InputBufferElements);
        float* Output = BufferOutput.GetBuffer(OutputBufferElements);
        float* OutputReference = BufferOutputReference.GetBuffer(OutputBufferElements);

        //
        // Compares the MLAS output against the reference output byte for byte
        // and reports the configuration on a difference. The size_t values are
        // printed with %zu, the conversion that matches their type.
        //

        const auto ReportMismatch = [&](const char* KindName) {
            if (memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)) != 0) {
                printf("mismatch: %s input(%zu,%zu,%zu,%zu),kernel(%zu,%zu,%zu)!!!\n",
                    KindName, InputChannels, InputDepth, InputHeight, InputWidth, KernelDepth, KernelHeight, KernelWidth);
            }
        };

        MlasPool(MlasMaximumPooling, 3, InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output, threadpool);
        ReferenceMaximumPool3D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference);
        ReportMismatch("maximum");

        MlasPool(MlasAveragePoolingExcludePad, 3, InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output, threadpool);
        ReferenceAveragePool3D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference, false);
        ReportMismatch("averageexcpad");

        MlasPool(MlasAveragePoolingIncludePad, 3, InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output, threadpool);
        ReferenceAveragePool3D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference, true);
        ReportMismatch("averageincpad");
    }

    //
    // Reference 3D maximum pooling.
    //
    // InputShape holds five extents: the first two multiply into the number of
    // independent volumes, the last three are depth, height and width.
    // KernelShape and StrideShape are {depth, height, width}. Padding is the
    // three leading paddings followed by the three trailing paddings. Windows
    // are clipped to the input, so only real input elements are read.
    //
    void
    ReferenceMaximumPool3D(
        const int64_t* InputShape,
        const int64_t* KernelShape,
        const int64_t* Padding,
        const int64_t* StrideShape,
        const float* Input,
        float* Output
        )
    {
        const int64_t ChannelCount = InputShape[0] * InputShape[1];
        const int64_t InputDepth = InputShape[2];
        const int64_t InputHeight = InputShape[3];
        const int64_t InputWidth = InputShape[4];

        const int64_t KernelDepth = KernelShape[0];
        const int64_t KernelHeight = KernelShape[1];
        const int64_t KernelWidth = KernelShape[2];

        const int64_t PaddingLeftZ = Padding[0];
        const int64_t PaddingLeftY = Padding[1];
        const int64_t PaddingLeftX = Padding[2];
        const int64_t PaddingRightZ = Padding[3];
        const int64_t PaddingRightY = Padding[4];
        const int64_t PaddingRightX = Padding[5];

        const int64_t StrideDepth = StrideShape[0];
        const int64_t StrideHeight = StrideShape[1];
        const int64_t StrideWidth = StrideShape[2];

        const int64_t OutputDepth = (InputDepth + PaddingLeftZ + PaddingRightZ - KernelDepth) / StrideDepth + 1;
        const int64_t OutputHeight = (InputHeight + PaddingLeftY + PaddingRightY - KernelHeight) / StrideHeight + 1;
        const int64_t OutputWidth = (InputWidth + PaddingLeftX + PaddingRightX - KernelWidth) / StrideWidth + 1;

        for (int64_t c = 0; c < ChannelCount; c++) {

            for (int64_t pd = 0; pd < OutputDepth; pd++) {

                const int64_t WindowZ = pd * StrideDepth - PaddingLeftZ;
                const int64_t idStart = (std::max)(WindowZ, int64_t(0));
                const int64_t idEnd = (std::min)(WindowZ + KernelDepth, InputDepth);

                for (int64_t ph = 0; ph < OutputHeight; ph++) {

                    const int64_t WindowY = ph * StrideHeight - PaddingLeftY;
                    const int64_t ihStart = (std::max)(WindowY, int64_t(0));
                    const int64_t ihEnd = (std::min)(WindowY + KernelHeight, InputHeight);

                    for (int64_t pw = 0; pw < OutputWidth; pw++) {

                        const int64_t WindowX = pw * StrideWidth - PaddingLeftX;
                        const int64_t iwStart = (std::max)(WindowX, int64_t(0));
                        const int64_t iwEnd = (std::min)(WindowX + KernelWidth, InputWidth);

                        float m = std::numeric_limits<float>::lowest();

                        for (int64_t id = idStart; id < idEnd; id++) {
                            for (int64_t ih = ihStart; ih < ihEnd; ih++) {
                                for (int64_t iw = iwStart; iw < iwEnd; iw++) {
                                    m = (std::max)(m, Input[id * InputHeight * InputWidth + ih * InputWidth + iw]);
                                }
                            }
                        }

                        Output[pd * OutputHeight * OutputWidth + ph * OutputWidth + pw] = m;
                    }
                }
            }

            //
            // Step to the next input/output volume.
            //

            Input += InputDepth * InputHeight * InputWidth;
            Output += OutputDepth * OutputHeight * OutputWidth;
        }
    }

    //
    // Reference 3D average pooling; the shape arguments have the same layout
    // as for ReferenceMaximumPool3D.
    //
    // Only elements inside the input are summed. When CountIncludePad is true
    // the sum is divided by the full kernel volume; otherwise by the number of
    // input elements covered by the clipped window. The accumulation order
    // and the single division are kept exactly as written because Test()
    // compares the result bitwise.
    //
    void
    ReferenceAveragePool3D(
        const int64_t* InputShape,
        const int64_t* KernelShape,
        const int64_t* Padding,
        const int64_t* StrideShape,
        const float* Input,
        float* Output,
        bool CountIncludePad
        )
    {
        const int64_t ChannelCount = InputShape[0] * InputShape[1];
        const int64_t InputDepth = InputShape[2];
        const int64_t InputHeight = InputShape[3];
        const int64_t InputWidth = InputShape[4];

        const int64_t KernelDepth = KernelShape[0];
        const int64_t KernelHeight = KernelShape[1];
        const int64_t KernelWidth = KernelShape[2];

        const int64_t PaddingLeftZ = Padding[0];
        const int64_t PaddingLeftY = Padding[1];
        const int64_t PaddingLeftX = Padding[2];
        const int64_t PaddingRightZ = Padding[3];
        const int64_t PaddingRightY = Padding[4];
        const int64_t PaddingRightX = Padding[5];

        const int64_t StrideDepth = StrideShape[0];
        const int64_t StrideHeight = StrideShape[1];
        const int64_t StrideWidth = StrideShape[2];

        const int64_t OutputDepth = (InputDepth + PaddingLeftZ + PaddingRightZ - KernelDepth) / StrideDepth + 1;
        const int64_t OutputHeight = (InputHeight + PaddingLeftY + PaddingRightY - KernelHeight) / StrideHeight + 1;
        const int64_t OutputWidth = (InputWidth + PaddingLeftX + PaddingRightX - KernelWidth) / StrideWidth + 1;

        for (int64_t c = 0; c < ChannelCount; c++) {

            for (int64_t pd = 0; pd < OutputDepth; pd++) {

                const int64_t WindowZ = pd * StrideDepth - PaddingLeftZ;
                const int64_t idStart = (std::max)(WindowZ, int64_t(0));
                const int64_t idEnd = (std::min)(WindowZ + KernelDepth, InputDepth);

                for (int64_t ph = 0; ph < OutputHeight; ph++) {

                    const int64_t WindowY = ph * StrideHeight - PaddingLeftY;
                    const int64_t ihStart = (std::max)(WindowY, int64_t(0));
                    const int64_t ihEnd = (std::min)(WindowY + KernelHeight, InputHeight);

                    for (int64_t pw = 0; pw < OutputWidth; pw++) {

                        const int64_t WindowX = pw * StrideWidth - PaddingLeftX;
                        const int64_t iwStart = (std::max)(WindowX, int64_t(0));
                        const int64_t iwEnd = (std::min)(WindowX + KernelWidth, InputWidth);

                        float m = 0.0f;

                        for (int64_t id = idStart; id < idEnd; id++) {
                            for (int64_t ih = ihStart; ih < ihEnd; ih++) {
                                for (int64_t iw = iwStart; iw < iwEnd; iw++) {
                                    m += Input[id * InputHeight * InputWidth + ih * InputWidth + iw];
                                }
                            }
                        }

                        if (CountIncludePad) {
                            m /= (KernelDepth * KernelHeight * KernelWidth);
                        } else {
                            m /= (idEnd - idStart) * (ihEnd - ihStart) * (iwEnd - iwStart);
                        }

                        Output[pd * OutputHeight * OutputWidth + ph * OutputWidth + pw] = m;
                    }
                }
            }

            //
            // Step to the next input/output volume.
            //

            Input += InputDepth * InputHeight * InputWidth;
            Output += OutputDepth * OutputHeight * OutputWidth;
        }
    }

    //
    // Storage for the pooling input, the MLAS output and the reference output.
    //

    MatrixGuardBuffer<float> BufferInput;
    MatrixGuardBuffer<float> BufferOutput;
    MatrixGuardBuffer<float> BufferOutputReference;

public:
    //
    // Quick pass: cubic inputs with power-of-two extents from 1 through 32,
    // 16 channels, and a handful of kernel/padding/stride combinations. The
    // first and third configurations are identical.
    //
    void
    ExecuteShort(
        void
        ) override
    {
        for (unsigned i = 1; i < 64; i <<= 1) {
            Test(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1);
            Test(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 2, 2, 2);
            Test(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1);
            Test(1, 16, i, i, i, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1);
            Test(1, 16, i, i, i, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1);
            Test(1, 16, i, i, i, 1, i, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1);
            Test(1, 16, i, i, i, 1, 1, i, 0, 0, 0, 0, 0, 0, 1, 1, 1);
        }
    }

    //
    // Exhaustive pass: every triple of input extents from a fixed list and,
    // for each, every kernel size up to 4 (never larger than the input extent),
    // every stride from 1 to 3 and every padding smaller than the kernel.
    //
    void
    ExecuteLong(
        void
        ) override
    {
        static const unsigned Extents[] = { 11, 5, 4, 3, 2, 1 };

        for (const unsigned Depth : Extents) {
            for (const unsigned Height : Extents) {
                for (const unsigned Width : Extents) {

                    fprintf(stderr, "Handling %ux%ux%u\n", Depth, Height, Width);

                    // Kernel covering the entire input volume.
                    Test(1, 1, Depth, Height, Width, Depth, Height, Width, 0, 0, 0, 0, 0, 0, 1, 1, 1);

                    //
                    // Kernel extents are capped at both 4 and the matching input extent.
                    //

                    const unsigned KernelDepthLimit = (std::min)(Depth, 4u);
                    const unsigned KernelHeightLimit = (std::min)(Height, 4u);
                    const unsigned KernelWidthLimit = (std::min)(Width, 4u);

                    for (unsigned kd = 1; kd <= KernelDepthLimit; kd++) {
                        for (unsigned kh = 1; kh <= KernelHeightLimit; kh++) {
                            for (unsigned kw = 1; kw <= KernelWidthLimit; kw++) {
                                for (unsigned sd = 1; sd <= 3; sd++) {
                                    for (unsigned sh = 1; sh <= 3; sh++) {
                                        for (unsigned sw = 1; sw <= 3; sw++) {
                                            for (unsigned p0 = 0; p0 < kd; p0++) {
                                                for (unsigned p1 = 0; p1 < kh; p1++) {
                                                    for (unsigned p2 = 0; p2 < kw; p2++) {
                                                        for (unsigned p3 = 0; p3 < kd; p3++) {
                                                            for (unsigned p4 = 0; p4 < kh; p4++) {
                                                                for (unsigned p5 = 0; p5 < kw; p5++) {
                                                                    Test(1, 1, Depth, Height, Width, kd, kh, kw, p0, p1, p2, p3, p4, p5, sd, sh, sw);
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
};
//
// Tests MlasActivation against a table of precomputed bit patterns, once with
// the whole buffer processed in a single call and once with one call per
// element.
//
class MlasActivationTest : public MlasTestBase
{
public:
    void
    ExecuteShort(
        void
        ) override
    {
        //
        // Lets the table spell out exact float bit patterns and lets results
        // be compared both as bits and as float values.
        //

        union AliasedValue {
            unsigned u;
            float f;
        };

        // N.B. The test data includes values at the edge of Tanh/Logistic boundaries.
        // Column 0 (Identity) doubles as the input value for every activation kind.
        //    Identity,     Relu,         LeakyRelu,    Tanh,         Logistic,     Clip,
        static const AliasedValue TestData[20][6] = {
            { {0x00000001}, {0x00000001}, {0x00000001}, {0x00000000}, {0x3f000000}, {0x00000001}, }, // positive denormal
            { {0x80000001}, {0x00000000}, {0x80000000}, {0x80000000}, {0x3f000000}, {0x00000000}, }, // negative denormal
            { {0x7ff00002}, {0x7ff00002}, {0x7ff00002}, {0x7ff00002}, {0x7ff00002}, {0x7ff00002}, }, // positive NaN
            { {0xfff00002}, {0xfff00002}, {0xfff00002}, {0xfff00002}, {0xfff00002}, {0xfff00002}, }, // negative NaN
            { {0x00000000}, {0x00000000}, {0x00000000}, {0x00000000}, {0x3f000000}, {0x00000000}, }, // 0.0f
            { {0x80000000}, {0x80000000}, {0x80000000}, {0x80000000}, {0x3f000000}, {0x80000000}, }, // -0.0f
            { {0x3e800000}, {0x3e800000}, {0x3e800000}, {0x3e7acbf5}, {0x3f0feacc}, {0x3e800000}, }, // 0.25f
            { {0xbe800000}, {0x00000000}, {0xbd4ccccd}, {0xbe7acbf5}, {0x3ee02a67}, {0x00000000}, }, // -0.25f
            { {0x40800000}, {0x40800000}, {0x40800000}, {0x3f7fd40a}, {0x3f7b6541}, {0x40800000}, }, // 4.0f
            { {0xc0800000}, {0x00000000}, {0xbf4ccccd}, {0xbf7fd40a}, {0x3c9357e0}, {0x00000000}, }, // -4.0f
            { {0x41200000}, {0x41200000}, {0x41200000}, {0x3f800000}, {0x3f7ffd06}, {0x40c00000}, }, // 10.0f
            { {0xc1200000}, {0x00000000}, {0xc0000000}, {0xbf800000}, {0x383e6000}, {0x00000000}, }, // -10.0f
            { {0xc18866eb}, {0x00000000}, {0xc05a3e45}, {0xbf800000}, {0x33000000}, {0x00000000}, }, // -17.0502529144f
            { {0xc18869bb}, {0x00000000}, {0xc05a42c5}, {0xbf800000}, {0x33c00000}, {0x00000000}, }, // -17.0516262054f
            { {0xc18852a8}, {0x00000000}, {0xc05a1dda}, {0xbf800000}, {0x00000000}, {0x00000000}, }, // -17.0403594971f
            { {0xc18844aa}, {0x00000000}, {0xc05a0777}, {0xbf800000}, {0x00000000}, {0x00000000}, }, // -17.0335273743f
            { {0x418866eb}, {0x418866eb}, {0x418866eb}, {0x3f800000}, {0x3f800000}, {0x40c00000}, }, // +17.0502529144f
            { {0x418869bb}, {0x418869bb}, {0x418869bb}, {0x3f800000}, {0x3f7ffffe}, {0x40c00000}, }, // +17.0516262054f
            { {0x418852a8}, {0x418852a8}, {0x418852a8}, {0x3f800000}, {0x3f800000}, {0x40c00000}, }, // +17.0403594971f
            { {0x418844aa}, {0x418844aa}, {0x418844aa}, {0x3f800000}, {0x3f800000}, {0x40c00000}, }, // +17.0335273743f
        };

        MLAS_ACTIVATION Activation;
        AliasedValue Buffer[_countof(TestData)];

        //
        // NOTE(review): the loop stops before MlasClipActivation, so the Clip
        // parameter branch below is never taken and (presumably) the Clip
        // column of the table is never checked. Left unchanged; confirm
        // whether excluding Clip is intentional before widening the bound.
        //

        for (unsigned kind = 0; kind < unsigned(MlasClipActivation); kind++) {

            Activation.ActivationKind = MLAS_ACTIVATION_KIND(kind);

            if (Activation.ActivationKind == MlasLeakyReluActivation) {
                Activation.Parameters.LeakyRelu.alpha = 0.2f;
            } else if (Activation.ActivationKind == MlasClipActivation) {
                Activation.Parameters.Clip.minimum = 0.0f;
                Activation.Parameters.Clip.maximum = 6.0f;
            }

            //
            // Reports every element of Buffer that matches the expected entry
            // neither bitwise nor as a float value. Requiring both to differ
            // accepts +0.0/-0.0 interchangeably (equal as floats) and accepts
            // a NaN only when its bit pattern is exactly the expected one.
            // The unsigned values are printed with %u, matching their type.
            //

            const auto ReportMismatches = [&]() {
                for (unsigned i = 0; i < _countof(TestData); i++) {
                    if (Buffer[i].u != TestData[i][kind].u && Buffer[i].f != TestData[i][kind].f) {
                        printf("mismatch activation kind=%u i=%u value=%08x expected=%08x\n", kind, i, Buffer[i].u, TestData[i][kind].u);
                    }
                }
            };

            //
            // Test the vectorized activations: the whole buffer in one call.
            //

            for (unsigned i = 0; i < _countof(TestData); i++) {
                Buffer[i].u = TestData[i][0].u;
            }

            MlasActivation(&Activation, &Buffer[0].f, nullptr, _countof(Buffer), 1, 1);

            ReportMismatches();

            //
            // Test the scalar activations: one call per element.
            //

            for (unsigned i = 0; i < _countof(TestData); i++) {
                Buffer[i].u = TestData[i][0].u;
                MlasActivation(&Activation, &Buffer[i].f, nullptr, 1, 1, 1);
            }

            ReportMismatches();
        }
    }

    //
    // There are no additional long-running activation tests.
    //
    void
    ExecuteLong(
        void
        ) override
    {
    }
};
//
// Test driver: runs the short variant of every MLAS test suite. When the
// onnxruntime thread pool is available the suites are run twice, first with
// the global `threadpool` left at nullptr and then with a thread pool
// installed; otherwise they are run once.
//
int
#if defined(_WIN32)
__cdecl
#endif
main(
    void
    )
{
    for (int i = 0; i != 2; ++i) {

        printf("SGEMM tests.\n");
        onnxruntime::make_unique<MlasFgemmTest<float>>()->ExecuteShort();

#ifdef MLAS_HAS_DGEMM
        printf("DGEMM tests.\n");
        onnxruntime::make_unique<MlasFgemmTest<double>>()->ExecuteShort();
#endif

#ifdef MLAS_HAS_QGEMM_U8X8
        printf("QGEMM tests.\n");
        onnxruntime::make_unique<MlasQgemmU8X8Test<int8_t>>()->ExecuteShort();
        onnxruntime::make_unique<MlasQgemmU8X8Test<uint8_t>>()->ExecuteShort();
#endif

        //
        // The NCHWc variants only run when the platform reports a block size
        // greater than one.
        //

        printf("Conv2D tests.\n");
        onnxruntime::make_unique<MlasConv2DTest>()->ExecuteShort();
        if (MlasNchwcGetBlockSize() > 1) {
            onnxruntime::make_unique<MlasNchwcConv2DTest>()->ExecuteShort();
        }

        printf("Pool2D tests.\n");
        onnxruntime::make_unique<MlasPool2DTest>()->ExecuteShort();
        if (MlasNchwcGetBlockSize() > 1) {
            onnxruntime::make_unique<MlasNchwcPool2DTest>()->ExecuteShort();
        }

        printf("Pool3D tests.\n");
        onnxruntime::make_unique<MlasPool3DTest>()->ExecuteShort();

        printf("Activation tests.\n");
        onnxruntime::make_unique<MlasActivationTest>()->ExecuteShort();

        printf("Done.\n");

#if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
        //
        // Install the thread pool after the first pass so that the second
        // pass exercises the threaded code paths. The pool is only created
        // while none exists, so the final pass does not allocate a second
        // pool that would never be used.
        //
        if (threadpool == nullptr) {
            threadpool = new onnxruntime::concurrency::ThreadPool("test", 2);
        }
#else
        break;
#endif
    }

#if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
    delete threadpool;
#endif

    return 0;
}