diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp index 4b047aa2fcb..c8c3859bd4f 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp @@ -4,19 +4,11 @@ namespace c10 { -void DispatchKeyExtractor::setOperatorHasKernelForBackend(DispatchKey k, bool has_kernel) { - if (has_kernel) { - operatorHasKernelForBackend_ = operatorHasKernelForBackend_.add(k); - } else { - operatorHasKernelForBackend_ = operatorHasKernelForBackend_.remove(k); - } -} - -void DispatchKeyExtractor::setOperatorHasFallthroughForBackend(DispatchKey k, bool has_fallthrough) { +void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) { if (has_fallthrough) { - operatorHasFallthroughForBackend_ = operatorHasFallthroughForBackend_.add(k); + nonFallthroughKeys_ = nonFallthroughKeys_.remove(k); } else { - operatorHasFallthroughForBackend_ = operatorHasFallthroughForBackend_.remove(k); + nonFallthroughKeys_ = nonFallthroughKeys_.add(k); } } @@ -29,7 +21,7 @@ std::string DispatchKeyExtractor::dumpState() const { oss << "0"; } } - oss << " " << operatorHasKernelForBackend_ << "\n"; + oss << " " << nonFallthroughKeys_ << "\n"; return oss.str(); } diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index d7e9c673242..aff6176d481 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -28,9 +28,8 @@ static inline DispatchKey dispatchTypeId( // The key mask lets us eliminate (by zero entries) keys which should not // be considered for dispatch. There are two cases when we use this: // - // - If there is no operator registered for a backend whose fallback behavior - // is to fallthrough, we eliminate that backend from consideration (since - // we want to "fallthrough" to the next valid key.) 
+ // - If an operator's dispatch table contains a fallthrough entry, we + // should bypass it entirely when finding the key // - If a user invokes with redispatch, the mask lets us // zero out the key the user asked us to stop. // @@ -119,7 +118,7 @@ public: dispatch_arg_indices_reverse_ = c10::utils::bitset(); } - DispatchKey getDispatchKeyBoxed(DispatchKeySet backendsWithoutFallthrough, const torch::jit::Stack* stack) const { + DispatchKey getDispatchKeyBoxed(const torch::jit::Stack* stack) const { DispatchKeySet ks; dispatch_arg_indices_reverse_.for_each_set_bit([&] (size_t reverse_arg_index) { const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1); @@ -133,19 +132,16 @@ public: } } }); - return dispatchKeySetToDispatchKey_(backendsWithoutFallthrough, DispatchKeySet::FULL, ks); + return dispatchKeySetToDispatchKey_(DispatchKeySet::FULL, ks); } template - DispatchKey getDispatchKeyUnboxed(DispatchKeySet backendsWithoutFallthrough, DispatchKeySet eligibleKeys, const Args&... args) const { + DispatchKey getDispatchKeyUnboxed(DispatchKeySet eligibleKeys, const Args&... 
args) const { auto ks = detail::multi_dispatch_key_set(args...); - return dispatchKeySetToDispatchKey_(backendsWithoutFallthrough, eligibleKeys, ks); + return dispatchKeySetToDispatchKey_(eligibleKeys, ks); } - // Used by DispatchTable to maintain the fallthrough invariant, see - // docs on operatorHasKernelForBackend_ - void setOperatorHasKernelForBackend(DispatchKey k, bool has_kernel); - void setOperatorHasFallthroughForBackend(DispatchKey k, bool has_fallthrough); + void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough); std::string dumpState() const; void checkInvariants(const FunctionSchema& schema) const; @@ -166,22 +162,13 @@ private: // NB: If there is no valid dispatch key, this will return Undefined DispatchKey dispatchKeySetToDispatchKey_( - DispatchKeySet backendsWithoutFallthrough, // This is often known statically to be all ones; IN OPTIMIZER WE TRUST DispatchKeySet eligibleKeys, DispatchKeySet ks ) const { return impl::dispatchTypeId(ks, - // We must NOT respect the passed in backendsWithoutFallthrough if an operator has - // specifically overridden the backend, since that means we've opted to - // not fallthrough and instead apply some specific behavior (which we - // must dispatch to). - // - // This scheme doesn't work if you want to also apply fallthrough on a - // per-op basis, but while we could directly fix this by maintaining a - // second DispatchKeySet, it doesn't seem that there is any actual use case, - // so we are deferring it for #32454. 
- ((backendsWithoutFallthrough | operatorHasKernelForBackend_) - operatorHasFallthroughForBackend_) + // Keys that are fallthrough should be skipped + nonFallthroughKeys_ // Regardless of fallthrough behavior, only accept keys which are eligible // for dispatch, as requested by the user & eligibleKeys); @@ -189,8 +176,7 @@ private: explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse) : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse) - , operatorHasKernelForBackend_() - , operatorHasFallthroughForBackend_() {} + , nonFallthroughKeys_(DispatchKeySet::FULL) {} // this is a bitset that has ones for each argument index which has to be // considered for dispatch. This avoids having to iterate over the stack @@ -202,10 +188,8 @@ private: // fallthrough c10::utils::bitset dispatch_arg_indices_reverse_; - // Set of backends for which the operator has explicitly registered a kernel. - DispatchKeySet operatorHasKernelForBackend_; - // Set of backends for which the operator has explicitly registered a fallthrough kernel. - DispatchKeySet operatorHasFallthroughForBackend_; + // Set of keys for which the operator does NOT have fallthrough kernel. 
+ DispatchKeySet nonFallthroughKeys_; }; } diff --git a/aten/src/ATen/core/dispatch/DispatchTable.cpp b/aten/src/ATen/core/dispatch/DispatchTable.cpp deleted file mode 100644 index 5d8adc7b594..00000000000 --- a/aten/src/ATen/core/dispatch/DispatchTable.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include - -#include - -namespace c10 { - -namespace impl { - std::string KernelFunctionTable::dumpState() const { - std::ostringstream oss; - for (uint8_t i = 0; i < static_cast(DispatchKey::NumDispatchKeys); i++) { - if (kernels_[i].isValid()) oss << " " << kernels_[i].dumpState() << "\n"; - } - return oss.str(); - } -} - -std::string DispatchTable::dumpState() const { - std::ostringstream oss; - oss << "table:\n"; - oss << kernels_.dumpState(); - oss << " catchall: " << catchallKernel_.dumpState() << "\n"; - oss << " extractor: " << dispatchKeyExtractor_.dumpState() << "\n"; - oss << " name: " << operatorName_ << "\n"; - return oss.str(); -} - -} // namespace c10 diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h deleted file mode 100644 index 7ec7fe531fa..00000000000 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ /dev/null @@ -1,247 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace c10 { - -namespace impl { -/** - * A KernelFunctionTable is a map from DispatchKey to a KernelFunction. - * It can store zero or one KernelFunctions for each DispatchKey. 
- */ -class KernelFunctionTable final { -public: - explicit KernelFunctionTable() - : kernels_() - , kernelCount_(0) {} - - void setKernel(DispatchKey dispatchKey, KernelFunction kernel) { - TORCH_INTERNAL_ASSERT(dispatchKey != DispatchKey::Undefined); - auto& slot = kernels_[static_cast(dispatchKey)]; - if (!slot.isValid()) { - ++kernelCount_; - } - slot = std::move(kernel); - } - - void removeKernelIfExists(DispatchKey dispatchKey) { - auto& slot = kernels_[static_cast(dispatchKey)]; - if (slot.isValid()) { - --kernelCount_; - slot = {}; - } else { - } - } - - const KernelFunction& operator[](DispatchKey dispatchKey) const { - return kernels_[static_cast(dispatchKey)]; - } - - KernelFunction& operator[](DispatchKey dispatchKey) { - return kernels_[static_cast(dispatchKey)]; - } - - size_t size() const { - return kernelCount_; - } - - std::string dumpState() const; - -private: - std::array(DispatchKey::NumDispatchKeys)> kernels_; - size_t kernelCount_; -}; -} - -/** - * Per-operator dispatch table. - * - * Given an operator specified by a FunctionSchema, this class records a dispatch - * table for various kernels provided for this operator. For example, if we - * consider the operator add(Tensor, Tensor), the dispatch table for this - * operator may contain implementations for various dynamic tensor types, such - * as CPU, CUDA, etc. - */ -class DispatchTable final { - public: - explicit DispatchTable(const FunctionSchema& schema) - : kernels_() - , catchallKernel_() - , dispatchKeyExtractor_(DispatchKeyExtractor::make(schema)) - , operatorName_(schema.operator_name()) {} - - // a dispatch table may be default constructed with only an - // operator name. 
Such a dispatch table is not callable until - // the schema is provided - DispatchTable(OperatorName op_name) - : kernels_() - , catchallKernel_() - , dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized()) - , operatorName_(std::move(op_name)) {} - - /** - * Register a kernel in the table at some dispatch key. - * @param dispatch_key Dispatch key to define when this kernel is selected. - * @param kernel Concrete kernel function implementation to register - */ - void setKernel(DispatchKey dispatchKey, KernelFunction kernel) { - if (manuallyBoxedKernel_.has_value()) { - kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_); - } - kernels_.setKernel(dispatchKey, std::move(kernel)); - dispatchKeyExtractor_.setOperatorHasKernelForBackend(dispatchKey, true); - if (kernel.isFallthrough()) { - dispatchKeyExtractor_.setOperatorHasFallthroughForBackend(dispatchKey, true); - } - } - - /** - * Deregister the kernel for some dispatch key. - * - * @param dispatch_key Dispatch key to unregister. - */ - void removeKernelIfExists(DispatchKey dispatchKey) { - kernels_.removeKernelIfExists(dispatchKey); - dispatchKeyExtractor_.setOperatorHasKernelForBackend(dispatchKey, false); - dispatchKeyExtractor_.setOperatorHasFallthroughForBackend(dispatchKey, false); // may be no op - } - - /** - * Register a catch-all kernel that is called for this operator - * independent of the inputs. An operator can have either - * a catch-all kernel or a set of kernels with concrete - * dispatch keys, not both. - */ - void setCatchallKernel(KernelFunction kernel) { - if (manuallyBoxedKernel_.has_value()) { - kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_); - } - catchallKernel_ = std::move(kernel); - } - - /** - * Remove the catch-all kernel. 
- */ - void removeCatchallKernel() { - catchallKernel_ = {}; - } - - bool isEmpty() const { - return !catchallKernel_.isValid() && kernels_.size() == 0; - } - - std::string listAllDispatchKeys() const { - std::ostringstream str; - str << "["; - - bool has_kernels = false; - for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { - if (!kernels_[static_cast(iter)].isValid()) { - continue; - } - if (has_kernels) { - str << ", "; - } - str << static_cast(iter); - has_kernels = true; - } - - if (catchallKernel_.isValid()) { - if (has_kernels) { - str << ", "; - } - str << "CATCH-ALL"; - } - str << "]"; - return str.str(); - } - - const KernelFunction* lookup(DispatchKey dispatchKey) const { - auto& slot = kernels_[dispatchKey]; - // TODO: this condition shouldn't be necessary - if (slot.isValid()) { - return &slot; - } else { - return nullptr; - } - } - - const KernelFunction* lookupCatchallKernel() const { - // TODO: this condition shouldn't be necessary - if (!catchallKernel_.isValid()) { - return nullptr; - } - - return &catchallKernel_; - } - - const DispatchKeyExtractor& dispatchKeyExtractor() const { - return dispatchKeyExtractor_; - } - - const OperatorName& operatorName() const { - return operatorName_; - } - - void registerSchema(const FunctionSchema& schema) { - dispatchKeyExtractor_.registerSchema(schema); - } - - void deregisterSchema() { - dispatchKeyExtractor_.deregisterSchema(); - } - - std::string dumpState() const; - - // This function is a temporary hack, see comment at manuallyBoxedKernel_ member - void setManuallyBoxedKernel_(KernelFunction::InternalBoxedKernelFunction* func) { - TORCH_INTERNAL_ASSERT(!manuallyBoxedKernel_.has_value(), "Cannot set multiple manually boxed kernels for the same operator ", operatorName_); - manuallyBoxedKernel_ = func; - - // make sure that all previously registered kernels get this manually boxed kernel - for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { - 
auto& kernel = kernels_[static_cast(iter)]; - if (kernel.isValid()) { - kernel.setManuallyBoxedKernel_(func); - } - } - if (catchallKernel_.isValid()) { - catchallKernel_.setManuallyBoxedKernel_(func); - } - } - - c10::optional manuallyBoxedKernel() const { - return manuallyBoxedKernel_; - } - -private: - - impl::KernelFunctionTable kernels_; - KernelFunction catchallKernel_; - DispatchKeyExtractor dispatchKeyExtractor_; - OperatorName operatorName_; - - // This manuallyBoxedKernel_ member is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed - // unboxing wrapper for aten operators. We still need those for some operators because not all work - // with the templated unboxing logic yet. - // TODO Delete manuallyBoxedKernel_ once all operators work with the templated boxing logic - c10::optional manuallyBoxedKernel_; -}; - -} // namespace c10 diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 0f24f0eca62..6b4774d8f67 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -38,7 +38,6 @@ Dispatcher::Dispatcher() : operators_() , operatorLookupTable_() , backendFallbackKernels_() -, backendsWithoutFallthrough_(DispatchKeySet::FULL) , listeners_(std::make_unique()) , mutex_() {} @@ -206,7 +205,14 @@ RegistrationHandleRAII Dispatcher::registerImpl( auto op = findOrRegisterName_(op_name); - auto handle = op.operatorIterator_->op.registerKernel(dispatch_key, std::move(kernel), std::move(cpp_signature), std::move(inferred_function_schema), std::move(debug)); + auto handle = op.operatorIterator_->op.registerKernel( + *this, + dispatch_key, + std::move(kernel), + std::move(cpp_signature), + std::move(inferred_function_schema), + std::move(debug) + ); ++op.operatorIterator_->def_and_impl_count; @@ -215,10 +221,10 @@ RegistrationHandleRAII Dispatcher::registerImpl( }); } -void Dispatcher::deregisterImpl_(const OperatorHandle& op, const 
OperatorName& op_name, c10::optional dispatch_key, std::list::iterator handle) { +void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, c10::optional dispatch_key, std::list::iterator handle) { std::lock_guard lock(mutex_); - op.operatorIterator_->op.deregisterKernel_(dispatch_key, handle); + op.operatorIterator_->op.deregisterKernel_(*this, dispatch_key, handle); TORCH_INTERNAL_ASSERT(op.operator_name() == op_name); @@ -249,8 +255,6 @@ void Dispatcher::deregisterName_( // Test if the operator entry is completely dead, and if so remove it completely void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) { if (0 == op.operatorIterator_->def_and_impl_count) { - // TODO: rename this to "assert deregistration invariants" - op.operatorIterator_->op.prepareForDeregistration(); operators_.erase(op.operatorIterator_); operatorLookupTable_.write([&] (ska::flat_hash_map& operatorLookupTable) { operatorLookupTable.erase(op_name); @@ -261,14 +265,17 @@ void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) { std::lock_guard lock(mutex_); - // TODO: preserve debug for old fallback TORCH_CHECK( - !backendFallbackKernels_[dispatchKey].isValid(), - "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, " (", debug, ")" + !backendFallbackKernels_[static_cast(dispatchKey)].kernel.isValid(), + "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ", + backendFallbackKernels_[static_cast(dispatchKey)].debug, ", new registration ", debug ); - backendFallbackKernels_.setKernel(dispatchKey, std::move(kernel)); - if (kernel.isFallthrough()) { - backendsWithoutFallthrough_ = backendsWithoutFallthrough_.remove(dispatchKey); + // NB: inferred function schema is always nullptr for fallbacks, as 
fallbacks + // cannot be unboxed + backendFallbackKernels_[static_cast(dispatchKey)] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug)); + + for (auto& op : operators_) { + op.op.updateFallback(*this, dispatchKey); + } return RegistrationHandleRAII([this, dispatchKey] { @@ -279,8 +286,11 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) { std::lock_guard lock(mutex_); - backendFallbackKernels_.removeKernelIfExists(dispatchKey); - backendsWithoutFallthrough_ = backendsWithoutFallthrough_.add(dispatchKey); + backendFallbackKernels_[static_cast(dispatchKey)] = {}; + + for (auto& op : operators_) { + op.op.updateFallback(*this, dispatchKey); + } } @@ -300,38 +310,16 @@ RegistrationHandleRAII Dispatcher::addRegistrationListener(std::unique_ptr(DispatchKey::NumDispatchKeys); i++) { - auto k = static_cast(i); - if (!backendsWithoutFallthrough_.has(k)) { - const auto& kernel = backendFallbackKernels_[k]; - TORCH_INTERNAL_ASSERT(kernel.isFallthrough()); - } - } } void Dispatcher::setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func) { - op.operatorIterator_->op.setManuallyBoxedKernel_(func); + std::lock_guard lock(mutex_); + op.operatorIterator_->op.setManuallyBoxedKernel_(*this, func); + // NB: Do not need to set manually boxed kernel for backend fallbacks } } diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 1264941625d..db73feecacc 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -41,6 +41,9 @@ class SchemaRegistrationHandleRAII; */ class CAFFE2_API Dispatcher final { private: + // For direct access to backend fallback information + friend class impl::OperatorEntry; + struct OperatorDef final { explicit OperatorDef(OperatorName&& op_name) : op(std::move(op_name)) {} @@ -206,26 +209,20 @@ private: const 
OperatorHandle& op, const OperatorName& op_name, c10::optional dispatch_key, - std::list::iterator kernel_handle); + std::list::iterator kernel_handle); void deregisterName_(const OperatorHandle& op, const OperatorName& op_name); void deregisterFallback_(DispatchKey dispatchKey); void deregisterLibrary_(const std::string& ns); void cleanup(const OperatorHandle& op, const OperatorName& op_name); void checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug); - [[noreturn]] static void reportError(const DispatchTable& dispatchTable, DispatchKey dispatchKey); - - const KernelFunction& dispatch_(const DispatchTable& dispatchTable, DispatchKey dispatch_key) const; - std::list operators_; LeftRight> operatorLookupTable_; // Map from namespace to debug string (saying, e.g., where the library was defined) ska::flat_hash_map libraries_; - impl::KernelFunctionTable backendFallbackKernels_; - // Set of backends which have specified they do NOT want fallthrough behavior - // (we store the inverse because it avoids a negation when we use this for - // masking) - DispatchKeySet backendsWithoutFallthrough_; + + std::array(DispatchKey::NumDispatchKeys)> backendFallbackKernels_; + std::unique_ptr listeners_; std::mutex mutex_; }; @@ -268,6 +265,12 @@ public: template TypedOperatorHandle typed() const { + // NB: This assert is not 100% sound: you can retrieve a typed() operator + // handle prior to ANY C++ signature being registered on the operator + // and the check will say everything is OK (at which point you can then + // smuggle in a kernel that is typed incorrectly). For everything + // in core library this won't happen, because all the static registrations + // will be done by the time a typed() handle is acquired. operatorIterator_->op.assertSignatureIsCorrect(); return TypedOperatorHandle(operatorIterator_); } @@ -324,57 +327,38 @@ template inline void unused_arg_(const Args&...) 
{} template inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle& op, DispatchKey dispatchKey, Args... args) const { detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 - const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); - const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey); + const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey); return kernel.template call(op, std::forward(args)...); } template inline Return Dispatcher::call(const TypedOperatorHandle& op, Args... args) const { detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 - const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); - auto dispatchKey = dispatchTable.dispatchKeyExtractor().template getDispatchKeyUnboxed(backendsWithoutFallthrough_, DispatchKeySet::FULL, args...); + auto dispatchKey = op.operatorIterator_->op.dispatchKeyExtractor() + .template getDispatchKeyUnboxed( + DispatchKeySet::FULL, + args... + ); return callWithDispatchKey(op, dispatchKey, args...); } template inline Return Dispatcher::redispatch(const TypedOperatorHandle& op, DispatchKey currentDispatchKey, Args... args) const { detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 - const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); - auto dispatchKey = dispatchTable.dispatchKeyExtractor().template getDispatchKeyUnboxed( - backendsWithoutFallthrough_, - DispatchKeySet(DispatchKeySet::FULL_AFTER, currentDispatchKey), - args...); - const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey); - return kernel.template call(op, std::forward(args)...); + auto dispatchKey = op.operatorIterator_->op.dispatchKeyExtractor() + .template getDispatchKeyUnboxed( + DispatchKeySet(DispatchKeySet::FULL_AFTER, currentDispatchKey), + args... 
+ ); + return callWithDispatchKey(op, dispatchKey, args...); } inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const { // note: this doesn't need the mutex because write operations on the list keep iterators intact. - const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); - auto dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyBoxed(backendsWithoutFallthrough_, stack); - const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey); + const auto& entry = op.operatorIterator_->op; + auto dispatchKey = entry.dispatchKeyExtractor().getDispatchKeyBoxed(stack); + const auto& kernel = entry.lookup(dispatchKey); kernel.callBoxed(op, stack); } -inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, DispatchKey dispatchKey) const { - const KernelFunction* backendKernel = dispatchTable.lookup(dispatchKey); - - if (nullptr != backendKernel) { - return *backendKernel; - } - - const auto& backendFallbackKernel = backendFallbackKernels_[dispatchKey]; - if (backendFallbackKernel.isValid()) { - return backendFallbackKernel; - } - - const KernelFunction* catchallKernel = dispatchTable.lookupCatchallKernel(); - if (C10_LIKELY(nullptr != catchallKernel)) { - return *catchallKernel; - } - - reportError(dispatchTable, dispatchKey); -} - } // namespace c10 diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 70f006fbc94..9ba9053659b 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -1,11 +1,11 @@ #include #include +#include namespace c10 { namespace impl { namespace { - std::string toString(c10::optional k) { if (k.has_value()) { return toString(*k); @@ -13,33 +13,21 @@ namespace { return "(catch all)"; } } - - std::string listAllDispatchKeys(const ska::flat_hash_map, std::list>& kernels) { - if (kernels.size() == 0) { - return ""; - } - std::ostringstream str; - str << 
toString(kernels.begin()->first); - for (auto iter = ++kernels.begin(); iter != kernels.end(); ++iter) { - str << ", " << toString(iter->first); - } - return str.str(); - } } OperatorEntry::OperatorEntry(OperatorName&& operator_name) : name_(std::move(operator_name)) , schema_() -, debug_() -, dispatchTable_(name_) -, kernels_() { -} - -void OperatorEntry::prepareForDeregistration() { - if (!dispatchTable_.isEmpty()) { - TORCH_INTERNAL_ASSERT(false, "Tried to deregister op schema for an operator that still has kernels registered. The operator is ", toString(name_), ". Registered kernels for dispatch keys: ", dispatchTable_.listAllDispatchKeys()); - } - TORCH_INTERNAL_ASSERT(kernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have kernels for dispatch keys ", listAllDispatchKeys(kernels_), ". The operator is ", toString(name_)); +, dispatchTable_() +, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized()) +, manuallyBoxedKernel_() +, kernels_() +, catchAllKernel_() +, cpp_signature_() +{ + // Pick up any backend fallbacks that were registered prior to this + // OperatorEntry being created + updateDispatchTableFull_(c10::Dispatcher::singleton()); } namespace { @@ -52,7 +40,7 @@ namespace { *schema_difference); } } -} +} // anonymous namespace void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug) { TORCH_INTERNAL_ASSERT(!schema_.has_value()); @@ -63,28 +51,37 @@ void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug) } } } + for (auto j = catchAllKernel_.begin(); j != catchAllKernel_.end(); ++j) { + if (j->inferred_function_schema != nullptr) { + checkSchema(name_, schema, debug, *j->inferred_function_schema, j->debug); + } + } // NB: don't register schema until after we've checked everything! 
- schema_ = std::move(schema); - debug_ = std::move(debug); - dispatchTable_.registerSchema(*schema_); + dispatchKeyExtractor_.registerSchema(schema); + schema_ = AnnotatedSchema(std::move(schema), std::move(debug)); } void OperatorEntry::deregisterSchema() { TORCH_INTERNAL_ASSERT(schema_.has_value()); schema_ = c10::nullopt; - debug_ = c10::nullopt; - dispatchTable_.deregisterSchema(); + dispatchKeyExtractor_.deregisterSchema(); } -std::list::iterator OperatorEntry::registerKernel( +std::list::iterator OperatorEntry::registerKernel( + const c10::Dispatcher& dispatcher, c10::optional dispatch_key, KernelFunction kernel, c10::optional cpp_signature, std::unique_ptr inferred_function_schema, std::string debug ) { - std::unique_lock lock(kernelsMutex_); - + // NB: cpp_signature doesn't get cleared even after the kernel that populated + // it is deleted. This means you could poison the value of cpp_signature_ + // with a bad signature value, and then it would permanently stay there until + // you deregister the schema. This can't really be fixed, because we + // only do a typed() test once in the lifetime of a TypedOperatorHandle, + // which means if you could validly change the type of a cpp_signature, then + // that would also invalidate the old TypedOperatorHandles. if (cpp_signature.has_value()) { if (cpp_signature_.has_value()) { TORCH_INTERNAL_ASSERT(*cpp_signature == *cpp_signature_, @@ -98,116 +95,229 @@ std::list::iterator OperatorEntry::registerKernel( } if (schema_ && inferred_function_schema) { - checkSchema(name_, *schema_, *debug_, *inferred_function_schema, debug); + checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug); } // Add the kernel to the kernels list, // possibly creating the list if this is the first kernel. - auto& k = kernels_[dispatch_key]; + auto& k = dispatch_key.has_value() ? 
kernels_[*dispatch_key] : catchAllKernel_; if (k.size() > 0) { TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " for dispatch key ", toString(dispatch_key), " that overwrote a previously registered kernel with the same dispatch key for the same operator."); } + if (manuallyBoxedKernel_.has_value()) { + kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_); + } + k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug)); - std::list::iterator inserted = k.begin(); + std::list::iterator inserted = k.begin(); // update the dispatch table, i.e. re-establish the invariant // that the dispatch table points to the newest kernel - updateDispatchTable_(dispatch_key); + if (dispatch_key.has_value()) { + updateDispatchTable_(dispatcher, *dispatch_key); + } else { + updateDispatchTableFull_(dispatcher); + } return inserted; } void OperatorEntry::deregisterKernel_( + const c10::Dispatcher& dispatcher, c10::optional dispatch_key, - std::list::iterator kernel + std::list::iterator kernel ) { - std::unique_lock lock(kernelsMutex_); - - auto found = kernels_.find(dispatch_key); - TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_)); - auto& k = found->second; - k.erase(kernel); - if (k.empty()) { - // the invariant says we don't want empty lists but instead remove the list from the map - kernels_.erase(found); + if (dispatch_key.has_value()) { + auto found = kernels_.find(*dispatch_key); + TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. 
The operator is ", toString(name_)); + auto& k = found->second; + k.erase(kernel); + if (k.empty()) { + // the invariant says we don't want empty lists but instead remove the list from the map + kernels_.erase(found); + } + updateDispatchTable_(dispatcher, *dispatch_key); + } else { + catchAllKernel_.erase(kernel); + updateDispatchTableFull_(dispatcher); } - - updateDispatchTable_(dispatch_key); } -void OperatorEntry::updateDispatchTable_(c10::optional dispatch_key) { - // precondition: kernelsMutex_ is locked +void OperatorEntry::updateFallback(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) { + updateDispatchTable_(dispatcher, dispatch_key); +} - auto k = kernels_.find(dispatch_key); - if (dispatch_key.has_value()) { - if (k == kernels_.end()) { - dispatchTable_.removeKernelIfExists(*dispatch_key); - } else { - dispatchTable_.setKernel(*dispatch_key, k->second.front().kernel); - } +const KernelFunction& OperatorEntry::computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const { + return computeDispatchTableEntryWithDebug(dispatcher, dispatch_key).first.kernel; +} + +std::pair OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const { + auto dispatch_ix = static_cast(dispatch_key); + + // 1. Operator registration + auto kern_it = kernels_.find(dispatch_key); + if (kern_it != kernels_.end()) { + TORCH_INTERNAL_ASSERT(!kern_it->second.empty()); + TORCH_INTERNAL_ASSERT(kern_it->second.front().kernel.isValid()); + return {kern_it->second.front(), "kernel"}; + + // 2. Backend fallback + } else if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) { + return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"}; + + // 3. Catch all + } else if (!catchAllKernel_.empty()) { + TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid()); + return {catchAllKernel_.front(), "catch all"}; + + // 4. 
Default to error } else { - if (k == kernels_.end()) { - dispatchTable_.removeCatchallKernel(); - } else { - dispatchTable_.setCatchallKernel(k->second.front().kernel); + return {missingKernel_, "missing"}; + } +} + +void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) { + auto dispatch_ix = static_cast(dispatch_key); + dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key); + dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough()); +} + +void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) { + for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { + updateDispatchTable_(dispatcher, static_cast(iter)); + } +} + +void OperatorEntry::setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func) { + TORCH_INTERNAL_ASSERT(!manuallyBoxedKernel_); + manuallyBoxedKernel_ = func; + + for (auto& kv : kernels_) { + for (auto& k : kv.second) { + k.kernel.setManuallyBoxedKernel_(func); } } + for (auto& k : catchAllKernel_) { + k.kernel.setManuallyBoxedKernel_(func); + } + + // Refresh entries in dispatchTable_ + updateDispatchTableFull_(dispatcher); } void OperatorEntry::checkInvariants() const { if (schema_) { - TORCH_INTERNAL_ASSERT(schema_->operator_name() == name_); - dispatchTable_.dispatchKeyExtractor().checkInvariants(*schema_); + TORCH_INTERNAL_ASSERT(schema_->schema.operator_name() == name_, dumpState()); + dispatchKeyExtractor().checkInvariants(schema_->schema); } - TORCH_INTERNAL_ASSERT(schema_.has_value() == debug_.has_value()); - TORCH_INTERNAL_ASSERT(name_ == dispatchTable_.operatorName()); - TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end()); + TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end(), dumpState()); for (const auto& kv : kernels_) { - auto mb_dispatch_key = 
kv.first; - TORCH_INTERNAL_ASSERT(kv.second.size() > 0); - auto* kernel = mb_dispatch_key ? dispatchTable_.lookup(*mb_dispatch_key) : dispatchTable_.lookupCatchallKernel(); - auto manual_boxed_kernel = dispatchTable_.manuallyBoxedKernel(); - // NB: this is a copy - auto local_kernel = kv.second.front().kernel; - if (manual_boxed_kernel.has_value()) { - local_kernel.setManuallyBoxedKernel_(*manual_boxed_kernel); - } - TORCH_INTERNAL_ASSERT(local_kernel._equalsBoxedAndUnboxed(*kernel)); + TORCH_INTERNAL_ASSERT(kv.second.size() > 0, dumpState()); + } + for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { + auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), static_cast(iter)); + TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[iter]), + "Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n" + "Computed table:\n~~~~~~~~~~~\n", dumpComputedTable()); } } +std::string OperatorEntry::listAllDispatchKeys() const { + std::ostringstream str; + str << "["; + + bool has_kernels = false; + for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { + if (!dispatchTable_[iter].isValid()) { + continue; + } + if (has_kernels) { + str << ", "; + } + str << static_cast(iter); + has_kernels = true; + } + str << "]"; + return str.str(); +} + +void OperatorEntry::reportError(DispatchKey dispatchKey) const { + // If there is an invariant problem, report it now. + checkInvariants(); + + if (dispatchKey == DispatchKey::Undefined) { + TORCH_CHECK(false, + "There were no tensor arguments to this function (e.g., you passed an " + "empty list of Tensors), but no fallback function is registered for schema ", name_, + ". This usually means that this function requires a non-empty list of Tensors. 
" + "Available functions are ", listAllDispatchKeys(), ".\n\n", dumpComputedTable()) + } + + TORCH_CHECK(false, "Could not run '", name_, "' with arguments", + " from the '", toString(dispatchKey), "' backend. '", + name_, "' is only available for these backends: ", + listAllDispatchKeys(), ".\n\n", dumpComputedTable()); +} + +// INSPECTING DISPATCHER STATE +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// The dumper functions purposely do not check invariants, as you might be using +// them to debug situations where the invariants are violated. + +// Inspect what the computed dispatch table would be (e.g., what +// updateDispatchTableFull_ would update the dispatch table to be) +std::string OperatorEntry::dumpComputedTable() const { + std::ostringstream oss; + for (uint8_t i = 0; i < static_cast(DispatchKey::NumDispatchKeys); i++) { + auto k = static_cast(i); + auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k); + if (kernel_prov.first.kernel.isValid()) { + oss << toString(k) << ": " + << (kernel_prov.first.kernel.isFallthrough() ? "fallthrough " : "") + << kernel_prov.first.debug << " [" << kernel_prov.second << "]\n"; + } + } + return oss.str(); +} + +// Inspect the "canonical" information in OperatorEntry. This only prints out +// *non-derived* information; i.e., what the source of truth says about the +// operator. This dumping function is appropriate for expect tests. +// This WON'T report backend fallbacks. std::string OperatorEntry::dumpState() const { std::ostringstream oss; oss << "name: " << name_ << "\n"; if (schema_) { - oss << "schema: " << *schema_ << "\n"; - oss << "debug: " << *debug_ << "\n"; - oss << "alias analysis kind: " << toString(schema_->aliasAnalysis()) << (schema_->isDefaultAliasAnalysisKind() ? 
" (default)" : "") << "\n"; + oss << "schema: " << schema_->schema << "\n"; + oss << "debug: " << schema_->debug << "\n"; + oss << "alias analysis kind: " << toString(schema_->schema.aliasAnalysis()) + << (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << "\n"; } else { oss << "schema: (none)\n"; } - // Iterate over DispatchKey, not the flat hash map, so we have a stable order - auto print_key = [&](c10::optional k) { - auto it = kernels_.find(k); - if (it != kernels_.end()) { - int64_t i = 0; - for (const auto& jt : it->second) { - oss << (k ? toString(k) : "catchall") - << (i > 0 ? " (inactive)" : "") - << ": " - << jt.debug << " :: " - << toString(*jt.inferred_function_schema) << " [ " << jt.kernel.dumpState() << "]\n"; - i++; - } + + auto print_kernel = [&](const char* k_desc, const std::list& jts) { + int64_t i = 0; + for (const auto& jt : jts) { + oss << k_desc + << (i > 0 ? " (inactive)" : "") + << ": " + << jt.debug << " :: " + << (jt.inferred_function_schema ? 
toString(*jt.inferred_function_schema) : "(none)") + << " [ " << jt.kernel.dumpState() << "]\n"; + i++; } }; + + // Iterate over DispatchKey, not the flat hash map, so we have a stable order for (uint8_t i = 0; i < static_cast(DispatchKey::NumDispatchKeys); i++) { - print_key(static_cast(i)); + auto k = static_cast(i); + auto it = kernels_.find(k); + if (it != kernels_.end()) { + print_kernel(toString(k), it->second); + } } - print_key(c10::nullopt); - // dispatch table is 100% specified by OperatorEntry; so if you want to check - // if it makes sense use checkInvariants - // oss << dispatchTable_.dumpState(); + print_kernel("catchall", catchAllKernel_); return oss.str(); } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 057eddb93fb..cdc8c7f797a 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -1,36 +1,67 @@ #pragma once -#include +#include +#include +#include +#include +#include +#include +#include +#include + #include #include #include + #include +#include namespace c10 { -namespace impl { - class OperatorEntry; -} + +class Dispatcher; namespace impl { -// This is a private class used inside the Dispatcher to represent an operator -// and its dispatch table. This is not part of the public API. +// This data structure represents a kernel that was registered to us from a +// user. Unlike KernelFunction, AnnotatedKernel contains some extra metadata +// about the kernel that isn't necessary for actual dispatching (this is why +// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for +// giving good error messages. 
+struct AnnotatedKernel final { + AnnotatedKernel(KernelFunction k, std::unique_ptr s, std::string d) + : kernel(std::move(k)) + , inferred_function_schema(std::move(s)) + , debug(std::move(d)) + {} + AnnotatedKernel() {} + KernelFunction kernel; + std::unique_ptr inferred_function_schema; + // A little debug string to help us identify the kernel in question. + // Most importantly it records the TORCH_LIBRARY block that did the + // registration. + std::string debug; +}; + +// This data structure represents operator schema, with metadata specifying +// where the registration of this schema occurred +struct AnnotatedSchema final { + AnnotatedSchema(FunctionSchema s, std::string d) + : schema(std::move(s)) + , debug(std::move(d)) + {} + FunctionSchema schema; + std::string debug; +}; + +// Internal data structure that records information about a specific operator. +// It's not part of the public API; typically, users will interact with +// OperatorHandle instead. +// +// Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher +// lock (this is important because some methods in OperatorEntry access +// dispatcher state) class CAFFE2_API OperatorEntry final { public: - struct KernelEntry final { - KernelEntry(KernelFunction k, std::unique_ptr s, std::string d) - : kernel(std::move(k)) - , inferred_function_schema(std::move(s)) - , debug(std::move(d)) - {} - KernelFunction kernel; - std::unique_ptr inferred_function_schema; - // A little debug string to help us identify the kernel in question. 
- // Mostly used in testing but it might be possible to augment - // regular registrations with some more info here too - std::string debug; - }; - explicit OperatorEntry(OperatorName&& operator_name); OperatorEntry(const OperatorEntry&) = delete; @@ -40,30 +71,19 @@ public: const FunctionSchema& schema() const { TORCH_INTERNAL_ASSERT(schema_.has_value(), "Tried to access the schema for ", name_, " which doesn't have a schema registered yet"); - return *schema_; + return schema_->schema; } const std::string& debug() const { - TORCH_INTERNAL_ASSERT(debug_.has_value()); - return *debug_; + TORCH_INTERNAL_ASSERT(schema_.has_value()); + return schema_->debug; } bool hasSchema() const { return schema_.has_value(); } - // An OperatorEntry may be initialized with only an OperatorName. - // If this is the case, we may post facto register a schema to it. - // - // Some rules: - // - The following programs are equivalent: - // OperatorEntry op(std::move(schema)) - // and - // OperatorEntry op(schema.operator_name()) - // op.registerSchema(std::move(schema)) - // - The following programs are equivalent: - // OperatorEntry op(schema.operator_name()) - // and - // OperatorEntry op(std::move(schema)) - // op.deregisterSchema() + // We may allocate an OperatorEntry for an operator even when we don't + // have a schema. When we receive the schema registration, we post + // facto register a schema. // // NB: registerSchema/deregisterSchema are not idempotent; if you // attempt to register a schema when one is already present or vice @@ -76,31 +96,58 @@ public: return name_; } - const DispatchTable& dispatch_table() const { - return dispatchTable_; - } - - void prepareForDeregistration(); + // Why are kernels and fallback asymmetric? It has to do with ownership. + // Kernels and the computed dispatch tables for them are canonically + // owned by OperatorEntry, but backend fallbacks are specified once + // and apply for all operators, so they should be owned by Dispatcher. 
+ // However, the registration of a backend fallback affects the + // state of the computed dispatch table, so when a backend fallback + // is updated, we need to update the operator tables too. Thus, + // registerKernel is the mechanism by which we give kernels to + // operator entry to own (and update dispatch table), but we only + // need a non-owning mechanism to update fallback. + // Precondition: Dispatcher::mutex_ is held // Postcondition: caller is responsible for disposing of the kernel - std::list::iterator registerKernel(c10::optional dispatch_key, KernelFunction kernel, c10::optional cpp_signature, std::unique_ptr inferred_function_schema, std::string debug); - void deregisterKernel_(c10::optional dispatch_key, std::list::iterator kernel); + std::list::iterator registerKernel( + const Dispatcher& dispatcher, + c10::optional dispatch_key, + KernelFunction kernel, + c10::optional cpp_signature, + std::unique_ptr inferred_function_schema, + std::string debug + ); + // Precondition: Dispatcher::mutex_ is held + void deregisterKernel_( + const Dispatcher& dispatcher, + c10::optional dispatch_key, + std::list::iterator kernel + ); + + // Precondition: Dispatcher::mutex_ is held + void updateFallback( + const Dispatcher& dispatcher, + DispatchKey dispatch_key + ); + + // Precondition: Dispatcher::mutex_ is held void updateSchemaAliasAnalysis(AliasAnalysisKind a) { TORCH_INTERNAL_ASSERT(schema_.has_value()); - schema_->setAliasAnalysis(a); + schema_->schema.setAliasAnalysis(a); } + std::string dumpComputedTable() const; std::string dumpState() const; void checkInvariants() const; + const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; } + // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed // unboxing wrapper for aten operators. We still need those for some operators because not all work // with the templated unboxing logic yet. 
// TODO Delete setManuallyBoxedKernel_ once all operators work with the templated boxing logic - void setManuallyBoxedKernel_(KernelFunction::InternalBoxedKernelFunction* func) { - dispatchTable_.setManuallyBoxedKernel_(func); - } + void setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func); // Asserts that the given FuncType is correct for calling this operator in an unboxed way. template @@ -111,20 +158,36 @@ public: " but the operator was registered with ", cpp_signature_->name(), " (", - debug_.value(), + (schema_.has_value() ? schema_->debug : "unknown debug info"), ") This likely happened in a call to OperatorHandle::typed(). Please make sure that the function signature matches the signature in the operator registration call." ); } + [[noreturn]] void reportError(DispatchKey dispatchKey) const; + + const KernelFunction& lookup(DispatchKey k) const { + const auto& kernel = dispatchTable_[static_cast(k)]; + if (C10_UNLIKELY(!kernel.isValid())) { + reportError(k); + } + return kernel; + } + + std::string listAllDispatchKeys() const; + private: OperatorName name_; - c10::optional schema_; - c10::optional debug_; - // INVARIANT: schema_.has_value() == debug_.has_value() + c10::optional schema_; - // The dispatchTable stores the current kernel for each dispatch key - DispatchTable dispatchTable_; + std::array(DispatchKey::NumDispatchKeys)> dispatchTable_; + DispatchKeyExtractor dispatchKeyExtractor_; + + // This manuallyBoxedKernel_ member is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed + // unboxing wrapper for aten operators. We still need those for some operators because not all work + // with the templated unboxing logic yet. 
+ // TODO Delete manuallyBoxedKernel_ once all operators work with the templated boxing logic + c10::optional manuallyBoxedKernel_; // kernels_ stores all registered kernels for the corresponding dispatch key // and catchAllKernels_ stores the catch-all kernels. @@ -157,9 +220,10 @@ private: // re-executed and then only allow one kernel here, i.e. error if a kernel // is already registered, but that's a lot of effort to implement and // currently not high-pri. - ska::flat_hash_map, std::list> kernels_; + ska::flat_hash_map> kernels_; - std::mutex kernelsMutex_; // protects kernels_ + std::list catchAllKernel_; + AnnotatedKernel missingKernel_; // signature_hash_ is set to the hash of the function signature if any of // the kernels was created in a way that allowed us to know the function @@ -168,10 +232,16 @@ private: // to verify their arguments against the known function signature. c10::optional cpp_signature_; + const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const; + std::pair computeDispatchTableEntryWithDebug( + const c10::Dispatcher& dispatcher, DispatchKey dispatch_key + ) const; // This function re-establishes the invariant that dispatchTable // contains the front element from the kernels list for a given dispatch key. - void updateDispatchTable_(c10::optional dispatch_key); + void updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key); + // Like above, but for ALL entries in the dispatch table. 
+ void updateDispatchTableFull_(const c10::Dispatcher& dispatcher); }; -} -} +} // namespace impl +} // namespace c10 diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 75db6feca10..fa9aaeab08c 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -148,8 +148,7 @@ TEST(OperatorRegistrationTest, whenCallingOpWithWrongDispatchKey_thenFails) { expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CUDA)); }, "Could not run '_test::dummy' with arguments from the 'CUDA'" - " backend. '_test::dummy' is only available for these backends:" - " [CPU]."); + " backend."); } TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCallsCatchallKernel) { @@ -237,7 +236,7 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithSchema_t expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CPU)); }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend. '_test::dummy' is only available for these backends: []."); + " backend."); } TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithoutSchema_thenFails) { @@ -288,7 +287,7 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterw expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CPU)); }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend. '_test::dummy' is only available for these backends: []."); + " backend."); } TEST(OperatorRegistrationTest, givenOpWithoutKernelsWithoutTensorInputs_whenRegistering_thenRegisters) { @@ -455,7 +454,7 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlder expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CPU)); }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend. 
'_test::dummy' is only available for these backends: []."); + " backend."); } TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderAndThenNewerKernelDeletedAndOpCalled_thenFails) { @@ -474,7 +473,7 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderAndThenNewe expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CPU)); }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend. '_test::dummy' is only available for these backends: []."); + " backend."); } TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) { @@ -493,7 +492,7 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewer expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CPU)); }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend. '_test::dummy' is only available for these backends: []."); + " backend."); } TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) { @@ -512,7 +511,7 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerAndThenOlde expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CPU)); }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend. '_test::dummy' is only available for these backends: []."); + " backend."); } TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnboxedWithCPUDispatchKey) { @@ -533,7 +532,7 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox expectThrows([&] { callOpUnboxedWithDispatchKey(*op, c10::DispatchKey::CUDA, dummyTensor(c10::DispatchKey::CPU)); }, "Could not run '_test::dummy' with arguments from the 'CUDA'" - " backend. 
'_test::dummy' is only available for these backends: [CPU]."); + " backend."); } TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallAndCalling_thenCallsCorrectKernel) { @@ -559,7 +558,7 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallAndCall expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::XLA)); }, "Could not run '_test::dummy' with arguments from the 'XLA'" - " backend. '_test::dummy' is only available for these backends: ["); + " backend."); // also assert that the error message contains the available tensor type ids, but don't assert their order expectThrows([&] { @@ -586,17 +585,17 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallOutOfSc expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CPU)); }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend. '_test::dummy' is only available for these backends: []."); + " backend."); expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::CUDA)); }, "Could not run '_test::dummy' with arguments from the 'CUDA'" - " backend. '_test::dummy' is only available for these backends: []."); + " backend."); expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::XLA)); }, "Could not run '_test::dummy' with arguments from the 'XLA'" - " backend. '_test::dummy' is only available for these backends: []."); + " backend."); } bool called_stackbased_kernel = false; @@ -727,7 +726,7 @@ TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelForWrongBacke ASSERT_TRUE(op.has_value()); expectThrows([&] { auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello "); - }, "Could not run '_test::dummy' with arguments from the 'CPU' backend. 
'_test::dummy' is only available for these backends: []."); + }, "Could not run '_test::dummy' with arguments from the 'CPU' backend."); } bool called = false; diff --git a/test/test_dispatch.py b/test/test_dispatch.py index f10e3714c7e..1c721eecabc 100644 --- a/test/test_dispatch.py +++ b/test/test_dispatch.py @@ -356,7 +356,7 @@ alias analysis kind: PURE_FUNCTION except RuntimeError as e: self.assertExpectedInline( str(e), - '''Tried to register multiple backend fallbacks for the same dispatch key XLA (registered at /dev/null:0)''' # noqa + '''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration registered at /dev/null:0, new registration registered at /dev/null:0''' # noqa ) else: self.assertTrue(False)