diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp index 2ba2fba3757..a2418918336 100644 --- a/test/cpp/jit/test_interpreter.cpp +++ b/test/cpp/jit/test_interpreter.cpp @@ -265,6 +265,21 @@ graph(%0 : Tensor, "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"); } EXPECT_TRUE(exception_handled); + + FLAGS_torch_jit_enable_rethrow_caught_exception = true; + c10::intrusive_ptr future = interp.runAsync(stack); + future->wait(); + ASSERT_TRUE(future->completed()); + ASSERT_TRUE(future->hasError()); + try { + std::rethrow_exception(future->exception_ptr()); + } catch (c10::Error& e) { + std::string exception_msg = e.what_without_backtrace(); + EXPECT_STREQ( + exception_msg.c_str(), + "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"); + } + FLAGS_torch_jit_enable_rethrow_caught_exception = original_flag_value; } diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index a095e4a26ad..be2019e532f 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -720,7 +720,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } if (FLAGS_torch_jit_enable_rethrow_caught_exception) { if (future_) { - future_->setError(std::make_exception_ptr(e)); + future_->setError(std::current_exception()); + return false; } throw; }