mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix SourceRange comparison
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23341 Test Plan: Imported from OSS Differential Revision: D16505398 Pulled By: jamesr66a fbshipit-source-id: 0bf6a1a054c7749c0a3334654d5746dd9f5dee96
This commit is contained in:
parent
3497891c14
commit
23e526e6ff
6 changed files with 58 additions and 15 deletions
|
|
@ -3395,6 +3395,47 @@ def foo(xyz):
|
|||
self.assertEqual(records1, records2)
|
||||
self.assertEqual(records2, records3)
|
||||
|
||||
def test_serialized_source_ranges_no_dups(self):
|
||||
class FooTest3(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, lim):
|
||||
first = 1
|
||||
second = 1
|
||||
i = 1
|
||||
somenum = 5
|
||||
dontmutateme = 3
|
||||
third = 0
|
||||
while bool(i < lim):
|
||||
third = first + second
|
||||
first = second
|
||||
second = third
|
||||
j = 0
|
||||
while j < 10:
|
||||
somenum = somenum * 2
|
||||
j = j + 1
|
||||
i = i + j
|
||||
i = i + dontmutateme
|
||||
|
||||
st = second + third
|
||||
fs = first + second
|
||||
return third, st, fs
|
||||
|
||||
ft3 = FooTest3()
|
||||
|
||||
def debug_records_from_mod(mod):
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(ft3, buffer)
|
||||
buffer.seek(0)
|
||||
archive = zipfile.ZipFile(buffer)
|
||||
debug_file = archive.open('archive/debug/archive.pkl')
|
||||
return pickle.load(debug_file), buffer
|
||||
|
||||
records, _ = debug_records_from_mod(ft3)
|
||||
for i in range(len(records) - 1):
|
||||
offset, source_range = records[i]
|
||||
offset2, source_range2 = records[i + 1]
|
||||
self.assertNotEqual(source_range, source_range2)
|
||||
|
||||
def test_tensor_shape(self):
|
||||
x = torch.empty(34, 56, 78)
|
||||
|
||||
|
|
|
|||
|
|
@ -204,8 +204,7 @@ SourceRange Node::sourceRange() const {
|
|||
if (source_range_) {
|
||||
return *source_range_;
|
||||
}
|
||||
std::stringstream ss;
|
||||
return SourceRange(ss.str());
|
||||
return SourceRange();
|
||||
}
|
||||
|
||||
static std::ostream& indent(std::ostream& out, size_t level) {
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ const static std::unordered_set<std::string> reserved_names = {
|
|||
|
||||
struct PythonPrintPass {
|
||||
using SourceRangeStack = std::vector<SourceRange>;
|
||||
SourceRangeStack source_range_stack_ = {SourceRange("")};
|
||||
SourceRangeStack source_range_stack_ = {SourceRange()};
|
||||
|
||||
struct WithSourceRange {
|
||||
explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,10 @@ c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
|
|||
|
||||
// a range of a shared string 'file_' with
|
||||
C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
|
||||
// This is an empty SourceRange, used as a sentinel value.
|
||||
if (!source_) {
|
||||
return;
|
||||
}
|
||||
const std::string& str = source_->text();
|
||||
if (size() == str.size()) {
|
||||
// this is just the entire file, not a subset, so print it out.
|
||||
|
|
@ -66,11 +70,9 @@ C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
|
|||
if (!str.empty() && str.back() != '\n')
|
||||
out << "\n";
|
||||
// Retrieve original SourceRange, if present.
|
||||
if (source_) {
|
||||
if (auto orig_source_range = findSourceRangeThatGenerated()) {
|
||||
out << "Compiled from code ";
|
||||
orig_source_range->highlight(out);
|
||||
}
|
||||
if (auto orig_source_range = findSourceRangeThatGenerated()) {
|
||||
out << "Compiled from code ";
|
||||
orig_source_range->highlight(out);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -104,10 +104,7 @@ struct Source {
|
|||
struct CAFFE2_API SourceRange {
|
||||
SourceRange(std::shared_ptr<Source> source_, size_t start_, size_t end_)
|
||||
: source_(std::move(source_)), start_(start_), end_(end_) {}
|
||||
explicit SourceRange(std::string string_range)
|
||||
: source_(std::make_shared<Source>(std::move(string_range))),
|
||||
start_(0),
|
||||
end_(source_->text().size()) {}
|
||||
SourceRange() : source_(nullptr), start_(0), end_(0) {}
|
||||
|
||||
const std::string text() const {
|
||||
return source_->text().substr(start(), end() - start());
|
||||
|
|
|
|||
|
|
@ -71,8 +71,12 @@ c10::IValue SourceRangeSerializer::serialize_source(
|
|||
if (serialized_sources.count(s)) {
|
||||
return serialized_sources.at(s);
|
||||
}
|
||||
std::vector<c10::IValue> elements{
|
||||
s->text(), s->filename(), (int64_t)s->starting_line_no()};
|
||||
std::vector<c10::IValue> elements;
|
||||
if (s == nullptr) {
|
||||
elements = {"", "", 0};
|
||||
} else {
|
||||
elements = {s->text(), s->filename(), (int64_t)s->starting_line_no()};
|
||||
}
|
||||
auto serialized = c10::ivalue::Tuple::create(std::move(elements));
|
||||
serialized_sources[s] = serialized;
|
||||
return serialized;
|
||||
|
|
@ -126,7 +130,7 @@ c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
|
|||
findSourceRangeThatGenerated(const SourceRange& range) {
|
||||
unpickle();
|
||||
|
||||
auto query = TaggedRange(range.start(), SourceRange{""});
|
||||
auto query = TaggedRange(range.start(), SourceRange{});
|
||||
auto entry = std::upper_bound(
|
||||
unpickled_records->begin(),
|
||||
unpickled_records->end(),
|
||||
|
|
|
|||
Loading…
Reference in a new issue