Enable backward/forward compatibility for TS runtime (#57498)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57498

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D28162448

Pulled By: tugsbayasgalan

fbshipit-source-id: 5c21ced42a22aca7cee089e876e9d98d32f68955
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar 2021-05-07 15:39:49 -07:00 committed by Facebook GitHub Bot
parent b38f153d91
commit b0c27b44cf
11 changed files with 116 additions and 48 deletions

View file

@ -1,5 +1,4 @@
def forward(self,
x: Tensor,
y: Tensor) -> Tensor:
_0 = torch.add(torch.mul(x, 2), y, alpha=1)
return _0
return torch.add(torch.mul(x, 2), y)

View file

@ -1,10 +1,10 @@
def loop_use_test(y: Tensor) -> Tuple[Tensor, Tensor]:
x = torch.add(y, 1, 1)
z = torch.add(x, 5, 1)
x = torch.add(y, 1)
z = torch.add(x, 5)
z0 = z
y0 = y
_0 = bool(torch.lt(y, 8))
while _0:
y1 = torch.add_(y0, 1, 1)
y1 = torch.add_(y0, 1)
_0, z0, y0 = bool(torch.lt(y1, 8)), x, y1
return (x, z0)

View file

@ -5,11 +5,11 @@ def while_if_test(a: Tensor,
b0 = b
_0 = bool(torch.lt(a, 10))
while _0:
a1 = torch.add(a0, 1, 1)
b1 = torch.add(b0, 1, 1)
a1 = torch.add(a0, 1)
b1 = torch.add(b0, 1)
if bool(torch.gt(a1, b1)):
c0 = 2
else:
c0 = 3
_0, a0, c, b0 = bool(torch.lt(a1, 10)), a1, c0, b1
return torch.add(torch.add(a0, 1, 1), c, 1)
return torch.add(torch.add(a0, 1), c)

View file

@ -5,6 +5,6 @@ def while_test(a: Tensor,
_0 = bool(torch.lt(i, 3))
while _0:
a1 = torch.mul_(a0, a0)
i1 = torch.add_(i0, 1, 1)
i1 = torch.add_(i0, 1)
_0, a0, i0 = bool(torch.lt(i1, 3)), a1, i1
return a0

View file

@ -0,0 +1,47 @@
import os
import sys
from torch._C import parse_ir
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
# Tests that Python slice class is supported in TorchScript
class TestIgnorableArgs(JitTestCase):
def test_slice_ignorable_args_for_slice(self):
graph_str = """graph():
%15 : int = prim::Constant[value=9223372036854775807]()
%13 : int = prim::Constant[value=0]()
%10 : bool = prim::Constant[value=0]()
%8 : NoneType = prim::Constant()
%0 : int = prim::Constant[value=1]()
%1 : int = prim::Constant[value=2]()
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=4]()
%4 : int = prim::Constant[value=9]()
%5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
%6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
%7 : int[][] = prim::ListConstruct(%5, %6)
%val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
%16 : Tensor = aten::slice(%val.1, %13, %1, %15, %0)
%20 : Tensor = aten::slice(%16, %0, %13, %0, %0)
return (%20)"""
graph = parse_ir(graph_str)
function = self.createFunctionFromGraph(graph)
function_copy = self.getExportImportCopy(function)
src = str(function.code)
# For a signature:
# aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
# We ignore trailing arguments after start=2 for dim 0
# and after end=1 for dim 1
# because in %16, %15 and %0 are default values for the schema.
FileCheck().check("torch.slice(torch.tensor(_0), 0, 2), 1, 0, 1)").run(src)
self.assertEqual(function(), function_copy())

View file

@ -35,6 +35,7 @@ from jit.test_enum import TestEnum # noqa: F401
from jit.test_string_formatting import TestStringFormatting # noqa: F401
from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_ignorable_args import TestIgnorableArgs # noqa: F401
from jit.test_hooks import TestHooks # noqa: F401
from jit.test_warn import TestWarn # noqa: F401
from jit.test_isinstance import TestIsinstance # noqa: F401

View file

@ -0,0 +1,43 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <cstddef>
namespace torch {
namespace jit {
inline size_t CalculateNecessaryArgs(
const std::vector<Argument>& schema_args,
at::ArrayRef<Value*> actual_inputs) {
if (schema_args.size() < actual_inputs.size()) {
return actual_inputs.size();
}
// keeps track of trailing unnecessary args
int schema_size = schema_args.size();
for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) {
// this means it is not default argument, so it is necessary
if (!schema_args.at(schema_idx).default_value().has_value()) {
return schema_idx + 1;
} else {
auto schema_value =
schema_args.at(schema_idx).default_value().value().toIValue();
// non-const value will become nullptr here, so will be marked necessary
// non-const would include prim::ListConstruct, prim::DictConstruct as
// well.
auto actual_value = toIValue(actual_inputs[schema_idx]);
if (!actual_value.has_value()) {
return schema_idx + 1;
}
// if the IR has same value as default value of the schema,
// it is not neccessary argument.
if (schema_value != actual_value.value()) {
return schema_idx + 1;
}
}
}
return 0;
}
} // namespace jit
} // namespace torch

View file

@ -811,7 +811,7 @@ const std::vector<Instruction>& Code::instructions() const {
return pImpl->instructions();
}
const std::unordered_map<std::string, int>& Code::op_to_num_specified_args()
const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
const {
return pImpl->op_to_num_specified_args();
}

View file

@ -65,7 +65,8 @@ struct TORCH_API Code {
const std::vector<c10::IValue>& constant_table() const;
const std::vector<c10::TypePtr>& type_table() const;
const std::vector<Instruction>& instructions() const;
const std::unordered_map<std::string, int>& op_to_num_specified_args() const;
const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
const;
const std::vector<Node*>& instructions_source() const;
void request_bailout(size_t index);
size_t register_size() const;

View file

@ -7,6 +7,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>
@ -101,7 +102,7 @@ struct CodeImpl {
// aten::foo("somestr", arg1=1, arg2=False, arg3=0.0)
// op_to_num_specified_args_["aten::foo.str"] = 3
// This is because for all usages, at most 3 args are used.
std::unordered_map<std::string, int> op_to_num_specified_args_;
std::unordered_map<std::string, size_t> op_to_num_specified_args_;
// running count of uses as we emit. When we reach use_count_[v] =
// v.uses().size() we know it is the final use and we can move rather than
@ -183,7 +184,8 @@ struct CodeImpl {
return instructions_;
}
const std::unordered_map<std::string, int>& op_to_num_specified_args() const {
const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
const {
return op_to_num_specified_args_;
}
@ -734,12 +736,12 @@ struct MobileCodeImpl : CodeImpl {
// skip if schema has vararg
if (!op_schema.is_vararg()) {
auto numInclude =
calculate_necessary_args(op_schema.arguments(), node->inputs());
CalculateNecessaryArgs(op_schema.arguments(), node->inputs());
auto unique_name = op_schema.overload_name() != ""
? op_schema.name() + "." + op_schema.overload_name()
: op_schema.name();
auto it = op_to_num_specified_args_.insert(
std::pair<std::string, int>(unique_name, 0));
std::pair<std::string, size_t>(unique_name, 0));
auto prev_value = it.first->second;
it.first->second = std::max(numInclude, prev_value);
}
@ -748,36 +750,6 @@ struct MobileCodeImpl : CodeImpl {
}
}
int calculate_necessary_args(
const std::vector<Argument>& schema_args,
at::ArrayRef<Value*> actual_inputs) {
AT_ASSERT(schema_args.size() == actual_inputs.size());
// keeps track of trailing unnecessary args
int schema_size = schema_args.size();
for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) {
// this means it is not default argument, so it is necessary
if (!schema_args.at(schema_idx).default_value().has_value()) {
return schema_idx + 1;
} else {
auto schema_value =
schema_args.at(schema_idx).default_value().value().toIValue();
// non-const value will become nullptr here, so will be marked necessary
// non-const would include prim::ListConstruct, prim::DictConstruct as
// well.
auto actual_value = toIValue(actual_inputs[schema_idx]);
if (!actual_value.has_value()) {
return schema_idx + 1;
}
// if the IR has same value as default value of the schema,
// it is not necessary argument.
if (schema_value != actual_value.value()) {
return schema_idx + 1;
}
}
}
return 0;
}
void emitOperator(Node* node) override {
CodeImpl::emitOperator(node);
// const Operator& op = node->getOperator();

View file

@ -10,6 +10,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <algorithm>
@ -1156,10 +1157,14 @@ struct PythonPrintImpl {
printOpName(stmt, node->kind());
const FunctionSchema& schema = node->schema();
stmt << "(";
for (size_t i = 0; i < node->inputs().size(); ++i) {
if (i > 0) {
// calculate how many args are specified.
// see (https://github.com/pytorch/pytorch/pull/56079) for more
// details.
size_t necessary_args =
CalculateNecessaryArgs(schema.arguments(), node->inputs());
for (size_t i = 0; i < necessary_args; ++i) {
if (i > 0)
stmt << ", ";
}
auto v = useOf(node->inputs().at(i));
// print the kwarg name if it is a kwarg only argument.
if (i < schema.arguments().size()) {