2020-01-15 19:12:17 +00:00
|
|
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
2019-09-11 15:55:51 +00:00
|
|
|
|
2024-01-31 00:32:35 +00:00
|
|
|
namespace c10::impl {
|
2019-09-11 15:55:51 +00:00
|
|
|
|
2021-03-31 17:46:38 +00:00
|
|
|
// NB: POD, must be zero initialized!
|
|
|
|
|
// Note [TLS Initialization]
|
|
|
|
|
// We wanted raw_local_dispatch_key_set to be initialized with non-zero state
|
2021-05-02 05:55:12 +00:00
|
|
|
// e.g. BackendSelect and ADInplaceOrView in included set. But certain Windows
|
2021-03-31 17:46:38 +00:00
|
|
|
// compiler (e.g the one used in ARVR tests) only allow TLS to be
|
|
|
|
|
// zero-initialized. To preserve the invariant that raw TLS storage of the
|
|
|
|
|
// default state is zero, we obtain the actual include keyset by XORing
|
|
|
|
|
// raw_local_dispatch_key_set.included_ with c10::default_included_set. This
|
|
|
|
|
// logic is encapsulated in struct PODLocalDispatchKeySet.
|
2021-03-29 18:16:19 +00:00
|
|
|
// Per-thread raw storage for the local dispatch key state.
// See Note [TLS Initialization] above: the raw storage must be
// zero-initializable, so the actual include set is recovered from this
// raw value inside PODLocalDispatchKeySet's accessors.
thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
|
2019-09-11 15:55:51 +00:00
|
|
|
|
2021-09-13 17:48:55 +00:00
|
|
|
#if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
|
2020-12-14 20:43:59 +00:00
|
|
|
// Returns a by-value snapshot of this thread's local dispatch key state.
// The raw POD TLS storage converts to LocalDispatchKeySet on return,
// which decodes the zero-initialized representation (see Note [TLS
// Initialization]).
LocalDispatchKeySet tls_local_dispatch_key_set() {
  const PODLocalDispatchKeySet& raw_tls = raw_local_dispatch_key_set;
  return raw_tls;
}
|
2021-09-13 17:48:55 +00:00
|
|
|
#endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
|
2020-12-14 20:43:59 +00:00
|
|
|
|
2020-04-01 08:51:34 +00:00
|
|
|
// Unconditionally overwrites this thread's local dispatch key state with
// key_set. Prefer the RAII guards below; this exists for callers that need
// to force the TLS wholesale (e.g. restoring a previously captured state).
void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) {
  auto& tls = raw_local_dispatch_key_set;
  tls.set_included(key_set.included_);
  tls.set_excluded(key_set.excluded_);
}
|
|
|
|
|
|
2020-01-15 19:12:17 +00:00
|
|
|
// An RAII guard could snapshot and restore the entire state (entire
|
2020-08-08 23:24:55 +00:00
|
|
|
// DispatchKeySet) as opposed to only snapshotting and restoring the state of
|
2019-11-21 15:10:14 +00:00
|
|
|
// its assigned DispatchKeySet. I'm not sure which is better. If only the RAII
|
|
|
|
|
// API is used, the two choices are not distinguishable.
|
|
|
|
|
//
|
2020-01-15 19:12:17 +00:00
|
|
|
// However, if the guard chooses to snapshot and restore the entire
|
2019-11-21 15:10:14 +00:00
|
|
|
// DispatchKeySet, the interaction with the non-RAII API changes. Consider this
|
|
|
|
|
// sequence of events:
|
2020-08-08 23:24:55 +00:00
|
|
|
// - An RAII guard is declared for a particular DispatchKeySet, but snapshots
|
|
|
|
|
// the entire
|
2020-01-15 19:12:17 +00:00
|
|
|
// current DispatchKeySet.
|
2020-08-08 23:24:55 +00:00
|
|
|
// - A call to the non-RAII API changes the state for DispatchKeys outside the
|
|
|
|
|
// assigned
|
|
|
|
|
// set.
|
2020-01-15 19:12:17 +00:00
|
|
|
// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it
|
|
|
|
|
// snapshotted
|
|
|
|
|
// (which restores the state for its own assigned DispatchKey and wipes out
|
2020-08-08 23:24:55 +00:00
|
|
|
// the state for the other DispatchKeys set by the non-RAII API).
|
2019-11-21 15:10:14 +00:00
|
|
|
|
|
|
|
|
// RAII API
|
Tests for fallback boxed dispatch (including TLS mode) (#26719)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26719
This PR adds a pair of tests for fallback boxed dispatch, exercising two different ways you might use it: (1) to implement a "wrapper" tensor type (e.g., LazyTensor, NestedTensor), and (2) to implement a toggleable "mode" (e.g., Profiling, Tracing). Both implement the most trivial possible implementations of their type: they "wrap" a real tensor simply forward along to the real implementation. This PR also adds the necessary feature support for toggleable mode, which is in the original generic dispatch abstraction design, but was not previously implemented. I had not originally intended to add this, but it turns out writing a new "mode" is a lot simpler than writing a "wrapper" type, so I ended up writing the mode version first.
General structure of the PR:
* Add two new testing tensor type ids, `TESTING_ONLY_GenericWrapperTensorId` and `TESTING_ONLY_GenericModeTensorId`, which our tests use. They might find other use in other tests if necessary.
* Add support for toggling the availability of `TESTING_ONLY_GenericModeTensorId`. Introduces a new thread local variable accessible by `tls_local_tensor_type_set()` which is considered as part of dispatch.
* The mode fallback is very simple: it increments a counter and then passes on the call to the underlying kernel by invoking the JIT.
* The wrapper fallback is more complex: it parses the arguments, unwrapping any wrapped tensor arguments, then invokes the JIT, and then rewraps the outputs.
The examples here are somewhat simplistic; there are a number of engineering improvements that could be applied. We could save these for later (landing this patch to get immediate testing), or incorporate them into this patch:
* `getOperator` is horrible. Bram Wasti and I discussed a plan for how to make this easier, by simply refactoring the JIT interface.
* `GenericWrapperTensorImpl` doesn't populate all of its fields accurately. Most notably, size is not setup correctly.
* `generic_wrapper_fallback` should handle tensor lists in arguments and returns properly.
One pitfall: fallback dispatch only works with non-c10 code. That's why I test using `batch_norm`.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D17549624
Test Plan: Imported from OSS
Pulled By: ezyang
fbshipit-source-id: 57dbdd8d6812a66082aa6db2934c8edcda340ea6
2019-10-09 19:18:55 +00:00
|
|
|
|
2020-08-08 23:24:55 +00:00
|
|
|
// Adds `include` to the thread-local included set for the guard's lifetime.
// Only the keys that were NOT already included are recorded in include_,
// so nesting guards over overlapping sets restores state correctly: each
// guard removes exactly the keys it added.
IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKeySet include)
    : tls_(&raw_local_dispatch_key_set), include_(include - tls_->included()) {
  if (include_.empty()) {
    // Every requested key was already included; nothing to do (and the
    // destructor will likewise be a no-op).
    return;
  }
  tls_->set_included(tls_->included() | include_);
}
|
|
|
|
|
|
2020-01-15 19:12:17 +00:00
|
|
|
// Removes exactly the keys this guard added in its constructor (include_),
// leaving any keys that were already included beforehand untouched.
IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() {
  if (include_.empty()) {
    return;  // constructor added nothing, so there is nothing to undo
  }
  tls_->set_included(tls_->included() - include_);
}
|
|
|
|
|
|
2020-08-08 23:24:55 +00:00
|
|
|
// Adds `exclude` to the thread-local excluded set for the guard's lifetime.
// Mirrors IncludeDispatchKeyGuard: exclude_ records only the keys that were
// not already excluded, so the destructor restores the prior state exactly.
ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKeySet exclude)
    : tls_(&raw_local_dispatch_key_set), exclude_(exclude - tls_->excluded()) {
  if (exclude_.empty()) {
    // All requested keys were already excluded; guard is a no-op.
    return;
  }
  tls_->set_excluded(tls_->excluded() | exclude_);
}
|
|
|
|
|
|
2020-08-08 23:24:55 +00:00
|
|
|
// Removes exactly the keys this guard added in its constructor (exclude_),
// leaving previously excluded keys in place.
ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() {
  if (exclude_.empty()) {
    return;  // nothing was added, so nothing to remove
  }
  tls_->set_excluded(tls_->excluded() - exclude_);
}
|
|
|
|
|
|
2019-11-21 15:10:14 +00:00
|
|
|
// Non-RAII API
|
2020-01-15 19:12:17 +00:00
|
|
|
// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h
|
|
|
|
|
// for details.
|
2019-11-21 15:10:14 +00:00
|
|
|
|
2020-01-15 19:12:17 +00:00
|
|
|
// Returns true iff dispatch key x is currently in this thread's excluded set.
bool tls_is_dispatch_key_excluded(DispatchKey x) {
  const auto excluded_set = raw_local_dispatch_key_set.excluded();
  return excluded_set.has(x);
}
|
|
|
|
|
|
2020-01-15 19:12:17 +00:00
|
|
|
// Sets the thread-local exclusion state of dispatch key x to desired_state.
// No-ops when the key is already in the desired state. Prefer
// ExcludeDispatchKeyGuard; see declarations in LocalDispatchKeySet.h.
void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state) {
  auto& tls = raw_local_dispatch_key_set;
  const bool already_excluded = tls.excluded().has(x);
  if (already_excluded == desired_state) {
    return;  // nothing to change
  }
  tls.set_excluded(
      desired_state ? tls.excluded().add(x) : tls.excluded().remove(x));
}
|
|
|
|
|
|
2020-01-15 19:12:17 +00:00
|
|
|
// Returns true iff dispatch key x is currently in this thread's included set.
bool tls_is_dispatch_key_included(DispatchKey x) {
  const auto included_set = raw_local_dispatch_key_set.included();
  return included_set.has(x);
}
|
|
|
|
|
|
2020-01-15 19:12:17 +00:00
|
|
|
// Sets the thread-local inclusion state of dispatch key x to desired_state.
// No-ops when the key is already in the desired state. Prefer
// IncludeDispatchKeyGuard; see declarations in LocalDispatchKeySet.h.
void tls_set_dispatch_key_included(DispatchKey x, bool desired_state) {
  auto& tls = raw_local_dispatch_key_set;
  const bool already_included = tls.included().has(x);
  if (already_included == desired_state) {
    return;  // nothing to change
  }
  tls.set_included(
      desired_state ? tls.included().add(x) : tls.included().remove(x));
}
|
|
|
|
|
|
2021-03-31 17:46:38 +00:00
|
|
|
// Returns true iff EVERY key in ks is currently in this thread's excluded
// set (subset check, not intersection).
bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks) {
  const auto excluded_set = raw_local_dispatch_key_set.excluded();
  return excluded_set.isSupersetOf(ks);
}
|
|
|
|
|
|
|
|
|
|
// Returns true iff EVERY key in ks is currently in this thread's included
// set (subset check, not intersection).
bool tls_is_dispatch_keyset_included(DispatchKeySet ks) {
  const auto included_set = raw_local_dispatch_key_set.included();
  return included_set.isSupersetOf(ks);
}
|
2024-01-31 00:32:35 +00:00
|
|
|
} // namespace c10::impl
|