From e7724bb100a37b006fba66b861165ba07082ecbf Mon Sep 17 00:00:00 2001 From: Don Jang Date: Mon, 16 Aug 2021 17:30:26 -0700 Subject: [PATCH] [JIT] Set future's error to current exception as is when `--torch_jit_enable_rethrow_caught_exception=true` (#63348) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63348 This change addresses singlaiiit's comment on D30241792 (https://github.com/pytorch/pytorch/commit/61b49c8e41a2faf7fd40278ca72616c5d92963cb), which makes the JIT interpreter's behavior consistent between `future` is set and not. Test Plan: Enhanced `EnableRethrowCaughtExceptionTest.EnableRethrowCaughtExceptionTestRethrowsCaughtException` to cover the modified code path. Reviewed By: singlaiiit Differential Revision: D30347782 fbshipit-source-id: 79ce57283154ca4372e5341217d942398db21ac8 --- test/cpp/jit/test_interpreter.cpp | 15 +++++++++++++++ torch/csrc/jit/runtime/interpreter.cpp | 3 ++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp index 2ba2fba3757..a2418918336 100644 --- a/test/cpp/jit/test_interpreter.cpp +++ b/test/cpp/jit/test_interpreter.cpp @@ -265,6 +265,21 @@ graph(%0 : Tensor, "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"); } EXPECT_TRUE(exception_handled); + + FLAGS_torch_jit_enable_rethrow_caught_exception = true; + c10::intrusive_ptr future = interp.runAsync(stack); + future->wait(); + ASSERT_TRUE(future->completed()); + ASSERT_TRUE(future->hasError()); + try { + std::rethrow_exception(future->exception_ptr()); + } catch (c10::Error& e) { + std::string exception_msg = e.what_without_backtrace(); + EXPECT_STREQ( + exception_msg.c_str(), + "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"); + } + FLAGS_torch_jit_enable_rethrow_caught_exception = original_flag_value; } diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index a095e4a26ad..be2019e532f 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -720,7 +720,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } if (FLAGS_torch_jit_enable_rethrow_caught_exception) { if (future_) { - future_->setError(std::make_exception_ptr(e)); + future_->setError(std::current_exception()); + return false; } throw; }