From a2eeb126b9553827cce4a401b8a857df78d4a49e Mon Sep 17 00:00:00 2001 From: Prabhat Date: Wed, 4 Mar 2020 16:30:40 +0000 Subject: [PATCH] Optimised kernel_dot() in SVM op (#3135) --- .../core/providers/cpu/ml/svmclassifier.h | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/svmclassifier.h b/onnxruntime/core/providers/cpu/ml/svmclassifier.h index bc0cf83627..060f84cacf 100644 --- a/onnxruntime/core/providers/cpu/ml/svmclassifier.h +++ b/onnxruntime/core/providers/cpu/ml/svmclassifier.h @@ -33,25 +33,30 @@ class SVMCommon { double sum = 0; const T* pA = A + a; const float* pB = B.data() + b; - if (k == KERNEL::POLY) { - for (int64_t i = len; i > 0; --i, ++pA, ++pB) - sum += *pA * *pB; - sum = gamma_ * sum + coef0_; - sum = std::pow(sum, degree_); - } else if (k == KERNEL::SIGMOID) { - for (int64_t i = len; i > 0; --i, ++pA, ++pB) - sum += *pA * *pB; - sum = gamma_ * sum + coef0_; - sum = std::tanh(sum); - } else if (k == KERNEL::RBF) { - for (int64_t i = len; i > 0; --i, ++pA, ++pB) { - double val = *pA - *pB; - sum += val * val; - } - sum = std::exp(-gamma_ * sum); - } else if (k == KERNEL::LINEAR) { - for (int64_t i = len; i > 0; --i, ++pA, ++pB) - sum += *pA * *pB; + switch(k) { + case KERNEL::POLY: + for (int64_t i = len; i > 0; --i) + sum += *pA++ * *pB++; + sum = gamma_ * sum + coef0_; + sum = std::pow(sum, degree_); + break; + case KERNEL::SIGMOID: + for (int64_t i = len; i > 0; --i) + sum += *pA++ * *pB++; + sum = gamma_ * sum + coef0_; + sum = std::tanh(sum); + break; + case KERNEL::RBF: + for (int64_t i = len; i > 0; --i) { + double val = *pA++ - *pB++; + sum += val * val; + } + sum = std::exp(-gamma_ * sum); + break; + case KERNEL::LINEAR: + for (int64_t i = len; i > 0; --i) + sum += *pA++ * *pB++; + break; } return (float)sum; }