mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
fake_tensor.py had a mypy error ignored, which is less than desirable. Also added SafePyObjectT<T>, a tagged wrapper around SafePyObject that provides static type checking (with no other guarantees). Used `SafePyObjectT<TorchDispatchModeKey>` in parts of the TorchDispatchModeTLS API to ensure that we don't accidentally push an object of an unexpected type onto the stack. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428 Approved by: https://github.com/malfet
67 lines
2.2 KiB
C++
67 lines
2.2 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/SafePyObject.h>
|
|
#include <c10/macros/Export.h>
|
|
|
|
namespace c10::impl {
|
|
|
|
// Keys for the "infra" dispatch modes. TorchDispatchModeTLS keeps one
// optional slot per key, in an array sized by NUM_MODE_KEYS.
enum class TorchDispatchModeKey : int8_t {
  FAKE = 0, // presumably FakeTensorMode
  PROXY, // presumably ProxyTorchDispatchMode
  FUNCTIONAL, // presumably the functionalization mode — confirm
  // Count sentinel, not a real mode. It sizes the infra-mode array, so it
  // must remain the last enumerator.
  NUM_MODE_KEYS,
};
|
|
|
|
// A SafePyObject statically tagged with TorchDispatchModeKey. The tag only
// adds compile-time type checking (no other runtime guarantees). Its purpose
// is to stop an object of a different tagged kind from accidentally being
// handed to the TorchDispatchModeTLS APIs.
using PyObject_TorchDispatchMode = SafePyObjectT<TorchDispatchModeKey>;
|
|
|
|
struct C10_API TorchDispatchModeTLS {
|
|
// This API is NOT invariant safe.
|
|
// It must not take in an infra mode that uses TorchDispatchModeKey
|
|
// If you're pushing an infra mode onto the stack, we expect
|
|
// you to use set_mode
|
|
static void push_non_infra_mode_onto_stack(
|
|
std::shared_ptr<PyObject_TorchDispatchMode> mode);
|
|
// Pops the top mode of the stack,
|
|
// giving precedence to user modes before attempting to pop
|
|
// any infra modes
|
|
static const std::shared_ptr<PyObject_TorchDispatchMode> pop_stack();
|
|
// Returns the highest-priority infra mode on the stack,
|
|
// along with its mode key.
|
|
static const std::
|
|
tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
|
|
pop_highest_infra_mode();
|
|
|
|
static const std::shared_ptr<PyObject_TorchDispatchMode>& get_stack_at(
|
|
int64_t idx);
|
|
static int64_t stack_len();
|
|
|
|
static const c10::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
|
|
get_mode(TorchDispatchModeKey mode_key);
|
|
static const c10::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
|
|
unset_mode(TorchDispatchModeKey mode_key);
|
|
static void set_mode(
|
|
const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
|
|
TorchDispatchModeKey mode_key);
|
|
|
|
static const TorchDispatchModeTLS& get_state();
|
|
static void set_state(TorchDispatchModeTLS state);
|
|
|
|
static bool any_modes_set(bool skip_infra_modes = false);
|
|
|
|
private:
|
|
std::vector<std::shared_ptr<PyObject_TorchDispatchMode>> stack_;
|
|
// Users are allowed to push multiple ProxyTorchDispatchMode objects onto the
|
|
// stack
|
|
// However, we only allow a single FakeTensorMode onto the stack at a time
|
|
// (Pushing additional FakeTensorModes onto the stack is a no-op)
|
|
std::array<
|
|
c10::optional<std::shared_ptr<PyObject_TorchDispatchMode>>,
|
|
static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS)>
|
|
infra_modes_;
|
|
};
|
|
|
|
// Reports whether torch dispatch modes are enabled. Exact semantics (e.g.
// whether this means "any mode is currently set") are defined in the .cpp.
// NOTE(review): confirm against the implementation.
C10_API bool dispatch_mode_enabled();
|
|
|
|
// Converts `mode_key` to a std::string, presumably its enumerator name for
// diagnostics. The exact spelling is defined in the .cpp — confirm.
C10_API std::string to_string(TorchDispatchModeKey mode_key);
|
|
|
|
} // namespace c10::impl
|