Enable backward/forward compatibility for TS runtime (#57498)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57498
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D28162448
Pulled By: tugsbayasgalan
fbshipit-source-id: 5c21ced42a22aca7cee089e876e9d98d32f68955
Parent: b38f153d91 · Commit: b0c27b44cf
11 changed files with 116 additions and 48 deletions
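At a high level, this change makes the TorchScript serializer omit trailing operator arguments whose call-site values match the schema defaults (for example alpha=1 on aten::add), and makes the runtime record, per operator, how many arguments were actually specified. Printing fewer arguments keeps serialized code loadable across runtime versions when a schema later gains extra defaulted arguments, which is the backward/forward compatibility the title refers to. The updated expect files below show the effect on the printed TorchScript, e.g.:

    before: _0 = torch.add(torch.mul(x, 2), y, alpha=1)
    after:  return torch.add(torch.mul(x, 2), y)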
@@ -1,5 +1,4 @@
 def forward(self,
     x: Tensor,
     y: Tensor) -> Tensor:
-  _0 = torch.add(torch.mul(x, 2), y, alpha=1)
-  return _0
+  return torch.add(torch.mul(x, 2), y)
@@ -1,10 +1,10 @@
 def loop_use_test(y: Tensor) -> Tuple[Tensor, Tensor]:
-  x = torch.add(y, 1, 1)
-  z = torch.add(x, 5, 1)
+  x = torch.add(y, 1)
+  z = torch.add(x, 5)
   z0 = z
   y0 = y
   _0 = bool(torch.lt(y, 8))
   while _0:
-    y1 = torch.add_(y0, 1, 1)
+    y1 = torch.add_(y0, 1)
     _0, z0, y0 = bool(torch.lt(y1, 8)), x, y1
   return (x, z0)
@@ -5,11 +5,11 @@ def while_if_test(a: Tensor,
   b0 = b
   _0 = bool(torch.lt(a, 10))
   while _0:
-    a1 = torch.add(a0, 1, 1)
-    b1 = torch.add(b0, 1, 1)
+    a1 = torch.add(a0, 1)
+    b1 = torch.add(b0, 1)
     if bool(torch.gt(a1, b1)):
       c0 = 2
     else:
       c0 = 3
     _0, a0, c, b0 = bool(torch.lt(a1, 10)), a1, c0, b1
-  return torch.add(torch.add(a0, 1, 1), c, 1)
+  return torch.add(torch.add(a0, 1), c)
@@ -5,6 +5,6 @@ def while_test(a: Tensor,
   _0 = bool(torch.lt(i, 3))
   while _0:
     a1 = torch.mul_(a0, a0)
-    i1 = torch.add_(i0, 1, 1)
+    i1 = torch.add_(i0, 1)
     _0, a0, i0 = bool(torch.lt(i1, 3)), a1, i1
   return a0
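The four hunks above are updated printer expect files: calls such as torch.add(y, 1, 1) and torch.add_(i0, 1, 1) now print without the trailing 1 because it equals the schema default (alpha=1), and the in-place variants get the same treatment.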
test/jit/test_ignorable_args.py (new file, +47 lines)
@@ -0,0 +1,47 @@
+import os
+import sys
+
+from torch._C import parse_ir
+from torch.testing import FileCheck
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == '__main__':
+    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
+                       "\tpython test/test_jit.py TESTNAME\n\n"
+                       "instead.")
+
+# Tests that trailing arguments matching schema defaults are ignorable in TorchScript
+class TestIgnorableArgs(JitTestCase):
+    def test_slice_ignorable_args_for_slice(self):
+        graph_str = """graph():
+            %15 : int = prim::Constant[value=9223372036854775807]()
+            %13 : int = prim::Constant[value=0]()
+            %10 : bool = prim::Constant[value=0]()
+            %8 : NoneType = prim::Constant()
+            %0 : int = prim::Constant[value=1]()
+            %1 : int = prim::Constant[value=2]()
+            %2 : int = prim::Constant[value=3]()
+            %3 : int = prim::Constant[value=4]()
+            %4 : int = prim::Constant[value=9]()
+            %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
+            %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
+            %7 : int[][] = prim::ListConstruct(%5, %6)
+            %val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
+            %16 : Tensor = aten::slice(%val.1, %13, %1, %15, %0)
+            %20 : Tensor = aten::slice(%16, %0, %13, %0, %0)
+            return (%20)"""
+        graph = parse_ir(graph_str)
+        function = self.createFunctionFromGraph(graph)
+        function_copy = self.getExportImportCopy(function)
+        src = str(function.code)
+        # For a signature:
+        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
+        # we ignore trailing arguments after start=2 for dim 0
+        # and after end=1 for dim 1,
+        # because in %16, %15 and %0 are the schema's default values.
+        FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, 0, 1)").run(src)
+        self.assertEqual(function(), function_copy())
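The test parses raw IR, prints it back as TorchScript source, and round-trips it through export/import. It checks two things: the printed source drops the aten::slice arguments that equal the schema defaults (end=2^63-1 and step=1), and the function still computes the same tensor after the round trip.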
@@ -35,6 +35,7 @@ from jit.test_enum import TestEnum  # noqa: F401
 from jit.test_string_formatting import TestStringFormatting  # noqa: F401
 from jit.test_profiler import TestProfiler  # noqa: F401
 from jit.test_slice import TestSlice  # noqa: F401
+from jit.test_ignorable_args import TestIgnorableArgs  # noqa: F401
 from jit.test_hooks import TestHooks  # noqa: F401
 from jit.test_warn import TestWarn  # noqa: F401
 from jit.test_isinstance import TestIsinstance  # noqa: F401
torch/csrc/jit/runtime/calculate_necessary_args.h (new file, +43 lines)
@@ -0,0 +1,43 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/frontend/schema_matching.h>
+#include <cstddef>
+
+namespace torch {
+namespace jit {
+
+inline size_t CalculateNecessaryArgs(
+    const std::vector<Argument>& schema_args,
+    at::ArrayRef<Value*> actual_inputs) {
+  if (schema_args.size() < actual_inputs.size()) {
+    return actual_inputs.size();
+  }
+  // Scan backward over trailing args to find how many are unnecessary.
+  int schema_size = schema_args.size();
+  for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) {
+    // No default value means the argument is necessary.
+    if (!schema_args.at(schema_idx).default_value().has_value()) {
+      return schema_idx + 1;
+    } else {
+      auto schema_value =
+          schema_args.at(schema_idx).default_value().value().toIValue();
+      // A non-constant value becomes nullopt here, so it is marked necessary;
+      // non-constant includes prim::ListConstruct, prim::DictConstruct as
+      // well.
+      auto actual_value = toIValue(actual_inputs[schema_idx]);
+      if (!actual_value.has_value()) {
+        return schema_idx + 1;
+      }
+      // If the IR has the same value as the schema's default value,
+      // the argument is not necessary.
+      if (schema_value != actual_value.value()) {
+        return schema_idx + 1;
+      }
+    }
+  }
+  return 0;
+}
+
+} // namespace jit
+} // namespace torch
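For readers who do not want to trace the C++ above, here is a rough, self-contained Python sketch of the same backward scan. It is an illustration only, with hypothetical names (necessary_args, _NO_DEFAULT); the authoritative logic is CalculateNecessaryArgs above.

    from typing import Any, List, Optional

    _NO_DEFAULT = object()  # stands in for "this argument has no schema default"

    def necessary_args(defaults: List[Any], actual: List[Optional[Any]]) -> int:
        # defaults[i] is the schema default for argument i (_NO_DEFAULT if none);
        # actual[i] is the constant value at the call site (None if non-constant).
        if len(defaults) < len(actual):
            return len(actual)  # more inputs than schema args: keep every input
        for i in range(len(defaults) - 1, -1, -1):
            if defaults[i] is _NO_DEFAULT:   # no default: argument is necessary
                return i + 1
            if actual[i] is None:            # non-constant input: keep it
                return i + 1
            if actual[i] != defaults[i]:     # differs from the default: keep it
                return i + 1
        return 0

    # aten::add(Tensor self, Tensor other, Scalar alpha=1) called as add(x, y, 1):
    # x and y are non-constant, alpha equals its default, so two args suffice.
    assert necessary_args([_NO_DEFAULT, _NO_DEFAULT, 1], [None, None, 1]) == 2

The scan stops at the first argument (from the right) that is required, non-constant, or different from its default; everything to its left must be kept so positional meaning is preserved.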
@@ -811,7 +811,7 @@ const std::vector<Instruction>& Code::instructions() const {
   return pImpl->instructions();
 }
 
-const std::unordered_map<std::string, int>& Code::op_to_num_specified_args()
+const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
     const {
   return pImpl->op_to_num_specified_args();
 }
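This hunk and the related ones below widen op_to_num_specified_args from int to size_t, so the map's value type matches the size_t returned by CalculateNecessaryArgs.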
@@ -65,7 +65,8 @@ struct TORCH_API Code {
   const std::vector<c10::IValue>& constant_table() const;
   const std::vector<c10::TypePtr>& type_table() const;
   const std::vector<Instruction>& instructions() const;
-  const std::unordered_map<std::string, int>& op_to_num_specified_args() const;
+  const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
+      const;
   const std::vector<Node*>& instructions_source() const;
   void request_bailout(size_t index);
   size_t register_size() const;
@@ -7,6 +7,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/bailout_graph.h>
+#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
 #include <torch/csrc/jit/runtime/graph_iterator.h>
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>
@@ -101,7 +102,7 @@ struct CodeImpl {
   // aten::foo("somestr", arg1=1, arg2=False, arg3=0.0)
   // op_to_num_specified_args_["aten::foo.str"] = 3
   // This is because for all usages, at most 3 args are used.
-  std::unordered_map<std::string, int> op_to_num_specified_args_;
+  std::unordered_map<std::string, size_t> op_to_num_specified_args_;
 
   // running count of uses as we emit. When we reach use_count_[v] =
   // v.uses().size() we know it is the final use and we can move rather than
@@ -183,7 +184,8 @@ struct CodeImpl {
     return instructions_;
   }
 
-  const std::unordered_map<std::string, int>& op_to_num_specified_args() const {
+  const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
+      const {
     return op_to_num_specified_args_;
   }
 
@@ -734,12 +736,12 @@ struct MobileCodeImpl : CodeImpl {
       // skip if schema has vararg
       if (!op_schema.is_vararg()) {
         auto numInclude =
-            calculate_necessary_args(op_schema.arguments(), node->inputs());
+            CalculateNecessaryArgs(op_schema.arguments(), node->inputs());
         auto unique_name = op_schema.overload_name() != ""
             ? op_schema.name() + "." + op_schema.overload_name()
             : op_schema.name();
         auto it = op_to_num_specified_args_.insert(
-            std::pair<std::string, int>(unique_name, 0));
+            std::pair<std::string, size_t>(unique_name, 0));
         auto prev_value = it.first->second;
         it.first->second = std::max(numInclude, prev_value);
       }
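The insert-then-max pattern above keeps, per operator, the largest number of specified arguments seen at any call site. A minimal Python sketch of that bookkeeping (standalone illustration with hypothetical names, not a PyTorch API):

    from typing import Dict

    op_to_num_specified_args: Dict[str, int] = {}

    def record_usage(unique_name: str, num_include: int) -> None:
        # Keep the maximum across all call sites so the runtime reserves
        # enough arguments for every use of this operator.
        prev = op_to_num_specified_args.get(unique_name, 0)
        op_to_num_specified_args[unique_name] = max(num_include, prev)

    record_usage("aten::add.Tensor", 2)  # call site that relies on alpha=1
    record_usage("aten::add.Tensor", 3)  # call site that passes alpha explicitly
    assert op_to_num_specified_args["aten::add.Tensor"] == 3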
@@ -748,36 +750,6 @@ struct MobileCodeImpl : CodeImpl {
     }
   }
 
-  int calculate_necessary_args(
-      const std::vector<Argument>& schema_args,
-      at::ArrayRef<Value*> actual_inputs) {
-    AT_ASSERT(schema_args.size() == actual_inputs.size());
-    // keeps track of trailing unnecessary args
-    int schema_size = schema_args.size();
-    for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) {
-      // this means it is not default argument, so it is necessary
-      if (!schema_args.at(schema_idx).default_value().has_value()) {
-        return schema_idx + 1;
-      } else {
-        auto schema_value =
-            schema_args.at(schema_idx).default_value().value().toIValue();
-        // non-const value will become nullptr here, so will be marked necessary
-        // non-const would include prim::ListConstruct, prim::DictConstruct as
-        // well.
-        auto actual_value = toIValue(actual_inputs[schema_idx]);
-        if (!actual_value.has_value()) {
-          return schema_idx + 1;
-        }
-        // if the IR has same value as default value of the schema,
-        // it is not necessary argument.
-        if (schema_value != actual_value.value()) {
-          return schema_idx + 1;
-        }
-      }
-    }
-    return 0;
-  }
-
   void emitOperator(Node* node) override {
     CodeImpl::emitOperator(node);
     // const Operator& op = node->getOperator();
@@ -10,6 +10,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/resource_guard.h>
+#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
 
 #include <algorithm>
 
@@ -1156,10 +1157,14 @@ struct PythonPrintImpl {
     printOpName(stmt, node->kind());
     const FunctionSchema& schema = node->schema();
     stmt << "(";
-    for (size_t i = 0; i < node->inputs().size(); ++i) {
-      if (i > 0) {
+    // calculate how many args are specified.
+    // see (https://github.com/pytorch/pytorch/pull/56079) for more
+    // details.
+    size_t necessary_args =
+        CalculateNecessaryArgs(schema.arguments(), node->inputs());
+    for (size_t i = 0; i < necessary_args; ++i) {
+      if (i > 0)
         stmt << ", ";
-      }
       auto v = useOf(node->inputs().at(i));
       // print the kwarg name if it is a kwarg only argument.
       if (i < schema.arguments().size()) {
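This last hunk is the serializer-side half of the feature: the printer now emits only the first necessary_args inputs of each call, which is exactly what produces the shorter expect files at the top of this diff. The comment points to #56079 for background on how specified-argument counts are used.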