Fix SourceRange comparison

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23341

Test Plan: Imported from OSS

Differential Revision: D16505398

Pulled By: jamesr66a

fbshipit-source-id: 0bf6a1a054c7749c0a3334654d5746dd9f5dee96
This commit is contained in:
James Reed 2019-07-26 17:43:55 -07:00 committed by Facebook Github Bot
parent 3497891c14
commit 23e526e6ff
6 changed files with 58 additions and 15 deletions

View file

@ -3395,6 +3395,47 @@ def foo(xyz):
self.assertEqual(records1, records2)
self.assertEqual(records2, records3)
def test_serialized_source_ranges_no_dups(self):
class FooTest3(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, lim):
first = 1
second = 1
i = 1
somenum = 5
dontmutateme = 3
third = 0
while bool(i < lim):
third = first + second
first = second
second = third
j = 0
while j < 10:
somenum = somenum * 2
j = j + 1
i = i + j
i = i + dontmutateme
st = second + third
fs = first + second
return third, st, fs
ft3 = FooTest3()
def debug_records_from_mod(mod):
buffer = io.BytesIO()
torch.jit.save(ft3, buffer)
buffer.seek(0)
archive = zipfile.ZipFile(buffer)
debug_file = archive.open('archive/debug/archive.pkl')
return pickle.load(debug_file), buffer
records, _ = debug_records_from_mod(ft3)
for i in range(len(records) - 1):
offset, source_range = records[i]
offset2, source_range2 = records[i + 1]
self.assertNotEqual(source_range, source_range2)
def test_tensor_shape(self):
x = torch.empty(34, 56, 78)

View file

@ -204,8 +204,7 @@ SourceRange Node::sourceRange() const {
if (source_range_) {
return *source_range_;
}
std::stringstream ss;
return SourceRange(ss.str());
return SourceRange();
}
static std::ostream& indent(std::ostream& out, size_t level) {

View file

@ -168,7 +168,7 @@ const static std::unordered_set<std::string> reserved_names = {
struct PythonPrintPass {
using SourceRangeStack = std::vector<SourceRange>;
SourceRangeStack source_range_stack_ = {SourceRange("")};
SourceRangeStack source_range_stack_ = {SourceRange()};
struct WithSourceRange {
explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {

View file

@ -14,6 +14,10 @@ c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
// a range of a shared string 'file_' with
C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
// This is an empty SourceRange, used as a sentinel value.
if (!source_) {
return;
}
const std::string& str = source_->text();
if (size() == str.size()) {
// this is just the entire file, not a subset, so print it out.
@ -66,11 +70,9 @@ C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
if (!str.empty() && str.back() != '\n')
out << "\n";
// Retrieve original SourceRange, if present.
if (source_) {
if (auto orig_source_range = findSourceRangeThatGenerated()) {
out << "Compiled from code ";
orig_source_range->highlight(out);
}
if (auto orig_source_range = findSourceRangeThatGenerated()) {
out << "Compiled from code ";
orig_source_range->highlight(out);
}
}

View file

@ -104,10 +104,7 @@ struct Source {
struct CAFFE2_API SourceRange {
SourceRange(std::shared_ptr<Source> source_, size_t start_, size_t end_)
: source_(std::move(source_)), start_(start_), end_(end_) {}
explicit SourceRange(std::string string_range)
: source_(std::make_shared<Source>(std::move(string_range))),
start_(0),
end_(source_->text().size()) {}
SourceRange() : source_(nullptr), start_(0), end_(0) {}
const std::string text() const {
return source_->text().substr(start(), end() - start());

View file

@ -71,8 +71,12 @@ c10::IValue SourceRangeSerializer::serialize_source(
if (serialized_sources.count(s)) {
return serialized_sources.at(s);
}
std::vector<c10::IValue> elements{
s->text(), s->filename(), (int64_t)s->starting_line_no()};
std::vector<c10::IValue> elements;
if (s == nullptr) {
elements = {"", "", 0};
} else {
elements = {s->text(), s->filename(), (int64_t)s->starting_line_no()};
}
auto serialized = c10::ivalue::Tuple::create(std::move(elements));
serialized_sources[s] = serialized;
return serialized;
@ -126,7 +130,7 @@ c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
findSourceRangeThatGenerated(const SourceRange& range) {
unpickle();
auto query = TaggedRange(range.start(), SourceRange{""});
auto query = TaggedRange(range.start(), SourceRange{});
auto entry = std::upper_bound(
unpickled_records->begin(),
unpickled_records->end(),