Revert "Remove abs in LpPool (#6303)"

This reverts commit 3b3e698674.
This commit is contained in:
Hariharan Seshadri 2021-02-09 21:27:36 -08:00 committed by Changming Sun
parent 8972621138
commit b09bfc8611
2 changed files with 1 additions and 20 deletions

View file

@ -88,7 +88,7 @@ class LpPool {
template <typename T>
static void Process(const T& x_data, T& y_data, const PoolProcessContext& cxt) {
y_data += static_cast<T>(std::pow(x_data, cxt.p_));
y_data += static_cast<T>(std::pow(std::abs(x_data), cxt.p_));
}
template <typename T>

View file

@ -1264,25 +1264,6 @@ TEST(PoolTest, LpPool) {
test.Run();
}
TEST(PoolTest, LpPoolWithNegativeNumbers) {
OpTester test("LpPool");
test.AddAttribute("p", static_cast<int64_t>(1));
test.AddAttribute("auto_pad", "");
test.AddAttribute("strides", std::vector<int64_t>{2});
test.AddAttribute("pads", vector<int64_t>{0, 0});
test.AddAttribute("kernel_shape", vector<int64_t>{2});
std::vector<float> x_vals = {0.2f, -0.6f};
std::vector<int64_t> x_dims = {1, 1, 2};
std::vector<int64_t> expected_dims = {1, 1, 1};
std::vector<float> expected_vals = {-0.4f};
test.AddInput<float>("X", x_dims, x_vals);
test.AddOutput<float>("Y", expected_dims, expected_vals);
test.Run();
}
TEST(PoolTest, GlobalLpPool) {
OpTester test("GlobalLpPool");
test.AddAttribute("p", static_cast<int64_t>(3));