mirror of https://github.com/saymrwulf/pytorch.git
[BE] Enable ruff's UP rules and autoformat torchgen/ (#105423)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105423 Approved by: https://github.com/Skylion007
This commit is contained in:
parent 6ca3d7e1a2
commit 964d29f312
11 changed files with 30 additions and 30 deletions
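The hunks below are mechanical rewrites produced by ruff's UP (pyupgrade) rules: redundant parentheses around a sole generator argument are dropped, the default "r" mode is removed from open(), IOError (an alias of OSError since Python 3.3) becomes OSError, and str.format() calls become f-strings. A minimal sketch of these patterns follows; the variable names are hypothetical and the snippet is illustrative, not taken from the PR.

# Hypothetical before/after sketch of the pyupgrade-style rewrites applied in this diff.
path = "example.yaml"
items = ["a", None, "b"]

# Before: redundant parentheses around a sole generator argument
has_missing = any((x is None for x in items))
# After
has_missing = any(x is None for x in items)

# Before: redundant "r" mode and the IOError alias
try:
    with open(path, "r") as f:
        text = f.read()
except IOError:
    text = ""
# After
try:
    with open(path) as f:
        text = f.read()
except OSError:
    text = ""

# Before: str.format
label = "{}:{}".format(path, len(items))
# After
label = f"{path}:{len(items)}"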
@@ -315,7 +315,7 @@ class PythonOutArgument(PythonArgument):
                 outputs=outputs,
             )
         elif size > 1:
-            if any((not a.type.is_tensor_like() for a in outputs)):
+            if any(not a.type.is_tensor_like() for a in outputs):
                 raise RuntimeError(f"Unsupported output type: {outputs}")
             return PythonOutArgument(
                 name="out",

@@ -882,10 +882,10 @@ def signature_from_schema(


 def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
-    if len(returns) <= 1 or all((r.name is None for r in returns)):
+    if len(returns) <= 1 or all(r.name is None for r in returns):
         return []
     else:
-        if any((r.name is None for r in returns)):
+        if any(r.name is None for r in returns):
             # When building on Windows, `PyStructSequence_UnnamedField` could not be
             # resolved by the linker for some reason, which cause error in building:
             #

@@ -1163,7 +1163,7 @@ def dispatch_lambda_return_str(f: NativeFunction) -> str:
     # mutable reference to temporary. Maybe we could assign it to a
     # variable itself.)
     returns_without_annotation = tuple(
-        (Return(r.name, r.type, None) for r in f.func.returns)
+        Return(r.name, r.type, None) for r in f.func.returns
     )
     return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
     if return_str not in SUPPORTED_RETURN_TYPES:

@@ -1195,7 +1195,7 @@ def cpp_dispatch_exprs(
     exprs: Tuple[str, ...] = tuple()
     if not isinstance(python_signature, PythonSignatureDeprecated):
         # By default the exprs are consistent with the C++ signature.
-        exprs = tuple((a.name for a in cpp_args))
+        exprs = tuple(a.name for a in cpp_args)
     else:
         # For deprecated python signature we may need fill in some constants.
         exprs = tuple(

@@ -1426,7 +1426,7 @@ def dispatch_lambda_exprs(
                     f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                 )
         if not all(
-            (a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS.keys())
+            a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS.keys()
         ):
             raise RuntimeError(
                 f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"

@@ -1454,7 +1454,7 @@ torch::utils::maybe_initialize_cuda(options);
             raise RuntimeError(
                 f"{f.func}: dtype in tensor_options_args without output arg"
             )
-        if not all((a in tensor_options_args_names for a in ("layout", "device"))):
+        if not all(a in tensor_options_args_names for a in ("layout", "device")):
             raise RuntimeError(
                 f"{f.func}: incomplete tensor options for output check"
             )

@@ -1473,6 +1473,6 @@ check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dt
         )

     return DispatchLambdaArgumentExprs(
-        exprs=tuple((lambda_args_exprs[a.name] for a in lambda_args)),
+        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
         inits=inits,
     )

@@ -20,7 +20,7 @@ class CodeTemplate:

     @staticmethod
     def from_file(filename: str) -> "CodeTemplate":
-        with open(filename, "r") as f:
+        with open(filename) as f:
             return CodeTemplate(f.read(), filename)

     def __init__(self, pattern: str, filename: str = "") -> None:

@@ -124,7 +124,7 @@ def parse_et_yaml(
     """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
     of fields to persist from native_functions.yaml to functions.yaml
     """
-    with open(path, "r") as f:
+    with open(path) as f:
         es = yaml.load(f, Loader=LineLoader)

     et_kernel = extract_kernel_fields(es)

@@ -212,7 +212,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
 def parse_tags_yaml(path: str) -> Set[str]:
     global _GLOBAL_PARSE_TAGS_YAML_CACHE
     if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
-        with open(path, "r") as f:
+        with open(path) as f:
             es = yaml.load(f, Loader=LineLoader)
             _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)

@@ -233,7 +233,7 @@ def parse_native_yaml(

         # if a loaded yaml is provided, use that instead of reading from path
         if loaded_yaml is None:
-            with open(path, "r") as f:
+            with open(path) as f:
                 es = yaml.load(f, Loader=LineLoader)
         else:
             es = loaded_yaml

@@ -47,7 +47,7 @@ def parse_backend_yaml(
         )
     }

-    with open(backend_yaml_path, "r") as f:
+    with open(backend_yaml_path) as f:
         yaml_values = yaml.load(f, Loader=YamlLoader)
     assert isinstance(yaml_values, dict)

@@ -253,9 +253,9 @@ def error_on_missing_kernels(
     full_codegen: Optional[List[OperatorName]] = None,
 ) -> None:
     try:
-        with open(kernel_defn_file_path, "r") as f:
+        with open(kernel_defn_file_path) as f:
             backend_defns = f.read()
-    except IOError as e:
+    except OSError as e:
         raise AssertionError(
             f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
         ) from e

@@ -575,7 +575,7 @@ def translate_native_yaml(
         None
     """
     if use_aten_lib:
-        with open(aten_yaml_path, "r") as aten_yaml:
+        with open(aten_yaml_path) as aten_yaml:
             out_file.writelines(aten_yaml.readlines())
         return

@@ -604,7 +604,7 @@ def translate_native_yaml(
         or os.stat(native_yaml_path).st_size == 0
     ):
         return
-    with open(native_yaml_path, "r") as native_yaml:
+    with open(native_yaml_path) as native_yaml:
         native_es = yaml.load(native_yaml, Loader=LineLoader)
     if not native_es:
         return

@@ -641,7 +641,7 @@ def parse_yaml(
     Union[Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ETKernelIndex],
 ]:
     if path and os.path.exists(path) and os.stat(path).st_size > 0:
-        with open(path, "r") as f:
+        with open(path) as f:
             es = yaml.load(f, Loader=LineLoader)

         # Check for kernel index structure

@@ -115,7 +115,7 @@ def parse_native_functions_keys(
         )
     }

-    with open(backend_yaml_path, "r") as f:
+    with open(backend_yaml_path) as f:
         yaml_values = yaml.load(f, Loader=YamlLoader)
     assert isinstance(yaml_values, dict)

@@ -134,10 +134,10 @@ def validate_shape_inference_header(
     shape_inference_hdr: str, expected_shape_infr_decls: List[str]
 ) -> None:
     try:
-        with open(shape_inference_hdr, "r") as f:
+        with open(shape_inference_hdr) as f:
             shape_infr_decls = f.read()
             shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
-    except IOError as e:
+    except OSError as e:
         raise AssertionError(
             f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
         ) from e

@@ -40,7 +40,7 @@ class Location:
     line: int

     def __str__(self) -> str:
-        return "{}:{}".format(self.file, self.line)
+        return f"{self.file}:{self.line}"


 # Valid values of the 'variants' field in native_functions.yaml

@@ -83,7 +83,7 @@ class SelectiveBuildOperator:
         if "debug_info" in op_info:
             di_list = op_info["debug_info"]
             assert isinstance(di_list, list)
-            debug_info = tuple((str(x) for x in di_list))
+            debug_info = tuple(str(x) for x in di_list)

         return SelectiveBuildOperator(
             name=op_name,

@@ -93,7 +93,7 @@ class SelectiveBuilder:
             di_list = data["debug_info"]
             assert isinstance(di_list, list)

-            debug_info = tuple((str(x) for x in di_list))
+            debug_info = tuple(str(x) for x in di_list)

         operators = {}
         operators_dict = data.get("operators", {})

@@ -141,7 +141,7 @@ class SelectiveBuilder:

     @staticmethod
     def from_yaml_path(config_path: str) -> "SelectiveBuilder":
-        with open(config_path, "r") as f:
+        with open(config_path) as f:
             contents = yaml.safe_load(f)
             return SelectiveBuilder.from_yaml_dict(contents)

@@ -105,7 +105,7 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]:
 # for getting mypy to do exhaustiveness checking
 # TODO: put this somewhere else, maybe
 def assert_never(x: NoReturn) -> NoReturn:
-    raise AssertionError("Unhandled type: {}".format(type(x).__name__))
+    raise AssertionError(f"Unhandled type: {type(x).__name__}")


 @functools.lru_cache(maxsize=None)

@@ -137,9 +137,9 @@ class FileManager:
     def _write_if_changed(self, filename: str, contents: str) -> None:
         old_contents: Optional[str]
         try:
-            with open(filename, "r") as f:
+            with open(filename) as f:
                 old_contents = f.read()
-        except IOError:
+        except OSError:
             old_contents = None
         if contents != old_contents:
             # Create output directory if it doesn't exist

@@ -157,7 +157,7 @@ class FileManager:
             # TODO: Update the comment reference to the correct location
             if "generated_comment" not in env:
                 comment = "@" + "generated by torchgen/gen.py"
-                comment += " from {}".format(os.path.basename(template_path))
+                comment += f" from {os.path.basename(template_path)}"
                 env["generated_comment"] = comment
             template = _read_template(template_path)
             return template.substitute(env)

@@ -172,7 +172,7 @@ class FileManager:
         template_fn: str,
         env_callable: Callable[[], Union[str, Dict[str, Any]]],
     ) -> None:
-        filename = "{}/{}".format(self.install_dir, filename)
+        filename = f"{self.install_dir}/{filename}"
         assert filename not in self.filenames, "duplicate file write {filename}"
         self.filenames.add(filename)
         if not self.dry_run: