From 23e526e6fff074ea901cd567830104654bdb2ced Mon Sep 17 00:00:00 2001 From: James Reed Date: Fri, 26 Jul 2019 17:43:55 -0700 Subject: [PATCH] Fix SourceRange comparison Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23341 Test Plan: Imported from OSS Differential Revision: D16505398 Pulled By: jamesr66a fbshipit-source-id: 0bf6a1a054c7749c0a3334654d5746dd9f5dee96 --- test/test_jit.py | 41 +++++++++++++++++++ torch/csrc/jit/ir.cpp | 3 +- torch/csrc/jit/passes/python_print.cpp | 2 +- torch/csrc/jit/source_range.cpp | 12 +++--- torch/csrc/jit/source_range.h | 5 +-- torch/csrc/jit/source_range_serialization.cpp | 10 +++-- 6 files changed, 58 insertions(+), 15 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 6f644cb8bae..b1d1f575455 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3395,6 +3395,47 @@ def foo(xyz): self.assertEqual(records1, records2) self.assertEqual(records2, records3) + def test_serialized_source_ranges_no_dups(self): + class FooTest3(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, lim): + first = 1 + second = 1 + i = 1 + somenum = 5 + dontmutateme = 3 + third = 0 + while bool(i < lim): + third = first + second + first = second + second = third + j = 0 + while j < 10: + somenum = somenum * 2 + j = j + 1 + i = i + j + i = i + dontmutateme + + st = second + third + fs = first + second + return third, st, fs + + ft3 = FooTest3() + + def debug_records_from_mod(mod): + buffer = io.BytesIO() + torch.jit.save(ft3, buffer) + buffer.seek(0) + archive = zipfile.ZipFile(buffer) + debug_file = archive.open('archive/debug/archive.pkl') + return pickle.load(debug_file), buffer + + records, _ = debug_records_from_mod(ft3) + for i in range(len(records) - 1): + offset, source_range = records[i] + offset2, source_range2 = records[i + 1] + self.assertNotEqual(source_range, source_range2) + def test_tensor_shape(self): x = torch.empty(34, 56, 78) diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 438dcb47eec..0982b23ca18 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -204,8 +204,7 @@ SourceRange Node::sourceRange() const { if (source_range_) { return *source_range_; } - std::stringstream ss; - return SourceRange(ss.str()); + return SourceRange(); } static std::ostream& indent(std::ostream& out, size_t level) { diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 5ba57d1de99..9cbb22f785c 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -168,7 +168,7 @@ const static std::unordered_set reserved_names = { struct PythonPrintPass { using SourceRangeStack = std::vector; - SourceRangeStack source_range_stack_ = {SourceRange("")}; + SourceRangeStack source_range_stack_ = {SourceRange()}; struct WithSourceRange { explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) { diff --git a/torch/csrc/jit/source_range.cpp b/torch/csrc/jit/source_range.cpp index 8a5e573b30d..f15a62ecf63 100644 --- a/torch/csrc/jit/source_range.cpp +++ b/torch/csrc/jit/source_range.cpp @@ -14,6 +14,10 @@ c10::optional Source::findSourceRangeThatGenerated( // a range of a shared string 'file_' with C10_EXPORT void SourceRange::highlight(std::ostream& out) const { + // This is an empty SourceRange, used as a sentinel value. + if (!source_) { + return; + } const std::string& str = source_->text(); if (size() == str.size()) { // this is just the entire file, not a subset, so print it out. @@ -66,11 +70,9 @@ C10_EXPORT void SourceRange::highlight(std::ostream& out) const { if (!str.empty() && str.back() != '\n') out << "\n"; // Retrieve original SourceRange, if present. - if (source_) { - if (auto orig_source_range = findSourceRangeThatGenerated()) { - out << "Compiled from code "; - orig_source_range->highlight(out); - } + if (auto orig_source_range = findSourceRangeThatGenerated()) { + out << "Compiled from code "; + orig_source_range->highlight(out); } } diff --git a/torch/csrc/jit/source_range.h b/torch/csrc/jit/source_range.h index 4176e799f85..fc4965886c6 100644 --- a/torch/csrc/jit/source_range.h +++ b/torch/csrc/jit/source_range.h @@ -104,10 +104,7 @@ struct Source { struct CAFFE2_API SourceRange { SourceRange(std::shared_ptr source_, size_t start_, size_t end_) : source_(std::move(source_)), start_(start_), end_(end_) {} - explicit SourceRange(std::string string_range) - : source_(std::make_shared(std::move(string_range))), - start_(0), - end_(source_->text().size()) {} + SourceRange() : source_(nullptr), start_(0), end_(0) {} const std::string text() const { return source_->text().substr(start(), end() - start()); diff --git a/torch/csrc/jit/source_range_serialization.cpp b/torch/csrc/jit/source_range_serialization.cpp index 2b7e3f1e48e..0c3041f7ed6 100644 --- a/torch/csrc/jit/source_range_serialization.cpp +++ b/torch/csrc/jit/source_range_serialization.cpp @@ -71,8 +71,12 @@ c10::IValue SourceRangeSerializer::serialize_source( if (serialized_sources.count(s)) { return serialized_sources.at(s); } - std::vector elements{ - s->text(), s->filename(), (int64_t)s->starting_line_no()}; + std::vector elements; + if (s == nullptr) { + elements = {"", "", 0}; + } else { + elements = {s->text(), s->filename(), (int64_t)s->starting_line_no()}; + } auto serialized = c10::ivalue::Tuple::create(std::move(elements)); serialized_sources[s] = serialized; return serialized; @@ -126,7 +130,7 @@ c10::optional ConcreteSourceRangeUnpickler:: findSourceRangeThatGenerated(const SourceRange& range) { unpickle(); - auto query = TaggedRange(range.start(), SourceRange{""}); + auto query = TaggedRange(range.start(), SourceRange{}); auto entry = std::upper_bound( unpickled_records->begin(), unpickled_records->end(),