mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
Rework CDist (#3393)
* Make CDist faster via Eigen squaredNorma and GEMM. * Add call to abs() as the GEMM output may differ slightly due to floating point accuracy and result in a negative distance which returns NaN if sqrt() is applied to it. * Update math::Gemm to use the type for alpha and beta instead of hardcoding to float. Matches the GemmEx definition. * Provide Eigen based replication of the GEMM call on x86 if T=double. * Make test model data deterministic. * Do the GEMM first so we can avoid potentially subtracting two numbers that are very close to each other.
This commit is contained in:
parent
718068f020
commit
40d80cde8f
6 changed files with 227 additions and 129 deletions
|
|
@ -2,6 +2,11 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "cdist.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -12,5 +17,115 @@ namespace contrib {
|
|||
DEFINE_KERNEL(float);
|
||||
DEFINE_KERNEL(double);
|
||||
|
||||
template <typename T>
|
||||
static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, concurrency::ThreadPool* threadpool) {
|
||||
// input shapes have already been validated
|
||||
const auto& shape_a = a.Shape().GetDims(); // {m, k}
|
||||
const auto& shape_b = b.Shape().GetDims(); // {n, k}
|
||||
int64_t m = shape_a[0];
|
||||
int64_t n = shape_b[0];
|
||||
int64_t k = shape_a[1];
|
||||
|
||||
// https://github.com/droyed/eucl_dist/wiki/Main-Article
|
||||
// dist(Xi,Yj) = sum_k(Xik**2) + sum_k(Yjk**2) - 2*sum_k(Xik*Yjk)
|
||||
|
||||
const auto* a_data = a.Data<T>();
|
||||
const auto* b_data = b.Data<T>();
|
||||
auto* c_data = c.MutableData<T>();
|
||||
|
||||
// ReduceSumSquare for A
|
||||
std::vector<T> a_ss;
|
||||
a_ss.resize(m);
|
||||
const auto* cur_a = a_data;
|
||||
for (int64_t i = 0; i < m; ++i) {
|
||||
a_ss[i] = ConstEigenVectorMap<T>(cur_a, k).squaredNorm();
|
||||
cur_a += k;
|
||||
}
|
||||
|
||||
// ReduceSumSquare for B
|
||||
std::vector<T> b_ss;
|
||||
b_ss.resize(n);
|
||||
const auto* cur_b = b_data;
|
||||
for (int64_t i = 0; i < n; ++i) {
|
||||
b_ss[i] = ConstEigenVectorMap<T>(cur_b, k).squaredNorm();
|
||||
cur_b += k;
|
||||
}
|
||||
|
||||
// NOTE: We want to avoid subtracting two numbers that are very close to each other as that can lead to
|
||||
// 'catastrophic cancellation'. (sum_k(Xik**2) + sum_k(Yjk**2)) would be close to 2*sum_k(Xik*Yjk) if the values
|
||||
// in Xij and Yjk are very similar, so subtracting can be problematic.
|
||||
// Due to that we calculate -2*sum_k(Xik*Yjk) using GEMM, add sum_k(Xik**2) next, and add sum_k(Yjk**2) last.
|
||||
|
||||
// use MLAS on 64-bit (no 32-bit dgemm), or MKL on 32-bit or 64-bit
|
||||
#if defined(_M_AMD64) || defined(__x86_64__) || defined(USE_MKLML_FOR_BLAS)
|
||||
// Use GEMM of A and B^T with -2 as alpha to calculate -2*sum_k(Xik*Yjk)
|
||||
math::Gemm<T>(CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans,
|
||||
m, n, k,
|
||||
static_cast<T>(-2.), a_data, b_data, static_cast<T>(0.),
|
||||
c_data,
|
||||
threadpool);
|
||||
#else
|
||||
// the performance of this isn't great as the eigen matmul is single threaded by default
|
||||
// if you're on x86 and care about performance try MKL first. if there's a good enough argument for optimising this
|
||||
// we can look into it in the future.
|
||||
ORT_UNUSED_PARAMETER(threadpool);
|
||||
|
||||
// https://eigen.tuxfamily.org/dox/TopicWritingEfficientProductExpression.html
|
||||
auto out_map = EigenMatrixMapRowMajor<T>(c_data, m, n);
|
||||
out_map.noalias() = static_cast<T>(-2.) *
|
||||
(ConstEigenMatrixMapRowMajor<T>(a_data, m, k) *
|
||||
ConstEigenMatrixMapRowMajor<T>(b_data, n, k).transpose());
|
||||
#endif
|
||||
|
||||
// add a_ss and b_ss, with broadcast
|
||||
// output shape is {m, n}
|
||||
auto* cur_out = c_data;
|
||||
for (int64_t i = 0; i < m; ++i) {
|
||||
T a_val = a_ss[i];
|
||||
for (int64_t j = 0; j < n; ++j) {
|
||||
*cur_out = (*cur_out + a_val) + b_ss[j];
|
||||
++cur_out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
common::Status CDist<T>::Compute(OpKernelContext* context) const {
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
|
||||
assert(context->InputCount() == 2);
|
||||
const Tensor* A = context->Input<Tensor>(0);
|
||||
const Tensor* B = context->Input<Tensor>(1);
|
||||
const TensorShape& shape_a = A->Shape();
|
||||
const TensorShape& shape_b = B->Shape();
|
||||
if (shape_a.NumDimensions() != 2 || shape_a[1] <= 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The first input of CDist kernel has wrong shape: ", shape_a);
|
||||
}
|
||||
|
||||
if (shape_b.NumDimensions() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The second input of CDist kernel has wrong shape: ", shape_b);
|
||||
}
|
||||
if (shape_a[1] != shape_b[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input shape dimensions mismatch:", shape_a, " and ", shape_b);
|
||||
}
|
||||
|
||||
TensorShape output_shape = {shape_a[0], shape_b[0]};
|
||||
Tensor* C = context->Output(0, output_shape);
|
||||
T* output = C->MutableData<T>();
|
||||
|
||||
CalculateSqeuclidean<T>(*A, *B, *C, tp);
|
||||
auto map_out = EigenVectorArrayMap<T>(output, output_shape.Size());
|
||||
|
||||
// because we use GEMM in CalculateSqeuclidean there's a slight chance a number extremely close to zero
|
||||
// could be negative, so we need to run abs() to avoid NaN's in the results.
|
||||
if (mode_ == Mode::EUCLIDEAN) {
|
||||
map_out = map_out.abs().sqrt(); // do both abs and sqrt in one call so Eigen has a chance to combine
|
||||
} else if (mode_ == Mode::SQEUCLIDEAN) {
|
||||
map_out = map_out.abs();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,152 +4,33 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/util/distance.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "assert.h"
|
||||
#ifndef _OPENMP
|
||||
#include "core/util/eigen_common_wrapper.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html
|
||||
//\param a: matrix with shape of[ma,n]
|
||||
//\param b: matrix with shape of[mb,n]
|
||||
//\param dest: matrix with shape of [ma,mb]
|
||||
template <typename T, typename ElemFunc>
|
||||
void cdist_single_threaded(const T* a, const T* b, T* dest, size_t ma, size_t mb, size_t n) {
|
||||
ElemFunc f;
|
||||
for (size_t i = 0; i != ma; ++i) {
|
||||
// i-th row of matrix A
|
||||
const T* a1 = a + n * i;
|
||||
for (size_t j = 0; j != mb; ++j) {
|
||||
// j-th row of matrix B
|
||||
const T* b1 = b + n * j;
|
||||
*dest++ = f(a1, b1, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename ElemFunc>
|
||||
class CDistOneBlock {
|
||||
public:
|
||||
const T* a;
|
||||
const T* b;
|
||||
T* dest;
|
||||
size_t mb;
|
||||
size_t n;
|
||||
|
||||
CDistOneBlock(const T* a1, const T* b1, T* dest1, size_t mb1, size_t n1)
|
||||
: a(a1), b(b1), dest(dest1), mb(mb1), n(n1) {}
|
||||
|
||||
void operator()(Eigen::Index start, Eigen::Index end) {
|
||||
Eigen::Index mb_local = mb;
|
||||
Eigen::Index i = start / mb_local;
|
||||
Eigen::Index j = start - i * mb_local;
|
||||
assert(i * mb_local + j == start);
|
||||
Eigen::Index i_end = end / mb_local;
|
||||
Eigen::Index j_end = end - i_end * mb_local;
|
||||
assert(i_end * mb_local + j_end == end);
|
||||
|
||||
T* dest_local = dest + start;
|
||||
ElemFunc f;
|
||||
const T* a1 = a + n * i;
|
||||
for (; i != i_end; ++i) {
|
||||
a1 = a + n * i;
|
||||
for (; j != mb_local; ++j) {
|
||||
const T* b1 = b + n * j;
|
||||
*dest_local++ = f(a1, b1, n);
|
||||
}
|
||||
j = 0;
|
||||
}
|
||||
a1 = a + n * i;
|
||||
for (j = 0; j != j_end; ++j) {
|
||||
const T* b1 = b + n * j;
|
||||
*dest_local++ = f(a1, b1, n);
|
||||
}
|
||||
assert(dest_local == dest + end);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename ElemFunc>
|
||||
void cdist(const T* a, const T* b, T* dest, size_t ma, size_t mb, size_t n, concurrency::ThreadPool* tp) {
|
||||
#ifndef _OPENMP
|
||||
if (tp == nullptr) {
|
||||
#else
|
||||
(void)tp;
|
||||
#endif
|
||||
return cdist_single_threaded<T, ElemFunc>(a, b, dest, ma, mb, n);
|
||||
#ifndef _OPENMP
|
||||
}
|
||||
tp->ParallelFor(ma * mb, static_cast<double>(3 * n), CDistOneBlock<T, ElemFunc>(a, b, dest, mb, n));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class CDist final : public OpKernel {
|
||||
private:
|
||||
typedef void (*DistFunc)(const T* a, const T* b, T* dest, size_t ma, size_t mb, size_t n,
|
||||
concurrency::ThreadPool* tp);
|
||||
enum { EUCLIDEAN,
|
||||
SQEUCLIDEAN } mode_;
|
||||
enum class Mode { EUCLIDEAN,
|
||||
SQEUCLIDEAN } mode_;
|
||||
|
||||
public:
|
||||
CDist(const OpKernelInfo& info) : OpKernel(info) {
|
||||
std::string metric;
|
||||
ORT_ENFORCE(info.GetAttr<std::string>("metric", &metric).IsOK());
|
||||
if (metric.compare("sqeuclidean") == 0)
|
||||
mode_ = SQEUCLIDEAN;
|
||||
mode_ = Mode::SQEUCLIDEAN;
|
||||
else if (metric.compare("euclidean") == 0) {
|
||||
mode_ = EUCLIDEAN;
|
||||
mode_ = Mode::EUCLIDEAN;
|
||||
} else
|
||||
ORT_NOT_IMPLEMENTED();
|
||||
}
|
||||
|
||||
common::Status Compute(OpKernelContext* context) const override {
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
|
||||
assert(context->InputCount() == 2);
|
||||
const Tensor* A = context->Input<Tensor>(0);
|
||||
const Tensor* B = context->Input<Tensor>(1);
|
||||
const TensorShape& shape_a = A->Shape();
|
||||
const TensorShape& shape_b = B->Shape();
|
||||
if (shape_a.NumDimensions() != 2 || shape_a[1] <= 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The first input of CDist kernel has wrong shape: ", shape_a);
|
||||
}
|
||||
|
||||
if (shape_b.NumDimensions() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The second input of CDist kernel has wrong shape: ", shape_b);
|
||||
}
|
||||
if (shape_a[1] != shape_b[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input shape dimensions mismatch:", shape_a, " and ", shape_b);
|
||||
}
|
||||
|
||||
TensorShape output_shape = {shape_a[0], shape_b[0]};
|
||||
Tensor* C = context->Output(0, output_shape);
|
||||
T* output = C->MutableData<T>();
|
||||
switch (mode_) {
|
||||
case EUCLIDEAN:
|
||||
if (shape_a[1] >= 8)
|
||||
cdist<T, EuclideanWithEigen<T> >(A->Data<T>(), B->Data<T>(), output, shape_a[0], shape_b[0], shape_a[1], tp);
|
||||
else // for smaller vector size, a raw loop is better
|
||||
cdist<T, Euclidean<T> >(A->Data<T>(), B->Data<T>(), output, shape_a[0], shape_b[0], shape_a[1], tp);
|
||||
break;
|
||||
case SQEUCLIDEAN:
|
||||
if (shape_a[1] >= 8)
|
||||
cdist<T, SqeuclideanWithEigen<T> >(A->Data<T>(), B->Data<T>(), output, shape_a[0], shape_b[0], shape_a[1],
|
||||
tp);
|
||||
else // for smaller vector size, a raw loop is better
|
||||
cdist<T, Sqeuclidean<T> >(A->Data<T>(), B->Data<T>(), output, shape_a[0], shape_b[0], shape_a[1], tp);
|
||||
break;
|
||||
default:
|
||||
return Status(ONNXRUNTIME, NOT_IMPLEMENTED);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
common::Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -88,10 +88,10 @@ void Gemm(
|
|||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
float alpha,
|
||||
T alpha,
|
||||
const T* A,
|
||||
const T* B,
|
||||
float beta,
|
||||
T beta,
|
||||
T* C,
|
||||
Provider*);
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,17 @@ void Gemm<float, ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE
|
|||
MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N, threadpool);
|
||||
}
|
||||
|
||||
#if defined(_M_AMD64) || defined(__x86_64__)
|
||||
template <>
|
||||
void Gemm<double, ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, double alpha, const double* A, const double* B, double beta,
|
||||
double* C, ThreadPool* threadpool) {
|
||||
int lda = static_cast<int>((TransA == CblasNoTrans) ? K : M);
|
||||
int ldb = static_cast<int>((TransB == CblasNoTrans) ? N : K);
|
||||
MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N, threadpool);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
void MatMul<float>(int M, int N, int K, const float* A, const float* B, float* C, ThreadPool* threadpool) {
|
||||
MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool);
|
||||
|
|
@ -160,6 +171,20 @@ void Gemm<float, ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE
|
|||
beta, C, gsl::narrow_cast<int>(N));
|
||||
}
|
||||
|
||||
template <>
|
||||
void Gemm<double, ThreadPool>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int64_t M,
|
||||
const int64_t N, const int64_t K, double alpha, const double* A, const double* B, double beta,
|
||||
double* C, ThreadPool* /*context*/) {
|
||||
int lda = gsl::narrow_cast<int>((TransA == CblasNoTrans) ? K : M);
|
||||
int ldb = gsl::narrow_cast<int>((TransB == CblasNoTrans) ? N : K);
|
||||
cblas_dgemm(CblasRowMajor, TransA, TransB,
|
||||
gsl::narrow_cast<int>(M),
|
||||
gsl::narrow_cast<int>(N),
|
||||
gsl::narrow_cast<int>(K),
|
||||
alpha, A, lda, B, ldb,
|
||||
beta, C, gsl::narrow_cast<int>(N));
|
||||
}
|
||||
|
||||
template <>
|
||||
void MatMul<float>(int M, int N, int K, const float* A, const float* B, float* C, ThreadPool*) {
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, 1, A, K, B, N, 0, C, N);
|
||||
|
|
|
|||
73
onnxruntime/test/contrib_ops/cdist_op_test.cc
Normal file
73
onnxruntime/test/contrib_ops/cdist_op_test.cc
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(CDistOpTest, Euclidean) {
|
||||
OpTester test("CDist", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("metric", "euclidean");
|
||||
|
||||
test.AddInput<float>("A", {4, 2},
|
||||
{-1.0856307f, 0.99734545f,
|
||||
0.2829785f, -1.5062947f,
|
||||
-0.5786002f, 1.6514366f,
|
||||
-2.4266791f, -0.42891264f});
|
||||
test.AddInput<float>("B", {3, 2},
|
||||
{1.2659363f, -0.8667404f,
|
||||
-0.6788862f, -0.09470897f,
|
||||
1.4913896f, -0.638902f});
|
||||
|
||||
test.AddOutput<float>("y", {4, 3},
|
||||
{3.0007803f, 1.1653428f, 3.0525956f,
|
||||
1.1727045f, 1.7081447f, 1.4874904f,
|
||||
3.1214628f, 1.749023f, 3.0871522f,
|
||||
3.718481f, 1.7794584f, 3.923692f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(CDistOpTest, Sqeuclidean) {
|
||||
OpTester test("CDist", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("metric", "sqeuclidean");
|
||||
|
||||
test.AddInput<float>("A", {4, 2},
|
||||
{-1.0856307f, 0.99734545f,
|
||||
0.2829785f, -1.5062947f,
|
||||
-0.5786002f, 1.6514366f,
|
||||
-2.4266791f, -0.42891264f});
|
||||
test.AddInput<float>("B", {3, 2},
|
||||
{1.2659363f, -0.8667404f,
|
||||
-0.6788862f, -0.09470897f,
|
||||
1.4913896f, -0.638902f});
|
||||
|
||||
test.AddOutput<float>("y", {4, 3},
|
||||
{9.004683f, 1.3580238f, 9.318338f,
|
||||
1.3752356f, 2.917758f, 2.2126276f,
|
||||
9.74353f, 3.0590816f, 9.530509f,
|
||||
13.827101f, 3.1664724f, 15.395359f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(CDistOpTest, DoubleEuclidean) {
|
||||
OpTester test("CDist", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("metric", "euclidean");
|
||||
|
||||
test.AddInput<double>("A", {2, 3},
|
||||
{0.17251948, 1.6354825, 0.0373364,
|
||||
-0.8841497, -1.1431923, -0.621366});
|
||||
test.AddInput<double>("B", {3, 3},
|
||||
{-1.3486496, -0.81973106, -0.1342539,
|
||||
1.5996001, -0.28360364, -0.5063398,
|
||||
0.06890842, 1.4522595, -1.6390957});
|
||||
|
||||
test.AddOutput<double>("y", {2, 3},
|
||||
{2.8933496, 2.4525568, 1.6895947,
|
||||
0.7467701, 2.6308053, 2.9462626});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -5,6 +5,7 @@ import onnx
|
|||
import numpy as np
|
||||
import os
|
||||
import argparse
|
||||
from datetime import date
|
||||
from onnx import numpy_helper
|
||||
from onnx import helper
|
||||
from onnx import utils
|
||||
|
|
@ -20,7 +21,7 @@ def parse_arguments():
|
|||
|
||||
def write_config(model_dir):
|
||||
with open(os.path.join(model_dir, "config.txt"), "w") as f:
|
||||
f.write("per_sample_tolerance:1e-6\n")
|
||||
f.write("per_sample_tolerance:1e-3\n")
|
||||
f.write("relative_per_sample_tolerance:1e-6\n")
|
||||
|
||||
|
||||
|
|
@ -182,6 +183,9 @@ def test_cdist(output_dir):
|
|||
|
||||
args = parse_arguments()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
# make test values deterministic but variable
|
||||
today = date.today()
|
||||
np.random.seed(today.year + today.month + today.day)
|
||||
test_abs(args.output_dir)
|
||||
test_size(args.output_dir)
|
||||
test_reducesum(args.output_dir)
|
||||
|
|
|
|||
Loading…
Reference in a new issue