pytorch/c10/core/impl/TorchDispatchModeTLS.cpp
Benson Ma 66a2600b6a [T153220354] Fix header inclusions in c10 (#1541) (#101846)
Summary:
This is a re-attempt to land the include-what-you-use (IWYU) header changes, by taking the diff from [PR 100304](https://github.com/pytorch/pytorch/pull/100304) and adding the bare minimum of changes needed to make the diff build correctly in the internal builds.

X-link: https://github.com/facebookresearch/pytorch3d/pull/1541

X-link: https://github.com/fairinternal/pytorch3d/pull/44

- Re-work D45769819 to fix header inclusions in c10

Test Plan:
```
buck2 build --no-remote-cache mode/dev-nosan //caffe2/c10/...

buck2 build --no-remote-cache mode/dev-nosan //deeplearning/fbgemm/fbgemm_gpu/...

buck2 build mode/dev-nosan //vision/fair/pytorch3d/pytorch3d:_C
```

Reviewed By: malfet

Differential Revision: D45920611

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101846
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-05-20 19:35:14 +00:00

74 lines
2.3 KiB
C++

#include <c10/core/DispatchKey.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <utility>
namespace c10 {
namespace impl {
thread_local TorchDispatchModeTLS torchDispatchModeState;
void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
if (torchDispatchModeState.stack_.empty()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
}
torchDispatchModeState.stack_.push_back(std::move(mode));
}
const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
TORCH_CHECK(
!torchDispatchModeState.stack_.empty(),
"trying to pop from empty mode stack");
std::shared_ptr<SafePyObject> out = torchDispatchModeState.stack_.back();
torchDispatchModeState.stack_.pop_back();
if (torchDispatchModeState.stack_.empty()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
return out;
}
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
int64_t idx) {
TORCH_CHECK(
idx < static_cast<int64_t>(torchDispatchModeState.stack_.size()),
"Tried to get stack at idx that's too big");
return torchDispatchModeState.stack_[idx];
}
int64_t TorchDispatchModeTLS::stack_len() {
return static_cast<int64_t>(torchDispatchModeState.stack_.size());
}
const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
return torchDispatchModeState;
}
void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
torchDispatchModeState = std::move(state);
if (torchDispatchModeState.stack_.empty()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
} else {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
}
}
// UTIL
bool dispatch_mode_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
TorchDispatchModeTLS::stack_len() > 0;
}
} // namespace impl
} // namespace c10