pytorch/torch/_sources.py

# mypy: allow-untyped-defs
import ast
import functools
import inspect
from textwrap import dedent
from typing import Any, NamedTuple, Optional

from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory


def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> tuple[list[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.

    Returns: (sourcelines, file_lineno, filename)
    """
    filename = None  # in case getsourcefile throws
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        msg = (
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        )
        if error_msg:
            msg += "\n" + error_msg
        raise OSError(msg) from e
    return sourcelines, file_lineno, filename
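

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch of how get_source_lines_and_file is typically called. The
# nested `_sample_fn` is hypothetical and exists only to have something with
# retrievable source; real callers pass user functions, methods, or classes.
def _example_get_source_lines_and_file():
    def _sample_fn(x):
        return x + 1

    lines, lineno, fname = get_source_lines_and_file(
        _sample_fn, error_msg="extra context appended if the source is unavailable"
    )
    # `lines` holds the raw source lines of `_sample_fn`; `lineno` is the line
    # in `fname` at which its definition starts.
    assert lines[0].lstrip().startswith("def _sample_fn")
    assert isinstance(lineno, int)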


def normalize_source_lines(sourcelines: list[str]) -> list[str]:
    """
    This helper function accepts a list of source lines. It finds the
    indentation level of the function definition (`def`), then it indents
    all lines in the function body to a point at or greater than that
    level. This allows for comments and continued string literals that
    are at a lower indentation than the rest of the code.

    Args:
        sourcelines: function source code, separated into lines by
                     the '\n' character

    Returns:
        A list of source lines that have been correctly aligned
    """

    def remove_prefix(text, prefix):
        return text[text.startswith(prefix) and len(prefix) :]
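    # (The slice above removes `prefix` when `text` starts with it and returns
    # `text` unchanged otherwise, since a False result indexes as 0.)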

    # Find the line and line number containing the function definition
Reland "[pytorch][PR] Support dataclasses in TorchScript" take 2 (#74353) (#74353) (#76771) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74353 Repatched `d00de0d43598522b8f6ab2de553b6aaf6768faa5` by Nora Belrose (norabelrose). With following changes: * Register fake source of generated methods in linecache so that inspect.get_source will succeed. * this patching is only triggered if the given dataclass passed to torch.jit.script previously. Effectively we make this feature opt-in. ## Original Summary: Fixes https://github.com/pytorch/pytorch/issues/72901. Since we can't get access to the source code for synthesized magic methods on dataclasses, we have to synthesize our own versions. torch/jit/_dataclass_impls.py has the code that does this. What's supported Synthesized __init__, __eq__, and the comparison magic methods when order=True is set on the dataclass decorator Default values for fields __post_init__, including using InitVar fields inside of __post_init__, on Python 3.8+ Overriding __eq__ or any of the comparison magic methods to provide your own implementation What's not supported Default factory initializers for fields Frozen dataclasses InitVar on Python 3.7 __repr__ and __hash__ (these are actually implemented, but the TorchScript interpreter won't call them) Using the != operator on dataclasses inside TorchScript; this is because TorchScript requires that you implement __ne__ to use this operator, whereas in regular Python the != operator will resolve to the negation of whatever is returned by __eq__ if there's no __ne__. Dataclasses don't actually synthesize an __ne__ method for this reason. I've been toying with different ways to fix this but != is not working in this PR at the moment. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74889 Test Plan: unittest Also run previously failed test: ``` buck test mode/dev-nosan //fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests -- --exact 'fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests - test_mixmatch_multiclass (fblearner.flow.projects.fluent2.definition.transformers.contrib.faim.test.faim_mixmatch_test.TestFaimTransformerMixMatch)' ``` passes Reviewed By: zhxchen17 Differential Revision: D35206262 Pulled By: qihqi Pull Request resolved: https://github.com/pytorch/pytorch/pull/76771 Approved by: https://github.com/seemethere
2022-06-07 21:44:55 +00:00
    idx = None
    for i, l in enumerate(sourcelines):
        if l.lstrip().startswith("def"):
            idx = i
            break
Reland "[pytorch][PR] Support dataclasses in TorchScript" take 2 (#74353) (#74353) (#76771) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74353 Repatched `d00de0d43598522b8f6ab2de553b6aaf6768faa5` by Nora Belrose (norabelrose). With following changes: * Register fake source of generated methods in linecache so that inspect.get_source will succeed. * this patching is only triggered if the given dataclass passed to torch.jit.script previously. Effectively we make this feature opt-in. ## Original Summary: Fixes https://github.com/pytorch/pytorch/issues/72901. Since we can't get access to the source code for synthesized magic methods on dataclasses, we have to synthesize our own versions. torch/jit/_dataclass_impls.py has the code that does this. What's supported Synthesized __init__, __eq__, and the comparison magic methods when order=True is set on the dataclass decorator Default values for fields __post_init__, including using InitVar fields inside of __post_init__, on Python 3.8+ Overriding __eq__ or any of the comparison magic methods to provide your own implementation What's not supported Default factory initializers for fields Frozen dataclasses InitVar on Python 3.7 __repr__ and __hash__ (these are actually implemented, but the TorchScript interpreter won't call them) Using the != operator on dataclasses inside TorchScript; this is because TorchScript requires that you implement __ne__ to use this operator, whereas in regular Python the != operator will resolve to the negation of whatever is returned by __eq__ if there's no __ne__. Dataclasses don't actually synthesize an __ne__ method for this reason. I've been toying with different ways to fix this but != is not working in this PR at the moment. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74889 Test Plan: unittest Also run previously failed test: ``` buck test mode/dev-nosan //fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests -- --exact 'fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests - test_mixmatch_multiclass (fblearner.flow.projects.fluent2.definition.transformers.contrib.faim.test.faim_mixmatch_test.TestFaimTransformerMixMatch)' ``` passes Reviewed By: zhxchen17 Differential Revision: D35206262 Pulled By: qihqi Pull Request resolved: https://github.com/pytorch/pytorch/pull/76771 Approved by: https://github.com/seemethere
2022-06-07 21:44:55 +00:00
    # This will happen when the function is a lambda - we won't find "def" anywhere in
    # the source lines in that case. Currently trying to JIT compile a lambda will throw
    # an error up in `parse_def()`, but we might want to handle this case in the future.
    if idx is None:
        return sourcelines

    # Get a string representing the amount of leading whitespace
Reland "[pytorch][PR] Support dataclasses in TorchScript" take 2 (#74353) (#74353) (#76771) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74353 Repatched `d00de0d43598522b8f6ab2de553b6aaf6768faa5` by Nora Belrose (norabelrose). With following changes: * Register fake source of generated methods in linecache so that inspect.get_source will succeed. * this patching is only triggered if the given dataclass passed to torch.jit.script previously. Effectively we make this feature opt-in. ## Original Summary: Fixes https://github.com/pytorch/pytorch/issues/72901. Since we can't get access to the source code for synthesized magic methods on dataclasses, we have to synthesize our own versions. torch/jit/_dataclass_impls.py has the code that does this. What's supported Synthesized __init__, __eq__, and the comparison magic methods when order=True is set on the dataclass decorator Default values for fields __post_init__, including using InitVar fields inside of __post_init__, on Python 3.8+ Overriding __eq__ or any of the comparison magic methods to provide your own implementation What's not supported Default factory initializers for fields Frozen dataclasses InitVar on Python 3.7 __repr__ and __hash__ (these are actually implemented, but the TorchScript interpreter won't call them) Using the != operator on dataclasses inside TorchScript; this is because TorchScript requires that you implement __ne__ to use this operator, whereas in regular Python the != operator will resolve to the negation of whatever is returned by __eq__ if there's no __ne__. Dataclasses don't actually synthesize an __ne__ method for this reason. I've been toying with different ways to fix this but != is not working in this PR at the moment. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74889 Test Plan: unittest Also run previously failed test: ``` buck test mode/dev-nosan //fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests -- --exact 'fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests - test_mixmatch_multiclass (fblearner.flow.projects.fluent2.definition.transformers.contrib.faim.test.faim_mixmatch_test.TestFaimTransformerMixMatch)' ``` passes Reviewed By: zhxchen17 Differential Revision: D35206262 Pulled By: qihqi Pull Request resolved: https://github.com/pytorch/pytorch/pull/76771 Approved by: https://github.com/seemethere
2022-06-07 21:44:55 +00:00
    fn_def = sourcelines[idx]
    whitespace = fn_def.split("def")[0]

    # Add this leading whitespace to all lines before and after the `def`
    aligned_prefix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
    ]
    aligned_suffix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
    ]

    # Put it together again
    aligned_prefix.append(fn_def)
    return aligned_prefix + aligned_suffix
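

# --- Illustrative usage sketch (not part of the original module) ---
# A small, made-up example of what normalize_source_lines does: a comment sitting
# at column 0 inside an indented method body is pushed to at least the `def`'s
# indentation, so the block can later be dedented and parsed as a unit.
def _example_normalize_source_lines():
    raw = [
        "    def forward(self, x):\n",
        "# a stray comment at column 0\n",
        "        return x\n",
    ]
    aligned = normalize_source_lines(raw)
    assert aligned[1] == "    # a stray comment at column 0\n"
    assert aligned[2] == "        return x\n"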


# Thin wrapper around SourceRangeFactory to store extra metadata
# about the function-to-be-compiled.
class SourceContext(SourceRangeFactory):
    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        self.uses_true_division = uses_true_division
        self.filename = filename
        self.funcname = funcname


@functools.cache
def make_source_context(*args):
    return SourceContext(*args)


def fake_range():
    return SourceContext("", None, 0, 0).make_raw_range(0, 1)


class ParsedDef(NamedTuple):
    ast: ast.Module
    ctx: SourceContext
    source: str
    filename: Optional[str]
    file_lineno: int


def parse_def(fn):
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        fn, ErrorReport.call_stack()
    )
    sourcelines = normalize_source_lines(sourcelines)
    source = "".join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError(
            f"Expected a single top-level function: {filename}:{file_lineno}"
        )
    leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
        dedent_src.split("\n", 1)[0]
    )
    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
    )
    return ParsedDef(py_ast, ctx, source, filename, file_lineno)
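

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch of parse_def in action, assuming a working torch build
# (ErrorReport comes from torch._C). The nested `_sample` function is
# hypothetical; real callers pass the function being scripted.
def _example_parse_def():
    def _sample(x, y):
        return x * y

    parsed = parse_def(_sample)
    # ParsedDef bundles the parsed AST with the metadata later passes need.
    assert isinstance(parsed.ast.body[0], ast.FunctionDef)
    assert parsed.ast.body[0].name == "_sample"
    assert parsed.ctx.funcname == "_sample"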