mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[RELAND] Add DispatchKey impl overload; remove use of torch::dispatch (#36222)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36222 Reland of #35706, with fixes to code analyzer. It is extremely common to define implementations of operators at a specific dispatch key, so we add an overload to impl specifically for this case. I then delete most uses of torch::dispatch dispatch_autograd call sites can't make use of this overload. So instead the new preferred way to specify something as autograd is to pass kAutograd as the dispatch key (short form, analogous to kCPU/kCUDA which we support today). I flip flopped about whether or not kAutograd should have the type DispatchKey or some other type (to help better encapsulate the DispatchKey enum); this is more direct and I can't think of any BC problems from this usage. Some other reorganization I did: - I renamed all of the worker functions in op_registration to have a leading underscore and made them private, just to make it more clear what the public versus private API were (the private API shouldn't be used by users because it doesn't come with && overloads) Note that this means I needed to adjust the regex in the code analyzer, because - In a few places where I was touching lines already, I replaced full DispatchKey typed out enums with shorter kFoo names, similar to kAutograd but I didn't publish these globally. - Code analyzer now prints a unified diff, and in the other order (because I tend to think of the diff as reporting how the /new/ result is different) Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D20929256 Pulled By: ezyang fbshipit-source-id: c69b803d2b3a1a8aff70e14da33d3adec5239f13
This commit is contained in:
parent
477f1c047c
commit
ef07bb65e9
16 changed files with 169 additions and 124 deletions
|
|
@ -332,8 +332,7 @@ I think Option 2 is the right answer for all ops, not just convolutions. Option
|
|||
*****************************************************************************************************************/
|
||||
|
||||
auto register_fallthrough = c10::import()
|
||||
.fallback(c10::dispatch(c10::DispatchKey::AutocastTensorId,
|
||||
c10::CppFunction::makeFallthrough()));
|
||||
.fallback(c10::DispatchKey::AutocastTensorId, c10::CppFunction::makeFallthrough());
|
||||
|
||||
/********************************************************************************************************************
|
||||
Explicit registration for out-of-place ops
|
||||
|
|
@ -361,17 +360,17 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin
|
|||
// Common cases where registration signature matches redispatch signature
|
||||
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
|
||||
#define KERNEL(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
|
||||
.impl(REGISTER_NAME, c10::dispatch(DispatchKey::AutocastTensorId, \
|
||||
&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call))
|
||||
.impl(REGISTER_NAME, DispatchKey::AutocastTensorId, \
|
||||
&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call)
|
||||
|
||||
#define KERNEL_UNBOXED_ONLY(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
|
||||
.impl(REGISTER_NAME, c10::dispatch(DispatchKey::AutocastTensorId, \
|
||||
c10::CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call)))
|
||||
.impl(REGISTER_NAME, DispatchKey::AutocastTensorId, \
|
||||
c10::CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call))
|
||||
|
||||
// Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype)
|
||||
#define KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
|
||||
.impl(REGISTER_NAME, c10::dispatch(DispatchKey::AutocastTensorId, \
|
||||
c10::CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::POLICY, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call)))
|
||||
.impl(REGISTER_NAME, DispatchKey::AutocastTensorId, \
|
||||
c10::CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::POLICY, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call))
|
||||
|
||||
/*****************************************
|
||||
Explicit registration for out-of-place ops
|
||||
|
|
@ -426,9 +425,8 @@ auto register_out_of_place = c10::import()
|
|||
KERNEL(ADD_NS(gelu), "aten::gelu", Tensor (const Tensor &), fp32)
|
||||
KERNEL_UNBOXED_ONLY(ADD_NS(layer_norm), "aten::layer_norm", Tensor (const Tensor &, IntArrayRef, const Tensor &, const Tensor &, double, bool), fp32)
|
||||
// The macro doesn't like this one so I had to write it out manually.
|
||||
.impl("aten::native_layer_norm",
|
||||
c10::dispatch(DispatchKey::AutocastTensorId,
|
||||
CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call)))
|
||||
.impl("aten::native_layer_norm", DispatchKey::AutocastTensorId,
|
||||
CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call))
|
||||
KERNEL_UNBOXED_ONLY(ADD_NS(group_norm), "aten::group_norm", Tensor (const Tensor &, int64_t, const Tensor &, const Tensor &, double, bool), fp32)
|
||||
KERNEL_UNBOXED_ONLY(ADD_NS(frobenius_norm), "aten::frobenius_norm", Tensor (const Tensor &), fp32)
|
||||
KERNEL_UNBOXED_ONLY(ADD_NS(frobenius_norm), "aten::frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
|
||||
|
|
@ -496,9 +494,8 @@ auto register_out_of_place = c10::import()
|
|||
;
|
||||
|
||||
auto register_banned = torch::import()
|
||||
.impl("aten::binary_cross_entropy",
|
||||
torch::dispatch(DispatchKey::AutocastTensorId,
|
||||
CppFunction::makeUnboxedOnly(&at::autocast::binary_cross_entropy_banned)));
|
||||
.impl("aten::binary_cross_entropy", DispatchKey::AutocastTensorId,
|
||||
CppFunction::makeUnboxedOnly(&at::autocast::binary_cross_entropy_banned));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
namespace {
|
||||
|
||||
static auto registry = c10::import()
|
||||
.fallback(c10::dispatch(c10::DispatchKey::BackendSelect, c10::CppFunction::makeFallthrough()))
|
||||
.fallback(c10::DispatchKey::BackendSelect, c10::CppFunction::makeFallthrough())
|
||||
;
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,13 +128,13 @@ Module& Module::operator=(Module&&) = default;
|
|||
// TODO: Error if an operator is def'ed multiple times. Right now we just
|
||||
// merge everything
|
||||
|
||||
Module& Module::def(FunctionSchema&& schema) & {
|
||||
Module& Module::_def(FunctionSchema&& schema) & {
|
||||
if (ns_.has_value()) schema.setNamespaceIfNotSet(ns_->c_str());
|
||||
registrars_.emplace_back(Dispatcher::singleton().registerDef(std::move(schema)));
|
||||
return *this;
|
||||
}
|
||||
|
||||
Module& Module::def(c10::either<OperatorName, FunctionSchema>&& name_or_schema, CppFunction&& f) & {
|
||||
Module& Module::_def(c10::either<OperatorName, FunctionSchema>&& name_or_schema, CppFunction&& f) & {
|
||||
FunctionSchema schema = [&] {
|
||||
if (name_or_schema.is_right()) {
|
||||
return std::move(name_or_schema).right();
|
||||
|
|
@ -156,7 +156,7 @@ Module& Module::def(c10::either<OperatorName, FunctionSchema>&& name_or_schema,
|
|||
return *this;
|
||||
}
|
||||
|
||||
Module& Module::impl(const char* name_str, CppFunction&& f) & {
|
||||
Module& Module::_impl(const char* name_str, CppFunction&& f) & {
|
||||
auto name = torch::jit::parseName(name_str);
|
||||
if (ns_.has_value()) name.setNamespaceIfNotSet(ns_->c_str());
|
||||
registrars_.emplace_back(
|
||||
|
|
@ -171,7 +171,7 @@ Module& Module::impl(const char* name_str, CppFunction&& f) & {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Module& Module::fallback(CppFunction&& f) & {
|
||||
Module& Module::_fallback(CppFunction&& f) & {
|
||||
TORCH_CHECK(!ns_, "Cannot define fallbacks from namespaces, use c10::import().fallback() instead");
|
||||
TORCH_CHECK(f.dispatch_key_, "Fallback for catch all function not supported");
|
||||
registrars_.emplace_back(Dispatcher::singleton().registerFallback(*f.dispatch_key_, std::move(f.func_)));
|
||||
|
|
|
|||
|
|
@ -617,8 +617,8 @@ private:
|
|||
// // provide multiple; one per backend). We'll take care of calling
|
||||
// // the correct implementation depending on if we get a CPU
|
||||
// // tensor or a CUDA tensor
|
||||
// .impl("aten::mul", torch::dispatch(torch::kCPU, &mul_cpu_impl))
|
||||
// .impl("aten::mul", torch::dispatch(torch::kCUDA, &mul_cuda_impl))
|
||||
// .impl("aten::mul", torch::kCPU, &mul_cpu_impl)
|
||||
// .impl("aten::mul", torch::kCUDA, &mul_cuda_impl)
|
||||
//
|
||||
// Also, you can omit the top level namespace and specify it explicitly in
|
||||
// the sub-definitions, e.g., torch::import().impl("aten::mul", ...)
|
||||
|
|
@ -726,12 +726,17 @@ template <typename Func>
|
|||
inline CppFunction dispatch(DeviceType type, Func&& raw_f) {
|
||||
auto deviceTypeToDispatchKey = [](DeviceType t){
|
||||
switch (t) {
|
||||
// This list is synchronized with the k-constants in c10/core/DeviceType.h
|
||||
case DeviceType::CPU:
|
||||
return c10::DispatchKey::CPUTensorId;
|
||||
case DeviceType::CUDA:
|
||||
return c10::DispatchKey::CUDATensorId;
|
||||
case DeviceType::XLA:
|
||||
return c10::DispatchKey::XLATensorId;
|
||||
case DeviceType::HIP:
|
||||
return c10::DispatchKey::HIPTensorId;
|
||||
case DeviceType::MSNPU:
|
||||
return c10::DispatchKey::MSNPUTensorId;
|
||||
default:
|
||||
TORCH_CHECK(false,
|
||||
"Device type ", t, " cannot be overloaded at dispatch time, "
|
||||
|
|
@ -741,12 +746,6 @@ inline CppFunction dispatch(DeviceType type, Func&& raw_f) {
|
|||
return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
|
||||
}
|
||||
|
||||
// Convenience for overriding autograd functionality
|
||||
template <typename Func>
|
||||
inline CppFunction dispatch_autograd(Func&& raw_f) {
|
||||
return dispatch(c10::DispatchKey::VariableTensorId, std::forward<Func>(raw_f));
|
||||
}
|
||||
|
||||
inline FunctionSchema schema(const char* str, AliasAnalysisKind k) {
|
||||
FunctionSchema s = torch::jit::parseSchema(str);
|
||||
s.setAliasAnalysis(k);
|
||||
|
|
@ -794,6 +793,14 @@ class CAFFE2_API Module final {
|
|||
friend Module _import_DOES_NOT_WORK_WITH_MOBILE_CUSTOM_BUILD(std::string ns);
|
||||
friend Module import();
|
||||
|
||||
private:
|
||||
// Non-user visible actual implementations of functions. These aren't
|
||||
// public because we only implement & qualifier and not && qualifier
|
||||
Module& _def(FunctionSchema&& schema) &;
|
||||
Module& _def(c10::either<OperatorName, FunctionSchema>&&, CppFunction&& f) &;
|
||||
Module& _impl(const char* name, CppFunction&& f) &;
|
||||
Module& _fallback(CppFunction&& f) &;
|
||||
|
||||
public:
|
||||
Module(const Module&) = delete;
|
||||
Module& operator=(const Module&) = delete;
|
||||
|
|
@ -832,11 +839,10 @@ public:
|
|||
// Declare an operator with a schema, but don't provide any implementations
|
||||
// for it. You're expected to then provide implementations using the
|
||||
// impl() method.
|
||||
Module& def(FunctionSchema&& schema) &;
|
||||
template <typename Schema>
|
||||
Module& def(Schema&& raw_schema) & {
|
||||
FunctionSchema s = schema(std::forward<Schema>(raw_schema));
|
||||
return def(std::move(s));
|
||||
return _def(std::move(s));
|
||||
}
|
||||
template <typename Schema>
|
||||
Module&& def(Schema&& raw_schema) && {
|
||||
|
|
@ -848,12 +854,11 @@ public:
|
|||
// an implementation for it. def(n, f) is almost equivalent to def(n).impl(f),
|
||||
// except that if n is not a schema, then the schema is inferred from the
|
||||
// static type of f.
|
||||
Module& def(c10::either<OperatorName, FunctionSchema>&&, CppFunction&& f) &;
|
||||
template <typename NameOrSchema, typename Func>
|
||||
Module& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
|
||||
CppFunction f(std::forward<Func>(raw_f));
|
||||
auto name_or_schema = detail::constructSchemaOrName(std::forward<NameOrSchema>(raw_name_or_schema));
|
||||
return def(std::move(name_or_schema), std::move(f));
|
||||
return _def(std::move(name_or_schema), std::move(f));
|
||||
}
|
||||
template <typename NameOrSchema, typename Func>
|
||||
Module&& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) && {
|
||||
|
|
@ -865,27 +870,50 @@ public:
|
|||
// implementations for a single operator at different dispatch keys
|
||||
// (see torch::dispatch). Implementations must have a corresponding
|
||||
// declaration (from def), otherwise they are invalid.
|
||||
Module& impl(const char* name, CppFunction&& f) &;
|
||||
template <typename Func>
|
||||
Module& impl(const char* name, Func&& raw_f) & {
|
||||
CppFunction f(std::forward<Func>(raw_f));
|
||||
return impl(name, std::move(f));
|
||||
return _impl(name, std::move(f));
|
||||
}
|
||||
template <typename Func>
|
||||
Module&& impl(const char* name, Func&& raw_f) && {
|
||||
impl(name, std::forward<Func>(raw_f));
|
||||
return std::move(*this);
|
||||
}
|
||||
// Convenience overload for directly specifying the dispatch key. Dispatch
|
||||
// can validly be either DeviceType or DispatchKey; check torch::dispatch for
|
||||
// the canonical list of accepted overloads.
|
||||
template <typename Dispatch, typename Func>
|
||||
Module& impl(const char* name, Dispatch&& key, Func&& raw_f) & {
|
||||
return impl(name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
|
||||
}
|
||||
template <typename Dispatch, typename Func>
|
||||
Module&& impl(const char* name, Dispatch&& key, Func&& raw_f) && {
|
||||
impl(name, std::forward<Dispatch>(key), std::forward<Func>(raw_f));
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
// Register a fallback implementation for all operators which will be used
|
||||
// if there is not a specific implementation for an operator available.
|
||||
// At the moment, you must specify a dispatch key (see torch::dispatch) for
|
||||
// your fallback.
|
||||
Module& fallback(CppFunction&& f) &;
|
||||
// Providing a DispatchKey is MANDATORY for fallback at the moment.
|
||||
//
|
||||
// Dispatch can validly be either DeviceType or DispatchKey; check
|
||||
// torch::dispatch for the canonical list of accepted overloads.
|
||||
template <typename Dispatch, typename Func>
|
||||
Module& fallback(Dispatch&& key, Func&& raw_f) & {
|
||||
return fallback(c10::dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
|
||||
}
|
||||
template <typename Dispatch, typename Func>
|
||||
Module&& fallback(Dispatch&& key, Func&& raw_f) && {
|
||||
fallback(std::forward<Dispatch>(key), std::forward<Func>(raw_f));
|
||||
return std::move(*this);
|
||||
}
|
||||
// NB: these overloads are here for completeness, but you'll probably want to
|
||||
// use the direct Dispatch overload
|
||||
template <typename Func>
|
||||
Module& fallback(Func&& raw_f) & {
|
||||
CppFunction f(std::forward<Func>(raw_f));
|
||||
return fallback(std::move(f));
|
||||
CppFunction f((std::forward<Func>(raw_f)));
|
||||
return _fallback(std::move(f));
|
||||
}
|
||||
template <typename Func>
|
||||
Module&& fallback(Func&& raw_f) && {
|
||||
|
|
@ -921,7 +949,6 @@ namespace torch {
|
|||
|
||||
// New-style API
|
||||
using c10::dispatch;
|
||||
using c10::dispatch_autograd;
|
||||
using c10::schema;
|
||||
using c10::import;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1297,11 +1297,11 @@ TEST(NewOperatorRegistrationTest, testBasics) {
|
|||
.def("_test::dummy2(Tensor self) -> Tensor")
|
||||
.def("_test::dummy3(Tensor self, Tensor other) -> Tensor", [](const Tensor& self, const Tensor& other) { return self; })
|
||||
.def("_test::dummy4", [](const Tensor& self, const Tensor& other) { return other; })
|
||||
.impl("_test::dummy", c10::dispatch(c10::DeviceType::CPU, [](const Tensor& self) { return self; }))
|
||||
.impl("_test::dummy", c10::dispatch(c10::DeviceType::XLA, [](const Tensor& self) { return self; }))
|
||||
.impl("_test::dummy", c10::DeviceType::CPU, [](const Tensor& self) { return self; })
|
||||
.impl("_test::dummy", c10::DeviceType::XLA, [](const Tensor& self) { return self; })
|
||||
// Internal API
|
||||
.impl("_test::dummy2", c10::dispatch(c10::DispatchKey::CPUTensorId, [](const Tensor& self) { return self; }))
|
||||
.impl("_test::dummy2", c10::dispatch(c10::DispatchKey::XLATensorId, [](const Tensor& self) { return self; }));
|
||||
.impl("_test::dummy2", c10::DispatchKey::CPUTensorId, [](const Tensor& self) { return self; })
|
||||
.impl("_test::dummy2", c10::DispatchKey::XLATensorId, [](const Tensor& self) { return self; });
|
||||
|
||||
ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy", ""}).has_value());
|
||||
// Should have a schema even if there are no impls
|
||||
|
|
@ -1382,7 +1382,7 @@ TEST(NewOperatorRegistrationTest, dispatch) {
|
|||
auto registrar = c10::import()
|
||||
.def("test::fn_cpu", torch::dispatch(c10::DispatchKey::CPUTensorId, [&](const Tensor& x) { cpu_called = true; return x; }))
|
||||
.def("test::fn_cuda", torch::dispatch(c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; }))
|
||||
.def("test::fn_autograd", torch::dispatch_autograd([&](const Tensor& x) { autograd_called = true; return x; }));
|
||||
.def("test::fn_autograd", torch::dispatch(c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; }));
|
||||
|
||||
{
|
||||
auto op = Dispatcher::singleton().findSchema({"test::fn_cpu", ""});
|
||||
|
|
@ -1415,9 +1415,11 @@ TEST(NewOperatorRegistrationTest, dispatchMultiple) {
|
|||
bool autograd_called = false;
|
||||
auto registrar = c10::import()
|
||||
.def("test::fn(Tensor self) -> Tensor")
|
||||
.impl("test::fn", torch::dispatch(c10::DispatchKey::CPUTensorId, [&](const Tensor& x) { cpu_called = true; return x; }))
|
||||
.impl("test::fn", torch::dispatch(c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; }))
|
||||
.impl("test::fn", torch::dispatch_autograd([&](const Tensor& x) { autograd_called = true; return x; }));
|
||||
// NB: Direct use of DispatchKey is discouraged; use the DeviceType
|
||||
// k-synonyms instead
|
||||
.impl("test::fn", c10::DispatchKey::CPUTensorId, [&](const Tensor& x) { cpu_called = true; return x; })
|
||||
.impl("test::fn", c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; })
|
||||
.impl("test::fn", c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; });
|
||||
|
||||
auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
|
@ -1440,7 +1442,7 @@ TEST(NewOperatorRegistrationTest, dispatchMultiple) {
|
|||
|
||||
TEST(NewOperatorRegistrationTest, fallback) {
|
||||
auto registrar = c10::import()
|
||||
.fallback(torch::dispatch(c10::kCPU, c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>()));
|
||||
.fallback(c10::kCPU, c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
|
||||
|
||||
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
|
||||
auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
|
||||
|
|
@ -1454,12 +1456,12 @@ TEST(NewOperatorRegistrationTest, BackendSelectRedispatchesToCPU) {
|
|||
bool backend_generic_called = false;
|
||||
auto registrar = c10::import()
|
||||
.def("test::fn(Tensor self) -> Tensor")
|
||||
.impl("test::fn", torch::dispatch(c10::kCPU, [&](const Tensor& x) { cpu_called = true; return x; }))
|
||||
.impl("test::fn", torch::dispatch(c10::DispatchKey::BackendSelect, [&](const Tensor& x) {
|
||||
.impl("test::fn", c10::kCPU, [&](const Tensor& x) { cpu_called = true; return x; })
|
||||
.impl("test::fn", c10::DispatchKey::BackendSelect, [&](const Tensor& x) {
|
||||
backend_generic_called = true;
|
||||
auto op = c10::Dispatcher::singleton().findSchema({"test::fn", ""});
|
||||
return c10::Dispatcher::singleton().callUnboxedRedispatch<Tensor, const Tensor&>(*op, c10::DispatchKey::BackendSelect, x);
|
||||
}))
|
||||
})
|
||||
;
|
||||
auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
|
|
|||
|
|
@ -117,16 +117,16 @@ DEFAULT_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
|
|||
CppFunction::makeUnboxedOnly(TypeDefault::${type_wrapper_name}))
|
||||
""")
|
||||
BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
|
||||
.impl("${operator_name_with_overload}", torch::dispatch(
|
||||
.impl("${operator_name_with_overload}",
|
||||
DispatchKey::${Backend}TensorId,
|
||||
CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name})))
|
||||
CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name}))
|
||||
""")
|
||||
DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\
|
||||
.impl("${operator_name_with_overload}", &TypeDefault::${type_wrapper_name})
|
||||
""")
|
||||
BACKEND_FUNCTION_REGISTRATION = CodeTemplate("""\
|
||||
.impl("${operator_name_with_overload}",
|
||||
torch::dispatch(DispatchKey::${Backend}TensorId, &${Type}::${type_wrapper_name}))
|
||||
DispatchKey::${Backend}TensorId, &${Type}::${type_wrapper_name})
|
||||
""")
|
||||
|
||||
# add non-virtual declaration to TensorBody.h
|
||||
|
|
|
|||
|
|
@ -138,12 +138,11 @@ TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) {
|
|||
TEST(BackendFallbackTest, TestFallthroughBackendFallback) {
|
||||
// By default fallthrough
|
||||
auto registry = c10::import()
|
||||
.fallback(
|
||||
c10::dispatch(DispatchKey::TESTING_ONLY_GenericModeTensorId,
|
||||
c10::CppFunction::makeFallthrough()))
|
||||
.fallback(DispatchKey::TESTING_ONLY_GenericModeTensorId,
|
||||
c10::CppFunction::makeFallthrough())
|
||||
.impl("aten::mul.Tensor",
|
||||
c10::dispatch(DispatchKey::TESTING_ONLY_GenericModeTensorId,
|
||||
c10::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>()));
|
||||
DispatchKey::TESTING_ONLY_GenericModeTensorId,
|
||||
c10::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());
|
||||
|
||||
c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericModeTensorId);
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,10 @@ using namespace at;
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr auto kCustomRNG = DispatchKey::CustomRNGKeyId;
|
||||
|
||||
struct TestCPUGenerator : public c10::GeneratorImpl {
|
||||
TestCPUGenerator(uint64_t value) : GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, value_(value) { }
|
||||
TestCPUGenerator(uint64_t value) : GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(kCustomRNG)}, value_(value) { }
|
||||
~TestCPUGenerator() = default;
|
||||
uint32_t random() { return value_; }
|
||||
uint64_t random64() { return value_; }
|
||||
|
|
@ -98,21 +100,20 @@ class RNGTest : public ::testing::Test {
|
|||
void SetUp() override {
|
||||
static auto registry = torch::import()
|
||||
// Random
|
||||
.impl("aten::random_.from", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(random_from_to)))
|
||||
.impl("aten::random_.to", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(random_to)))
|
||||
.impl("aten::random_", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(random_)))
|
||||
.impl("aten::random_.from", kCustomRNG, CppFunction::makeUnboxedOnly(random_from_to))
|
||||
.impl("aten::random_.to", kCustomRNG, CppFunction::makeUnboxedOnly(random_to))
|
||||
.impl("aten::random_", kCustomRNG, CppFunction::makeUnboxedOnly(random_))
|
||||
// Normal
|
||||
.impl("aten::normal_", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(normal_)))
|
||||
.impl("aten::normal.Tensor_float_out", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(normal_Tensor_float_out)))
|
||||
.impl("aten::normal.float_Tensor_out", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(normal_float_Tensor_out)))
|
||||
.impl("aten::normal.Tensor_Tensor_out", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(normal_Tensor_Tensor_out)))
|
||||
.impl("aten::normal.Tensor_float", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(normal_Tensor_float)))
|
||||
.impl("aten::normal.float_Tensor", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(normal_float_Tensor)))
|
||||
.impl("aten::normal.Tensor_Tensor", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(normal_Tensor_Tensor)))
|
||||
// Uniform
|
||||
.impl("aten::uniform_", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(uniform_)))
|
||||
.impl("aten::normal_", kCustomRNG, CppFunction::makeUnboxedOnly(normal_))
|
||||
.impl("aten::normal.Tensor_float_out", kCustomRNG, CppFunction::makeUnboxedOnly(normal_Tensor_float_out))
|
||||
.impl("aten::normal.float_Tensor_out", kCustomRNG, CppFunction::makeUnboxedOnly(normal_float_Tensor_out))
|
||||
.impl("aten::normal.Tensor_Tensor_out", kCustomRNG, CppFunction::makeUnboxedOnly(normal_Tensor_Tensor_out))
|
||||
.impl("aten::normal.Tensor_float", kCustomRNG, CppFunction::makeUnboxedOnly(normal_Tensor_float))
|
||||
.impl("aten::normal.float_Tensor", kCustomRNG, CppFunction::makeUnboxedOnly(normal_float_Tensor))
|
||||
.impl("aten::normal.Tensor_Tensor", kCustomRNG, CppFunction::makeUnboxedOnly(normal_Tensor_Tensor))
|
||||
.impl("aten::uniform_", kCustomRNG, CppFunction::makeUnboxedOnly(uniform_))
|
||||
// Cauchy
|
||||
.impl("aten::cauchy_", torch::dispatch(DispatchKey::CustomRNGKeyId, CppFunction::makeUnboxedOnly(custom_rng_cauchy_)))
|
||||
.impl("aten::cauchy_", kCustomRNG, CppFunction::makeUnboxedOnly(custom_rng_cauchy_))
|
||||
;
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -204,8 +204,21 @@ static inline DispatchKey XLATensorId() {
|
|||
return DispatchKey::XLATensorId;
|
||||
}
|
||||
|
||||
// These are some convenience identifiers for dispatch keys which are
|
||||
// shorter to type than their long counterparts. Note that some of these
|
||||
// dispatch keys directly correspond to DeviceType; and most APIs that
|
||||
// accept DispatchKey also accept DeviceType; e.g.,
|
||||
// torch::dispatch(torch::kCPU, ...) is also valid.
|
||||
constexpr DispatchKey kAutograd = DispatchKey::VariableTensorId;
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace torch {
|
||||
// Expose the constant, but not the TYPE (DispatchKey is an implementation
|
||||
// detail!)
|
||||
using c10::kAutograd;
|
||||
}
|
||||
|
||||
// NB: You really shouldn't use this instance; this enum is guaranteed
|
||||
// to be pretty small so a regular array should be acceptable.
|
||||
namespace std {
|
||||
|
|
|
|||
|
|
@ -49,10 +49,10 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
|
|||
|
||||
void init_msnpu_extension() {
|
||||
static auto registry = torch::import()
|
||||
.impl("aten::empty.memory_format", torch::dispatch(DispatchKey::MSNPUTensorId, CppFunction::makeUnboxedOnly(empty_override)))
|
||||
.impl("aten::add.Tensor", torch::dispatch(DispatchKey::MSNPUTensorId, CppFunction::makeUnboxedOnly(add_override)))
|
||||
.impl("aten::convolution_overrideable", torch::dispatch(DispatchKey::MSNPUTensorId, CppFunction::makeUnboxedOnly(fake_convolution)))
|
||||
.impl("aten::convolution_backward_overrideable", torch::dispatch(DispatchKey::MSNPUTensorId, CppFunction::makeUnboxedOnly(fake_convolution_backward)))
|
||||
.impl("aten::empty.memory_format", kMSNPU, CppFunction::makeUnboxedOnly(empty_override))
|
||||
.impl("aten::add.Tensor", kMSNPU, CppFunction::makeUnboxedOnly(add_override))
|
||||
.impl("aten::convolution_overrideable", kMSNPU, CppFunction::makeUnboxedOnly(fake_convolution))
|
||||
.impl("aten::convolution_backward_overrideable", kMSNPU, CppFunction::makeUnboxedOnly(fake_convolution_backward))
|
||||
;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -62,23 +62,29 @@ Tensor FF_op(const Tensor& self) {
|
|||
|
||||
namespace {
|
||||
|
||||
// NB: Some of these registrations (AA, EE) are not what you
|
||||
// actually expect to see in practice, but we cover them here
|
||||
// as they are technically "valid" API calls and we want to
|
||||
// make sure the analyzer catches them. (The analyzer is very
|
||||
// generic, so actually there isn't any reason it shouldn't work,
|
||||
// but it's good to test them!)
|
||||
//
|
||||
// Additionally, the code in this file is not really runnable; for
|
||||
// example we are missing schemas for all of the impl registrations
|
||||
// here. The analyzer doesn't really care, as it only really
|
||||
// cares about the name
|
||||
auto registerer = torch::import()
|
||||
.def("aten::AA(Tensor self) -> Tensor",
|
||||
torch::dispatch(DispatchKey::CPUTensorId, &AA_op))
|
||||
.def("aten::AA(Tensor self) -> Tensor", torch::dispatch(kCPU, &AA_op))
|
||||
.def("aten::BB(Tensor self) -> Tensor", &BB_op)
|
||||
.impl("aten::CC(Tensor self) -> Tensor",
|
||||
torch::dispatch(DispatchKey::CPUTensorId, &CC_op))
|
||||
.impl("aten::DD(Tensor self) -> Tensor", &DD_op)
|
||||
.def("aten::EE(Tensor self) -> Tensor", torch::dispatch(
|
||||
DispatchKey::CPUTensorId,
|
||||
CppFunction::makeUnboxedOnly(EE_op)))
|
||||
.def("aten::FF(Tensor self) -> Tensor",
|
||||
CppFunction::makeUnboxedOnly(FF_op))
|
||||
.impl("aten::GG(Tensor self) -> Tensor", torch::dispatch(
|
||||
DispatchKey::CPUTensorId, [] (Tensor a) -> Tensor {
|
||||
.impl("aten::CC", kCPU, &CC_op)
|
||||
.impl("aten::DD", &DD_op)
|
||||
.def("aten::EE(Tensor self) -> Tensor", torch::dispatch(kCPU, CppFunction::makeUnboxedOnly(EE_op)))
|
||||
.def("aten::FF(Tensor self) -> Tensor", CppFunction::makeUnboxedOnly(FF_op))
|
||||
.impl("aten::GG",
|
||||
kCPU, [] (Tensor a) -> Tensor {
|
||||
return call_FF_op(a);
|
||||
}))
|
||||
.impl("aten::HH(Tensor self) -> Tensor",
|
||||
})
|
||||
.impl("aten::HH",
|
||||
[] (Tensor a) -> Tensor {
|
||||
return a;
|
||||
});
|
||||
|
|
|
|||
|
|
@ -190,8 +190,7 @@ class TestDispatch(TestCase):
|
|||
lambda m: m.def_("foo(Tensor x) -> Tensor"),
|
||||
# m.impl("test_def", [](const Tensor& x) { return x })
|
||||
lambda m: m.impl_t_t("foo"),
|
||||
# m.impl("test_def",
|
||||
# torch::dispatch_autograd([](const Tensor& x) { return x }))
|
||||
# m.impl("test_def", kAutograd, [](const Tensor& x) { return x })
|
||||
lambda m: m.impl_t_t("foo", dispatch="autograd")
|
||||
])
|
||||
self.assertExpectedInline(r, '''\
|
||||
|
|
@ -217,7 +216,7 @@ catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
|
|||
r = self.commute("foo", [
|
||||
# m.def("foo", [](const Tensor & x) { return x })
|
||||
lambda m: m.def_name_t_t("foo"),
|
||||
# m.impl("foo", torch::dispatch_autograd([](const Tensor & x) { return x }))
|
||||
# m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
|
||||
lambda m: m.impl_t_t("foo", "autograd")
|
||||
])
|
||||
self.assertExpectedInline(r, '''\
|
||||
|
|
@ -243,8 +242,7 @@ alias analysis kind: FROM_SCHEMA
|
|||
r = self.commute("foo", [
|
||||
# m.impl("foo", [](const Tensor& x) { return x })
|
||||
lambda m: m.impl_t_t("foo"),
|
||||
# m.impl("foo",
|
||||
# torch::dispatch_autograd([](const Tensor& x) { return x }))
|
||||
# m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x })
|
||||
lambda m: m.impl_t_t("foo", "autograd")
|
||||
])
|
||||
self.assertExpectedInline(r, '''\
|
||||
|
|
|
|||
|
|
@ -183,17 +183,15 @@ ${return_type} ${type_wrapper_name}(${type_method_formals}) {
|
|||
""")
|
||||
|
||||
UNBOXEDONLY_WRAPPER_REGISTRATION = CodeTemplate("""\
|
||||
.impl("${operator_name_with_overload}",
|
||||
torch::dispatch_autograd(
|
||||
.impl("${operator_name_with_overload}", torch::kAutograd,
|
||||
CppFunction::makeUnboxedOnly(VariableType::${type_wrapper_name})
|
||||
))
|
||||
)
|
||||
""")
|
||||
|
||||
WRAPPER_REGISTRATION = CodeTemplate("""\
|
||||
.impl("${operator_name_with_overload}",
|
||||
torch::dispatch_autograd(
|
||||
.impl("${operator_name_with_overload}", torch::kAutograd,
|
||||
&VariableType::${type_wrapper_name}
|
||||
))
|
||||
)
|
||||
""")
|
||||
|
||||
UNPACK_TENSOR = CodeTemplate("""\
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ check_test_result() {
|
|||
echo "Test result is the same as expected."
|
||||
else
|
||||
echo "Test result is DIFFERENT from expected!"
|
||||
diff "${OUTPUT}" "${TEST_SRC_ROOT}/expected_deps.yaml"
|
||||
diff -u "${TEST_SRC_ROOT}/expected_deps.yaml" "${OUTPUT}"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ echo "Analyze: ${INPUT}"
|
|||
|
||||
"${ANALYZER_BIN}" \
|
||||
-op_schema_pattern="^(_aten|_prim|aten|quantized|profiler|_test)::[^ ]+" \
|
||||
-op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(def|impl)" \
|
||||
-op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl)" \
|
||||
-op_invoke_pattern="c10::Dispatcher::findSchema|callOp" \
|
||||
-format="${FORMAT}" \
|
||||
${EXTRA_ANALYZER_FLAGS} \
|
||||
|
|
|
|||
|
|
@ -105,30 +105,34 @@ static auto registry =
|
|||
torch::import()
|
||||
.impl(
|
||||
"_aten::add.Scalar",
|
||||
torch::dispatch_autograd(torch::autograd::VariableType::add_Scalar))
|
||||
torch::kAutograd,
|
||||
torch::autograd::VariableType::add_Scalar)
|
||||
.impl(
|
||||
"_aten::mul.Tensor",
|
||||
torch::dispatch_autograd(torch::autograd::VariableType::mul_Tensor))
|
||||
torch::kAutograd,
|
||||
torch::autograd::VariableType::mul_Tensor)
|
||||
.impl(
|
||||
"_aten::conv2d",
|
||||
torch::dispatch_autograd(
|
||||
CppFunction::makeFromBoxedFunction<conv2d_kernel>()))
|
||||
.impl("_aten::dropout", torch::dispatch_autograd(VariableType::dropout))
|
||||
torch::kAutograd,
|
||||
CppFunction::makeFromBoxedFunction<conv2d_kernel>())
|
||||
.impl("_aten::dropout", torch::kAutograd, VariableType::dropout)
|
||||
.impl(
|
||||
"_aten::feature_dropout",
|
||||
torch::dispatch_autograd(VariableType::feature_dropout))
|
||||
torch::kAutograd,
|
||||
VariableType::feature_dropout)
|
||||
.impl(
|
||||
"_aten::log_softmax.int",
|
||||
torch::dispatch_autograd(
|
||||
CppFunction::makeFromBoxedFunction<log_softmax_kernel>()))
|
||||
torch::kAutograd,
|
||||
CppFunction::makeFromBoxedFunction<log_softmax_kernel>())
|
||||
.impl(
|
||||
"_aten::max_pool2d",
|
||||
torch::dispatch_autograd([](const Tensor& self,
|
||||
c10::List<int64_t> kernel_size,
|
||||
c10::List<int64_t> stride,
|
||||
c10::List<int64_t> padding,
|
||||
c10::List<int64_t> dilation,
|
||||
bool ceil_mode = false) {
|
||||
torch::kAutograd,
|
||||
[](const Tensor& self,
|
||||
c10::List<int64_t> kernel_size,
|
||||
c10::List<int64_t> stride,
|
||||
c10::List<int64_t> padding,
|
||||
c10::List<int64_t> dilation,
|
||||
bool ceil_mode = false) {
|
||||
return VariableType::max_pool2d(
|
||||
self,
|
||||
kernel_size.vec(),
|
||||
|
|
@ -136,12 +140,12 @@ static auto registry =
|
|||
padding.vec(),
|
||||
dilation.vec(),
|
||||
ceil_mode);
|
||||
}))
|
||||
.impl("_aten::relu", torch::dispatch_autograd(VariableType::relu))
|
||||
})
|
||||
.impl("_aten::relu", torch::kAutograd, VariableType::relu)
|
||||
.impl(
|
||||
"_aten::view",
|
||||
torch::dispatch_autograd(
|
||||
CppFunction::makeFromBoxedFunction<view_kernel>()))
|
||||
.impl("_aten::t", torch::dispatch_autograd(VariableType::t))
|
||||
.impl("_aten::addmm", torch::dispatch_autograd(VariableType::addmm));
|
||||
torch::kAutograd,
|
||||
CppFunction::makeFromBoxedFunction<view_kernel>())
|
||||
.impl("_aten::t", torch::kAutograd, VariableType::t)
|
||||
.impl("_aten::addmm", torch::kAutograd, VariableType::addmm);
|
||||
} // anonymous namespace
|
||||
|
|
|
|||
Loading…
Reference in a new issue