mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add file-line info for jit.load and string frontend (#21217)
Summary: This makes file-line reporting also work for things loaded using `torch.jit.load()` as well as the string frontend (via `CompilationUnit`) Pull Request resolved: https://github.com/pytorch/pytorch/pull/21217 Differential Revision: D15590838 Pulled By: jamesr66a fbshipit-source-id: 6b6a12574bf9eca0b83f24f0b50535fda5863243
This commit is contained in:
parent
b663eec119
commit
619261d7a7
14 changed files with 68 additions and 26 deletions
|
|
@ -41,8 +41,18 @@ void testClassImport() {
|
|||
CompilationUnit cu2;
|
||||
std::vector<at::Tensor> constantTable;
|
||||
// Import different versions of FooTest into two namespaces.
|
||||
import_libs(cu1, "__torch__", classSrcs1, constantTable, nullptr);
|
||||
import_libs(cu2, "__torch__", classSrcs2, constantTable, nullptr);
|
||||
import_libs(
|
||||
cu1,
|
||||
"__torch__",
|
||||
std::make_shared<Source>(classSrcs1),
|
||||
constantTable,
|
||||
nullptr);
|
||||
import_libs(
|
||||
cu2,
|
||||
"__torch__",
|
||||
std::make_shared<Source>(classSrcs2),
|
||||
constantTable,
|
||||
nullptr);
|
||||
|
||||
// We should get the correct version of `FooTest` for whichever namespace we
|
||||
// are referencing
|
||||
|
|
@ -70,13 +80,13 @@ void testScriptObject() {
|
|||
import_libs(
|
||||
m1.class_compilation_unit(),
|
||||
"__torch__",
|
||||
classSrcs1,
|
||||
std::make_shared<Source>(classSrcs1),
|
||||
constantTable,
|
||||
nullptr);
|
||||
import_libs(
|
||||
m2.class_compilation_unit(),
|
||||
"__torch__",
|
||||
classSrcs2,
|
||||
std::make_shared<Source>(classSrcs2),
|
||||
constantTable,
|
||||
nullptr);
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ const auto testSource = R"JIT(
|
|||
|
||||
void testClassParser() {
|
||||
auto cu = std::make_shared<Module>();
|
||||
Parser p(testSource);
|
||||
Parser p(std::make_shared<Source>(testSource));
|
||||
std::vector<Def> definitions;
|
||||
std::vector<Resolver> resolvers;
|
||||
|
||||
|
|
|
|||
|
|
@ -3708,6 +3708,31 @@ def foo(x):
|
|||
FileCheck().check('test_jit.py:{}:20'.format(lineno + 1))\
|
||||
.run(scripted.graph)
|
||||
|
||||
def test_file_line_save_load(self):
|
||||
class Scripted(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, xyz):
|
||||
return torch.neg(xyz)
|
||||
|
||||
scripted = Scripted()
|
||||
|
||||
# NB: not using getExportImportCopy because that takes a different
|
||||
# code path that calls CompilationUnit._import rather than
|
||||
# going through the full save/load pathway
|
||||
buffer = scripted.save_to_buffer()
|
||||
bytesio = io.BytesIO(buffer)
|
||||
scripted = torch.jit.load(bytesio)
|
||||
|
||||
FileCheck().check('code/archive.py:4:10').run(scripted.graph)
|
||||
|
||||
def test_file_line_string(self):
|
||||
scripted = torch.jit.CompilationUnit('''
|
||||
def foo(xyz):
|
||||
return torch.neg(xyz)
|
||||
''')
|
||||
|
||||
FileCheck().check('<string>:2:12').run(scripted.foo.graph)
|
||||
|
||||
def test_tensor_shape(self):
|
||||
x = torch.empty(34, 56, 78)
|
||||
|
||||
|
|
|
|||
|
|
@ -253,7 +253,8 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) {
|
|||
at::DataPtr data;
|
||||
size_t size;
|
||||
std::tie(data, size) = reader_.getRecord(path);
|
||||
std::string src(static_cast<const char*>(data.get()), size);
|
||||
auto src = std::make_shared<Source>(
|
||||
std::string(static_cast<const char*>(data.get()), size), path, 0);
|
||||
script::import_libs(
|
||||
main_module_->class_compilation_unit(),
|
||||
qualifier,
|
||||
|
|
@ -321,13 +322,17 @@ void ScriptModuleDeserializer::convertModule(
|
|||
std::tie(data, size) =
|
||||
reader_.getRecord(module_def.torchscript_arena().key());
|
||||
std::string data_str(static_cast<const char*>(data.get()), size);
|
||||
auto src = std::make_shared<Source>(
|
||||
std::string(static_cast<const char*>(data.get()), size),
|
||||
module_def.torchscript_arena().key(),
|
||||
1);
|
||||
|
||||
std::function<void(const std::string&)> import_callback =
|
||||
[this](const std::string& qualifier) { importCallback(qualifier); };
|
||||
script::import_methods(
|
||||
main_module_->class_compilation_unit(),
|
||||
module,
|
||||
data_str,
|
||||
src,
|
||||
tensor_table_,
|
||||
import_callback);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ struct SourceResolver : public Resolver {
|
|||
struct SourceImporter {
|
||||
SourceImporter(
|
||||
const CompilationUnit& lib_cu,
|
||||
const std::string& src,
|
||||
const std::shared_ptr<Source>& src,
|
||||
const std::vector<at::Tensor>& constant_table,
|
||||
const std::function<void(const std::string&)>& import_callback)
|
||||
: p_(src),
|
||||
|
|
@ -269,7 +269,7 @@ struct SourceImporter {
|
|||
void import_functions(
|
||||
const CompilationUnit& lib_cu,
|
||||
CompilationUnit& cu,
|
||||
const std::string& src,
|
||||
const std::shared_ptr<Source>& src,
|
||||
const std::vector<at::Tensor>& constant_table,
|
||||
const Self& self,
|
||||
const std::function<void(const std::string&)>& import_callback) {
|
||||
|
|
@ -280,7 +280,7 @@ void import_functions(
|
|||
void import_methods(
|
||||
const CompilationUnit& lib_cu,
|
||||
const std::shared_ptr<Module>& mod,
|
||||
const std::string& src,
|
||||
const std::shared_ptr<Source>& src,
|
||||
const std::vector<at::Tensor>& constant_table,
|
||||
const std::function<void(const std::string&)>& import_callback) {
|
||||
auto self = [&](Value* v) {
|
||||
|
|
@ -299,7 +299,7 @@ void import_methods(
|
|||
void import_libs(
|
||||
CompilationUnit& lib_cu,
|
||||
const std::string& class_qualifier,
|
||||
const std::string& src,
|
||||
const std::shared_ptr<Source>& src,
|
||||
const std::vector<at::Tensor>& constant_table,
|
||||
const std::function<void(const std::string&)>& import_callback) {
|
||||
SourceImporter importer(lib_cu, src, constant_table, import_callback);
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ TORCH_API void import_methods(
|
|||
// CompilationUnit in which to look up any classes used
|
||||
const CompilationUnit& lib_cu,
|
||||
const std::shared_ptr<script::Module>& mod,
|
||||
const std::string& src,
|
||||
const std::shared_ptr<Source>& src,
|
||||
const std::vector<at::Tensor>& constant_table,
|
||||
// Callback to import any dependencies of this source before compiling
|
||||
const std::function<void(const std::string&)>& import_callback);
|
||||
|
|
@ -30,7 +30,7 @@ TORCH_API void import_libs(
|
|||
// Qualifier for any classes that `src` defines. Looks like a module path,
|
||||
// like "foo.bar.baz"
|
||||
const std::string& class_qualifier,
|
||||
const std::string& src,
|
||||
const std::shared_ptr<Source>& src,
|
||||
const std::vector<at::Tensor>& constant_table,
|
||||
// Callback to import any dependencies of this source before compiling
|
||||
const std::function<void(const std::string&)>& import_callback);
|
||||
|
|
@ -44,7 +44,7 @@ TORCH_API void import_functions(
|
|||
const CompilationUnit& lib_cu,
|
||||
// CompilationoUnit to define the functions in.
|
||||
CompilationUnit& cu,
|
||||
const std::string& src,
|
||||
const std::shared_ptr<Source>& src,
|
||||
const std::vector<at::Tensor>& constant_table,
|
||||
const Self& self = nullptr,
|
||||
const std::function<void(const std::string&)>& import_callback = nullptr);
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class IRParser {
|
|||
const std::string& str,
|
||||
torch::jit::Graph* graph,
|
||||
std::unordered_map<std::string, Value*>& vmap)
|
||||
: L(str),
|
||||
: L(std::make_shared<Source>(str)),
|
||||
g(graph),
|
||||
vmap(vmap),
|
||||
type_parser(L, /*parse_complete_tensor_types*/ true) {}
|
||||
|
|
|
|||
|
|
@ -3141,7 +3141,7 @@ void CompilationUnit::define(
|
|||
const std::string& source,
|
||||
const ResolverPtr& resolver,
|
||||
const Self& self) {
|
||||
Parser p(source);
|
||||
Parser p(std::make_shared<Source>(source, "<string>", 1));
|
||||
std::vector<Def> definitions;
|
||||
std::vector<ResolverPtr> resolvers;
|
||||
while (p.lexer().cur().kind != TK_EOF) {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ namespace script {
|
|||
namespace {
|
||||
struct SchemaParser {
|
||||
SchemaParser(const std::string& str)
|
||||
: L(str), type_parser(L, /*parse_complete_tensor_types*/ false) {}
|
||||
: L(std::make_shared<Source>(str)),
|
||||
type_parser(L, /*parse_complete_tensor_types*/ false) {}
|
||||
|
||||
either<OperatorName, FunctionSchema> parseDeclaration() {
|
||||
OperatorName name = parseName();
|
||||
|
|
|
|||
|
|
@ -613,7 +613,7 @@ void initJitScriptBindings(PyObject* module) {
|
|||
});
|
||||
|
||||
m.def("parse_type_comment", [](const std::string& comment) {
|
||||
Parser p(comment);
|
||||
Parser p(std::make_shared<Source>(comment));
|
||||
return Decl(p.parseTypeComment());
|
||||
});
|
||||
|
||||
|
|
@ -657,7 +657,7 @@ void initJitScriptBindings(PyObject* module) {
|
|||
import_functions(
|
||||
CompilationUnit::_get_python_cu_const(),
|
||||
cu,
|
||||
src,
|
||||
std::make_shared<Source>(src),
|
||||
constant_table,
|
||||
self,
|
||||
nullptr);
|
||||
|
|
|
|||
|
|
@ -367,8 +367,8 @@ struct Token {
|
|||
};
|
||||
|
||||
struct Lexer {
|
||||
explicit Lexer(const std::string& str)
|
||||
: source(std::make_shared<Source>(str)),
|
||||
explicit Lexer(const std::shared_ptr<Source>& source)
|
||||
: source(source),
|
||||
pos(0),
|
||||
nesting(0),
|
||||
indent_stack(),
|
||||
|
|
|
|||
|
|
@ -45,8 +45,8 @@ Decl mergeTypesFromTypeComment(
|
|||
}
|
||||
|
||||
struct ParserImpl {
|
||||
explicit ParserImpl(const std::string& str)
|
||||
: L(str), shared(sharedParserData()) {}
|
||||
explicit ParserImpl(const std::shared_ptr<Source>& source)
|
||||
: L(source), shared(sharedParserData()) {}
|
||||
|
||||
Ident parseIdent() {
|
||||
auto t = L.expect(TK_IDENT);
|
||||
|
|
@ -618,7 +618,8 @@ struct ParserImpl {
|
|||
SharedParserData& shared;
|
||||
};
|
||||
|
||||
Parser::Parser(const std::string& src) : pImpl(new ParserImpl(src)) {}
|
||||
Parser::Parser(const std::shared_ptr<Source>& src)
|
||||
: pImpl(new ParserImpl(src)) {}
|
||||
|
||||
Parser::~Parser() = default;
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ TORCH_API Decl mergeTypesFromTypeComment(
|
|||
bool is_method);
|
||||
|
||||
struct TORCH_API Parser {
|
||||
explicit Parser(const std::string& str);
|
||||
explicit Parser(const std::shared_ptr<Source>& src);
|
||||
TreeRef parseFunction(bool is_method);
|
||||
TreeRef parseClass();
|
||||
Decl parseTypeComment();
|
||||
|
|
|
|||
|
|
@ -195,7 +195,7 @@ TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
|
|||
}
|
||||
|
||||
TypePtr ScriptTypeParser::parseType(const std::string& str) {
|
||||
Parser p(str);
|
||||
Parser p(std::make_shared<Source>(str));
|
||||
return parseTypeFromExpr(p.parseExp());
|
||||
}
|
||||
} // namespace script
|
||||
|
|
|
|||
Loading…
Reference in a new issue