From f4cee22b9bec2eac51f8a16c0d5aed78ef6031a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 26 Oct 2020 09:58:02 +0100 Subject: [PATCH] Handle -inf in ReduceSumLogExp, fix regression introduced in PR #5370 (#5583) * Handle -inf in ReduceSumLogExp operator * Update reduction_ops_test.cc * Remove a case which has a different behaviour CPU/GPU --- .../providers/cpu/reduction/reduction_ops.h | 38 ++++++++- .../cpu/reduction/reduction_ops_test.cc | 84 +++++++++++++++++++ 2 files changed, 120 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.h b/onnxruntime/core/providers/cpu/reduction/reduction_ops.h index 0f5366cfc0..9b0c30b880 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.h @@ -64,6 +64,36 @@ inline int32_t reduce_log(int32_t value) { return static_cast( template inline T reduce_exp(T value) { return static_cast(std::exp(value)); } +template +inline bool reduce_isinf(T value) { return std::isinf(value); } + +template <> +inline bool reduce_isinf(int8_t) { return false; } + +template <> +inline bool reduce_isinf(uint8_t) { return false; } + +template <> +inline bool reduce_isinf(int32_t) { return false; } + +template <> +inline bool reduce_isinf(int64_t) { return false; } + +template +inline bool reduce_isnan(T value) { return std::isnan(value); } + +template <> +inline bool reduce_isnan(int8_t) { return false; } + +template <> +inline bool reduce_isnan(uint8_t) { return false; } + +template <> +inline bool reduce_isnan(int32_t) { return false; } + +template <> +inline bool reduce_isnan(int64_t) { return false; } + template class ReduceAggregator { public: @@ -273,7 +303,9 @@ class ReduceAggregatorLogSumExp : public ReduceAggregator { T max_; public: - inline ReduceAggregatorLogSumExp(int64_t N, const T&) : ReduceAggregator(N, 0) { max_ = this->accumulator_; } + inline ReduceAggregatorLogSumExp(int64_t N, const T& init) : ReduceAggregator(N, 0) { + max_ = reduce_isinf(init) ? this->accumulator_ : init; + } inline TVAL aggall(const T* from_data) { max_ = Eigen::Map>(from_data, this->N_).maxCoeff(); for (int64_t i = 0; i < this->N_; ++i) { @@ -281,7 +313,9 @@ class ReduceAggregatorLogSumExp : public ReduceAggregator { } return get_value(); } - inline void update0(const T& v) { max_ = v > max_ ? v : max_; } + inline void update0(const T& v) { + max_ = (reduce_isinf(v) || reduce_isnan(v) || v < max_) ? max_ : v; + } inline void update(const T& v) { this->accumulator_ += reduce_exp(v - max_); } inline TVAL get_value() { return reduce_log(this->accumulator_) + max_; } static inline bool two_loops() { return true; } diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 602913d300..8c3ab7981d 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -15,6 +15,9 @@ namespace onnxruntime { namespace test { +const float FLOAT_INF = std::numeric_limits::infinity(); +const float FLOAT_NINF = -std::numeric_limits::infinity(); + // Disable TensorRT on some of the tests because the limit in its parser: axis >=0 && axis < nbDims template void TestReduceOp(const std::string& op, @@ -2018,5 +2021,86 @@ TEST(ReductionOpTest, ReduceDimWithZero) { run(test3); } +TEST(ReductionOpTest, ReduceInfMax) { + OpTester test("ReduceMax"); + test.AddAttribute("axes", std::vector{1}); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {6, 2}, + {1.0f, FLOAT_NINF, + FLOAT_NINF, 4.0f, + FLOAT_INF, FLOAT_NINF, + FLOAT_NINF, FLOAT_INF, + 1.0f, FLOAT_INF, + FLOAT_INF, 4.0f}); + test.AddOutput("reduced", {6}, + {1.0f, 4.0f, + FLOAT_INF, FLOAT_INF, + FLOAT_INF, FLOAT_INF}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceInfMin) { + OpTester test("ReduceMin"); + test.AddAttribute("axes", std::vector{1}); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {6, 2}, + {1.0f, FLOAT_INF, + FLOAT_INF, 4.0f, + FLOAT_INF, FLOAT_NINF, + FLOAT_NINF, FLOAT_INF, + 1.0f, FLOAT_NINF, + FLOAT_NINF, 4.0f}); + test.AddOutput("reduced", {6}, + {1.0f, 4.0f, + FLOAT_NINF, FLOAT_NINF, + FLOAT_NINF, FLOAT_NINF}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceInfSum) { + OpTester test("ReduceSum"); + test.AddAttribute("axes", std::vector{1}); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {6, 2}, + {1.0f, FLOAT_INF, + FLOAT_INF, 4.0f, + FLOAT_INF, FLOAT_NINF, + FLOAT_NINF, FLOAT_INF, + 1.0f, FLOAT_NINF, + FLOAT_NINF, 4.0f}); + test.AddOutput("reduced", {6}, + {FLOAT_INF, FLOAT_INF, + std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), + FLOAT_NINF, FLOAT_NINF}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceInfLogSum) { + OpTester test("ReduceLogSum"); + test.AddAttribute("axes", std::vector{1}); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {6, 2}, + {1.0f, FLOAT_INF, + FLOAT_INF, 1.0f, + FLOAT_INF, FLOAT_NINF, + FLOAT_NINF, FLOAT_INF, + 1.0f, FLOAT_NINF, + FLOAT_NINF, 1.0f}); + test.AddOutput("reduced", {6}, + {FLOAT_INF, FLOAT_INF, + -std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceInfLogSumExp) { + OpTester test("ReduceLogSumExp"); + test.AddAttribute("axes", std::vector{1}); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {2, 2}, {1.0f, FLOAT_NINF, FLOAT_NINF, 1.0f}); + test.AddOutput("reduced", {2}, {1.0f, 1.0f}); + test.Run(); +} + } // namespace test } // namespace onnxruntime