mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
Optimised kernel_dot() in SVM op (#3135)
This commit is contained in:
parent
9d874c1225
commit
a2eeb126b9
1 changed files with 24 additions and 19 deletions
|
|
@ -33,25 +33,30 @@ class SVMCommon {
|
|||
double sum = 0;
|
||||
const T* pA = A + a;
|
||||
const float* pB = B.data() + b;
|
||||
if (k == KERNEL::POLY) {
|
||||
for (int64_t i = len; i > 0; --i, ++pA, ++pB)
|
||||
sum += *pA * *pB;
|
||||
sum = gamma_ * sum + coef0_;
|
||||
sum = std::pow(sum, degree_);
|
||||
} else if (k == KERNEL::SIGMOID) {
|
||||
for (int64_t i = len; i > 0; --i, ++pA, ++pB)
|
||||
sum += *pA * *pB;
|
||||
sum = gamma_ * sum + coef0_;
|
||||
sum = std::tanh(sum);
|
||||
} else if (k == KERNEL::RBF) {
|
||||
for (int64_t i = len; i > 0; --i, ++pA, ++pB) {
|
||||
double val = *pA - *pB;
|
||||
sum += val * val;
|
||||
}
|
||||
sum = std::exp(-gamma_ * sum);
|
||||
} else if (k == KERNEL::LINEAR) {
|
||||
for (int64_t i = len; i > 0; --i, ++pA, ++pB)
|
||||
sum += *pA * *pB;
|
||||
switch(k) {
|
||||
case KERNEL::POLY:
|
||||
for (int64_t i = len; i > 0; --i)
|
||||
sum += *pA++ * *pB++;
|
||||
sum = gamma_ * sum + coef0_;
|
||||
sum = std::pow(sum, degree_);
|
||||
break;
|
||||
case KERNEL::SIGMOID:
|
||||
for (int64_t i = len; i > 0; --i)
|
||||
sum += *pA++ * *pB++;
|
||||
sum = gamma_ * sum + coef0_;
|
||||
sum = std::tanh(sum);
|
||||
break;
|
||||
case KERNEL::RBF:
|
||||
for (int64_t i = len; i > 0; --i) {
|
||||
double val = *pA++ - *pB++;
|
||||
sum += val * val;
|
||||
}
|
||||
sum = std::exp(-gamma_ * sum);
|
||||
break;
|
||||
case KERNEL::LINEAR:
|
||||
for (int64_t i = len; i > 0; --i)
|
||||
sum += *pA++ * *pB++;
|
||||
break;
|
||||
}
|
||||
return (float)sum;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue