Precompute entries in dispatch tables (#40512)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40512

Fixes https://github.com/pytorch/pytorch/issues/32454

The heart of this diff is changing this:

```
inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, DispatchKey dispatchKey) c
nst {
  const KernelFunction* backendKernel = dispatchTable.lookup(dispatchKey);

  if (nullptr != backendKernel) {
    return *backendKernel;
  }

  const auto& backendFallbackKernel = backendFallbackKernels_[dispatchKey];
  if (backendFallbackKernel.isValid()) {
    return backendFallbackKernel;
  }

  const KernelFunction* catchallKernel = dispatchTable.lookupCatchallKernel();
  if (C10_LIKELY(nullptr != catchallKernel)) {
    return *catchallKernel;
  }

  reportError(dispatchTable, dispatchKey);
}
```

to this:

```
const KernelFunction& OperatorEntry::lookup(DispatchKey k) const {
  const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
  if (C10_UNLIKELY(!kernel.isValid())) {
    reportError(k);
  }
  return kernel;
}
```

The difference is that instead of checking a bunch of places to find the
right kernel to use for an operator, all of the operators are
precomputed into dispatchTable_ itself (so you don't have to consult
anything else at runtime.)  OperatorEntry::computeDispatchTableEntry
contains that computation (which is exactly the same as it was before.)
By doing this, we are able to substantially simplify many runtime
components of dispatch.

The diff is fairly large, as there are also some refactors interspersed
with the substantive change:

- I deleted the DispatchTable abstraction, folding it directly into
  OperatorEntry.  It might make sense to have some sort of DispatchTable
  abstraction (if only to let you do operator[] on DispatchKey without
  having to cast it to integers first), but I killed DispatchTable to
  avoid having to design a new abstraction; the old abstraction wasn't
  appropriate for the new algorithm.

- I renamed OperatorEntry::KernelEntry to AnnotatedKernel, and use it
  to store backend fallbacks as well as regular kernel registrations
  (this improves error messages when you incorrectly register a backend
  fallback twice).

- I moved schema_ and debug_ into an AnnotatedSchema type, to make the
  invariant clearer that these are set together, or not at all.

- I moved catch-all kernels out of kernels_ into its own property
  (undoing a refactor I did before).  The main reason I did this was
  because our intended future state is to not have a single catch-all,
  but rather possibly multiple catch-alls which fill-in different
  portions of the dispatch table.  This may change some more in
  the future: if we allow registrations for multiple types of
  catch alls, we will need a NEW data type (representing bundles
  of dispatch keys) which can represent this case, or perhaps
  overload DispatchKey to also record these types.

The key changes for precomputation:

- OperatorEntry::updateDispatchTable_ is now updated to fill in the
  entry at a DispatchKey, considering both kernels (what it did
  before) as well as catch-all and backend fallback.  There is also
  OperatorEntry::updateDispatchTableFull_ which will update the
  entire dispatch table (which is necessary when someone sets a
  catch-all kernel).  OperatorEntry::computeDispatchTableEntry
  holds the canonical algorithm specifying how we decide what
  function will handle a dispatch key for the operator.

- Because dispatch table entry computation requires knowledge of
  what backend fallbacks are (which is recorded in Dispatcher,
  not OperatorEntry), several functions on OperatorEntry now
  take Dispatcher as an argument so they can query this information.

- I modified the manual boxing wrapper invariant: previously, kernels
  stored in kernels_ did NOT have manual boxing wrappers and this
  was maintained by DispatchTable.  Now, we just ALWAYS maintain
  manual boxing wrappers for all KernelFunctions we store.

- DispatchKeyExtractor is greatly simplified: we only need to maintain
  a single per-operator bitmask of what entries are fallthrough
  (we don't need the global bitmask anymore).

- Introduced a new debugging 'dumpComputedTable' method, which prints
  out the computed dispatch table, and how we computed it to be some way.
  This was helpful for debugging cases when the dispatch table and
  the canonical metadata were not in sync.

Things that I didn't do but would be worth doing at some point:

- I really wanted to get rid of the C10_UNLIKELY branch for
  whether or not the KernelFunction is valid, but it looks like
  I cannot easily do this while maintaining good error messages.
  In principle, I could always populate a KernelFunction which
  errors, but the KernelFunction needs to know what the dispatch
  key that is missing is (this is not passed in from the
  calling convention).  Actually, it might be possible to do
  something with functors, but I didn't do it here.

- If we are going to get serious about catchalls for subsets of
  operators, we will need to design a new API for them.  This diff
  is agnostic to this question; we don't change public API at all.

- Precomputation opens up the possibility of subsuming DispatchStub
  by querying CPU capability when filling in the dispatch table.
  This is not implemented yet. (There is also a mild blocker here,
  which is that DispatchStub is also used to share TensorIterator
  configuration, and this cannot be directly supported by the
  regular Dispatcher.)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D22236352

Pulled By: ezyang

fbshipit-source-id: d6d90f267078451816b1899afc3f79737b4e128c
This commit is contained in:
Edward Yang 2020-06-26 08:59:33 -07:00 committed by Facebook GitHub Bot
parent a4cabd1a3c
commit a0ba7fb43e
10 changed files with 421 additions and 568 deletions

View file

@ -4,19 +4,11 @@
namespace c10 {
void DispatchKeyExtractor::setOperatorHasKernelForBackend(DispatchKey k, bool has_kernel) {
if (has_kernel) {
operatorHasKernelForBackend_ = operatorHasKernelForBackend_.add(k);
} else {
operatorHasKernelForBackend_ = operatorHasKernelForBackend_.remove(k);
}
}
void DispatchKeyExtractor::setOperatorHasFallthroughForBackend(DispatchKey k, bool has_fallthrough) {
void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) {
if (has_fallthrough) {
operatorHasFallthroughForBackend_ = operatorHasFallthroughForBackend_.add(k);
nonFallthroughKeys_ = nonFallthroughKeys_.remove(k);
} else {
operatorHasFallthroughForBackend_ = operatorHasFallthroughForBackend_.remove(k);
nonFallthroughKeys_ = nonFallthroughKeys_.add(k);
}
}
@ -29,7 +21,7 @@ std::string DispatchKeyExtractor::dumpState() const {
oss << "0";
}
}
oss << " " << operatorHasKernelForBackend_ << "\n";
oss << " " << nonFallthroughKeys_ << "\n";
return oss.str();
}

View file

@ -28,9 +28,8 @@ static inline DispatchKey dispatchTypeId(
// The key mask lets us eliminate (by zero entries) keys which should not
// be considered for dispatch. There are two cases when we use this:
//
// - If there is no operator registered for a backend whose fallback behavior
// is to fallthrough, we eliminate that backend from consideration (since
// we want to "fallthrough" to the next valid key.)
// - If an operator's dispatch table contains a fallthrough entry, we
// should bypass it entirely when finding the key
// - If a user invokes with redispatch, the mask lets us
// zero out the key the user asked us to stop.
//
@ -119,7 +118,7 @@ public:
dispatch_arg_indices_reverse_ = c10::utils::bitset();
}
DispatchKey getDispatchKeyBoxed(DispatchKeySet backendsWithoutFallthrough, const torch::jit::Stack* stack) const {
DispatchKey getDispatchKeyBoxed(const torch::jit::Stack* stack) const {
DispatchKeySet ks;
dispatch_arg_indices_reverse_.for_each_set_bit([&] (size_t reverse_arg_index) {
const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
@ -133,19 +132,16 @@ public:
}
}
});
return dispatchKeySetToDispatchKey_(backendsWithoutFallthrough, DispatchKeySet::FULL, ks);
return dispatchKeySetToDispatchKey_(DispatchKeySet::FULL, ks);
}
template<class... Args>
DispatchKey getDispatchKeyUnboxed(DispatchKeySet backendsWithoutFallthrough, DispatchKeySet eligibleKeys, const Args&... args) const {
DispatchKey getDispatchKeyUnboxed(DispatchKeySet eligibleKeys, const Args&... args) const {
auto ks = detail::multi_dispatch_key_set(args...);
return dispatchKeySetToDispatchKey_(backendsWithoutFallthrough, eligibleKeys, ks);
return dispatchKeySetToDispatchKey_(eligibleKeys, ks);
}
// Used by DispatchTable to maintain the fallthrough invariant, see
// docs on operatorHasKernelForBackend_
void setOperatorHasKernelForBackend(DispatchKey k, bool has_kernel);
void setOperatorHasFallthroughForBackend(DispatchKey k, bool has_fallthrough);
void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
std::string dumpState() const;
void checkInvariants(const FunctionSchema& schema) const;
@ -166,22 +162,13 @@ private:
// NB: If there is no valid dispatch key, this will return Undefined
DispatchKey dispatchKeySetToDispatchKey_(
DispatchKeySet backendsWithoutFallthrough,
// This is often known statically to be all ones; IN OPTIMIZER WE TRUST
DispatchKeySet eligibleKeys,
DispatchKeySet ks
) const {
return impl::dispatchTypeId(ks,
// We must NOT respect the passed in backendsWithoutFallthrough if an operator has
// specifically overridden the backend, since that means we've opted to
// not fallthrough and instead apply some specific behavior (which we
// must dispatch to).
//
// This scheme doesn't work if you want to also apply fallthrough on a
// per-op basis, but while we could directly fix this by maintaining a
// second DispatchKeySet, it doesn't seem that there is any actual use case,
// so we are deferring it for #32454.
((backendsWithoutFallthrough | operatorHasKernelForBackend_) - operatorHasFallthroughForBackend_)
// Keys that are fallthrough should be skipped
nonFallthroughKeys_
// Regardless of fallthrough behavior, only accept keys which are eligible
// for dispatch, as requested by the user
& eligibleKeys);
@ -189,8 +176,7 @@ private:
explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
: dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
, operatorHasKernelForBackend_()
, operatorHasFallthroughForBackend_() {}
, nonFallthroughKeys_(DispatchKeySet::FULL) {}
// this is a bitset that has ones for each argument index which has to be
// considered for dispatch. This avoids having to iterate over the stack
@ -202,10 +188,8 @@ private:
// fallthrough
c10::utils::bitset dispatch_arg_indices_reverse_;
// Set of backends for which the operator has explicitly registered a kernel.
DispatchKeySet operatorHasKernelForBackend_;
// Set of backends for which the operator has explicitly registered a fallthrough kernel.
DispatchKeySet operatorHasFallthroughForBackend_;
// Set of keys for which the operator does NOT have fallthrough kernel.
DispatchKeySet nonFallthroughKeys_;
};
}

