mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36299 Test Plan: Imported from OSS Differential Revision: D20943005 Pulled By: ezyang fbshipit-source-id: 9dd0a58824bd0f1b5ad259942f92954ba1f63eae
126 lines
3.2 KiB
C++
126 lines
3.2 KiB
C++
#include <c10/util/Exception.h>
|
|
#include <c10/util/Backtrace.h>
|
|
#include <c10/util/Type.h>
|
|
#include <c10/util/Logging.h>
|
|
|
|
#include <iostream>
|
|
#include <numeric>
|
|
#include <string>
|
|
|
|
namespace c10 {
|
|
|
|
// Primary constructor. Seeds the message stack with the single entry
// `new_msg`, stores the backtrace text and the opaque `caller` pointer, then
// renders and caches both string forms of the message.
Error::Error(
    const std::string& new_msg,
    const std::string& backtrace,
    const void* caller)
    : msg_stack_{new_msg}, backtrace_(backtrace), caller_(caller) {
  // Both renderings are derived purely from msg_stack_ / backtrace_, so the
  // order in which the two caches are filled does not matter.
  msg_without_backtrace_ = msg_without_backtrace();
  msg_ = msg();
}
|
|
|
|
// PyTorch-style error message constructor:
//   Error::Error(SourceLocation source_location, const std::string& msg)
// NB: It is defined in Logging.cpp rather than here so that it can access
// GetFetchStackTrace.
|
|
|
|
// Caffe2-style error message
|
|
// Caffe2-style constructor. Formats the enforce-failure header
//   "[enforce fail at <file>:<line>] <condition>. <msg>\n"
// where <file> is `file` passed through detail::StripBasename, and forwards
// that text, together with `backtrace` and `caller`, to the
// (msg, backtrace, caller) constructor.
Error::Error(
    const char* file,
    const uint32_t line,
    const char* condition,
    const std::string& msg,
    const std::string& backtrace,
    const void* caller)
    : Error(
          str("[enforce fail at ", detail::StripBasename(file), ":", line,
              "] ", condition, ". ", msg, "\n"),
          backtrace,
          caller) {}
|
|
|
|
std::string Error::msg() const {
|
|
return std::accumulate(
|
|
msg_stack_.begin(), msg_stack_.end(), std::string("")) +
|
|
backtrace_;
|
|
}
|
|
|
|
std::string Error::msg_without_backtrace() const {
|
|
return std::accumulate(msg_stack_.begin(), msg_stack_.end(), std::string(""));
|
|
}
|
|
|
|
// Adds `new_msg` as a new entry at the end of the message stack, then
// re-renders msg_without_backtrace_ and msg_ so the cached strings stay in
// sync with msg_stack_.
void Error::AppendMessage(const std::string& new_msg) {
  msg_stack_.emplace_back(new_msg);

  // Both caches are rebuilt from scratch from the whole stack.
  // TODO: Calling AppendMessage O(n) times has O(n^2) cost. We can fix
  // this perf problem by populating the fields lazily... if this ever
  // actually is a problem.
  msg_without_backtrace_ = msg_without_backtrace();
  msg_ = msg();
}
|
|
|
|
namespace Warning {
|
|
|
|
namespace {
|
|
WarningHandler* getBaseHandler() {
|
|
static WarningHandler base_warning_handler_ = WarningHandler();
|
|
return &base_warning_handler_;
|
|
};
|
|
|
|
class ThreadWarningHandler {
|
|
public:
|
|
ThreadWarningHandler() = delete;
|
|
|
|
static WarningHandler* get_handler() {
|
|
if (!warning_handler_) {
|
|
warning_handler_ = getBaseHandler();
|
|
}
|
|
return warning_handler_;
|
|
}
|
|
|
|
static void set_handler(WarningHandler* handler) {
|
|
warning_handler_ = handler;
|
|
}
|
|
|
|
private:
|
|
static thread_local WarningHandler* warning_handler_;
|
|
};
|
|
|
|
thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr;
|
|
|
|
}
|
|
|
|
void warn(SourceLocation source_location, const std::string& msg) {
|
|
ThreadWarningHandler::get_handler()->process(source_location, msg);
|
|
}
|
|
|
|
void set_warning_handler(WarningHandler* handler) noexcept(true) {
|
|
ThreadWarningHandler::set_handler(handler);
|
|
}
|
|
|
|
WarningHandler* get_warning_handler() noexcept(true) {
|
|
return ThreadWarningHandler::get_handler();
|
|
}
|
|
|
|
} // namespace Warning
|
|
|
|
void WarningHandler::process(
|
|
const SourceLocation& source_location,
|
|
const std::string& msg) {
|
|
std::cerr << "Warning: " << msg << " (" << source_location << ")\n";
|
|
}
|
|
|
|
|
|
std::string GetExceptionString(const std::exception& e) {
|
|
#ifdef __GXX_RTTI
|
|
return demangle(typeid(e).name()) + ": " + e.what();
|
|
#else
|
|
return std::string("Exception (no RTTI available): ") + e.what();
|
|
#endif // __GXX_RTTI
|
|
}
|
|
|
|
} // namespace c10
|