Add file-line info for jit.load and string frontend (#21217)

Summary:
This makes file-line reporting also work for things loaded using `torch.jit.load()` as well as the string frontend (via `CompilationUnit`)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21217

Differential Revision: D15590838

Pulled By: jamesr66a

fbshipit-source-id: 6b6a12574bf9eca0b83f24f0b50535fda5863243
This commit is contained in:
James Reed 2019-05-31 23:37:38 -07:00 committed by Facebook Github Bot
parent b663eec119
commit 619261d7a7
14 changed files with 68 additions and 26 deletions

View file

@ -41,8 +41,18 @@ void testClassImport() {
CompilationUnit cu2;
std::vector<at::Tensor> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(cu1, "__torch__", classSrcs1, constantTable, nullptr);
import_libs(cu2, "__torch__", classSrcs2, constantTable, nullptr);
import_libs(
cu1,
"__torch__",
std::make_shared<Source>(classSrcs1),
constantTable,
nullptr);
import_libs(
cu2,
"__torch__",
std::make_shared<Source>(classSrcs2),
constantTable,
nullptr);
// We should get the correct version of `FooTest` for whichever namespace we
// are referencing
@ -70,13 +80,13 @@ void testScriptObject() {
import_libs(
m1.class_compilation_unit(),
"__torch__",
classSrcs1,
std::make_shared<Source>(classSrcs1),
constantTable,
nullptr);
import_libs(
m2.class_compilation_unit(),
"__torch__",
classSrcs2,
std::make_shared<Source>(classSrcs2),
constantTable,
nullptr);

View file

@ -17,7 +17,7 @@ const auto testSource = R"JIT(
void testClassParser() {
auto cu = std::make_shared<Module>();
Parser p(testSource);
Parser p(std::make_shared<Source>(testSource));
std::vector<Def> definitions;
std::vector<Resolver> resolvers;

View file

@ -3708,6 +3708,31 @@ def foo(x):
FileCheck().check('test_jit.py:{}:20'.format(lineno + 1))\
.run(scripted.graph)
def test_file_line_save_load(self):
class Scripted(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, xyz):
return torch.neg(xyz)
scripted = Scripted()
# NB: not using getExportImportCopy because that takes a different
# code path that calls CompilationUnit._import rather than
# going through the full save/load pathway
buffer = scripted.save_to_buffer()
bytesio = io.BytesIO(buffer)
scripted = torch.jit.load(bytesio)
FileCheck().check('code/archive.py:4:10').run(scripted.graph)
def test_file_line_string(self):
scripted = torch.jit.CompilationUnit('''
def foo(xyz):
return torch.neg(xyz)
''')
FileCheck().check('<string>:2:12').run(scripted.foo.graph)
def test_tensor_shape(self):
x = torch.empty(34, 56, 78)

View file

@ -253,7 +253,8 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) {
at::DataPtr data;
size_t size;
std::tie(data, size) = reader_.getRecord(path);
std::string src(static_cast<const char*>(data.get()), size);
auto src = std::make_shared<Source>(
std::string(static_cast<const char*>(data.get()), size), path, 0);
script::import_libs(
main_module_->class_compilation_unit(),
qualifier,
@ -321,13 +322,17 @@ void ScriptModuleDeserializer::convertModule(
std::tie(data, size) =
reader_.getRecord(module_def.torchscript_arena().key());
std::string data_str(static_cast<const char*>(data.get()), size);
auto src = std::make_shared<Source>(
std::string(static_cast<const char*>(data.get()), size),
module_def.torchscript_arena().key(),
1);
std::function<void(const std::string&)> import_callback =
[this](const std::string& qualifier) { importCallback(qualifier); };
script::import_methods(
main_module_->class_compilation_unit(),
module,
data_str,
src,
tensor_table_,
import_callback);
}

View file

@ -152,7 +152,7 @@ struct SourceResolver : public Resolver {
struct SourceImporter {
SourceImporter(
const CompilationUnit& lib_cu,
const std::string& src,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& constant_table,
const std::function<void(const std::string&)>& import_callback)
: p_(src),
@ -269,7 +269,7 @@ struct SourceImporter {
void import_functions(
const CompilationUnit& lib_cu,
CompilationUnit& cu,
const std::string& src,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& constant_table,
const Self& self,
const std::function<void(const std::string&)>& import_callback) {
@ -280,7 +280,7 @@ void import_functions(
void import_methods(
const CompilationUnit& lib_cu,
const std::shared_ptr<Module>& mod,
const std::string& src,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& constant_table,
const std::function<void(const std::string&)>& import_callback) {
auto self = [&](Value* v) {
@ -299,7 +299,7 @@ void import_methods(
void import_libs(
CompilationUnit& lib_cu,
const std::string& class_qualifier,
const std::string& src,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& constant_table,
const std::function<void(const std::string&)>& import_callback) {
SourceImporter importer(lib_cu, src, constant_table, import_callback);

View file

@ -18,7 +18,7 @@ TORCH_API void import_methods(
// CompilationUnit in which to look up any classes used
const CompilationUnit& lib_cu,
const std::shared_ptr<script::Module>& mod,
const std::string& src,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& constant_table,
// Callback to import any dependencies of this source before compiling
const std::function<void(const std::string&)>& import_callback);
@ -30,7 +30,7 @@ TORCH_API void import_libs(
// Qualifier for any classes that `src` defines. Looks like a module path,
// like "foo.bar.baz"
const std::string& class_qualifier,
const std::string& src,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& constant_table,
// Callback to import any dependencies of this source before compiling
const std::function<void(const std::string&)>& import_callback);
@ -44,7 +44,7 @@ TORCH_API void import_functions(
const CompilationUnit& lib_cu,
// CompilationoUnit to define the functions in.
CompilationUnit& cu,
const std::string& src,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& constant_table,
const Self& self = nullptr,
const std::function<void(const std::string&)>& import_callback = nullptr);

View file

@ -23,7 +23,7 @@ class IRParser {
const std::string& str,
torch::jit::Graph* graph,
std::unordered_map<std::string, Value*>& vmap)
: L(str),
: L(std::make_shared<Source>(str)),
g(graph),
vmap(vmap),
type_parser(L, /*parse_complete_tensor_types*/ true) {}

View file

@ -3141,7 +3141,7 @@ void CompilationUnit::define(
const std::string& source,
const ResolverPtr& resolver,
const Self& self) {
Parser p(source);
Parser p(std::make_shared<Source>(source, "<string>", 1));
std::vector<Def> definitions;
std::vector<ResolverPtr> resolvers;
while (p.lexer().cur().kind != TK_EOF) {

View file

@ -25,7 +25,8 @@ namespace script {
namespace {
struct SchemaParser {
SchemaParser(const std::string& str)
: L(str), type_parser(L, /*parse_complete_tensor_types*/ false) {}
: L(std::make_shared<Source>(str)),
type_parser(L, /*parse_complete_tensor_types*/ false) {}
either<OperatorName, FunctionSchema> parseDeclaration() {
OperatorName name = parseName();

View file

@ -613,7 +613,7 @@ void initJitScriptBindings(PyObject* module) {
});
m.def("parse_type_comment", [](const std::string& comment) {
Parser p(comment);
Parser p(std::make_shared<Source>(comment));
return Decl(p.parseTypeComment());
});
@ -657,7 +657,7 @@ void initJitScriptBindings(PyObject* module) {
import_functions(
CompilationUnit::_get_python_cu_const(),
cu,
src,
std::make_shared<Source>(src),
constant_table,
self,
nullptr);

View file

@ -367,8 +367,8 @@ struct Token {
};
struct Lexer {
explicit Lexer(const std::string& str)
: source(std::make_shared<Source>(str)),
explicit Lexer(const std::shared_ptr<Source>& source)
: source(source),
pos(0),
nesting(0),
indent_stack(),

View file

@ -45,8 +45,8 @@ Decl mergeTypesFromTypeComment(
}
struct ParserImpl {
explicit ParserImpl(const std::string& str)
: L(str), shared(sharedParserData()) {}
explicit ParserImpl(const std::shared_ptr<Source>& source)
: L(source), shared(sharedParserData()) {}
Ident parseIdent() {
auto t = L.expect(TK_IDENT);
@ -618,7 +618,8 @@ struct ParserImpl {
SharedParserData& shared;
};
Parser::Parser(const std::string& src) : pImpl(new ParserImpl(src)) {}
Parser::Parser(const std::shared_ptr<Source>& src)
: pImpl(new ParserImpl(src)) {}
Parser::~Parser() = default;

View file

@ -18,7 +18,7 @@ TORCH_API Decl mergeTypesFromTypeComment(
bool is_method);
struct TORCH_API Parser {
explicit Parser(const std::string& str);
explicit Parser(const std::shared_ptr<Source>& src);
TreeRef parseFunction(bool is_method);
TreeRef parseClass();
Decl parseTypeComment();

View file

@ -195,7 +195,7 @@ TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
}
TypePtr ScriptTypeParser::parseType(const std::string& str) {
Parser p(str);
Parser p(std::make_shared<Source>(str));
return parseTypeFromExpr(p.parseExp());
}
} // namespace script