View file

@ -1,27 +0,0 @@
#include <ATen/core/dispatch/DispatchTable.h>
#include <sstream>
namespace c10 {
namespace impl {
std::string KernelFunctionTable::dumpState() const {
std::ostringstream oss;
for (uint8_t i = 0; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
if (kernels_[i].isValid()) oss << " " << kernels_[i].dumpState() << "\n";
}
return oss.str();
}
}
std::string DispatchTable::dumpState() const {
std::ostringstream oss;
oss << "table:\n";
oss << kernels_.dumpState();
oss << " catchall: " << catchallKernel_.dumpState() << "\n";
oss << " extractor: " << dispatchKeyExtractor_.dumpState() << "\n";
oss << " name: " << operatorName_ << "\n";
return oss.str();
}
} // namespace c10

View file

@ -1,247 +0,0 @@
#pragma once
#include <ATen/core/function_schema.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/either.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/DispatchKeyExtractor.h>
#include <array>
#include <atomic>
#include <iostream>
#include <mutex>
#include <type_traits>
#include <sstream>
#include <unordered_map>
#include <functional>
namespace c10 {
namespace impl {
/**
* A KernelFunctionTable is a map from DispatchKey to a KernelFunction.
* It can store zero or one KernelFunctions for each DispatchKey.
*/
class KernelFunctionTable final {
public:
explicit KernelFunctionTable()
: kernels_()
, kernelCount_(0) {}
void setKernel(DispatchKey dispatchKey, KernelFunction kernel) {
TORCH_INTERNAL_ASSERT(dispatchKey != DispatchKey::Undefined);
auto& slot = kernels_[static_cast<uint8_t>(dispatchKey)];
if (!slot.isValid()) {
++kernelCount_;
}
slot = std::move(kernel);
}
void removeKernelIfExists(DispatchKey dispatchKey) {
auto& slot = kernels_[static_cast<uint8_t>(dispatchKey)];
if (slot.isValid()) {
--kernelCount_;
slot = {};
} else {
}
}
const KernelFunction& operator[](DispatchKey dispatchKey) const {
return kernels_[static_cast<uint8_t>(dispatchKey)];
}
KernelFunction& operator[](DispatchKey dispatchKey) {
return kernels_[static_cast<uint8_t>(dispatchKey)];
}
size_t size() const {
return kernelCount_;
}
std::string dumpState() const;
private:
std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> kernels_;
size_t kernelCount_;
};
}
/**
* Per-operator dispatch table.
*
* Given an operator specified by a FunctionSchema, this class records a dispatch
* table for various kernels provided for this operator. For example, if we
* consider the operator add(Tensor, Tensor), the dispatch table for this
* operator may contain implementations for various dynamic tensor types, such
* as CPU, CUDA, etc.
*/
class DispatchTable final {
public:
explicit DispatchTable(const FunctionSchema& schema)
: kernels_()
, catchallKernel_()
, dispatchKeyExtractor_(DispatchKeyExtractor::make(schema))
, operatorName_(schema.operator_name()) {}
// a dispatch table may be default constructed with only an
// operator name. Such a dispatch table is not callable until
// the schema is provided
DispatchTable(OperatorName op_name)
: kernels_()
, catchallKernel_()
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
, operatorName_(std::move(op_name)) {}
/**
* Register a kernel in the table at some dispatch key.
* @param dispatch_key Dispatch key to define when this kernel is selected.
* @param kernel Concrete kernel function implementation to register
*/
void setKernel(DispatchKey dispatchKey, KernelFunction kernel) {
if (manuallyBoxedKernel_.has_value()) {
kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_);
}
kernels_.setKernel(dispatchKey, std::move(kernel));
dispatchKeyExtractor_.setOperatorHasKernelForBackend(dispatchKey, true);
if (kernel.isFallthrough()) {
dispatchKeyExtractor_.setOperatorHasFallthroughForBackend(dispatchKey, true);
}
}
/**
* Deregister the kernel for some dispatch key.
*
* @param dispatch_key Dispatch key to unregister.
*/
void removeKernelIfExists(DispatchKey dispatchKey) {
kernels_.removeKernelIfExists(dispatchKey);
dispatchKeyExtractor_.setOperatorHasKernelForBackend(dispatchKey, false);
dispatchKeyExtractor_.setOperatorHasFallthroughForBackend(dispatchKey, false); // may be no op
}
/**
* Register a catch-all kernel that is called for this operator
* independent of the inputs. An operator can have either
* a catch-all kernel or a set of kernels with concrete
* dispatch keys, not both.
*/
void setCatchallKernel(KernelFunction kernel) {
if (manuallyBoxedKernel_.has_value()) {
kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_);
}
catchallKernel_ = std::move(kernel);
}
/**
* Remove the catch-all kernel.
*/
void removeCatchallKernel() {
catchallKernel_ = {};
}
bool isEmpty() const {
return !catchallKernel_.isValid() && kernels_.size() == 0;
}
std::string listAllDispatchKeys() const {
std::ostringstream str;
str << "[";
bool has_kernels = false;
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
if (!kernels_[static_cast<DispatchKey>(iter)].isValid()) {
continue;
}
if (has_kernels) {
str << ", ";
}
str << static_cast<DispatchKey>(iter);
has_kernels = true;
}
if (catchallKernel_.isValid()) {
if (has_kernels) {
str << ", ";
}
str << "CATCH-ALL";
}
str << "]";
return str.str();
}
const KernelFunction* lookup(DispatchKey dispatchKey) const {
auto& slot = kernels_[dispatchKey];
// TODO: this condition shouldn't be necessary
if (slot.isValid()) {
return &slot;
} else {
return nullptr;
}
}
const KernelFunction* lookupCatchallKernel() const {
// TODO: this condition shouldn't be necessary
if (!catchallKernel_.isValid()) {
return nullptr;
}
return &catchallKernel_;
}
const DispatchKeyExtractor& dispatchKeyExtractor() const {
return dispatchKeyExtractor_;
}
const OperatorName& operatorName() const {
return operatorName_;
}
void registerSchema(const FunctionSchema& schema) {
dispatchKeyExtractor_.registerSchema(schema);
}
void deregisterSchema() {
dispatchKeyExtractor_.deregisterSchema();
}
std::string dumpState() const;
// This function is a temporary hack, see comment at manuallyBoxedKernel_ member
void setManuallyBoxedKernel_(KernelFunction::InternalBoxedKernelFunction* func) {
TORCH_INTERNAL_ASSERT(!manuallyBoxedKernel_.has_value(), "Cannot set multiple manually boxed kernels for the same operator ", operatorName_);
manuallyBoxedKernel_ = func;
// make sure that all previously registered kernels get this manually boxed kernel
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
auto& kernel = kernels_[static_cast<DispatchKey>(iter)];
if (kernel.isValid()) {
kernel.setManuallyBoxedKernel_(func);
}
}
if (catchallKernel_.isValid()) {
catchallKernel_.setManuallyBoxedKernel_(func);
}
}
c10::optional<KernelFunction::InternalBoxedKernelFunction*> manuallyBoxedKernel() const {
return manuallyBoxedKernel_;
}
private:
impl::KernelFunctionTable kernels_;
KernelFunction catchallKernel_;
DispatchKeyExtractor dispatchKeyExtractor_;
OperatorName operatorName_;
// This manuallyBoxedKernel_ member is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
// unboxing wrapper for aten operators. We still need those for some operators because not all work
// with the templated unboxing logic yet.
// TODO Delete manuallyBoxedKernel_ once all operators work with the templated boxing logic
c10::optional<KernelFunction::InternalBoxedKernelFunction*> manuallyBoxedKernel_;
};
} // namespace c10

