mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: This is a re-attempt to land the iwyu header changes, by taking the diff from [PR 100304](https://github.com/pytorch/pytorch/pull/100304), and adding the bare minimal changes to make the diff build correctly in the internal builds. X-link: https://github.com/facebookresearch/pytorch3d/pull/1541 X-link: https://github.com/fairinternal/pytorch3d/pull/44 - Re-work D45769819 to fix header inclusions in c10 Test Plan: ``` buck2 build --no-remote-cache mode/dev-nosan //caffe2/c10/... buck2 build --no-remote-cache mode/dev-nosan //deeplearning/fbgemm/fbgemm_gpu/... buck2 build mode/dev-nosan //vision/fair/pytorch3d/pytorch3d:_C ``` Reviewed By: malfet Differential Revision: D45920611 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101846 Approved by: https://github.com/malfet, https://github.com/Skylion007
74 lines
2.3 KiB
C++
74 lines
2.3 KiB
C++
#include <c10/core/DispatchKey.h>
|
|
#include <c10/core/SafePyObject.h>
|
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
|
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
|
|
|
#include <utility>
|
|
|
|
namespace c10 {
|
|
namespace impl {
|
|
|
|
thread_local TorchDispatchModeTLS torchDispatchModeState;
|
|
|
|
void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
|
|
if (torchDispatchModeState.stack_.empty()) {
|
|
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
|
|
c10::impl::tls_set_dispatch_key_included(
|
|
DispatchKey::PythonTLSSnapshot, true);
|
|
}
|
|
torchDispatchModeState.stack_.push_back(std::move(mode));
|
|
}
|
|
|
|
const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
|
|
TORCH_CHECK(
|
|
!torchDispatchModeState.stack_.empty(),
|
|
"trying to pop from empty mode stack");
|
|
std::shared_ptr<SafePyObject> out = torchDispatchModeState.stack_.back();
|
|
torchDispatchModeState.stack_.pop_back();
|
|
|
|
if (torchDispatchModeState.stack_.empty()) {
|
|
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
|
|
c10::impl::tls_set_dispatch_key_included(
|
|
DispatchKey::PythonTLSSnapshot, false);
|
|
}
|
|
return out;
|
|
}
|
|
|
|
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
|
|
int64_t idx) {
|
|
TORCH_CHECK(
|
|
idx < static_cast<int64_t>(torchDispatchModeState.stack_.size()),
|
|
"Tried to get stack at idx that's too big");
|
|
return torchDispatchModeState.stack_[idx];
|
|
}
|
|
|
|
int64_t TorchDispatchModeTLS::stack_len() {
|
|
return static_cast<int64_t>(torchDispatchModeState.stack_.size());
|
|
}
|
|
|
|
const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
|
|
return torchDispatchModeState;
|
|
}
|
|
|
|
void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
|
|
torchDispatchModeState = std::move(state);
|
|
if (torchDispatchModeState.stack_.empty()) {
|
|
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
|
|
c10::impl::tls_set_dispatch_key_included(
|
|
DispatchKey::PythonTLSSnapshot, false);
|
|
} else {
|
|
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
|
|
c10::impl::tls_set_dispatch_key_included(
|
|
DispatchKey::PythonTLSSnapshot, true);
|
|
}
|
|
}
|
|
|
|
// UTIL
|
|
|
|
bool dispatch_mode_enabled() {
|
|
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
|
|
TorchDispatchModeTLS::stack_len() > 0;
|
|
}
|
|
|
|
} // namespace impl
|
|
} // namespace c10
|