diff --git a/aten/src/ATen/core/op_registration/op_whitelist.h b/aten/src/ATen/core/op_registration/op_whitelist.h index 92d71e2628c..cf7f09bb0f9 100644 --- a/aten/src/ATen/core/op_registration/op_whitelist.h +++ b/aten/src/ATen/core/op_registration/op_whitelist.h @@ -60,9 +60,13 @@ constexpr bool op_whitelist_check(string_view op_name) { #else return op_whitelist_contains( C10_STRINGIZE(TORCH_OPERATOR_WHITELIST), - // Strip overload name (as whitelist doesn't contain overloads) - OperatorNameView::parse(op_name).name - ); + // This function is majorly used for mobile selective build with + // root operators, where the overload is included in the whitelist. + op_name); + // // Strip overload name (as whitelist doesn't contain overloads) + // // Another function based on this may be added when there's usage + // // on op names without overload. + // OperatorNameView::parse(op_name).name); #endif } @@ -76,6 +80,12 @@ constexpr bool schema_whitelist_check(string_view schema) { #endif } +// schema_whitelist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST. +// Add this API to pass arbitrary whitelist. +constexpr bool op_whitelist_contains_name_in_schema(string_view whitelist, string_view schema) { + return op_whitelist_contains(whitelist, schema.substr(0, schema.find("("))); +} + // Returns true iff the given dispatch key is on the whitelist // and should be registered. When we turn this on, the list of valid // mobile dispatch keys is hard coded (but you need to make sure diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp index 064f34c929e..529b36385bd 100644 --- a/test/cpp/jit/test_custom_operators.cpp +++ b/test/cpp/jit/test_custom_operators.cpp @@ -5,6 +5,7 @@ #include "torch/csrc/jit/ir/irparser.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/runtime/custom_operator.h" +#include "torch/csrc/jit/runtime/register_ops_utils.h" #include "torch/jit.h" namespace torch { @@ -191,5 +192,68 @@ void testIValueKWargs() { ASSERT_EQ(result.toInt(), 19); } +void testTemplatedOperatorCreator() { + constexpr char op_list[] = "foofoo::bar.template;foo::another"; +#define TORCH_SELECTIVE_NAME_IN_SCHEMA(l, n) \ + torch::detail::SelectiveStr(n) + + { + // Try to register an op name that does not exist in op_list. + // Expected: the op name is not registered. + torch::jit::RegisterOperators reg({OperatorGenerator( + TORCH_SELECTIVE_NAME_IN_SCHEMA( + op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"), + [](Stack* stack) { + double a; + at::Tensor b; + pop(stack, a, b); + push(stack, a + b); + }, + aliasAnalysisFromSchema())}); + + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); + ASSERT_EQ(ops.size(), 0); + } + + { + // The operator should be successfully registered since its name is in the + // whitelist. + torch::jit::RegisterOperators reg({OperatorGenerator( + TORCH_SELECTIVE_NAME_IN_SCHEMA( + op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"), + [](Stack* stack) { + double a; + at::Tensor b; + pop(stack, a, b); + push(stack, a + b); + }, + aliasAnalysisFromSchema())}); + + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); + ASSERT_EQ(ops.size(), 1); + + auto& op = ops.front(); + ASSERT_EQ(op->schema().name(), "foofoo::bar"); + + ASSERT_EQ(op->schema().arguments().size(), 2); + ASSERT_EQ(op->schema().arguments()[0].name(), "a"); + ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); + ASSERT_EQ(op->schema().arguments()[1].name(), "b"); + ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType); + + ASSERT_EQ(op->schema().returns().size(), 1); + ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType); + + Stack stack; + push(stack, 2.0f, at::ones(5)); + op->getOperation()(&stack); + at::Tensor output; + pop(stack, output); + + ASSERT_TRUE(output.allclose(at::full(5, 3.0f))); + } +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 66cd0cebb85..32b1e8d1855 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -19,6 +19,7 @@ namespace jit { _(CreateAutodiffSubgraphs) \ _(CustomOperators) \ _(CustomOperatorAliasing) \ + _(TemplatedOperatorCreator) \ _(IValueKWargs) \ _(CustomFusion) \ _(SchemaMatching) \ diff --git a/torch/csrc/jit/runtime/custom_operator.h b/torch/csrc/jit/runtime/custom_operator.h index d1fe948e15e..45ad6676376 100644 --- a/torch/csrc/jit/runtime/custom_operator.h +++ b/torch/csrc/jit/runtime/custom_operator.h @@ -17,9 +17,13 @@ struct TORCH_API RegisterOperators { RegisterOperators() = default; /// Registers a vector of already created `Operator`s. - RegisterOperators(std::vector operators) { - for (Operator& o : operators) { - registerOperator(std::move(o)); + /// The operator element is now optional to filter null ops. It's backward + /// compatible and works for selective operator registration. + RegisterOperators(std::vector> operators) { + for (c10::optional& o : operators) { + if (o) { + registerOperator(std::move(o.value())); + } } } }; diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 6acc2aee7d7..07e464910e0 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -5,10 +5,12 @@ #include #include +#include #include #include #include #include +#include #include #include @@ -223,5 +225,28 @@ TORCH_API void ensure_c10_registerer_defined(); // Used to assert that unschematized operators have an analysis method written TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym); +// A factory function to generate an optional operator. It has two +// instantiations depending on the template bool arg value. The arg can be a +// compile-time function for the selective op registration based on schema +// string. +template +c10::optional OperatorGenerator( + torch::detail::SelectiveStr schema_str, + Func&& op, + AliasAnalysisKind alias_analysis) { + return c10::optional(Operator( + std::string(schema_str), + std::forward(op), + alias_analysis)); +} + +template +c10::optional OperatorGenerator( + torch::detail::SelectiveStr schema_str, + Func&& op, + AliasAnalysisKind alias_analysis) { + return c10::nullopt; +} + } // namespace jit } // namespace torch