Optimised kernel_dot() in SVM op (#3135)

This commit is contained in:
Prabhat 2020-03-04 16:30:40 +00:00 committed by GitHub
parent 9d874c1225
commit a2eeb126b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -33,25 +33,30 @@ class SVMCommon {
double sum = 0;
const T* pA = A + a;
const float* pB = B.data() + b;
if (k == KERNEL::POLY) {
for (int64_t i = len; i > 0; --i, ++pA, ++pB)
sum += *pA * *pB;
sum = gamma_ * sum + coef0_;
sum = std::pow(sum, degree_);
} else if (k == KERNEL::SIGMOID) {
for (int64_t i = len; i > 0; --i, ++pA, ++pB)
sum += *pA * *pB;
sum = gamma_ * sum + coef0_;
sum = std::tanh(sum);
} else if (k == KERNEL::RBF) {
for (int64_t i = len; i > 0; --i, ++pA, ++pB) {
double val = *pA - *pB;
sum += val * val;
}
sum = std::exp(-gamma_ * sum);
} else if (k == KERNEL::LINEAR) {
for (int64_t i = len; i > 0; --i, ++pA, ++pB)
sum += *pA * *pB;
switch(k) {
case KERNEL::POLY:
for (int64_t i = len; i > 0; --i)
sum += *pA++ * *pB++;
sum = gamma_ * sum + coef0_;
sum = std::pow(sum, degree_);
break;
case KERNEL::SIGMOID:
for (int64_t i = len; i > 0; --i)
sum += *pA++ * *pB++;
sum = gamma_ * sum + coef0_;
sum = std::tanh(sum);
break;
case KERNEL::RBF:
for (int64_t i = len; i > 0; --i) {
double val = *pA++ - *pB++;
sum += val * val;
}
sum = std::exp(-gamma_ * sum);
break;
case KERNEL::LINEAR:
for (int64_t i = len; i > 0; --i)
sum += *pA++ * *pB++;
break;
}
return (float)sum;
}