View file

@ -38,7 +38,6 @@ Dispatcher::Dispatcher()
: operators_()
, operatorLookupTable_()
, backendFallbackKernels_()
, backendsWithoutFallthrough_(DispatchKeySet::FULL)
, listeners_(std::make_unique<detail::RegistrationListenerList>())
, mutex_() {}
@ -206,7 +205,14 @@ RegistrationHandleRAII Dispatcher::registerImpl(
auto op = findOrRegisterName_(op_name);
auto handle = op.operatorIterator_->op.registerKernel(dispatch_key, std::move(kernel), std::move(cpp_signature), std::move(inferred_function_schema), std::move(debug));
auto handle = op.operatorIterator_->op.registerKernel(
*this,
dispatch_key,
std::move(kernel),
std::move(cpp_signature),
std::move(inferred_function_schema),
std::move(debug)
);
++op.operatorIterator_->def_and_impl_count;
@ -215,10 +221,10 @@ RegistrationHandleRAII Dispatcher::registerImpl(
});
}
void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key, std::list<impl::OperatorEntry::KernelEntry>::iterator handle) {
void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key, std::list<impl::AnnotatedKernel>::iterator handle) {
std::lock_guard<std::mutex> lock(mutex_);
op.operatorIterator_->op.deregisterKernel_(dispatch_key, handle);
op.operatorIterator_->op.deregisterKernel_(*this, dispatch_key, handle);
TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);
@ -249,8 +255,6 @@ void Dispatcher::deregisterName_(
// Test if the operator entry is completely dead, and if so remove it completely
void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) {
if (0 == op.operatorIterator_->def_and_impl_count) {
// TODO: rename this to "assert deregistration invariants"
op.operatorIterator_->op.prepareForDeregistration();
operators_.erase(op.operatorIterator_);
operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
operatorLookupTable.erase(op_name);
@ -261,14 +265,17 @@ void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name)
RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
std::lock_guard<std::mutex> lock(mutex_);
// TODO: preserve debug for old fallback
TORCH_CHECK(
!backendFallbackKernels_[dispatchKey].isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, " (", debug, ")"
!backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].debug, ", new registration ", debug
);
backendFallbackKernels_.setKernel(dispatchKey, std::move(kernel));
if (kernel.isFallthrough()) {
backendsWithoutFallthrough_ = backendsWithoutFallthrough_.remove(dispatchKey);
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
// cannot be unobxed
backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
for (auto& op : operators_) {
op.op.updateFallback(*this, dispatchKey);
}
return RegistrationHandleRAII([this, dispatchKey] {
@ -279,8 +286,11 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
std::lock_guard<std::mutex> lock(mutex_);
backendFallbackKernels_.removeKernelIfExists(dispatchKey);
backendsWithoutFallthrough_ = backendsWithoutFallthrough_.add(dispatchKey);
backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = {};
for (auto& op : operators_) {
op.op.updateFallback(*this, dispatchKey);
}
}
@ -300,38 +310,16 @@ RegistrationHandleRAII Dispatcher::addRegistrationListener(std::unique_ptr<OpReg
});
}
[[noreturn]] void Dispatcher::reportError(const DispatchTable& dispatchTable, DispatchKey dispatchKey) {
if (dispatchKey == DispatchKey::Undefined) {
TORCH_CHECK(false,
"There were no tensor arguments to this function (e.g., you passed an "
"empty list of Tensors), but no fallback function is registered for schema ", dispatchTable.operatorName(),
". This usually means that this function requires a non-empty list of Tensors. "
"Available functions are ", dispatchTable.listAllDispatchKeys())
}
const std::string dispatchKeyStr = toString(dispatchKey);
TORCH_CHECK(false, "Could not run '", dispatchTable.operatorName(), "' with arguments",
" from the '", dispatchKeyStr, "' backend. '",
dispatchTable.operatorName(), "' is only available for these backends: ",
dispatchTable.listAllDispatchKeys(), ".");
}
void Dispatcher::checkInvariants() const {
for (const auto& op : operators_) {
op.op.checkInvariants();
}
// NB: skip Undefined
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
auto k = static_cast<DispatchKey>(i);
if (!backendsWithoutFallthrough_.has(k)) {
const auto& kernel = backendFallbackKernels_[k];
TORCH_INTERNAL_ASSERT(kernel.isFallthrough());
}
}
}
void Dispatcher::setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func) {
op.operatorIterator_->op.setManuallyBoxedKernel_(func);
std::lock_guard<std::mutex> lock(mutex_);
op.operatorIterator_->op.setManuallyBoxedKernel_(*this, func);
// NB: Do not need to set manually boxed kernel for backend fallbacks
}
}

