From b0c27b44cf582d7f72a78d306897b77f40b904cd Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Fri, 7 May 2021 15:39:49 -0700 Subject: [PATCH] Enable backward/forward compatibility for TS runtime (#57498) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57498 Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D28162448 Pulled By: tugsbayasgalan fbshipit-source-id: 5c21ced42a22aca7cee089e876e9d98d32f68955 --- test/expect/TestJit.test_import_method.expect | 3 +- ...t.test_pretty_printer-loop_use_test.expect | 6 +-- ...t.test_pretty_printer-while_if_test.expect | 6 +-- ...tJit.test_pretty_printer-while_test.expect | 2 +- test/jit/test_ignorable_args.py | 47 +++++++++++++++++++ test/test_jit.py | 1 + .../jit/runtime/calculate_necessary_args.h | 43 +++++++++++++++++ torch/csrc/jit/runtime/interpreter.cpp | 2 +- torch/csrc/jit/runtime/interpreter.h | 3 +- .../csrc/jit/runtime/interpreter/code_impl.h | 40 +++------------- torch/csrc/jit/serialization/python_print.cpp | 11 +++-- 11 files changed, 116 insertions(+), 48 deletions(-) create mode 100644 test/jit/test_ignorable_args.py create mode 100644 torch/csrc/jit/runtime/calculate_necessary_args.h diff --git a/test/expect/TestJit.test_import_method.expect b/test/expect/TestJit.test_import_method.expect index d4a7d36ee9b..eb19228454a 100644 --- a/test/expect/TestJit.test_import_method.expect +++ b/test/expect/TestJit.test_import_method.expect @@ -1,5 +1,4 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor: - _0 = torch.add(torch.mul(x, 2), y, alpha=1) - return _0 + return torch.add(torch.mul(x, 2), y) diff --git a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect index 128f7cd1fad..80c5c031b98 100644 --- a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect +++ b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect @@ -1,10 +1,10 @@ def loop_use_test(y: Tensor) 
-> Tuple[Tensor, Tensor]: - x = torch.add(y, 1, 1) - z = torch.add(x, 5, 1) + x = torch.add(y, 1) + z = torch.add(x, 5) z0 = z y0 = y _0 = bool(torch.lt(y, 8)) while _0: - y1 = torch.add_(y0, 1, 1) + y1 = torch.add_(y0, 1) _0, z0, y0 = bool(torch.lt(y1, 8)), x, y1 return (x, z0) diff --git a/test/expect/TestJit.test_pretty_printer-while_if_test.expect b/test/expect/TestJit.test_pretty_printer-while_if_test.expect index 81a23e4d8c2..282c2d90b00 100644 --- a/test/expect/TestJit.test_pretty_printer-while_if_test.expect +++ b/test/expect/TestJit.test_pretty_printer-while_if_test.expect @@ -5,11 +5,11 @@ def while_if_test(a: Tensor, b0 = b _0 = bool(torch.lt(a, 10)) while _0: - a1 = torch.add(a0, 1, 1) - b1 = torch.add(b0, 1, 1) + a1 = torch.add(a0, 1) + b1 = torch.add(b0, 1) if bool(torch.gt(a1, b1)): c0 = 2 else: c0 = 3 _0, a0, c, b0 = bool(torch.lt(a1, 10)), a1, c0, b1 - return torch.add(torch.add(a0, 1, 1), c, 1) + return torch.add(torch.add(a0, 1), c) diff --git a/test/expect/TestJit.test_pretty_printer-while_test.expect b/test/expect/TestJit.test_pretty_printer-while_test.expect index e5297a32332..4fb5b8d1d25 100644 --- a/test/expect/TestJit.test_pretty_printer-while_test.expect +++ b/test/expect/TestJit.test_pretty_printer-while_test.expect @@ -5,6 +5,6 @@ def while_test(a: Tensor, _0 = bool(torch.lt(i, 3)) while _0: a1 = torch.mul_(a0, a0) - i1 = torch.add_(i0, 1, 1) + i1 = torch.add_(i0, 1) _0, a0, i0 = bool(torch.lt(i1, 3)), a1, i1 return a0 diff --git a/test/jit/test_ignorable_args.py b/test/jit/test_ignorable_args.py new file mode 100644 index 00000000000..9b14b8c42f6 --- /dev/null +++ b/test/jit/test_ignorable_args.py @@ -0,0 +1,47 @@ +import os +import sys + +from torch._C import parse_ir +from torch.testing import FileCheck + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if 
__name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + +# Tests that trailing arguments equal to the schema's default values are ignorable in TorchScript serialization +class TestIgnorableArgs(JitTestCase): + def test_slice_ignorable_args_for_slice(self): + graph_str = """graph(): + %15 : int = prim::Constant[value=9223372036854775807]() + %13 : int = prim::Constant[value=0]() + %10 : bool = prim::Constant[value=0]() + %8 : NoneType = prim::Constant() + %0 : int = prim::Constant[value=1]() + %1 : int = prim::Constant[value=2]() + %2 : int = prim::Constant[value=3]() + %3 : int = prim::Constant[value=4]() + %4 : int = prim::Constant[value=9]() + %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4) + %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4) + %7 : int[][] = prim::ListConstruct(%5, %6) + %val.1 : Tensor = aten::tensor(%7, %8, %8, %10) + %16 : Tensor = aten::slice(%val.1, %13, %1, %15, %0) + %20 : Tensor = aten::slice(%16, %0, %13, %0, %0) + return (%20)""" + graph = parse_ir(graph_str) + function = self.createFunctionFromGraph(graph) + function_copy = self.getExportImportCopy(function) + src = str(function.code) + # For a signature: + # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor + # We ignore trailing arguments after start=2 for dim 0 + # and after end=1 for dim 1 + # because in %16, %15 and %0 are default values for the schema. 
+ FileCheck().check("torch.slice(torch.tensor(_0), 0, 2), 1, 0, 1)").run(src) + self.assertEqual(function(), function_copy()) diff --git a/test/test_jit.py b/test/test_jit.py index 8b15ad69c6a..be60a2508a8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -35,6 +35,7 @@ from jit.test_enum import TestEnum # noqa: F401 from jit.test_string_formatting import TestStringFormatting # noqa: F401 from jit.test_profiler import TestProfiler # noqa: F401 from jit.test_slice import TestSlice # noqa: F401 +from jit.test_ignorable_args import TestIgnorableArgs # noqa: F401 from jit.test_hooks import TestHooks # noqa: F401 from jit.test_warn import TestWarn # noqa: F401 from jit.test_isinstance import TestIsinstance # noqa: F401 diff --git a/torch/csrc/jit/runtime/calculate_necessary_args.h b/torch/csrc/jit/runtime/calculate_necessary_args.h new file mode 100644 index 00000000000..5f37660ee14 --- /dev/null +++ b/torch/csrc/jit/runtime/calculate_necessary_args.h @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { + +inline size_t CalculateNecessaryArgs( + const std::vector& schema_args, + at::ArrayRef actual_inputs) { + if (schema_args.size() < actual_inputs.size()) { + return actual_inputs.size(); + } + // keeps track of trailing unnecessary args + int schema_size = schema_args.size(); + for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) { + // this means it is not default argument, so it is necessary + if (!schema_args.at(schema_idx).default_value().has_value()) { + return schema_idx + 1; + } else { + auto schema_value = + schema_args.at(schema_idx).default_value().value().toIValue(); + // non-const value will become nullptr here, so will be marked necessary + // non-const would include prim::ListConstruct, prim::DictConstruct as + // well. 
+ auto actual_value = toIValue(actual_inputs[schema_idx]); + if (!actual_value.has_value()) { + return schema_idx + 1; + } + // if the IR has same value as default value of the schema, + // it is not necessary argument. + if (schema_value != actual_value.value()) { + return schema_idx + 1; + } + } + } + return 0; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index fffb13c6b74..79d2c562db3 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -811,7 +811,7 @@ const std::vector& Code::instructions() const { return pImpl->instructions(); } -const std::unordered_map& Code::op_to_num_specified_args() +const std::unordered_map& Code::op_to_num_specified_args() const { return pImpl->op_to_num_specified_args(); } diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 9349d9c3f49..c11fae93f8a 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -65,7 +65,8 @@ struct TORCH_API Code { const std::vector& constant_table() const; const std::vector& type_table() const; const std::vector& instructions() const; - const std::unordered_map& op_to_num_specified_args() const; + const std::unordered_map& op_to_num_specified_args() + const; const std::vector& instructions_source() const; void request_bailout(size_t index); size_t register_size() const; diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h index ae198107dd5..1387196f2c8 100644 --- a/torch/csrc/jit/runtime/interpreter/code_impl.h +++ b/torch/csrc/jit/runtime/interpreter/code_impl.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -101,7 +102,7 @@ struct CodeImpl { // aten::foo("somestr", arg1=1, arg2=False, arg3=0.0) // op_to_num_specified_args_["aten::foo.str"] = 3 // This is because for all usages, at most 3 args 
are used. - std::unordered_map op_to_num_specified_args_; + std::unordered_map op_to_num_specified_args_; // running count of uses as we emit. When we reach use_count_[v] = // v.uses().size() we know it is the final use and we can move rather than @@ -183,7 +184,8 @@ struct CodeImpl { return instructions_; } - const std::unordered_map& op_to_num_specified_args() const { + const std::unordered_map& op_to_num_specified_args() + const { return op_to_num_specified_args_; } @@ -734,12 +736,12 @@ struct MobileCodeImpl : CodeImpl { // skip if schema has vararg if (!op_schema.is_vararg()) { auto numInclude = - calculate_necessary_args(op_schema.arguments(), node->inputs()); + CalculateNecessaryArgs(op_schema.arguments(), node->inputs()); auto unique_name = op_schema.overload_name() != "" ? op_schema.name() + "." + op_schema.overload_name() : op_schema.name(); auto it = op_to_num_specified_args_.insert( - std::pair(unique_name, 0)); + std::pair(unique_name, 0)); auto prev_value = it.first->second; it.first->second = std::max(numInclude, prev_value); } @@ -748,36 +750,6 @@ struct MobileCodeImpl : CodeImpl { } } - int calculate_necessary_args( - const std::vector& schema_args, - at::ArrayRef actual_inputs) { - AT_ASSERT(schema_args.size() == actual_inputs.size()); - // keeps track of trailing unnecessary args - int schema_size = schema_args.size(); - for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) { - // this means it is not default argument, so it is necessary - if (!schema_args.at(schema_idx).default_value().has_value()) { - return schema_idx + 1; - } else { - auto schema_value = - schema_args.at(schema_idx).default_value().value().toIValue(); - // non-const value will become nullptr here, so will be marked necessary - // non-const would include prim::ListConstruct, prim::DictConstruct as - // well. 
- auto actual_value = toIValue(actual_inputs[schema_idx]); - if (!actual_value.has_value()) { - return schema_idx + 1; - } - // if the IR has same value as default value of the schema, - // it is not necessary argument. - if (schema_value != actual_value.value()) { - return schema_idx + 1; - } - } - } - return 0; - } - void emitOperator(Node* node) override { CodeImpl::emitOperator(node); // const Operator& op = node->getOperator(); diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index 03aaf50e668..719e308d602 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include @@ -1156,10 +1157,14 @@ struct PythonPrintImpl { printOpName(stmt, node->kind()); const FunctionSchema& schema = node->schema(); stmt << "("; - for (size_t i = 0; i < node->inputs().size(); ++i) { - if (i > 0) { + // calculate how many args are specified. + // see (https://github.com/pytorch/pytorch/pull/56079) for more + // details. + size_t necessary_args = + CalculateNecessaryArgs(schema.arguments(), node->inputs()); + for (size_t i = 0; i < necessary_args; ++i) { + if (i > 0) stmt << ", "; - } auto v = useOf(node->inputs().at(i)); // print the kwarg name if it is a kwarg only argument. if (i < schema.arguments().size()) {