pytorch/torch/csrc/jit/runtime/symbolic_shape_registry.cpp


#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <unordered_map>
namespace torch::jit {
namespace {
std::mutex lock;
// split here to satisfy MSVC++
// https://docs.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2026?view=msvc-170
const std::string _xnnpack_shape_compute_functions =
#ifdef USE_XNNPACK
R"(def prepacked_conv2d_clamp_run(input: List[int], conv2dOpContext: Any):
assert isinstance(conv2dOpContext, __torch__.torch.classes.xnnpack.Conv2dOpContext)
(weight, bias, stride, padding, dilation, groups) = unchecked_cast(
Tuple[List[int], Optional[List[int]], List[int], List[int], List[int], int],
ops.prepacked.unpack_prepacked_sizes_conv2d(conv2dOpContext),
)
return conv2d(input, weight, bias, stride, padding, dilation, groups)
def prepacked_linear_clamp_run(input: List[int], linearOpContext: Any):
assert isinstance(linearOpContext, __torch__.torch.classes.xnnpack.LinearOpContext)
(weight, bias) = unchecked_cast(
Tuple[List[int], Optional[List[int]]],
ops.prepacked.unpack_prepacked_sizes_linear(linearOpContext),
)
return linear(input, weight, bias)
)"
#else
""
#endif
;
// Mapping function schemas to shape compute graphs allows multiple functions
// to share the same shape compute graph. This is memory efficient and also
// speeds up shape analysis by caching the result of running consecutive ops
// with the same graph (e.g. a series of pointwise ops).
// We need a map from schema to shape compute graph because the aten schema is
// not recoverable from the shape compute graph: the graph replaces Tensor
// inputs with List[int], and operators like Conv natively take List[int]
// inputs, so the two input kinds are indistinguishable in the graph.
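//
// For illustration, the shape function shared by unary pointwise ops could
// look like the following TorchScript (a minimal sketch; the actual
// serialized functions live in the serialized shape function registry):
//
//   def unary(self: List[int]) -> List[int]:
//       out: List[int] = []
//       for elem in self:
//           out.append(elem)
//       return out
//
// It receives the input tensor's sizes as List[int] and returns a fresh,
// unaliased list (see checkInputReturnedAsOutput below).
//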
// TODO: consider storing the shape compute graph directly on the operator,
// and merging this into native_functions.yaml.
//
// Wrapped in a function so that operators get registered before the map is
// initialized.
// Conditionally defined ops are not yet supported in the Python-serialized
// shape functions, so they are registered here instead.
static const OperatorMap<std::string>& conditionally_defined_ops() {
// clang-format off
static const OperatorMap<std::string> schema_to_function_graph{
#ifdef USE_XNNPACK
{"prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y", "prepacked_conv2d_clamp_run"},
{"prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y", "prepacked_linear_clamp_run"},
#endif
};
// clang-format on
return schema_to_function_graph;
}
std::unordered_map<const FunctionSchema*, std::shared_ptr<Graph>>
cached_schema_to_graph;
std::unordered_map<const FunctionSchema*, BoundedShapeGraphs>
cached_bounded_schema_to_graph;
// CompilationUnit that holds all these Functions and keeps them alive.
auto compilation_unit = std::make_shared<CompilationUnit>();
std::optional<const FunctionSchema*> getInplaceVariant(
const FunctionSchema& base_schema) {
auto& inplace_variants =
getAllOperatorsFor(c10::Symbol::fromQualString(base_schema.name() + "_"));
for (const auto& variant : inplace_variants) {
    // All arguments must match the base schema; the first argument may differ
    // only in its alias info, since the in-place variant writes to self.
const FunctionSchema* schema = &variant->schema();
if (!schema->isSubtypeOf(base_schema, false)) {
continue;
}
Argument self_arg = schema->arguments()[0];
if (!self_arg.alias_info()->isWrite()) {
continue;
}
Argument ret_arg = schema->returns()[0];
if (!ret_arg.alias_info()->isWrite()) {
continue;
}
return schema;
}
return std::nullopt;
}
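// Recursively replaces every Tensor type with List[int], mirroring how shape
// compute graphs receive a tensor's symbolic sizes rather than the tensor
// itself; non-tensor container types are mapped element-wise.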
TypePtr mapTensorToListOfInts(TypePtr type) {
if (type->cast<TensorType>()) {
return ListType::ofInts();
}
at::ArrayRef<TypePtr> contained = type->containedTypes();
if (contained.empty()) {
return type;
}
  return type->withContained(fmap(contained, mapTensorToListOfInts));
}
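// Loop unrolling during partial evaluation currently handles only for-loops,
// so warn when a shape function contains a while loop.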
void checkForWhileLoop(
    const FunctionSchema* schema,
    const std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator graph_it(graph);
for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
if (node->kind() != prim::Loop) {
continue;
}
LoopView loop(node);
if (loop.loopType() != LoopView::For) {
TORCH_WARN(
"While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: ",
*node,
" for schema ",
*schema);
}
}
}
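// Shape functions must allocate fresh output lists; returning an input list
// directly would alias caller-visible state during shape analysis.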
void checkInputReturnedAsOutput(
const FunctionSchema* schema,
const std::shared_ptr<Graph>& graph) {
  // We could use the alias db here as well, but since it is imprecise we
  // could only warn on a match rather than error.
for (size_t i : c10::irange(graph->inputs().size())) {
Value* input = graph->inputs().at(i);
for (size_t j : c10::irange(graph->outputs().size())) {
Value* output = graph->outputs().at(j);
TORCH_CHECK(
input != output,
"For schema: ",
*schema,
" input index ",
i,
" is returned as output index ",
j,
". Shape functions must return new unaliased lists");
}
}
}
void checkInputAndOutputTypes(
const FunctionSchema* schema,
const std::shared_ptr<Graph>& graph) {
  // Allow the graph to have fewer inputs than the schema has arguments, so
  // that multiple schemas (e.g. all unary ops) can share one shape function.
  TORCH_CHECK(
      graph->inputs().size() <= schema->arguments().size(),
      "Shape function must not have more arguments than its schema. Got ",
      graph->inputs().size(),
      " graph arguments and ",
      schema->arguments().size(),
      " schema arguments of schema: ",
      *schema);
for (auto i : c10::irange(graph->inputs().size())) {
auto inp_type = schema->arguments().at(i).type();
auto mapped_type = mapTensorToListOfInts(inp_type);
auto graph_type = graph->inputs().at(i)->type();
TORCH_INTERNAL_ASSERT(
mapped_type->isSubtypeOf(graph->inputs().at(i)->type()),
"For schema type: ",
inp_type->str(),
" Expected supertype of ",
mapped_type->str(),
" but got graph_type ",
graph_type->str(),
" at index ",
i,
" of schema: ",
*schema);
}
  TORCH_CHECK(
      graph->outputs().size() == schema->returns().size(),
      "Shape function must have the same number of outputs as its schema. Got ",
      graph->outputs().size(),
      " graph outputs and ",
      schema->returns().size(),
      " schema returns of schema: ",
      *schema);
for (auto i : c10::irange(schema->returns().size())) {
auto out_type = schema->returns().at(i).type();
auto mapped_type = mapTensorToListOfInts(out_type);
auto graph_type = graph->outputs().at(i)->type();
TORCH_INTERNAL_ASSERT(
mapped_type->isSubtypeOf(graph->outputs().at(i)->type()),
"For schema type: ",
out_type->str(),
" Expected supertype of ",
mapped_type->str(),
" but got graph_type ",
graph_type->str(),
" at output index ",
i,
" of schema: ",
*schema);
}
}
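// Prepares a shape function graph for registration: inlines it and, for
// multi-output aten schemas, unpacks the graph's single tuple output into one
// graph output per schema return.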
void transformShapeFunction(
const FunctionSchema* schema_string,
const std::shared_ptr<Graph>& graph) {
Inline(*graph);
  // ATen operators can return multiple unboxed values, in contrast to
  // functions defined in TorchScript or user-registered operators, which must
  // return a single Tuple. Here we modify the shape graphs of aten operators
  // with multiple outputs so that their outputs correspond one-to-one with
  // the schema's returns.
if (schema_string->returns().size() > 1) {
TORCH_INTERNAL_ASSERT(
graph->outputs().size() == 1 &&
graph->outputs().at(0)->type()->cast<TupleType>());
auto tuple_node = graph->outputs().at(0)->node();
WithInsertPoint guard(graph->return_node());
auto tuple_unpack_values = createTupleUnpack(tuple_node->output());
graph->eraseOutput(0);
for (Value* v : tuple_unpack_values) {
graph->registerOutput(v);
}
GRAPH_DUMP("After Output Tuple Unpacking", graph);
}
}
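// Fetches the named shape function from `module` (reusing an already
// transformed graph when several schemas share one function) and returns its
// graph.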
std::shared_ptr<Graph> genShapeComputeFn(
const FunctionSchema* schema_string,
const std::string& shape_compute_function_name,
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
const CompilationUnit& module) {
std::shared_ptr<Graph> graph;
GRAPH_DEBUG(
"Registering schema: ",
*schema_string,
" with shape compute func: ",
shape_compute_function_name);
if (reused_functions.count(shape_compute_function_name)) {
GRAPH_DEBUG("Registering reused schema");
graph = reused_functions[shape_compute_function_name];
} else {
Function& shape_compute_function =
module.get_function(shape_compute_function_name);
graph = toGraphFunction(shape_compute_function).graph();
transformShapeFunction(schema_string, graph);
    // NB: the shape functions registered from source are linted in a test
    // file rather than here:
    // LintShapeComputeGraph(schema_string, graph);
reused_functions[shape_compute_function_name] = graph;
}
  // Allow fewer graph inputs than schema arguments so that multiple schemas
  // can share one shape function (see checkInputAndOutputTypes).
TORCH_INTERNAL_ASSERT(
graph->inputs().size() <= schema_string->arguments().size());
return graph;
}
void registerSchema(
const FunctionSchema* schema_string,
const std::string& shape_compute_function_name,
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
const CompilationUnit& module) {
auto graph = genShapeComputeFn(
schema_string, shape_compute_function_name, reused_functions, module);
cached_schema_to_graph[schema_string] = graph;
}
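// Bounded schemas register a pair of shape graphs computing a lower and an
// upper bound on the output shapes, for use when an exact output shape cannot
// be computed symbolically (e.g. data-dependent output shapes).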
void registerBoundedSchema(
const FunctionSchema* schema_string,
const std::string& lower_bound_function_name,
const std::string& upper_bound_function_name,
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
const CompilationUnit& module) {
auto lower_graph = genShapeComputeFn(
schema_string, lower_bound_function_name, reused_functions, module);
auto upper_graph = genShapeComputeFn(
schema_string, upper_bound_function_name, reused_functions, module);
cached_bounded_schema_to_graph[schema_string] = {lower_graph, upper_graph};
}
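// Populates the schema -> graph caches from every known mapping source:
// conditionally defined ops, the tensorexpr elementwise set, the serialized
// shape function mappings, and the bounded shape mappings.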
void loadModule(const CompilationUnit& module) {
std::unordered_map<std::string, std::shared_ptr<Graph>> reused_functions;
std::vector<std::pair<std::shared_ptr<Operator>, std::string>>
operator_pairs = conditionally_defined_ops().getAllKeysAndValues();
auto te_ops = get_tensorexpr_elementwise_set().getAllKeysAndValues();
operator_pairs.insert(operator_pairs.end(), te_ops.begin(), te_ops.end());
auto more_mappings = GetShapeFunctionMappings().getAllKeysAndValues();
operator_pairs.insert(
operator_pairs.end(), more_mappings.begin(), more_mappings.end());
for (const auto& pair : operator_pairs) {
const FunctionSchema* schema_string = &pair.first->schema();
const std::string& shape_compute_function_name = pair.second;
registerSchema(
schema_string, shape_compute_function_name, reused_functions, module);
    // Register the in-place variant, if any, for functions with common shape
    // forms.
if (shape_compute_function_name == "unary") {
auto inplace_schema = getInplaceVariant(*schema_string);
if (inplace_schema.has_value()) {
registerSchema(
inplace_schema.value(), "unary", reused_functions, module);
}
}
if (shape_compute_function_name == "broadcast") {
auto inplace_schema = getInplaceVariant(*schema_string);
if (inplace_schema.has_value()) {
registerSchema(
inplace_schema.value(),
"broadcast_inplace",
reused_functions,
module);
}
}
}
// Now register the bounded schemas
for (const auto& pair : GetBoundedShapeMappings().getAllKeysAndValues()) {
const FunctionSchema* schema_string = &pair.first->schema();
const std::string& lower_bound_function_name = pair.second.first;
const std::string& upper_bound_function_name = pair.second.second;
registerBoundedSchema(
schema_string,
lower_bound_function_name,
upper_bound_function_name,
reused_functions,
module);
}
}
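// Compiles the serialized shape functions (plus the XNNPACK ones, when
// enabled) into `compilation_unit` and registers the resulting graphs.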
void loadFunctions() {
try {
auto shape_compute_functions =
GetSerializedShapeFunctions() + _xnnpack_shape_compute_functions;
auto src = std::make_shared<Source>(shape_compute_functions);
std::stringstream ss;
std::vector<at::IValue> constantTable;
auto resolver = std::make_shared<SourceImporterImpl>(
compilation_unit,
&constantTable,
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
1);
compilation_unit->define(
std::nullopt, shape_compute_functions, resolver, nullptr);
loadModule(*compilation_unit);
  } catch (...) {
    // Reset the caches and compilation unit so that we don't get weird errors
    // in later tests when one of the shape functions is invalid.
    compilation_unit = std::make_shared<CompilationUnit>();
    cached_schema_to_graph.clear();
    cached_bounded_schema_to_graph.clear();
    throw;
  }
}
} // anonymous namespace
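// Public API. Each entry point below lazily loads the serialized shape
// functions on first use, holding `lock` for the duration of the call. A
// caller might use the registry as follows (a sketch; assumes an existing
// Node* `n` for the op of interest):
//
//   if (auto graph = shapeComputeGraphForSchema(n->schema())) {
//     GRAPH_DUMP("shape graph for node", *graph);
//   }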
std::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema(
const FunctionSchema& schema) {
std::lock_guard<std::mutex> guard(lock);
if (cached_schema_to_graph.empty()) {
loadFunctions();
}
GRAPH_DEBUG("Trying to find schema: ", schema);
auto cache_it = cached_schema_to_graph.find(&schema);
if (cache_it != cached_schema_to_graph.end()) {
return cache_it->second;
}
GRAPH_DEBUG("Could not find schema: ", schema);
return std::nullopt;
}
TORCH_API std::optional<BoundedShapeGraphs> boundedGraphsForSchema(
const FunctionSchema& schema) {
std::lock_guard<std::mutex> guard(lock);
if (cached_bounded_schema_to_graph.empty()) {
loadFunctions();
}
GRAPH_DEBUG("Trying to find schema in bounded graphs: ", schema);
auto cache_it = cached_bounded_schema_to_graph.find(&schema);
if (cache_it != cached_bounded_schema_to_graph.end()) {
return cache_it->second;
}
return std::nullopt;
}
void RegisterShapeComputeGraphForSchema(
const FunctionSchema& schema,
const std::shared_ptr<Graph>& g) {
std::lock_guard<std::mutex> guard(lock);
if (cached_schema_to_graph.empty()) {
loadFunctions();
}
transformShapeFunction(&schema, g);
LintShapeComputeGraph(&schema, g);
cached_schema_to_graph[&schema] = g;
}
std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas() {
std::lock_guard<std::mutex> guard(lock);
if (cached_schema_to_graph.empty()) {
loadFunctions();
}
std::vector<const FunctionSchema*> schemas;
schemas.reserve(cached_schema_to_graph.size());
for (const auto& pair : cached_schema_to_graph) {
schemas.push_back(pair.first);
}
return schemas;
}
void LintShapeComputeGraph(
const FunctionSchema* schema,
const std::shared_ptr<Graph>& graph) {
checkInputAndOutputTypes(schema, graph);
checkForWhileLoop(schema, graph);
checkInputReturnedAsOutput(schema, graph);
// TODO: other checks ? list ops which we don't symbolically optimize, etc ?
}
} // namespace torch::jit