View file

@ -41,6 +41,9 @@ class SchemaRegistrationHandleRAII;
*/
class CAFFE2_API Dispatcher final {
private:
// For direct access to backend fallback information
friend class impl::OperatorEntry;
struct OperatorDef final {
explicit OperatorDef(OperatorName&& op_name)
: op(std::move(op_name)) {}
@ -206,26 +209,20 @@ private:
const OperatorHandle& op,
const OperatorName& op_name,
c10::optional<DispatchKey> dispatch_key,
std::list<impl::OperatorEntry::KernelEntry>::iterator kernel_handle);
std::list<impl::AnnotatedKernel>::iterator kernel_handle);
void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
void deregisterFallback_(DispatchKey dispatchKey);
void deregisterLibrary_(const std::string& ns);
void cleanup(const OperatorHandle& op, const OperatorName& op_name);
void checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug);
[[noreturn]] static void reportError(const DispatchTable& dispatchTable, DispatchKey dispatchKey);
const KernelFunction& dispatch_(const DispatchTable& dispatchTable, DispatchKey dispatch_key) const;
std::list<OperatorDef> operators_;
LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
// Map from namespace to debug string (saying, e.g., where the library was defined)
ska::flat_hash_map<std::string, std::string> libraries_;
impl::KernelFunctionTable backendFallbackKernels_;
// Set of backends which have specified they do NOT want fallthrough behavior
// (we store the inverse because it avoids a negation when we use this for
// masking)
DispatchKeySet backendsWithoutFallthrough_;
std::array<impl::AnnotatedKernel, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> backendFallbackKernels_;
std::unique_ptr<detail::RegistrationListenerList> listeners_;
std::mutex mutex_;
};
@ -268,6 +265,12 @@ public:
template<class FuncType>
TypedOperatorHandle<FuncType> typed() const {
// NB: This assert is not 100% sound: you can retrieve a typed() operator
// handle prior to ANY C++ signature being registered on the operator
// and the check will say everything is OK (at which point you can then
// smuggle in a kernel that is typed incorrectly). For everything
// in core library this won't happen, because all the static registrations
// will be done by the time a typed() handle is acquired.
operatorIterator_->op.assertSignatureIsCorrect<FuncType>();
return TypedOperatorHandle<FuncType>(operatorIterator_);
}
@ -324,57 +327,38 @@ template<class... Args> inline void unused_arg_(const Args&...) {}
template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(Args...)>& op, DispatchKey dispatchKey, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey);
const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey);
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
}
template<class Return, class... Args>
inline Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
auto dispatchKey = dispatchTable.dispatchKeyExtractor().template getDispatchKeyUnboxed<Args...>(backendsWithoutFallthrough_, DispatchKeySet::FULL, args...);
auto dispatchKey = op.operatorIterator_->op.dispatchKeyExtractor()
.template getDispatchKeyUnboxed<Args...>(
DispatchKeySet::FULL,
args...
);
return callWithDispatchKey<Return, Args...>(op, dispatchKey, args...);
}
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey currentDispatchKey, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
auto dispatchKey = dispatchTable.dispatchKeyExtractor().template getDispatchKeyUnboxed<Args...>(
backendsWithoutFallthrough_,
DispatchKeySet(DispatchKeySet::FULL_AFTER, currentDispatchKey),
args...);
const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey);
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
auto dispatchKey = op.operatorIterator_->op.dispatchKeyExtractor()
.template getDispatchKeyUnboxed<Args...>(
DispatchKeySet(DispatchKeySet::FULL_AFTER, currentDispatchKey),
args...
);
return callWithDispatchKey<Return, Args...>(op, dispatchKey, args...);
}
inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& dispatchTable = op.operatorIterator_->op.dispatch_table();
auto dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyBoxed(backendsWithoutFallthrough_, stack);
const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey);
const auto& entry = op.operatorIterator_->op;
auto dispatchKey = entry.dispatchKeyExtractor().getDispatchKeyBoxed(stack);
const auto& kernel = entry.lookup(dispatchKey);
kernel.callBoxed(op, stack);
}
inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, DispatchKey dispatchKey) const {
const KernelFunction* backendKernel = dispatchTable.lookup(dispatchKey);
if (nullptr != backendKernel) {
return *backendKernel;
}
const auto& backendFallbackKernel = backendFallbackKernels_[dispatchKey];
if (backendFallbackKernel.isValid()) {
return backendFallbackKernel;
}
const KernelFunction* catchallKernel = dispatchTable.lookupCatchallKernel();
if (C10_LIKELY(nullptr != catchallKernel)) {
return *catchallKernel;
}
reportError(dispatchTable, dispatchKey);
}
} // namespace c10

