mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
* Handle -inf in ReduceSumLogExp operator * Update reduction_ops_test.cc * Remove a case which has a different behaviour CPU/GPU
This commit is contained in:
parent
502f67ba58
commit
f4cee22b9b
2 changed files with 120 additions and 2 deletions
|
|
@ -64,6 +64,36 @@ inline int32_t reduce_log<int32_t>(int32_t value) { return static_cast<int32_t>(
|
|||
template <typename T>
|
||||
inline T reduce_exp(T value) { return static_cast<T>(std::exp(value)); }
|
||||
|
||||
template <typename T>
|
||||
inline bool reduce_isinf(T value) { return std::isinf(value); }
|
||||
|
||||
template <>
|
||||
inline bool reduce_isinf<int8_t>(int8_t) { return false; }
|
||||
|
||||
template <>
|
||||
inline bool reduce_isinf<uint8_t>(uint8_t) { return false; }
|
||||
|
||||
template <>
|
||||
inline bool reduce_isinf<int32_t>(int32_t) { return false; }
|
||||
|
||||
template <>
|
||||
inline bool reduce_isinf<int64_t>(int64_t) { return false; }
|
||||
|
||||
template <typename T>
|
||||
inline bool reduce_isnan(T value) { return std::isnan(value); }
|
||||
|
||||
template <>
|
||||
inline bool reduce_isnan<int8_t>(int8_t) { return false; }
|
||||
|
||||
template <>
|
||||
inline bool reduce_isnan<uint8_t>(uint8_t) { return false; }
|
||||
|
||||
template <>
|
||||
inline bool reduce_isnan<int32_t>(int32_t) { return false; }
|
||||
|
||||
template <>
|
||||
inline bool reduce_isnan<int64_t>(int64_t) { return false; }
|
||||
|
||||
template <typename T, typename TVAL = T>
|
||||
class ReduceAggregator {
|
||||
public:
|
||||
|
|
@ -273,7 +303,9 @@ class ReduceAggregatorLogSumExp : public ReduceAggregator<T, TVAL> {
|
|||
T max_;
|
||||
|
||||
public:
|
||||
inline ReduceAggregatorLogSumExp(int64_t N, const T&) : ReduceAggregator<T, TVAL>(N, 0) { max_ = this->accumulator_; }
|
||||
inline ReduceAggregatorLogSumExp(int64_t N, const T& init) : ReduceAggregator<T, TVAL>(N, 0) {
|
||||
max_ = reduce_isinf(init) ? this->accumulator_ : init;
|
||||
}
|
||||
inline TVAL aggall(const T* from_data) {
|
||||
max_ = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>>(from_data, this->N_).maxCoeff();
|
||||
for (int64_t i = 0; i < this->N_; ++i) {
|
||||
|
|
@ -281,7 +313,9 @@ class ReduceAggregatorLogSumExp : public ReduceAggregator<T, TVAL> {
|
|||
}
|
||||
return get_value();
|
||||
}
|
||||
inline void update0(const T& v) { max_ = v > max_ ? v : max_; }
|
||||
inline void update0(const T& v) {
|
||||
max_ = (reduce_isinf(v) || reduce_isnan(v) || v < max_) ? max_ : v;
|
||||
}
|
||||
inline void update(const T& v) { this->accumulator_ += reduce_exp(v - max_); }
|
||||
inline TVAL get_value() { return reduce_log<T>(this->accumulator_) + max_; }
|
||||
static inline bool two_loops() { return true; }
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
const float FLOAT_INF = std::numeric_limits<float>::infinity();
|
||||
const float FLOAT_NINF = -std::numeric_limits<float>::infinity();
|
||||
|
||||
// Disable TensorRT on some of the tests because the limit in its parser: axis >=0 && axis < nbDims
|
||||
template <typename OutT>
|
||||
void TestReduceOp(const std::string& op,
|
||||
|
|
@ -2018,5 +2021,86 @@ TEST(ReductionOpTest, ReduceDimWithZero) {
|
|||
run(test3);
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceInfMax) {
|
||||
OpTester test("ReduceMax");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{1});
|
||||
test.AddAttribute("keepdims", (int64_t)0);
|
||||
test.AddInput<float>("data", {6, 2},
|
||||
{1.0f, FLOAT_NINF,
|
||||
FLOAT_NINF, 4.0f,
|
||||
FLOAT_INF, FLOAT_NINF,
|
||||
FLOAT_NINF, FLOAT_INF,
|
||||
1.0f, FLOAT_INF,
|
||||
FLOAT_INF, 4.0f});
|
||||
test.AddOutput<float>("reduced", {6},
|
||||
{1.0f, 4.0f,
|
||||
FLOAT_INF, FLOAT_INF,
|
||||
FLOAT_INF, FLOAT_INF});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceInfMin) {
|
||||
OpTester test("ReduceMin");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{1});
|
||||
test.AddAttribute("keepdims", (int64_t)0);
|
||||
test.AddInput<float>("data", {6, 2},
|
||||
{1.0f, FLOAT_INF,
|
||||
FLOAT_INF, 4.0f,
|
||||
FLOAT_INF, FLOAT_NINF,
|
||||
FLOAT_NINF, FLOAT_INF,
|
||||
1.0f, FLOAT_NINF,
|
||||
FLOAT_NINF, 4.0f});
|
||||
test.AddOutput<float>("reduced", {6},
|
||||
{1.0f, 4.0f,
|
||||
FLOAT_NINF, FLOAT_NINF,
|
||||
FLOAT_NINF, FLOAT_NINF});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceInfSum) {
|
||||
OpTester test("ReduceSum");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{1});
|
||||
test.AddAttribute("keepdims", (int64_t)0);
|
||||
test.AddInput<float>("data", {6, 2},
|
||||
{1.0f, FLOAT_INF,
|
||||
FLOAT_INF, 4.0f,
|
||||
FLOAT_INF, FLOAT_NINF,
|
||||
FLOAT_NINF, FLOAT_INF,
|
||||
1.0f, FLOAT_NINF,
|
||||
FLOAT_NINF, 4.0f});
|
||||
test.AddOutput<float>("reduced", {6},
|
||||
{FLOAT_INF, FLOAT_INF,
|
||||
std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(),
|
||||
FLOAT_NINF, FLOAT_NINF});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceInfLogSum) {
|
||||
OpTester test("ReduceLogSum");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{1});
|
||||
test.AddAttribute("keepdims", (int64_t)0);
|
||||
test.AddInput<float>("data", {6, 2},
|
||||
{1.0f, FLOAT_INF,
|
||||
FLOAT_INF, 1.0f,
|
||||
FLOAT_INF, FLOAT_NINF,
|
||||
FLOAT_NINF, FLOAT_INF,
|
||||
1.0f, FLOAT_NINF,
|
||||
FLOAT_NINF, 1.0f});
|
||||
test.AddOutput<float>("reduced", {6},
|
||||
{FLOAT_INF, FLOAT_INF,
|
||||
-std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(),
|
||||
std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN()});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceInfLogSumExp) {
|
||||
OpTester test("ReduceLogSumExp");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{1});
|
||||
test.AddAttribute("keepdims", (int64_t)0);
|
||||
test.AddInput<float>("data", {2, 2}, {1.0f, FLOAT_NINF, FLOAT_NINF, 1.0f});
|
||||
test.AddOutput<float>("reduced", {2}, {1.0f, 1.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue