mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Follows #133399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134537 Approved by: https://github.com/Skylion007
448 lines
15 KiB
C++
448 lines
15 KiB
C++
#include <c10/util/Exception.h>
|
|
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
|
#include <torch/csrc/jit/ir/ir_views.h>
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
#include <torch/csrc/jit/passes/inliner.h>
|
|
#include <torch/csrc/jit/runtime/graph_iterator.h>
|
|
#include <torch/csrc/jit/runtime/operator.h>
|
|
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>
|
|
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
|
|
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
|
|
#include <torch/csrc/jit/serialization/import_source.h>
|
|
#include <unordered_map>
|
|
|
|
namespace torch::jit {
|
|
namespace {
|
|
// Guards the lazily-populated caches (cached_schema_to_graph,
// cached_bounded_schema_to_graph) and compilation_unit below.
// shapeComputeGraphForSchema, boundedGraphsForSchema,
// RegisterShapeComputeGraphForSchema and RegisteredShapeComputeSchemas hold it
// for their whole body, including any first-use call to loadFunctions().
std::mutex lock;
|
|
|
|
// Kept as a separate string literal ("split here") to satisfy MSVC++, which
// rejects overly long string literals (compiler error C2026):
|
|
// https://docs.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2026?view=msvc-170
|
|
// TorchScript source for the shape functions of the XNNPACK prepacked ops
// (mapped to their schemas in conditionally_defined_ops()). loadFunctions()
// appends it to GetSerializedShapeFunctions(); it is the empty string when
// USE_XNNPACK is not defined. Each function checks the op-context type,
// unpacks the prepacked weight sizes, and delegates to the generic conv2d /
// linear shape functions.
const std::string _xnnpack_shape_compute_functions =
|
|
#ifdef USE_XNNPACK
|
|
R"(def prepacked_conv2d_clamp_run(input: List[int], conv2dOpContext: Any):
|
|
assert isinstance(conv2dOpContext, __torch__.torch.classes.xnnpack.Conv2dOpContext)
|
|
(weight, bias, stride, padding, dilation, groups) = unchecked_cast(
|
|
Tuple[List[int], Optional[List[int]], List[int], List[int], List[int], int],
|
|
ops.prepacked.unpack_prepacked_sizes_conv2d(conv2dOpContext),
|
|
)
|
|
return conv2d(input, weight, bias, stride, padding, dilation, groups)
|
|
|
|
def prepacked_linear_clamp_run(input: List[int], linearOpContext: Any):
|
|
assert isinstance(linearOpContext, __torch__.torch.classes.xnnpack.LinearOpContext)
|
|
(weight, bias) = unchecked_cast(
|
|
Tuple[List[int], Optional[List[int]]],
|
|
ops.prepacked.unpack_prepacked_sizes_linear(linearOpContext),
|
|
)
|
|
return linear(input, weight, bias)
|
|
)"
|
|
#else
|
|
""
|
|
#endif
|
|
;
|
|
|
|
// mapping function schema to shape compute graphs allows multiple functions to
|
|
// share the same shape compute graph, which is memory efficient and also will
|
|
// help speed up shape analysis by caching the result of running consecutive ops
|
|
// for a particular set of inputs with the same graph, e.g. running a series
|
|
// of pointwise ops
|
|
// we need a map from schema to shape compute graph, because the aten schema
|
|
// is not recoverable from the shape compute graph, since the shape compute
|
|
// graph replaces Tensor inputs with List[int] and there are operators like Conv
|
|
// which natively have List[int] inputs
|
|
// TODO: consider storing shape compute graph directly on operator,
|
|
// and merge into native_functions.yaml
|
|
|
|
// wrapped in function so that operators get registered before map is
|
|
// initialized
|
|
// Conditionally defined ops not yet supported in python serialized
|
|
// operators
|
|
static const OperatorMap<std::string>& conditionally_defined_ops() {
|
|
// clang-format off
|
|
static const OperatorMap<std::string> schema_to_function_graph{
|
|
#ifdef USE_XNNPACK
|
|
{"prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y", "prepacked_conv2d_clamp_run"},
|
|
{"prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y", "prepacked_linear_clamp_run"},
|
|
#endif
|
|
};
|
|
// clang-format on
|
|
return schema_to_function_graph;
|
|
}
|
|
|
|
std::unordered_map<const FunctionSchema*, std::shared_ptr<Graph>>
|
|
cached_schema_to_graph;
|
|
|
|
std::unordered_map<const FunctionSchema*, BoundedShapeGraphs>
|
|
cached_bounded_schema_to_graph;
|
|
|
|
// CompilationUnit that holds all these Functions and keeps them alive.
|
|
auto compilation_unit = std::make_shared<CompilationUnit>();
|
|
|
|
const std::optional<const FunctionSchema*> getInplaceVariant(
|
|
const FunctionSchema& base_schema) {
|
|
auto& inplace_variants =
|
|
getAllOperatorsFor(c10::Symbol::fromQualString(base_schema.name() + "_"));
|
|
|
|
for (const auto& variant : inplace_variants) {
|
|
// Need to check that all args are the same except for the first, which
|
|
// is almost the same except for the Alias info
|
|
const FunctionSchema* schema = &variant->schema();
|
|
if (!schema->isSubtypeOf(base_schema, false)) {
|
|
continue;
|
|
}
|
|
|
|
Argument self_arg = schema->arguments()[0];
|
|
if (!self_arg.alias_info()->isWrite()) {
|
|
continue;
|
|
}
|
|
|
|
Argument ret_arg = schema->returns()[0];
|
|
if (!ret_arg.alias_info()->isWrite()) {
|
|
continue;
|
|
}
|
|
|
|
return schema;
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
// Maps a schema type to the type a shape function sees for it: a Tensor
// becomes List[int] (its sizes), types that contain other types are rebuilt
// with each contained type mapped recursively, and leaf types are returned
// unchanged.
TypePtr mapTensorToListOfInts(TypePtr type) {
  if (!type->cast<TensorType>()) {
    const at::ArrayRef<TypePtr> children = type->containedTypes();
    return children.empty()
        ? type
        : type->withContained(fmap(children, mapTensorToListOfInts));
  }
  return ListType::ofInts();
}
|
|
|
|
void checkForWhileLoop(
|
|
const FunctionSchema* schema,
|
|
std::shared_ptr<Graph> graph) {
|
|
DepthFirstGraphNodeIterator graph_it(graph);
|
|
for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
|
|
if (node->kind() != prim::Loop) {
|
|
continue;
|
|
}
|
|
LoopView loop(node);
|
|
if (loop.loopType() != LoopView::For) {
|
|
TORCH_WARN(
|
|
"While loops are not yet implemented in unrolling which may make this shape function difficult to partially evaluate: ",
|
|
*node,
|
|
" for schema ",
|
|
*schema);
|
|
}
|
|
}
|
|
}
|
|
|
|
void checkInputReturnedAsOutput(
|
|
const FunctionSchema* schema,
|
|
const std::shared_ptr<Graph>& graph) {
|
|
// Could use alias db here as well but would have to warn because it's
|
|
// imprecise
|
|
for (size_t i : c10::irange(graph->inputs().size())) {
|
|
Value* input = graph->inputs().at(i);
|
|
for (size_t j : c10::irange(graph->outputs().size())) {
|
|
Value* output = graph->outputs().at(j);
|
|
TORCH_CHECK(
|
|
input != output,
|
|
"For schema: ",
|
|
*schema,
|
|
" input index ",
|
|
i,
|
|
" is returned as output index ",
|
|
j,
|
|
". Shape functions must return new unaliased lists");
|
|
}
|
|
}
|
|
}
|
|
|
|
void checkInputAndOutputTypes(
|
|
const FunctionSchema* schema,
|
|
const std::shared_ptr<Graph>& graph) {
|
|
// allow extra unused arguments to map multiple functions to e.g. unary
|
|
TORCH_CHECK(
|
|
graph->inputs().size() <= schema->arguments().size(),
|
|
"Shape function must have fewer arguments than schema. Got ",
|
|
graph->inputs().size(),
|
|
" graph arguments and ",
|
|
schema->arguments().size(),
|
|
" schema arguments of schema: ",
|
|
*schema);
|
|
|
|
for (auto i : c10::irange(graph->inputs().size())) {
|
|
auto inp_type = schema->arguments().at(i).type();
|
|
auto mapped_type = mapTensorToListOfInts(inp_type);
|
|
auto graph_type = graph->inputs().at(i)->type();
|
|
TORCH_INTERNAL_ASSERT(
|
|
mapped_type->isSubtypeOf(graph->inputs().at(i)->type()),
|
|
"For schema type: ",
|
|
inp_type->str(),
|
|
" Expected supertype of ",
|
|
mapped_type->str(),
|
|
" but got graph_type ",
|
|
graph_type->str(),
|
|
" at index ",
|
|
i,
|
|
" of schema: ",
|
|
*schema);
|
|
}
|
|
|
|
TORCH_CHECK(
|
|
graph->outputs().size() == schema->returns().size(),
|
|
"Shape function equal number of outputs as schema. Got ",
|
|
graph->outputs().size(),
|
|
" graph outputs and ",
|
|
schema->returns().size(),
|
|
" schema returns of schema: ",
|
|
*schema);
|
|
|
|
for (auto i : c10::irange(schema->returns().size())) {
|
|
auto out_type = schema->returns().at(i).type();
|
|
auto mapped_type = mapTensorToListOfInts(out_type);
|
|
auto graph_type = graph->outputs().at(i)->type();
|
|
TORCH_INTERNAL_ASSERT(
|
|
mapped_type->isSubtypeOf(graph->outputs().at(i)->type()),
|
|
"For schema type: ",
|
|
out_type->str(),
|
|
" Expected supertype of ",
|
|
mapped_type->str(),
|
|
" but got graph_type ",
|
|
graph_type->str(),
|
|
" at output index ",
|
|
i,
|
|
" of schema: ",
|
|
*schema);
|
|
}
|
|
}
|
|
|
|
void transformShapeFunction(
|
|
const FunctionSchema* schema_string,
|
|
const std::shared_ptr<Graph>& graph) {
|
|
Inline(*graph);
|
|
|
|
// ATEN operators can return multiple unboxed values, this in contrast to
|
|
// functions defined in TorchScript or User-Registered Operators
|
|
// Which must use a Tuple
|
|
// Here, modify the shape graph of aten operators with multiple outputs
|
|
// so that they correspond to each other
|
|
if (schema_string->returns().size() > 1) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
graph->outputs().size() == 1 &&
|
|
graph->outputs().at(0)->type()->cast<TupleType>());
|
|
auto tuple_node = graph->outputs().at(0)->node();
|
|
WithInsertPoint guard(graph->return_node());
|
|
auto tuple_unpack_values = createTupleUnpack(tuple_node->output());
|
|
graph->eraseOutput(0);
|
|
for (Value* v : tuple_unpack_values) {
|
|
graph->registerOutput(v);
|
|
}
|
|
GRAPH_DUMP("After Output Tuple Unpacking", graph);
|
|
}
|
|
}
|
|
|
|
std::shared_ptr<Graph> genShapeComputeFn(
|
|
const FunctionSchema* schema_string,
|
|
const std::string& shape_compute_function_name,
|
|
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
|
|
const CompilationUnit& module) {
|
|
std::shared_ptr<Graph> graph;
|
|
GRAPH_DEBUG(
|
|
"Registering schema: ",
|
|
*schema_string,
|
|
" with shape compute func: ",
|
|
shape_compute_function_name);
|
|
if (reused_functions.count(shape_compute_function_name)) {
|
|
GRAPH_DEBUG("Registering reused schema");
|
|
graph = reused_functions[shape_compute_function_name];
|
|
} else {
|
|
Function& shape_compute_function =
|
|
module.get_function(shape_compute_function_name);
|
|
graph = toGraphFunction(shape_compute_function).graph();
|
|
|
|
transformShapeFunction(schema_string, graph);
|
|
// NB: we lint the shape functions registered in source
|
|
// in a test file
|
|
// LintShapeComputeGraph(schema_string, graph);
|
|
|
|
reused_functions[shape_compute_function_name] = graph;
|
|
}
|
|
// allow extra unused arguments to map multiple functions to e.g. unary
|
|
TORCH_INTERNAL_ASSERT(
|
|
graph->inputs().size() <= schema_string->arguments().size());
|
|
return graph;
|
|
}
|
|
|
|
void registerSchema(
|
|
const FunctionSchema* schema_string,
|
|
const std::string& shape_compute_function_name,
|
|
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
|
|
const CompilationUnit& module) {
|
|
auto graph = genShapeComputeFn(
|
|
schema_string, shape_compute_function_name, reused_functions, module);
|
|
|
|
cached_schema_to_graph[schema_string] = graph;
|
|
}
|
|
|
|
void registerBoundedSchema(
|
|
const FunctionSchema* schema_string,
|
|
const std::string& lower_bound_function_name,
|
|
const std::string& upper_bound_function_name,
|
|
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
|
|
const CompilationUnit& module) {
|
|
auto lower_graph = genShapeComputeFn(
|
|
schema_string, lower_bound_function_name, reused_functions, module);
|
|
auto upper_graph = genShapeComputeFn(
|
|
schema_string, upper_bound_function_name, reused_functions, module);
|
|
cached_bounded_schema_to_graph[schema_string] = {lower_graph, upper_graph};
|
|
}
|
|
|
|
void loadModule(const CompilationUnit& module) {
|
|
std::unordered_map<std::string, std::shared_ptr<Graph>> reused_functions;
|
|
|
|
std::vector<std::pair<std::shared_ptr<Operator>, std::string>>
|
|
operator_pairs = conditionally_defined_ops().getAllKeysAndValues();
|
|
auto te_ops = get_tensorexpr_elementwise_set().getAllKeysAndValues();
|
|
operator_pairs.insert(operator_pairs.end(), te_ops.begin(), te_ops.end());
|
|
auto more_mappings = GetShapeFunctionMappings().getAllKeysAndValues();
|
|
operator_pairs.insert(
|
|
operator_pairs.end(), more_mappings.begin(), more_mappings.end());
|
|
|
|
for (const auto& pair : operator_pairs) {
|
|
const FunctionSchema* schema_string = &pair.first->schema();
|
|
const std::string& shape_compute_function_name = pair.second;
|
|
|
|
registerSchema(
|
|
schema_string, shape_compute_function_name, reused_functions, module);
|
|
|
|
// Register the inplace variant if any for functions with common shape forms
|
|
if (shape_compute_function_name == "unary") {
|
|
auto inplace_schema = getInplaceVariant(*schema_string);
|
|
if (inplace_schema.has_value()) {
|
|
registerSchema(
|
|
inplace_schema.value(), "unary", reused_functions, module);
|
|
}
|
|
}
|
|
if (shape_compute_function_name == "broadcast") {
|
|
auto inplace_schema = getInplaceVariant(*schema_string);
|
|
if (inplace_schema.has_value()) {
|
|
registerSchema(
|
|
inplace_schema.value(),
|
|
"broadcast_inplace",
|
|
reused_functions,
|
|
module);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Now register the bounded schemas
|
|
for (const auto& pair : GetBoundedShapeMappings().getAllKeysAndValues()) {
|
|
const FunctionSchema* schema_string = &pair.first->schema();
|
|
const std::string& lower_bound_function_name = pair.second.first;
|
|
const std::string& upper_bound_function_name = pair.second.second;
|
|
|
|
registerBoundedSchema(
|
|
schema_string,
|
|
lower_bound_function_name,
|
|
upper_bound_function_name,
|
|
reused_functions,
|
|
module);
|
|
}
|
|
}
|
|
|
|
void loadFunctions() {
|
|
try {
|
|
auto shape_compute_functions =
|
|
GetSerializedShapeFunctions() + _xnnpack_shape_compute_functions;
|
|
|
|
auto src = std::make_shared<Source>(shape_compute_functions);
|
|
std::stringstream ss;
|
|
std::vector<at::IValue> constantTable;
|
|
auto resolver = std::make_shared<SourceImporterImpl>(
|
|
compilation_unit,
|
|
&constantTable,
|
|
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
|
|
1);
|
|
compilation_unit->define(
|
|
std::nullopt, shape_compute_functions, resolver, nullptr);
|
|
loadModule(*compilation_unit);
|
|
} catch (...) {
|
|
// Reset the cache and compilation unit so that we don't get weird errors
|
|
// in later tests when one of the shape functions is invalid.
|
|
compilation_unit = std::make_shared<CompilationUnit>();
|
|
cached_schema_to_graph.clear();
|
|
throw;
|
|
}
|
|
}
|
|
} // anonymous namespace
|
|
|
|
std::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema(
|
|
const FunctionSchema& schema) {
|
|
std::lock_guard<std::mutex> guard(lock);
|
|
if (cached_schema_to_graph.empty()) {
|
|
loadFunctions();
|
|
}
|
|
|
|
GRAPH_DEBUG("Trying to find schema: ", schema);
|
|
auto cache_it = cached_schema_to_graph.find(&schema);
|
|
if (cache_it != cached_schema_to_graph.end()) {
|
|
return cache_it->second;
|
|
}
|
|
GRAPH_DEBUG("Could not find schema: ", schema);
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
// Returns the (lower bound, upper bound) shape compute graphs registered for
// `schema`, or std::nullopt if it has none. The built-in registry is loaded on
// first use. Lookup is by schema address.
TORCH_API std::optional<BoundedShapeGraphs> boundedGraphsForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  // Use the same "registry loaded?" sentinel as the other entry points.
  // loadFunctions() fills both caches and clears cached_schema_to_graph on
  // failure, so a non-empty cached_schema_to_graph means a load succeeded.
  // Testing the bounded cache instead would re-run loadFunctions() (and
  // define() on the already-populated compilation unit) on every call
  // whenever there are no bounded mappings.
  if (cached_schema_to_graph.empty()) {
    loadFunctions();
  }
  GRAPH_DEBUG("Trying to find schema in bounded graphs: ", schema);
  auto cache_it = cached_bounded_schema_to_graph.find(&schema);
  if (cache_it != cached_bounded_schema_to_graph.end()) {
    return cache_it->second;
  }

  return std::nullopt;
}
|
|
|
|
void RegisterShapeComputeGraphForSchema(
|
|
const FunctionSchema& schema,
|
|
const std::shared_ptr<Graph>& g) {
|
|
std::lock_guard<std::mutex> guard(lock);
|
|
if (cached_schema_to_graph.empty()) {
|
|
loadFunctions();
|
|
}
|
|
transformShapeFunction(&schema, g);
|
|
LintShapeComputeGraph(&schema, g);
|
|
|
|
cached_schema_to_graph[&schema] = g;
|
|
}
|
|
|
|
std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas() {
|
|
std::lock_guard<std::mutex> guard(lock);
|
|
if (cached_schema_to_graph.empty()) {
|
|
loadFunctions();
|
|
}
|
|
|
|
std::vector<const FunctionSchema*> schemas;
|
|
schemas.reserve(cached_schema_to_graph.size());
|
|
for (const auto& pair : cached_schema_to_graph) {
|
|
schemas.push_back(pair.first);
|
|
}
|
|
return schemas;
|
|
}
|
|
|
|
void LintShapeComputeGraph(
|
|
const FunctionSchema* schema,
|
|
const std::shared_ptr<Graph>& graph) {
|
|
checkInputAndOutputTypes(schema, graph);
|
|
checkForWhileLoop(schema, graph);
|
|
checkInputReturnedAsOutput(schema, graph);
|
|
// TODO: other checks ? list ops which we don't symbolically optimize, etc ?
|
|
}
|
|
|
|
} // namespace torch::jit
|