View file

@ -1,11 +1,11 @@
#include <ATen/core/dispatch/OperatorEntry.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace c10 {
namespace impl {
namespace {
std::string toString(c10::optional<DispatchKey> k) {
if (k.has_value()) {
return toString(*k);
@ -13,33 +13,21 @@ namespace {
return "(catch all)";
}
}
std::string listAllDispatchKeys(const ska::flat_hash_map<c10::optional<DispatchKey>, std::list<OperatorEntry::KernelEntry>>& kernels) {
if (kernels.size() == 0) {
return "";
}
std::ostringstream str;
str << toString(kernels.begin()->first);
for (auto iter = ++kernels.begin(); iter != kernels.end(); ++iter) {
str << ", " << toString(iter->first);
}
return str.str();
}
}
OperatorEntry::OperatorEntry(OperatorName&& operator_name)
: name_(std::move(operator_name))
, schema_()
, debug_()
, dispatchTable_(name_)
, kernels_() {
}
void OperatorEntry::prepareForDeregistration() {
if (!dispatchTable_.isEmpty()) {
TORCH_INTERNAL_ASSERT(false, "Tried to deregister op schema for an operator that still has kernels registered. The operator is ", toString(name_), ". Registered kernels for dispatch keys: ", dispatchTable_.listAllDispatchKeys());
}
TORCH_INTERNAL_ASSERT(kernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have kernels for dispatch keys ", listAllDispatchKeys(kernels_), ". The operator is ", toString(name_));
, dispatchTable_()
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
, manuallyBoxedKernel_()
, kernels_()
, catchAllKernel_()
, cpp_signature_()
{
// Pick up any backend fallbacks that were registered prior to this
// OperatorEntry being created
updateDispatchTableFull_(c10::Dispatcher::singleton());
}
namespace {
@ -52,7 +40,7 @@ namespace {
*schema_difference);
}
}
}
} // anonymous namespace
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug) {
TORCH_INTERNAL_ASSERT(!schema_.has_value());
@ -63,28 +51,37 @@ void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug)
}
}
}
for (auto j = catchAllKernel_.begin(); j != catchAllKernel_.end(); ++j) {
if (j->inferred_function_schema != nullptr) {
checkSchema(name_, schema, debug, *j->inferred_function_schema, j->debug);
}
}
// NB: don't register schema until after we've checked everything!
schema_ = std::move(schema);
debug_ = std::move(debug);
dispatchTable_.registerSchema(*schema_);
dispatchKeyExtractor_.registerSchema(schema);
schema_ = AnnotatedSchema(std::move(schema), std::move(debug));
}
void OperatorEntry::deregisterSchema() {
TORCH_INTERNAL_ASSERT(schema_.has_value());
schema_ = c10::nullopt;
debug_ = c10::nullopt;
dispatchTable_.deregisterSchema();
dispatchKeyExtractor_.deregisterSchema();
}
std::list<OperatorEntry::KernelEntry>::iterator OperatorEntry::registerKernel(
std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
const c10::Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
) {
std::unique_lock<std::mutex> lock(kernelsMutex_);
// NB: cpp_signature doesn't get cleared even after the kernel that populated
// it is deleted. This means you could poison the value of cpp_signature_
// with a bad signature value, and then it would permanently stay there until
// you deregister the schema. This can't really be fixed, because we
// only do a typed() test once in the lifetime of a TypedOperatorHandle,
// which means if you could validly change the type of a cpp_signature, then
// that would also invalidate the old TypedOperatorHandles.
if (cpp_signature.has_value()) {
if (cpp_signature_.has_value()) {
TORCH_INTERNAL_ASSERT(*cpp_signature == *cpp_signature_,
@ -98,116 +95,229 @@ std::list<OperatorEntry::KernelEntry>::iterator OperatorEntry::registerKernel(
}
if (schema_ && inferred_function_schema) {
checkSchema(name_, *schema_, *debug_, *inferred_function_schema, debug);
checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug);
}
// Add the kernel to the kernels list,
// possibly creating the list if this is the first kernel.
auto& k = kernels_[dispatch_key];
auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : catchAllKernel_;
if (k.size() > 0) {
TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " for dispatch key ", toString(dispatch_key), " that overwrote a previously registered kernel with the same dispatch key for the same operator.");
}
if (manuallyBoxedKernel_.has_value()) {
kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_);
}
k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
std::list<OperatorEntry::KernelEntry>::iterator inserted = k.begin();
std::list<AnnotatedKernel>::iterator inserted = k.begin();
// update the dispatch table, i.e. re-establish the invariant
// that the dispatch table points to the newest kernel
updateDispatchTable_(dispatch_key);
if (dispatch_key.has_value()) {
updateDispatchTable_(dispatcher, *dispatch_key);
} else {
updateDispatchTableFull_(dispatcher);
}
return inserted;
}
void OperatorEntry::deregisterKernel_(
const c10::Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
std::list<OperatorEntry::KernelEntry>::iterator kernel
std::list<AnnotatedKernel>::iterator kernel
) {
std::unique_lock<std::mutex> lock(kernelsMutex_);
auto found = kernels_.find(dispatch_key);
TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
auto& k = found->second;
k.erase(kernel);
if (k.empty()) {
// the invariant says we don't want empty lists but instead remove the list from the map
kernels_.erase(found);
if (dispatch_key.has_value()) {
auto found = kernels_.find(*dispatch_key);
TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
auto& k = found->second;
k.erase(kernel);
if (k.empty()) {
// the invariant says we don't want empty lists but instead remove the list from the map
kernels_.erase(found);
}
updateDispatchTable_(dispatcher, *dispatch_key);
} else {
catchAllKernel_.erase(kernel);
updateDispatchTableFull_(dispatcher);
}
updateDispatchTable_(dispatch_key);
}
void OperatorEntry::updateDispatchTable_(c10::optional<DispatchKey> dispatch_key) {
// precondition: kernelsMutex_ is locked
void OperatorEntry::updateFallback(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
updateDispatchTable_(dispatcher, dispatch_key);
}
auto k = kernels_.find(dispatch_key);
if (dispatch_key.has_value()) {
if (k == kernels_.end()) {
dispatchTable_.removeKernelIfExists(*dispatch_key);
} else {
dispatchTable_.setKernel(*dispatch_key, k->second.front().kernel);
}
const KernelFunction& OperatorEntry::computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
return computeDispatchTableEntryWithDebug(dispatcher, dispatch_key).first.kernel;
}
std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
// 1. Operator registration
auto kern_it = kernels_.find(dispatch_key);
if (kern_it != kernels_.end()) {
TORCH_INTERNAL_ASSERT(!kern_it->second.empty());
TORCH_INTERNAL_ASSERT(kern_it->second.front().kernel.isValid());
return {kern_it->second.front(), "kernel"};
// 2. Backend fallback
} else if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
// 3. Catch all
} else if (!catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
// 4. Default to error
} else {
if (k == kernels_.end()) {
dispatchTable_.removeCatchallKernel();
} else {
dispatchTable_.setCatchallKernel(k->second.front().kernel);
return {missingKernel_, "missing"};
}
}
void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}
void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
}
}
void OperatorEntry::setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func) {
TORCH_INTERNAL_ASSERT(!manuallyBoxedKernel_);
manuallyBoxedKernel_ = func;
for (auto& kv : kernels_) {
for (auto& k : kv.second) {
k.kernel.setManuallyBoxedKernel_(func);
}
}
for (auto& k : catchAllKernel_) {
k.kernel.setManuallyBoxedKernel_(func);
}
// Refresh entries in dispatchTable_
updateDispatchTableFull_(dispatcher);
}
void OperatorEntry::checkInvariants() const {
if (schema_) {
TORCH_INTERNAL_ASSERT(schema_->operator_name() == name_);
dispatchTable_.dispatchKeyExtractor().checkInvariants(*schema_);
TORCH_INTERNAL_ASSERT(schema_->schema.operator_name() == name_, dumpState());
dispatchKeyExtractor().checkInvariants(schema_->schema);
}
TORCH_INTERNAL_ASSERT(schema_.has_value() == debug_.has_value());
TORCH_INTERNAL_ASSERT(name_ == dispatchTable_.operatorName());
TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end());
TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end(), dumpState());
for (const auto& kv : kernels_) {
auto mb_dispatch_key = kv.first;
TORCH_INTERNAL_ASSERT(kv.second.size() > 0);
auto* kernel = mb_dispatch_key ? dispatchTable_.lookup(*mb_dispatch_key) : dispatchTable_.lookupCatchallKernel();
auto manual_boxed_kernel = dispatchTable_.manuallyBoxedKernel();
// NB: this is a copy
auto local_kernel = kv.second.front().kernel;
if (manual_boxed_kernel.has_value()) {
local_kernel.setManuallyBoxedKernel_(*manual_boxed_kernel);
}
TORCH_INTERNAL_ASSERT(local_kernel._equalsBoxedAndUnboxed(*kernel));
TORCH_INTERNAL_ASSERT(kv.second.size() > 0, dumpState());
}
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), static_cast<DispatchKey>(iter));
TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[iter]),
"Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n"
"Computed table:\n~~~~~~~~~~~\n", dumpComputedTable());
}
}
std::string OperatorEntry::listAllDispatchKeys() const {
std::ostringstream str;
str << "[";
bool has_kernels = false;
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
if (!dispatchTable_[iter].isValid()) {
continue;
}
if (has_kernels) {
str << ", ";
}
str << static_cast<DispatchKey>(iter);
has_kernels = true;
}
str << "]";
return str.str();
}
void OperatorEntry::reportError(DispatchKey dispatchKey) const {
// If there is an invariant problem, report it now.
checkInvariants();
if (dispatchKey == DispatchKey::Undefined) {
TORCH_CHECK(false,
"There were no tensor arguments to this function (e.g., you passed an "
"empty list of Tensors), but no fallback function is registered for schema ", name_,
". This usually means that this function requires a non-empty list of Tensors. "
"Available functions are ", listAllDispatchKeys(), ".\n\n", dumpComputedTable())
}
TORCH_CHECK(false, "Could not run '", name_, "' with arguments",
" from the '", toString(dispatchKey), "' backend. '",
name_, "' is only available for these backends: ",
listAllDispatchKeys(), ".\n\n", dumpComputedTable());
}
// INSPECTING DISPATCHER STATE
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The dumper functions purposely do not check invariants, as you might be using
// them to debug situations where the invariants are violated.
// Inspect what the computed dispatch table would be (e.g., what
// updateDispatchTableFull_ would update the dispatch table to be)
std::string OperatorEntry::dumpComputedTable() const {
std::ostringstream oss;
for (uint8_t i = 0; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
auto k = static_cast<DispatchKey>(i);
auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
if (kernel_prov.first.kernel.isValid()) {
oss << toString(k) << ": "
<< (kernel_prov.first.kernel.isFallthrough() ? "fallthrough " : "")
<< kernel_prov.first.debug << " [" << kernel_prov.second << "]\n";
}
}
return oss.str();
}
// Inspect the "canonical" information in OperatorEntry. This only prints out
// *non-derived* information; i.e., what the source of truth says about the
// operator. This dumping function is appropriate for expect tests.
// This WON'T report backend fallbacks.
std::string OperatorEntry::dumpState() const {
std::ostringstream oss;
oss << "name: " << name_ << "\n";
if (schema_) {
oss << "schema: " << *schema_ << "\n";
oss << "debug: " << *debug_ << "\n";
oss << "alias analysis kind: " << toString(schema_->aliasAnalysis()) << (schema_->isDefaultAliasAnalysisKind() ? " (default)" : "") << "\n";
oss << "schema: " << schema_->schema << "\n";
oss << "debug: " << schema_->debug << "\n";
oss << "alias analysis kind: " << toString(schema_->schema.aliasAnalysis())
<< (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << "\n";
} else {
oss << "schema: (none)\n";
}
// Iterate over DispatchKey, not the flat hash map, so we have a stable order
auto print_key = [&](c10::optional<DispatchKey> k) {
auto it = kernels_.find(k);
if (it != kernels_.end()) {
int64_t i = 0;
for (const auto& jt : it->second) {
oss << (k ? toString(k) : "catchall")
<< (i > 0 ? " (inactive)" : "")
<< ": "
<< jt.debug << " :: "
<< toString(*jt.inferred_function_schema) << " [ " << jt.kernel.dumpState() << "]\n";
i++;
}
auto print_kernel = [&](const char* k_desc, const std::list<AnnotatedKernel>& jts) {
int64_t i = 0;
for (const auto& jt : jts) {
oss << k_desc
<< (i > 0 ? " (inactive)" : "")
<< ": "
<< jt.debug << " :: "
<< (jt.inferred_function_schema ? toString(*jt.inferred_function_schema) : "(none)")
<< " [ " << jt.kernel.dumpState() << "]\n";
i++;
}
};
// Iterate over DispatchKey, not the flat hash map, so we have a stable order
for (uint8_t i = 0; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
print_key(static_cast<DispatchKey>(i));
auto k = static_cast<DispatchKey>(i);
auto it = kernels_.find(k);
if (it != kernels_.end()) {
print_kernel(toString(k), it->second);
}
}
print_key(c10::nullopt);
// dispatch table is 100% specified by OperatorEntry; so if you want to check
// if it makes sense use checkInvariants
// oss << dispatchTable_.dumpState();
print_kernel("catchall", catchAllKernel_);
return oss.str();
}

View file

@ -1,36 +1,67 @@
#pragma once
#include <ATen/core/dispatch/DispatchTable.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/either.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/DispatchKeyExtractor.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <list>
#include <array>
namespace c10 {
namespace impl {
class OperatorEntry;
}
class Dispatcher;
namespace impl {
// This is a private class used inside the Dispatcher to represent an operator
// and its dispatch table. This is not part of the public API.
// This data structure represents a kernel that was registered to us from a
// user. Unlike KernelFunction, AnnotatedKernel contains some extra metadata
// about the kernel that isn't necessary for actual dispatching (this is why
// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for
// giving good error messages.
struct AnnotatedKernel final {
AnnotatedKernel(KernelFunction k, std::unique_ptr<FunctionSchema> s, std::string d)
: kernel(std::move(k))
, inferred_function_schema(std::move(s))
, debug(std::move(d))
{}
AnnotatedKernel() {}
KernelFunction kernel;
std::unique_ptr<FunctionSchema> inferred_function_schema;
// A little debug string to help us identify the kernel in question.
// Most importantly it records the TORCH_LIBRARY block that did the
// registration.
std::string debug;
};
// This data structure represents operator schema, with metadata specifying
// where the registration of this schema occurred
struct AnnotatedSchema final {
AnnotatedSchema(FunctionSchema s, std::string d)
: schema(std::move(s))
, debug(std::move(d))
{}
FunctionSchema schema;
std::string debug;
};
// Internal data structure that records information about a specific operator.
// It's not part of the public API; typically, users will interact with
// OperatorHandle instead.
//
// Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher
// lock (this is important because some methods in OperatorEntry access
// dispatcher state)
class CAFFE2_API OperatorEntry final {
public:
struct KernelEntry final {
KernelEntry(KernelFunction k, std::unique_ptr<FunctionSchema> s, std::string d)
: kernel(std::move(k))
, inferred_function_schema(std::move(s))
, debug(std::move(d))
{}
KernelFunction kernel;
std::unique_ptr<FunctionSchema> inferred_function_schema;
// A little debug string to help us identify the kernel in question.
// Mostly used in testing but it might be possible to augment
// regular registrations with some more info here too
std::string debug;
};
explicit OperatorEntry(OperatorName&& operator_name);
OperatorEntry(const OperatorEntry&) = delete;
@ -40,30 +71,19 @@ public:
const FunctionSchema& schema() const {
TORCH_INTERNAL_ASSERT(schema_.has_value(), "Tried to access the schema for ", name_, " which doesn't have a schema registered yet");
return *schema_;
return schema_->schema;
}
const std::string& debug() const {
TORCH_INTERNAL_ASSERT(debug_.has_value());
return *debug_;
TORCH_INTERNAL_ASSERT(schema_.has_value());
return schema_->debug;
}
bool hasSchema() const {
return schema_.has_value();
}
// An OperatorEntry may be initialized with only an OperatorName.
// If this is the case, we may post facto register a schema to it.
//
// Some rules:
// - The following programs are equivalent:
// OperatorEntry op(std::move(schema))
// and
// OperatorEntry op(schema.operator_name())
// op.registerSchema(std::move(schema))
// - The following programs are equivalent:
// OperatorEntry op(schema.operator_name())
// and
// OperatorEntry op(std::move(schema))
// op.deregisterSchema()
// We may allocate an OperatorEntry for an operator even when we don't
// have a schema. When we receive the schema registration, we post
// facto register a schema.
//
// NB: registerSchema/deregisterSchema are not idempotent; if you
// attempt to register a schema when one is already present or vice
@ -76,31 +96,58 @@ public:
return name_;
}
const DispatchTable& dispatch_table() const {
return dispatchTable_;
}
void prepareForDeregistration();
// Why are kernels and fallback asymmetric? It has to do with ownership.
// Kernels and the computed dispatch tables for them are canonically
// owned by OperatorEntry, but backend fallbacks are specified once
// and apply for all operators, so they should be owned by Dispatcher.
// However, the registration of a backend fallback affects the
// state of the computed dispatch table, so when a backend fallback
// is updated, we need to update the operator tables too. Thus,
// registerKernel is the mechanism by which we give kernels to
// operator entry to own (and update dispatch table), but we only
// need a non-owning mechanism to update fallback.
// Precondition: Dispatcher::mutex_ is held
// Postcondition: caller is responsible for disposing of the kernel
std::list<KernelEntry>::iterator registerKernel(c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
void deregisterKernel_(c10::optional<DispatchKey> dispatch_key, std::list<KernelEntry>::iterator kernel);
std::list<AnnotatedKernel>::iterator registerKernel(
const Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
);
// Precondition: Dispatcher::mutex_ is held
void deregisterKernel_(
const Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
std::list<AnnotatedKernel>::iterator kernel
);
// Precondition: Dispatcher::mutex_ is held
void updateFallback(
const Dispatcher& dispatcher,
DispatchKey dispatch_key
);
// Precondition: Dispatcher::mutex_ is held
void updateSchemaAliasAnalysis(AliasAnalysisKind a) {
TORCH_INTERNAL_ASSERT(schema_.has_value());
schema_->setAliasAnalysis(a);
schema_->schema.setAliasAnalysis(a);
}
std::string dumpComputedTable() const;
std::string dumpState() const;
void checkInvariants() const;
const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; }
// This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
// unboxing wrapper for aten operators. We still need those for some operators because not all work
// with the templated unboxing logic yet.
// TODO Delete setManuallyBoxedKernel_ once all operators work with the templated boxing logic
void setManuallyBoxedKernel_(KernelFunction::InternalBoxedKernelFunction* func) {
dispatchTable_.setManuallyBoxedKernel_(func);
}
void setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func);
// Asserts that the given FuncType is correct for calling this operator in an unboxed way.
template<class FuncType>
@ -111,20 +158,36 @@ public:
" but the operator was registered with ",
cpp_signature_->name(),
" (",
debug_.value(),
(schema_.has_value() ? schema_->debug : "unknown debug info"),
") This likely happened in a call to OperatorHandle::typed<Return (Args...)>(). Please make sure that the function signature matches the signature in the operator registration call."
);
}
[[noreturn]] void reportError(DispatchKey dispatchKey) const;
const KernelFunction& lookup(DispatchKey k) const {
const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
if (C10_UNLIKELY(!kernel.isValid())) {
reportError(k);
}
return kernel;
}
std::string listAllDispatchKeys() const;
private:
OperatorName name_;
c10::optional<FunctionSchema> schema_;
c10::optional<std::string> debug_;
// INVARIANT: schema_.has_value() == debug_.has_value()
c10::optional<AnnotatedSchema> schema_;
// The dispatchTable stores the current kernel for each dispatch key
DispatchTable dispatchTable_;
std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;
// This manuallyBoxedKernel_ member is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
// unboxing wrapper for aten operators. We still need those for some operators because not all work
// with the templated unboxing logic yet.
// TODO Delete manuallyBoxedKernel_ once all operators work with the templated boxing logic
c10::optional<KernelFunction::InternalBoxedKernelFunction*> manuallyBoxedKernel_;
// kernels_ stores all registered kernels for the corresponding dispatch key
// and catchAllKernels_ stores the catch-all kernels.
@ -157,9 +220,10 @@ private:
// re-executed and then only allow one kernel here, i.e. error if a kernel
// is already registered, but that's a lot of effort to implement and
// currently not high-pri.
ska::flat_hash_map<c10::optional<DispatchKey>, std::list<KernelEntry>> kernels_;
ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
std::mutex kernelsMutex_; // protects kernels_
std::list<AnnotatedKernel> catchAllKernel_;
AnnotatedKernel missingKernel_;
// signature_hash_ is set to the hash of the function signature if any of
// the kernels was created in a way that allowed us to know the function
@ -168,10 +232,16 @@ private:
// to verify their arguments against the known function signature.
c10::optional<CppSignature> cpp_signature_;
const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const;
std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug(
const c10::Dispatcher& dispatcher, DispatchKey dispatch_key
) const;
// This function re-establishes the invariant that dispatchTable
// contains the front element from the kernels list for a given dispatch key.
void updateDispatchTable_(c10::optional<DispatchKey> dispatch_key);
void updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
// Like above, but for ALL entries in the dispatch table.
void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);
};
}
}
} // namespace impl
} // namespace c10

