diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index 71c2e09773..ec7fca5c82 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -144,7 +144,14 @@ Status BFCArena::Extend(size_t rounded_bytes) { // Try allocating less memory. while (mem_addr == nullptr) { bytes = RoundedBytes(static_cast(bytes * kBackpedalFactor)); - if (bytes < rounded_bytes) + + // give up if we can't satisfy the requested size, or we're attempting an allocation of less than 8K. + // + // the latter protects against an infinite loop that occurs when bytes is less than 2560. at that point the 10% + // reduction to 2304 bytes is undone by rounding to a 256 boundary in RoundedBytes, leading to an infinite loop. + // the 8K value is just to give up a little earlier vs. getting all the way down to 2560 bytes. + // If we can't allocate 8K, we're pretty much dead. + if (bytes < rounded_bytes || bytes < 8 * 1024) break; mem_addr = safe_alloc(bytes); diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index df265922cb..873032fa37 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -284,5 +284,19 @@ TEST(BFCArenaTest, TestReserve) { a.GetStats(&stats); EXPECT_EQ(stats.total_allocated_bytes, 1048576); } + +class BadAllocator : public IAllocator { + public: + BadAllocator() : IAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {} + + void* Alloc(size_t /*size*/) override { throw std::bad_alloc(); } + void Free(void* /*p*/) override {} +}; + +TEST(BFCArenaTest, TestBackoffDoesntHang) { + // test that if there are allocation failures the backoff logic doesn't hang. See comments in BFCArena::Extend + BFCArena a(std::unique_ptr(new BadAllocator()), 10 * 1024 * 1024); + EXPECT_THROW(a.Alloc(1024), OnnxRuntimeException) << "Arena should be unable to allocate memory"; +} } // namespace test } // namespace onnxruntime