mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Operator generator based on templated selective build. (#43456)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43456 Introduce the template OperatorGenerator, which returns an optional Operator. It's null if the templated bool value is false. RegisterOperators() is updated to take the optional Operator. A null will not be registered. With this update the selective operator registration can be done at compile time. Tests are added to show an operator can be registered if it's in a whitelist and it will not be registered if it's not in the whitelist. Test Plan: Imported from OSS Reviewed By: ljk53 Differential Revision: D23283563 Pulled By: iseeyuan fbshipit-source-id: 456e0c72b2f335256be800aeabb797bd83bcf0b3
This commit is contained in:
parent
c25d0015f0
commit
288a2effa0
5 changed files with 110 additions and 6 deletions
|
|
@ -60,9 +60,13 @@ constexpr bool op_whitelist_check(string_view op_name) {
|
|||
#else
|
||||
return op_whitelist_contains(
|
||||
C10_STRINGIZE(TORCH_OPERATOR_WHITELIST),
|
||||
// Strip overload name (as whitelist doesn't contain overloads)
|
||||
OperatorNameView::parse(op_name).name
|
||||
);
|
||||
// This function is majorly used for mobile selective build with
|
||||
// root operators, where the overload is included in the whitelist.
|
||||
op_name);
|
||||
// // Strip overload name (as whitelist doesn't contain overloads)
|
||||
// // Another function based on this may be added when there's usage
|
||||
// // on op names without overload.
|
||||
// OperatorNameView::parse(op_name).name);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -76,6 +80,12 @@ constexpr bool schema_whitelist_check(string_view schema) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// schema_whitelist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
|
||||
// Add this API to pass arbitrary whitelist.
|
||||
constexpr bool op_whitelist_contains_name_in_schema(string_view whitelist, string_view schema) {
|
||||
return op_whitelist_contains(whitelist, schema.substr(0, schema.find("(")));
|
||||
}
|
||||
|
||||
// Returns true iff the given dispatch key is on the whitelist
|
||||
// and should be registered. When we turn this on, the list of valid
|
||||
// mobile dispatch keys is hard coded (but you need to make sure
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "torch/csrc/jit/ir/irparser.h"
|
||||
#include "torch/csrc/jit/passes/dead_code_elimination.h"
|
||||
#include "torch/csrc/jit/runtime/custom_operator.h"
|
||||
#include "torch/csrc/jit/runtime/register_ops_utils.h"
|
||||
#include "torch/jit.h"
|
||||
|
||||
namespace torch {
|
||||
|
|
@ -191,5 +192,68 @@ void testIValueKWargs() {
|
|||
ASSERT_EQ(result.toInt(), 19);
|
||||
}
|
||||
|
||||
void testTemplatedOperatorCreator() {
|
||||
constexpr char op_list[] = "foofoo::bar.template;foo::another";
|
||||
#define TORCH_SELECTIVE_NAME_IN_SCHEMA(l, n) \
|
||||
torch::detail::SelectiveStr<c10::impl::op_whitelist_contains_name_in_schema( \
|
||||
l, n)>(n)
|
||||
|
||||
{
|
||||
// Try to register an op name that does not exist in op_list.
|
||||
// Expected: the op name is not registered.
|
||||
torch::jit::RegisterOperators reg({OperatorGenerator(
|
||||
TORCH_SELECTIVE_NAME_IN_SCHEMA(
|
||||
op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"),
|
||||
[](Stack* stack) {
|
||||
double a;
|
||||
at::Tensor b;
|
||||
pop(stack, a, b);
|
||||
push(stack, a + b);
|
||||
},
|
||||
aliasAnalysisFromSchema())});
|
||||
|
||||
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
|
||||
ASSERT_EQ(ops.size(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
// The operator should be successfully registered since its name is in the
|
||||
// whitelist.
|
||||
torch::jit::RegisterOperators reg({OperatorGenerator(
|
||||
TORCH_SELECTIVE_NAME_IN_SCHEMA(
|
||||
op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"),
|
||||
[](Stack* stack) {
|
||||
double a;
|
||||
at::Tensor b;
|
||||
pop(stack, a, b);
|
||||
push(stack, a + b);
|
||||
},
|
||||
aliasAnalysisFromSchema())});
|
||||
|
||||
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
|
||||
ASSERT_EQ(ops.size(), 1);
|
||||
|
||||
auto& op = ops.front();
|
||||
ASSERT_EQ(op->schema().name(), "foofoo::bar");
|
||||
|
||||
ASSERT_EQ(op->schema().arguments().size(), 2);
|
||||
ASSERT_EQ(op->schema().arguments()[0].name(), "a");
|
||||
ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
|
||||
ASSERT_EQ(op->schema().arguments()[1].name(), "b");
|
||||
ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
|
||||
|
||||
ASSERT_EQ(op->schema().returns().size(), 1);
|
||||
ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
|
||||
|
||||
Stack stack;
|
||||
push(stack, 2.0f, at::ones(5));
|
||||
op->getOperation()(&stack);
|
||||
at::Tensor output;
|
||||
pop(stack, output);
|
||||
|
||||
ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ namespace jit {
|
|||
_(CreateAutodiffSubgraphs) \
|
||||
_(CustomOperators) \
|
||||
_(CustomOperatorAliasing) \
|
||||
_(TemplatedOperatorCreator) \
|
||||
_(IValueKWargs) \
|
||||
_(CustomFusion) \
|
||||
_(SchemaMatching) \
|
||||
|
|
|
|||
|
|
@ -17,9 +17,13 @@ struct TORCH_API RegisterOperators {
|
|||
RegisterOperators() = default;
|
||||
|
||||
/// Registers a vector of already created `Operator`s.
|
||||
RegisterOperators(std::vector<Operator> operators) {
|
||||
for (Operator& o : operators) {
|
||||
registerOperator(std::move(o));
|
||||
/// The operator element is now optional to filter null ops. It's backward
|
||||
/// compatible and works for selective operator registration.
|
||||
/// The operator element is optional so null (filtered-out) ops can be
/// skipped. This is backward compatible and supports selective operator
/// registration: only engaged optionals are registered.
RegisterOperators(std::vector<c10::optional<Operator>> operators) {
  for (c10::optional<Operator>& maybe_op : operators) {
    if (!maybe_op) {
      // Selective build produced no operator for this slot; skip it.
      continue;
    }
    registerOperator(std::move(*maybe_op));
  }
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -5,10 +5,12 @@
|
|||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <ATen/core/dispatch/OperatorOptions.h>
|
||||
#include <ATen/core/op_registration/op_whitelist.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
||||
#include <torch/csrc/jit/runtime/operator_options.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
|
|
@ -223,5 +225,28 @@ TORCH_API void ensure_c10_registerer_defined();
|
|||
// Used to assert that unschematized operators have an analysis method written
|
||||
TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym);
|
||||
|
||||
// A factory function to generate an optional operator. It has two
|
||||
// instantiations depending on the template bool arg value. The arg can be a
|
||||
// compile-time function for the selective op registration based on schema
|
||||
// string.
|
||||
template <typename Func>
|
||||
c10::optional<Operator> OperatorGenerator(
|
||||
torch::detail::SelectiveStr<true> schema_str,
|
||||
Func&& op,
|
||||
AliasAnalysisKind alias_analysis) {
|
||||
return c10::optional<Operator>(Operator(
|
||||
std::string(schema_str),
|
||||
std::forward<Func>(op),
|
||||
alias_analysis));
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
c10::optional<Operator> OperatorGenerator(
|
||||
torch::detail::SelectiveStr<false> schema_str,
|
||||
Func&& op,
|
||||
AliasAnalysisKind alias_analysis) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in a new issue