View file

@ -148,8 +148,7 @@ TEST(OperatorRegistrationTest, whenCallingOpWithWrongDispatchKey_thenFails) {
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
}, "Could not run '_test::dummy' with arguments from the 'CUDA'"
" backend. '_test::dummy' is only available for these backends:"
" [CPU].");
" backend.");
}
TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCallsCatchallKernel) {
@ -237,7 +236,7 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithSchema_t
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CPU'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
}
TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithoutSchema_thenFails) {
@ -288,7 +287,7 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterw
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CPU'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
}
TEST(OperatorRegistrationTest, givenOpWithoutKernelsWithoutTensorInputs_whenRegistering_thenRegisters) {
@ -455,7 +454,7 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlder
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CPU'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
}
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderAndThenNewerKernelDeletedAndOpCalled_thenFails) {
@ -474,7 +473,7 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderAndThenNewe
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CPU'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
}
TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) {
@ -493,7 +492,7 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewer
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CPU'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
}
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) {
@ -512,7 +511,7 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerAndThenOlde
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CPU'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
}
TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnboxedWithCPUDispatchKey) {
@ -533,7 +532,7 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox
expectThrows<c10::Error>([&] {
callOpUnboxedWithDispatchKey<void, Tensor>(*op, c10::DispatchKey::CUDA, dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CUDA'"
" backend. '_test::dummy' is only available for these backends: [CPU].");
" backend.");
}
TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallAndCalling_thenCallsCorrectKernel) {
@ -559,7 +558,7 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallAndCall
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::XLA));
}, "Could not run '_test::dummy' with arguments from the 'XLA'"
" backend. '_test::dummy' is only available for these backends: [");
" backend.");
// also assert that the error message contains the available tensor type ids, but don't assert their order
expectThrows<c10::Error>([&] {
@ -586,17 +585,17 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallOutOfSc
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CPU'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
}, "Could not run '_test::dummy' with arguments from the 'CUDA'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::XLA));
}, "Could not run '_test::dummy' with arguments from the 'XLA'"
" backend. '_test::dummy' is only available for these backends: [].");
" backend.");
}
bool called_stackbased_kernel = false;
@ -727,7 +726,7 @@ TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelForWrongBacke
ASSERT_TRUE(op.has_value());
expectThrows<c10::Error>([&] {
auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
}, "Could not run '_test::dummy' with arguments from the 'CPU' backend. '_test::dummy' is only available for these backends: [].");
}, "Could not run '_test::dummy' with arguments from the 'CPU' backend.");
}
bool called = false;

View file

@ -356,7 +356,7 @@ alias analysis kind: PURE_FUNCTION
except RuntimeError as e:
self.assertExpectedInline(
str(e),
'''Tried to register multiple backend fallbacks for the same dispatch key XLA (registered at /dev/null:0)''' # noqa
'''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration registered at /dev/null:0, new registration registered at /dev/null:0''' # noqa
)
else:
self.assertTrue(False)