mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[BE][CI] bump ruff to 0.9.2: multiline assert statements
Reference: https://docs.astral.sh/ruff/formatter/black/#assert-statements
> Unlike Black, Ruff prefers breaking the message over breaking the assertion, similar to how both Ruff and Black prefer breaking the assignment value over breaking the assignment target:
>
> ```python
> # Input
> assert (
> len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
>
> # Black
> assert (
> len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
> # Ruff
> assert len(policy_types) >= priority + num_duplicates, (
> f"This tests needs at least {priority + num_duplicates} many types."
> )
> ```
ghstack-source-id: f4c1376e01
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144546
This commit is contained in:
parent
33d0dbcdc0
commit
84a01dd825
80 changed files with 622 additions and 610 deletions
6
.github/scripts/label_utils.py
vendored
6
.github/scripts/label_utils.py
vendored
|
|
@ -63,9 +63,9 @@ def gh_get_labels(org: str, repo: str) -> list[str]:
|
|||
update_labels(labels, info)
|
||||
|
||||
last_page = get_last_page_num_from_header(header)
|
||||
assert (
|
||||
last_page > 0
|
||||
), "Error reading header info to determine total number of pages of labels"
|
||||
assert last_page > 0, (
|
||||
"Error reading header info to determine total number of pages of labels"
|
||||
)
|
||||
for page_number in range(2, last_page + 1): # skip page 1
|
||||
_, info = request_for_labels(prefix + f"&page={page_number}")
|
||||
update_labels(labels, info)
|
||||
|
|
|
|||
|
|
@ -1476,7 +1476,7 @@ init_command = [
|
|||
'black==23.12.1',
|
||||
'usort==1.0.8.post1',
|
||||
'isort==5.13.2',
|
||||
'ruff==0.8.4', # sync with RUFF
|
||||
'ruff==0.9.2', # sync with RUFF
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
|
|
@ -1561,7 +1561,7 @@ init_command = [
|
|||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'ruff==0.8.4', # sync with PYFMT
|
||||
'ruff==0.9.2', # sync with PYFMT
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
|
|
|
|||
|
|
@ -1006,9 +1006,9 @@ def latency_experiment_summary(suite_name, args, model, timings, **kwargs):
|
|||
row,
|
||||
)
|
||||
c_headers, c_data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
|
||||
assert (
|
||||
output_filename.find(".csv") > 0
|
||||
), f"expected output_filename to be a .csv, but got {output_filename}"
|
||||
assert output_filename.find(".csv") > 0, (
|
||||
f"expected output_filename to be a .csv, but got {output_filename}"
|
||||
)
|
||||
write_outputs(
|
||||
output_filename[:-4] + "_compilation_metrics.csv",
|
||||
first_headers + c_headers,
|
||||
|
|
@ -1182,9 +1182,9 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
|
|||
row,
|
||||
)
|
||||
c_headers, c_data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
|
||||
assert (
|
||||
output_filename.find(".csv") > 0
|
||||
), f"expected output_filename to be a .csv, but got {output_filename}"
|
||||
assert output_filename.find(".csv") > 0, (
|
||||
f"expected output_filename to be a .csv, but got {output_filename}"
|
||||
)
|
||||
write_outputs(
|
||||
output_filename[:-4] + "_compilation_metrics.csv",
|
||||
first_headers + c_headers,
|
||||
|
|
@ -1997,16 +1997,16 @@ class BenchmarkRunner:
|
|||
def deepcopy_and_maybe_parallelize(self, model):
|
||||
model = self.deepcopy_model(model)
|
||||
if self.args.ddp:
|
||||
assert (
|
||||
torch.distributed.is_available()
|
||||
), "Can't use DDP without a distributed enabled build"
|
||||
assert torch.distributed.is_available(), (
|
||||
"Can't use DDP without a distributed enabled build"
|
||||
)
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
model = DDP(model, find_unused_parameters=True)
|
||||
elif self.args.fsdp:
|
||||
assert (
|
||||
torch.distributed.is_available()
|
||||
), "Can't use FSDP without a distributed enabled build"
|
||||
assert torch.distributed.is_available(), (
|
||||
"Can't use FSDP without a distributed enabled build"
|
||||
)
|
||||
from torch.distributed.fsdp import (
|
||||
FullyShardedDataParallel as FSDP,
|
||||
MixedPrecision,
|
||||
|
|
@ -2375,9 +2375,9 @@ class BenchmarkRunner:
|
|||
self, name, model, example_inputs, optimize_ctx, experiment, tag=None
|
||||
):
|
||||
"Run performance test in non-alternately."
|
||||
assert (
|
||||
experiment.func is latency_experiment
|
||||
), "Must run with latency_experiment."
|
||||
assert experiment.func is latency_experiment, (
|
||||
"Must run with latency_experiment."
|
||||
)
|
||||
|
||||
def warmup(fn, model, example_inputs, mode, niters=10):
|
||||
peak_mem = 0
|
||||
|
|
|
|||
|
|
@ -81,9 +81,9 @@ def bench(shape, layer_id, p, fusion_types=[""]):
|
|||
torch._dynamo.reset()
|
||||
torch._inductor.metrics.reset()
|
||||
triton_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
|
||||
assert (
|
||||
torch._inductor.metrics.generated_kernel_count == 1
|
||||
), "codegen #kernel != 1"
|
||||
assert torch._inductor.metrics.generated_kernel_count == 1, (
|
||||
"codegen #kernel != 1"
|
||||
)
|
||||
row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])
|
||||
|
||||
p.add_row(row)
|
||||
|
|
|
|||
|
|
@ -265,9 +265,9 @@ class OperatorInputsLoader:
|
|||
def get_inputs_for_operator(
|
||||
self, operator, dtype=None, device="cuda"
|
||||
) -> Generator[tuple[Iterable[Any], dict[str, Any]], None, None]:
|
||||
assert (
|
||||
str(operator) in self.operator_db
|
||||
), f"Could not find {operator}, must provide overload"
|
||||
assert str(operator) in self.operator_db, (
|
||||
f"Could not find {operator}, must provide overload"
|
||||
)
|
||||
|
||||
if "embedding" in str(operator):
|
||||
log.warning("Embedding inputs NYI, input data cannot be randomized")
|
||||
|
|
@ -302,9 +302,9 @@ class OperatorInputsLoader:
|
|||
yield op
|
||||
|
||||
def get_call_frequency(self, op):
|
||||
assert (
|
||||
str(op) in self.operator_db
|
||||
), f"Could not find {op}, must provide overload"
|
||||
assert str(op) in self.operator_db, (
|
||||
f"Could not find {op}, must provide overload"
|
||||
)
|
||||
|
||||
count = 0
|
||||
for counter in self.operator_db[str(op)].values():
|
||||
|
|
|
|||
|
|
@ -538,21 +538,21 @@ class MultiheadAttentionContainer(torch.nn.Module):
|
|||
query.size(-1),
|
||||
)
|
||||
q, k, v = self.in_proj_container(query, key, value)
|
||||
assert (
|
||||
q.size(-1) % self.nhead == 0
|
||||
), "query's embed_dim must be divisible by the number of heads"
|
||||
assert q.size(-1) % self.nhead == 0, (
|
||||
"query's embed_dim must be divisible by the number of heads"
|
||||
)
|
||||
head_dim = q.size(-1) // self.nhead
|
||||
q = q.reshape(tgt_len, bsz * self.nhead, head_dim)
|
||||
|
||||
assert (
|
||||
k.size(-1) % self.nhead == 0
|
||||
), "key's embed_dim must be divisible by the number of heads"
|
||||
assert k.size(-1) % self.nhead == 0, (
|
||||
"key's embed_dim must be divisible by the number of heads"
|
||||
)
|
||||
head_dim = k.size(-1) // self.nhead
|
||||
k = k.reshape(src_len, bsz * self.nhead, head_dim)
|
||||
|
||||
assert (
|
||||
v.size(-1) % self.nhead == 0
|
||||
), "value's embed_dim must be divisible by the number of heads"
|
||||
assert v.size(-1) % self.nhead == 0, (
|
||||
"value's embed_dim must be divisible by the number of heads"
|
||||
)
|
||||
head_dim = v.size(-1) // self.nhead
|
||||
v = v.reshape(src_len, bsz * self.nhead, head_dim)
|
||||
|
||||
|
|
@ -629,9 +629,9 @@ class ScaledDotProduct(torch.nn.Module):
|
|||
attn_mask = torch.nn.functional.pad(_attn_mask, [0, 1])
|
||||
|
||||
tgt_len, head_dim = query.size(-3), query.size(-1)
|
||||
assert (
|
||||
query.size(-1) == key.size(-1) == value.size(-1)
|
||||
), "The feature dim of query, key, value must be equal."
|
||||
assert query.size(-1) == key.size(-1) == value.size(-1), (
|
||||
"The feature dim of query, key, value must be equal."
|
||||
)
|
||||
assert key.size() == value.size(), "Shape of key, value must match"
|
||||
src_len = key.size(-3)
|
||||
batch_heads = max(query.size(-2), key.size(-2))
|
||||
|
|
|
|||
|
|
@ -884,9 +884,9 @@ class HungarianMatcher(nn.Module):
|
|||
self.cost_class = cost_class
|
||||
self.cost_bbox = cost_bbox
|
||||
self.cost_giou = cost_giou
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
|
|
|
|||
|
|
@ -51,9 +51,9 @@ class ModelArgs:
|
|||
# take longer name (as it have more symbols matched)
|
||||
if len(config) > 1:
|
||||
config.sort(key=len, reverse=True)
|
||||
assert len(config[0]) != len(
|
||||
config[1]
|
||||
), name # make sure only one 'best' match
|
||||
assert len(config[0]) != len(config[1]), (
|
||||
name
|
||||
) # make sure only one 'best' match
|
||||
|
||||
return cls(**transformer_configs[config[0]])
|
||||
|
||||
|
|
|
|||
|
|
@ -80,9 +80,9 @@ def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]:
|
|||
|
||||
# And again, the type checker has no way of knowing that this line is valid.
|
||||
jit_model = module.jit_model # type: ignore[attr-defined]
|
||||
assert isinstance(
|
||||
jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)
|
||||
), f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}"
|
||||
assert isinstance(jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)), (
|
||||
f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}"
|
||||
)
|
||||
jit_model.save(artifact_path) # type: ignore[call-arg]
|
||||
|
||||
# Cleanup now that we have the actual serialized model.
|
||||
|
|
|
|||
|
|
@ -276,9 +276,9 @@ class BenchmarkRunner:
|
|||
if c in open_to_close.keys():
|
||||
curr_brackets.append(c)
|
||||
elif c in open_to_close.values():
|
||||
assert (
|
||||
curr_brackets and open_to_close[curr_brackets[-1]] == c
|
||||
), "ERROR: not able to parse the string!"
|
||||
assert curr_brackets and open_to_close[curr_brackets[-1]] == c, (
|
||||
"ERROR: not able to parse the string!"
|
||||
)
|
||||
curr_brackets.pop()
|
||||
elif c == "," and (not curr_brackets):
|
||||
break_idxs.append(i)
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ from torch._inductor.runtime.benchmarking import benchmarker
|
|||
|
||||
|
||||
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
|
||||
assert (
|
||||
sparsity <= 1.0 and sparsity >= 0.0
|
||||
), "sparsity should be a value between 0 and 1"
|
||||
assert sparsity <= 1.0 and sparsity >= 0.0, (
|
||||
"sparsity should be a value between 0 and 1"
|
||||
)
|
||||
assert M % blocksize[0] == 0
|
||||
assert N % blocksize[1] == 0
|
||||
shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]
|
||||
|
|
|
|||
|
|
@ -84,9 +84,9 @@ class CompositeMHA(torch.nn.Module):
|
|||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.embed_dim = embed_dim
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
assert self.head_dim * num_heads == self.embed_dim, (
|
||||
"embed_dim must be divisible by num_heads"
|
||||
)
|
||||
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
|
|
|
|||
|
|
@ -49,9 +49,9 @@ class ExperimentConfig:
|
|||
backends: list[str]
|
||||
|
||||
def __post_init__(self):
|
||||
assert (
|
||||
len(self.shape) == 6
|
||||
), "Shape must be of length 6" # [B, Hq, M, Hkv, N, D]
|
||||
assert len(self.shape) == 6, (
|
||||
"Shape must be of length 6"
|
||||
) # [B, Hq, M, Hkv, N, D]
|
||||
|
||||
def asdict(self):
|
||||
# Convert the dataclass instance to a dictionary
|
||||
|
|
|
|||
|
|
@ -625,9 +625,9 @@ def split(self, split_size_or_sections, dim=0):
|
|||
unbound.append(i)
|
||||
|
||||
if unbound:
|
||||
assert (
|
||||
total_bound_size <= size
|
||||
), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
|
||||
assert total_bound_size <= size, (
|
||||
f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
|
||||
)
|
||||
remaining_size = size - total_bound_size
|
||||
chunk_size = -(-remaining_size // len(unbound))
|
||||
for u in unbound:
|
||||
|
|
@ -636,9 +636,9 @@ def split(self, split_size_or_sections, dim=0):
|
|||
sizes[u] = sz
|
||||
remaining_size -= sz
|
||||
else:
|
||||
assert (
|
||||
total_bound_size == size
|
||||
), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
|
||||
assert total_bound_size == size, (
|
||||
f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
|
||||
)
|
||||
return tuple(
|
||||
t.index(dim, d)
|
||||
for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
|
||||
|
|
|
|||
|
|
@ -62,9 +62,9 @@ def download_reports(commit_sha, configs=("dynamo39", "dynamo311", "eager311")):
|
|||
for config in configs:
|
||||
required_jobs.extend(list(CONFIGS[config]))
|
||||
for job in required_jobs:
|
||||
assert (
|
||||
job in workflow_jobs
|
||||
), f"{job} not found, is the commit_sha correct? has the job finished running? The GitHub API may take a couple minutes to update."
|
||||
assert job in workflow_jobs, (
|
||||
f"{job} not found, is the commit_sha correct? has the job finished running? The GitHub API may take a couple minutes to update."
|
||||
)
|
||||
|
||||
# This page lists all artifacts.
|
||||
listings = requests.get(
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ if __name__ == "__main__":
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert os.path.exists(
|
||||
args.prefix
|
||||
), f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist."
|
||||
assert os.path.exists(args.prefix), (
|
||||
f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist."
|
||||
)
|
||||
|
||||
commit = schema_check.update_schema()
|
||||
|
||||
|
|
@ -40,7 +40,9 @@ if __name__ == "__main__":
|
|||
f"Treespec version downgraded from {commit.base['TREESPEC_VERSION']} to {commit.result['TREESPEC_VERSION']}."
|
||||
)
|
||||
else:
|
||||
assert args.force_unsafe, "Existing schema yaml file not found, please use --force-unsafe to try again."
|
||||
assert args.force_unsafe, (
|
||||
"Existing schema yaml file not found, please use --force-unsafe to try again."
|
||||
)
|
||||
|
||||
next_version, reason = schema_check.check(commit, args.force_unsafe)
|
||||
|
||||
|
|
|
|||
|
|
@ -182,9 +182,9 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
|
|||
baseline_param.grad, param.grad, atol=atol, rtol=rtol
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
test_backward is False
|
||||
), "Calculating backward with multiple outputs is not supported yet."
|
||||
assert test_backward is False, (
|
||||
"Calculating backward with multiple outputs is not supported yet."
|
||||
)
|
||||
for baseline_elem, result_elem in zip(baseline_result, result):
|
||||
torch.testing.assert_close(
|
||||
baseline_elem, result_elem, atol=atol, rtol=rtol
|
||||
|
|
|
|||
|
|
@ -30,15 +30,13 @@ class DynamoExporterTest(common_utils.TestCase):
|
|||
onnx_testing.assert_onnx_program(onnx_program)
|
||||
|
||||
|
||||
def _prepare_llm_model_gptj_to_test() -> (
|
||||
tuple[
|
||||
torch.nn.Module,
|
||||
dict[str, Any],
|
||||
dict[str, dict[int, str]],
|
||||
list[str],
|
||||
list[str],
|
||||
]
|
||||
):
|
||||
def _prepare_llm_model_gptj_to_test() -> tuple[
|
||||
torch.nn.Module,
|
||||
dict[str, Any],
|
||||
dict[str, dict[int, str]],
|
||||
list[str],
|
||||
list[str],
|
||||
]:
|
||||
model = transformers.GPTJForCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-gptj"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -93,9 +93,9 @@ def assert_dynamic_shapes(onnx_program: torch.onnx.ONNXProgram, dynamic_shapes:
|
|||
for dim in inp.type.tensor_type.shape.dim
|
||||
if dim.dim_value == 0 and dim.dim_param != ""
|
||||
]
|
||||
assert dynamic_shapes == (
|
||||
len(dynamic_inputs) > 0
|
||||
), "Dynamic shape check failed for graph inputs"
|
||||
assert dynamic_shapes == (len(dynamic_inputs) > 0), (
|
||||
"Dynamic shape check failed for graph inputs"
|
||||
)
|
||||
|
||||
|
||||
def parameterize_class_name(cls: type, idx: int, input_dicts: Mapping[Any, Any]):
|
||||
|
|
@ -249,9 +249,9 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
|
|||
ref_input_args = input_args
|
||||
ref_input_kwargs = input_kwargs
|
||||
|
||||
assert isinstance(ref_model, torch.nn.Module) or callable(
|
||||
ref_model
|
||||
), "Model must be a torch.nn.Module or callable"
|
||||
assert isinstance(ref_model, torch.nn.Module) or callable(ref_model), (
|
||||
"Model must be a torch.nn.Module or callable"
|
||||
)
|
||||
if (
|
||||
self.model_type
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
|
|
@ -640,9 +640,9 @@ def add_decorate_info(
|
|||
# Skip does not apply to this opset
|
||||
continue
|
||||
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
|
||||
assert (
|
||||
opinfo is not None
|
||||
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
|
||||
assert opinfo is not None, (
|
||||
f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
|
||||
)
|
||||
assert decorate_meta.model_type is None, (
|
||||
f"Tested op: {decorate_meta.op_name} in wrong position! "
|
||||
"If model_type needs to be specified, it should be "
|
||||
|
|
|
|||
|
|
@ -294,13 +294,13 @@ def xfail(error_message: str, reason: Optional[str] = None):
|
|||
except Exception as e:
|
||||
if isinstance(e, torch.onnx.OnnxExporterError):
|
||||
# diagnostic message is in the cause of the exception
|
||||
assert (
|
||||
error_message in str(e.__cause__)
|
||||
), f"Expected error message: {error_message} NOT in {str(e.__cause__)}"
|
||||
assert error_message in str(e.__cause__), (
|
||||
f"Expected error message: {error_message} NOT in {str(e.__cause__)}"
|
||||
)
|
||||
else:
|
||||
assert error_message in str(
|
||||
e
|
||||
), f"Expected error message: {error_message} NOT in {str(e)}"
|
||||
assert error_message in str(e), (
|
||||
f"Expected error message: {error_message} NOT in {str(e)}"
|
||||
)
|
||||
pytest.xfail(reason if reason else f"Expected failure: {error_message}")
|
||||
else:
|
||||
pytest.fail("Unexpected success!")
|
||||
|
|
|
|||
|
|
@ -31,9 +31,9 @@ class TestFxPasses(common_utils.TestCase):
|
|||
name_to_node = {node.name: node for node in gm.graph.nodes}
|
||||
pass_utils.set_node_name(nodes[0], base_name, name_to_node)
|
||||
assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
|
||||
assert len({node.name for node in nodes}) == len(
|
||||
nodes
|
||||
), f"Expected all names to be unique, got {nodes}"
|
||||
assert len({node.name for node in nodes}) == len(nodes), (
|
||||
f"Expected all names to be unique, got {nodes}"
|
||||
)
|
||||
|
||||
def test_set_node_name_succeeds_when_no_name_collisions(self):
|
||||
def func(x, y, z):
|
||||
|
|
@ -51,9 +51,9 @@ class TestFxPasses(common_utils.TestCase):
|
|||
name_to_node = {node.name: node for node in nodes}
|
||||
pass_utils.set_node_name(nodes[1], new_name, name_to_node)
|
||||
assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}"
|
||||
assert len({node.name for node in nodes}) == len(
|
||||
nodes
|
||||
), f"Expected all names to be unique, got {nodes}"
|
||||
assert len({node.name for node in nodes}) == len(nodes), (
|
||||
f"Expected all names to be unique, got {nodes}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -12630,12 +12630,12 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
|||
actual_std = np.std(ort_out)
|
||||
actual_mean = np.mean(ort_out)
|
||||
|
||||
assert (
|
||||
abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1
|
||||
), "the gap of mean between ort outputs and expected one is unacceptable."
|
||||
assert (
|
||||
abs(abs(actual_std) - expected_std) <= expected_std * 0.1
|
||||
), "the gap of variance between ort outputs and expected one is unacceptable."
|
||||
assert abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1, (
|
||||
"the gap of mean between ort outputs and expected one is unacceptable."
|
||||
)
|
||||
assert abs(abs(actual_std) - expected_std) <= expected_std * 0.1, (
|
||||
"the gap of variance between ort outputs and expected one is unacceptable."
|
||||
)
|
||||
|
||||
@skipScriptTest()
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
|
|
@ -12661,12 +12661,12 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
|||
actual_std = np.std(ort_out)
|
||||
actual_mean = np.mean(ort_out)
|
||||
|
||||
assert (
|
||||
abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1
|
||||
), "the gap of mean between ort outputs and expected one is unacceptable."
|
||||
assert (
|
||||
abs(abs(actual_std) - expected_std) <= expected_std * 0.1
|
||||
), "the gap of variance between ort outputs and expected one is unacceptable."
|
||||
assert abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1, (
|
||||
"the gap of mean between ort outputs and expected one is unacceptable."
|
||||
)
|
||||
assert abs(abs(actual_std) - expected_std) <= expected_std * 0.1, (
|
||||
"the gap of variance between ort outputs and expected one is unacceptable."
|
||||
)
|
||||
|
||||
@skipScriptTest()
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
|
|
@ -12705,15 +12705,15 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
|||
actual_max = np.max(ort_out)
|
||||
actual_mean = np.mean(ort_out)
|
||||
|
||||
assert (
|
||||
actual_min >= expected_min
|
||||
), "the minimum value of ort outputs is out of scope."
|
||||
assert (
|
||||
actual_max <= expected_max
|
||||
), "the maximum value of ort outputs is out of scope."
|
||||
assert (
|
||||
abs(actual_mean - expected_mean) <= expected_mean * 0.05
|
||||
), "the mean value of ort outputs is out of scope."
|
||||
assert actual_min >= expected_min, (
|
||||
"the minimum value of ort outputs is out of scope."
|
||||
)
|
||||
assert actual_max <= expected_max, (
|
||||
"the maximum value of ort outputs is out of scope."
|
||||
)
|
||||
assert abs(actual_mean - expected_mean) <= expected_mean * 0.05, (
|
||||
"the mean value of ort outputs is out of scope."
|
||||
)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(13)
|
||||
def test_sequence_to_int(self):
|
||||
|
|
|
|||
|
|
@ -1441,9 +1441,9 @@ def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False):
|
|||
|
||||
|
||||
def run_ci_sanity_check(test: ShardedTest, test_directory, options):
|
||||
assert (
|
||||
test.name == "test_ci_sanity_check_fail"
|
||||
), f"This handler only works for test_ci_sanity_check_fail, got {test.name}"
|
||||
assert test.name == "test_ci_sanity_check_fail", (
|
||||
f"This handler only works for test_ci_sanity_check_fail, got {test.name}"
|
||||
)
|
||||
ret_code = run_test(test, test_directory, options, print_log=False)
|
||||
# This test should fail
|
||||
if ret_code != 1:
|
||||
|
|
@ -1951,9 +1951,9 @@ def get_sharding_opts(options) -> tuple[int, int]:
|
|||
assert len(options.shard) == 2, "Unexpected shard format"
|
||||
assert min(options.shard) > 0, "Shards must be positive numbers"
|
||||
which_shard, num_shards = options.shard
|
||||
assert (
|
||||
which_shard <= num_shards
|
||||
), "Selected shard must be less than or equal to total number of shards"
|
||||
assert which_shard <= num_shards, (
|
||||
"Selected shard must be less than or equal to total number of shards"
|
||||
)
|
||||
|
||||
return (which_shard, num_shards)
|
||||
|
||||
|
|
@ -1996,9 +1996,9 @@ def run_test_module(
|
|||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
|
||||
handler = CUSTOM_HANDLERS.get(test_name, run_test)
|
||||
return_code = handler(test, test_directory, options)
|
||||
assert isinstance(return_code, int) and not isinstance(
|
||||
return_code, bool
|
||||
), f"While running {str(test)} got non integer return code {return_code}"
|
||||
assert isinstance(return_code, int) and not isinstance(return_code, bool), (
|
||||
f"While running {str(test)} got non integer return code {return_code}"
|
||||
)
|
||||
if return_code == 0:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -665,12 +665,12 @@ class ErrorTrackingProcess(mp.Process):
|
|||
raise
|
||||
|
||||
def print_traces_of_all_threads(self):
|
||||
assert (
|
||||
self.is_alive()
|
||||
), "can only use print_traces_of_all_threads if the process is alive"
|
||||
assert (
|
||||
not self.disable_stderr
|
||||
), "do not disable stderr if you use print_traces_of_all_threads"
|
||||
assert self.is_alive(), (
|
||||
"can only use print_traces_of_all_threads if the process is alive"
|
||||
)
|
||||
assert not self.disable_stderr, (
|
||||
"do not disable stderr if you use print_traces_of_all_threads"
|
||||
)
|
||||
# On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
|
||||
# `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
|
||||
# the process. So let's poll the exception first
|
||||
|
|
@ -1030,19 +1030,19 @@ class TestWorkerInfoDataset(SynchronizedDataset):
|
|||
# See _test_get_worker_info below for usage.
|
||||
def _test_worker_info_init_fn(worker_id):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
assert (
|
||||
worker_id == worker_info.id
|
||||
), "worker_init_fn and worker_info should have consistent id"
|
||||
assert (
|
||||
worker_id < worker_info.num_workers
|
||||
), "worker_init_fn and worker_info should have valid id"
|
||||
assert (
|
||||
worker_info.seed == torch.initial_seed()
|
||||
), "worker_init_fn and worker_info should have consistent seed"
|
||||
assert worker_id == worker_info.id, (
|
||||
"worker_init_fn and worker_info should have consistent id"
|
||||
)
|
||||
assert worker_id < worker_info.num_workers, (
|
||||
"worker_init_fn and worker_info should have valid id"
|
||||
)
|
||||
assert worker_info.seed == torch.initial_seed(), (
|
||||
"worker_init_fn and worker_info should have consistent seed"
|
||||
)
|
||||
dataset = worker_info.dataset
|
||||
assert isinstance(
|
||||
dataset, TestWorkerInfoDataset
|
||||
), "worker_info should have correct dataset copy"
|
||||
assert isinstance(dataset, TestWorkerInfoDataset), (
|
||||
"worker_info should have correct dataset copy"
|
||||
)
|
||||
assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy"
|
||||
# test that WorkerInfo attributes are read-only
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -87,9 +87,9 @@ class ForeachFuncWrapper:
|
|||
actual = self.func(*inputs, **kwargs)
|
||||
keys = tuple([e.key for e in p.key_averages()])
|
||||
mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
|
||||
assert (
|
||||
mta_called == (expect_fastpath and (not zero_size))
|
||||
), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}"
|
||||
assert mta_called == (expect_fastpath and (not zero_size)), (
|
||||
f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}"
|
||||
)
|
||||
else:
|
||||
actual = self.func(*inputs, **kwargs)
|
||||
if self.is_inplace:
|
||||
|
|
|
|||
|
|
@ -169,9 +169,9 @@ def random_nt(
|
|||
assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
|
||||
assert min_dim >= 0, "random_nt: min_dim must be non-negative"
|
||||
if require_non_empty:
|
||||
assert not (
|
||||
min_dim == 0 and max_dim == 1
|
||||
), "random_nt: zero cannot be the only possible value if require_non_empty is True"
|
||||
assert not (min_dim == 0 and max_dim == 1), (
|
||||
"random_nt: zero cannot be the only possible value if require_non_empty is True"
|
||||
)
|
||||
|
||||
if require_non_empty:
|
||||
# Select a random idx that will be required to be non-empty
|
||||
|
|
|
|||
|
|
@ -52,17 +52,17 @@ FP16_REDUCED_PRECISION = {"atol": 1e-5, "rtol": 1e-4}
|
|||
|
||||
|
||||
def rosenbrock(tensor):
|
||||
assert tensor.size() == torch.Size(
|
||||
[2]
|
||||
), f"Requires tensor with 2 scalars but got {tensor.size()}"
|
||||
assert tensor.size() == torch.Size([2]), (
|
||||
f"Requires tensor with 2 scalars but got {tensor.size()}"
|
||||
)
|
||||
x, y = tensor
|
||||
return (1 - x) ** 2 + 100 * (y - x**2) ** 2
|
||||
|
||||
|
||||
def drosenbrock(tensor):
|
||||
assert tensor.size() == torch.Size(
|
||||
[2]
|
||||
), f"Requires tensor with 2 scalars but got {tensor.size()}"
|
||||
assert tensor.size() == torch.Size([2]), (
|
||||
f"Requires tensor with 2 scalars but got {tensor.size()}"
|
||||
)
|
||||
x, y = tensor
|
||||
return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
|
||||
|
||||
|
|
|
|||
|
|
@ -616,9 +616,9 @@ def load_deprecated_signatures(
|
|||
}
|
||||
schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
|
||||
for name in call_args:
|
||||
assert (
|
||||
name in schema_args_by_name or name in known_constants
|
||||
), f"deprecation definiton: Unrecognized value {name}"
|
||||
assert name in schema_args_by_name or name in known_constants, (
|
||||
f"deprecation definiton: Unrecognized value {name}"
|
||||
)
|
||||
|
||||
# Map deprecated signature arguments to their aten signature and test
|
||||
# if the types and alias annotation match.
|
||||
|
|
@ -683,7 +683,9 @@ def load_deprecated_signatures(
|
|||
function=pair.function,
|
||||
)
|
||||
)
|
||||
assert any_schema_found, f"No native function with name {aten_name} matched signature:\n {str(schema)}"
|
||||
assert any_schema_found, (
|
||||
f"No native function with name {aten_name} matched signature:\n {str(schema)}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
|||
|
|
@ -1075,7 +1075,9 @@ def emit_body(
|
|||
assert (
|
||||
base_name_and_overload_name
|
||||
in _foreach_ops_without_differentiability_info
|
||||
), f"{'.'.join(base_name_and_overload_name)} should have a differentiability info"
|
||||
), (
|
||||
f"{'.'.join(base_name_and_overload_name)} should have a differentiability info"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
len(f.func.arguments.flat_non_out)
|
||||
|
|
@ -1634,9 +1636,9 @@ def emit_body(
|
|||
noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref()
|
||||
if noref_cpp_type == BaseCType(tensorT):
|
||||
if aliased_arg_name is not None:
|
||||
assert (
|
||||
i == 0
|
||||
), "Expect non-CompositeImplicitAutograd view function {base} to return single output"
|
||||
assert i == 0, (
|
||||
"Expect non-CompositeImplicitAutograd view function {base} to return single output"
|
||||
)
|
||||
stmts_after_call += [
|
||||
ENFORCE_SAME_TENSOR_STORAGE.substitute(
|
||||
tensor_name=aliased_arg_name, out_tensor_name=ret_name
|
||||
|
|
@ -1873,9 +1875,9 @@ def emit_body(
|
|||
for derivative in fw_derivatives:
|
||||
res = derivative.var_names
|
||||
if f.func.name.name.inplace:
|
||||
assert (
|
||||
len(res) == 1
|
||||
), "Expected number of outputs to be 1 if function is inplace"
|
||||
assert len(res) == 1, (
|
||||
"Expected number of outputs to be 1 if function is inplace"
|
||||
)
|
||||
# TODO update this when inplace namings are unified
|
||||
res = ("self",)
|
||||
|
||||
|
|
@ -1973,9 +1975,9 @@ def emit_body(
|
|||
isinstance(derivative.var_types[0], ListType)
|
||||
and derivative.var_types[0].is_tensor_like()
|
||||
):
|
||||
assert (
|
||||
len(derivative.var_types) == 1
|
||||
), "Expected number of outputs to be 1 if function returns ListType"
|
||||
assert len(derivative.var_types) == 1, (
|
||||
"Expected number of outputs to be 1 if function returns ListType"
|
||||
)
|
||||
if not is_foreach:
|
||||
opt_res_grad_type = OptionalCType(
|
||||
VectorCType(BaseCType(tensorT))
|
||||
|
|
|
|||
|
|
@ -313,9 +313,9 @@ def postprocess_forward_derivatives(
|
|||
formula = defn.formula
|
||||
required_inputs_tangent = find_required_inputs(formula, "_t")
|
||||
if formula == "auto_element_wise":
|
||||
assert (
|
||||
f.func.kind() != SchemaKind.inplace
|
||||
), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
|
||||
assert f.func.kind() != SchemaKind.inplace, (
|
||||
f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
|
||||
)
|
||||
if (
|
||||
(not len(args_with_derivatives) == 1)
|
||||
or len(forward_derivatives) > 1
|
||||
|
|
|
|||
|
|
@ -41,9 +41,9 @@ def parse_segments(raw_segments: list[list[int]]) -> list[LlvmCoverageSegment]:
|
|||
"""
|
||||
ret: list[LlvmCoverageSegment] = []
|
||||
for raw_segment in raw_segments:
|
||||
assert (
|
||||
len(raw_segment) == 5 or len(raw_segment) == 6
|
||||
), "list is not compatible with llvmcom export:"
|
||||
assert len(raw_segment) == 5 or len(raw_segment) == 6, (
|
||||
"list is not compatible with llvmcom export:"
|
||||
)
|
||||
" Expected to have 5 or 6 elements"
|
||||
if len(raw_segment) == 5:
|
||||
ret.append(
|
||||
|
|
|
|||
|
|
@ -116,12 +116,12 @@ def build_groups_memberships(
|
|||
_memberships[pg_guid] = set(ranks)
|
||||
else:
|
||||
# validation across ranks
|
||||
assert (
|
||||
_groups[pg_guid].desc == desc
|
||||
), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}"
|
||||
assert (
|
||||
_memberships[pg_guid] == set(ranks)
|
||||
), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}"
|
||||
assert _groups[pg_guid].desc == desc, (
|
||||
f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}"
|
||||
)
|
||||
assert _memberships[pg_guid] == set(ranks), (
|
||||
f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}"
|
||||
)
|
||||
return groups, _groups, memberships, _memberships, _pg_guids
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -73,13 +73,13 @@ class JobConfig:
|
|||
) -> argparse.Namespace:
|
||||
args = self.parser.parse_args(args)
|
||||
if args.selected_ranks is not None:
|
||||
assert (
|
||||
args.just_print_entries
|
||||
), "Not support selecting ranks without printing entries"
|
||||
assert args.just_print_entries, (
|
||||
"Not support selecting ranks without printing entries"
|
||||
)
|
||||
if args.pg_filters is not None:
|
||||
assert (
|
||||
args.just_print_entries
|
||||
), "Not support selecting pg filters without printing entries"
|
||||
assert args.just_print_entries, (
|
||||
"Not support selecting pg filters without printing entries"
|
||||
)
|
||||
if args.verbose:
|
||||
logger.set_log_level(logging.DEBUG)
|
||||
return args
|
||||
|
|
|
|||
|
|
@ -85,8 +85,8 @@ def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]:
|
|||
if not version:
|
||||
version = str(details[f]["version"])
|
||||
tb = time.time()
|
||||
assert (
|
||||
len(details) > 0
|
||||
), f"no files loaded from {args.trace_dir} with prefix {prefix}"
|
||||
assert len(details) > 0, (
|
||||
f"no files loaded from {args.trace_dir} with prefix {prefix}"
|
||||
)
|
||||
logger.debug("loaded %s files in %ss", filecount, tb - t0)
|
||||
return details, version
|
||||
|
|
|
|||
|
|
@ -378,9 +378,9 @@ class Op:
|
|||
meta = parts[1] if len(parts) == 2 else None
|
||||
self.state = event["state"]
|
||||
self.pg_name, self.pg_desc = event["process_group"]
|
||||
assert type in COLLECTIVES | P2P | {
|
||||
"coalesced"
|
||||
}, f"{type} is not a supported operation"
|
||||
assert type in COLLECTIVES | P2P | {"coalesced"}, (
|
||||
f"{type} is not a supported operation"
|
||||
)
|
||||
self.type = type
|
||||
if type == "send":
|
||||
assert isinstance(meta, str)
|
||||
|
|
|
|||
|
|
@ -263,16 +263,16 @@ def check_no_missing_dump_files(
|
|||
for membership in memberships:
|
||||
all_ranks.add(int(membership.global_rank))
|
||||
dumps_ranks = {int(key) for key in entries.keys()}
|
||||
assert (
|
||||
dumps_ranks == all_ranks
|
||||
), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
|
||||
assert dumps_ranks == all_ranks, (
|
||||
f"Missing dump files from ranks {all_ranks - dumps_ranks}"
|
||||
)
|
||||
|
||||
|
||||
def check_version(version_by_ranks: dict[str, str], version: str) -> None:
|
||||
for rank, v in version_by_ranks.items():
|
||||
assert (
|
||||
v == version
|
||||
), f"Rank {rank} has different version {v} from the given version {version}"
|
||||
assert v == version, (
|
||||
f"Rank {rank} has different version {v} from the given version {version}"
|
||||
)
|
||||
|
||||
|
||||
def get_version_detail(version: str) -> tuple[int, int]:
|
||||
|
|
|
|||
|
|
@ -102,13 +102,13 @@ def _format_rule_for_python_class(rule: _RuleType) -> str:
|
|||
if field_name is not None
|
||||
]
|
||||
for field_name in field_names:
|
||||
assert isinstance(
|
||||
field_name, str
|
||||
), f"Unexpected field type {type(field_name)} from {field_name}. "
|
||||
assert isinstance(field_name, str), (
|
||||
f"Unexpected field type {type(field_name)} from {field_name}. "
|
||||
)
|
||||
"Field name must be string.\nFull message template: {message_template}"
|
||||
assert (
|
||||
not field_name.isnumeric()
|
||||
), f"Unexpected numeric field name {field_name}. "
|
||||
assert not field_name.isnumeric(), (
|
||||
f"Unexpected numeric field name {field_name}. "
|
||||
)
|
||||
"Only keyword name formatting is supported.\nFull message template: {message_template}"
|
||||
message_arguments = ", ".join(field_names)
|
||||
message_arguments_assigned = ", ".join(
|
||||
|
|
|
|||
|
|
@ -212,12 +212,12 @@ def main() -> None:
|
|||
lazy_install_dir = os.path.join(install_dir, "lazy/generated")
|
||||
os.makedirs(lazy_install_dir, exist_ok=True)
|
||||
|
||||
assert os.path.isfile(
|
||||
ts_backend_yaml
|
||||
), f"Unable to access ts_backend_yaml: {ts_backend_yaml}"
|
||||
assert os.path.isfile(
|
||||
ts_native_functions
|
||||
), f"Unable to access {ts_native_functions}"
|
||||
assert os.path.isfile(ts_backend_yaml), (
|
||||
f"Unable to access ts_backend_yaml: {ts_backend_yaml}"
|
||||
)
|
||||
assert os.path.isfile(ts_native_functions), (
|
||||
f"Unable to access {ts_native_functions}"
|
||||
)
|
||||
from torchgen.dest.lazy_ir import GenTSLazyIR
|
||||
from torchgen.gen_lazy_tensor import run_gen_lazy_tensor
|
||||
|
||||
|
|
|
|||
|
|
@ -94,9 +94,9 @@ class GenOperatorsYAMLTest(unittest.TestCase):
|
|||
]
|
||||
|
||||
filtered_configs = list(filter(filter_func, config))
|
||||
assert (
|
||||
len(filtered_configs) == 2
|
||||
), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}"
|
||||
assert len(filtered_configs) == 2, (
|
||||
f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}"
|
||||
)
|
||||
|
||||
def test_verification_success(self) -> None:
|
||||
filter_func = make_filter_from_options(
|
||||
|
|
|
|||
|
|
@ -48,20 +48,20 @@ class TestPrioritizations:
|
|||
if test.test_file not in files:
|
||||
files[test.test_file] = copy(test)
|
||||
else:
|
||||
assert (
|
||||
files[test.test_file] & test
|
||||
).is_empty(), (
|
||||
assert (files[test.test_file] & test).is_empty(), (
|
||||
f"Test run `{test}` overlaps with `{files[test.test_file]}`"
|
||||
)
|
||||
files[test.test_file] |= test
|
||||
|
||||
for test in files.values():
|
||||
assert test.is_full_file(), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that" # noqa: B950
|
||||
assert test.is_full_file(), (
|
||||
f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"
|
||||
) # noqa: B950
|
||||
|
||||
# Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in
|
||||
assert (
|
||||
self._original_tests == set(files.keys())
|
||||
), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in"
|
||||
assert self._original_tests == set(files.keys()), (
|
||||
"The set of tests in the TestPrioritizations must be identical to the set of tests passed in"
|
||||
)
|
||||
|
||||
def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]:
|
||||
# Sort by score, then alphabetically by test name
|
||||
|
|
@ -228,9 +228,9 @@ class AggregatedHeuristics:
|
|||
def validate(self) -> None:
|
||||
for heuristic, heuristic_results in self._heuristic_results.items():
|
||||
heuristic_results.validate()
|
||||
assert (
|
||||
heuristic_results._original_tests == self._all_tests
|
||||
), f"Tests in {heuristic.name} are not the same as the tests in the AggregatedHeuristics"
|
||||
assert heuristic_results._original_tests == self._all_tests, (
|
||||
f"Tests in {heuristic.name} are not the same as the tests in the AggregatedHeuristics"
|
||||
)
|
||||
|
||||
def add_heuristic_results(
|
||||
self, heuristic: HeuristicInterface, heuristic_results: TestPrioritizations
|
||||
|
|
|
|||
|
|
@ -43,9 +43,9 @@ class TestRun:
|
|||
exs = set(excluded or [])
|
||||
|
||||
if "::" in name:
|
||||
assert (
|
||||
not included and not excluded
|
||||
), "Can't specify included or excluded tests when specifying a test class in the file name"
|
||||
assert not included and not excluded, (
|
||||
"Can't specify included or excluded tests when specifying a test class in the file name"
|
||||
)
|
||||
self.test_file, test_class = name.split("::")
|
||||
ins.add(test_class)
|
||||
else:
|
||||
|
|
@ -148,9 +148,9 @@ class TestRun:
|
|||
return copy(self)
|
||||
|
||||
# If not, ensure we have the same file
|
||||
assert (
|
||||
self.test_file == other.test_file
|
||||
), f"Can't exclude {other} from {self} because they're not the same test file"
|
||||
assert self.test_file == other.test_file, (
|
||||
f"Can't exclude {other} from {self} because they're not the same test file"
|
||||
)
|
||||
|
||||
# 4 possible cases:
|
||||
|
||||
|
|
|
|||
|
|
@ -124,9 +124,9 @@ def get_duration(
|
|||
|
||||
if included:
|
||||
return included_classes_duration
|
||||
assert (
|
||||
excluded
|
||||
), f"TestRun {test} is not full file but doesn't have included or excluded classes"
|
||||
assert excluded, (
|
||||
f"TestRun {test} is not full file but doesn't have included or excluded classes"
|
||||
)
|
||||
if file_duration is None:
|
||||
return None
|
||||
return file_duration - excluded_classes_duration
|
||||
|
|
@ -140,9 +140,9 @@ def shard(
|
|||
) -> None:
|
||||
# Modifies sharded_jobs in place
|
||||
if len(sharded_jobs) == 0:
|
||||
assert (
|
||||
len(pytest_sharded_tests) == 0
|
||||
), "No shards provided but there are tests to shard"
|
||||
assert len(pytest_sharded_tests) == 0, (
|
||||
"No shards provided but there are tests to shard"
|
||||
)
|
||||
return
|
||||
|
||||
round_robin_index = 0
|
||||
|
|
|
|||
|
|
@ -255,9 +255,9 @@ def createResolutionCallbackFromEnv(lookup_base):
|
|||
def parseExpr(expr, module):
|
||||
try:
|
||||
value, len_parsed = parseNestedExpr(expr, module)
|
||||
assert len_parsed == len(
|
||||
expr
|
||||
), "whole expression was not parsed, falling back to c++ parser"
|
||||
assert len_parsed == len(expr), (
|
||||
"whole expression was not parsed, falling back to c++ parser"
|
||||
)
|
||||
return value
|
||||
except Exception:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -795,9 +795,9 @@ class LOBPCG:
|
|||
# strict ordering of eigenpairs
|
||||
break
|
||||
count += 1
|
||||
assert (
|
||||
count >= prev_count
|
||||
), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
|
||||
assert count >= prev_count, (
|
||||
f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
|
||||
)
|
||||
self.ivars["converged_count"] = count
|
||||
self.tvars["rerr"] = rerr
|
||||
return count
|
||||
|
|
|
|||
|
|
@ -541,9 +541,9 @@ def meta_sparse_structured_linear(
|
|||
transposed_strides = (1, input.size(0))
|
||||
|
||||
if out_dtype is not None:
|
||||
assert (
|
||||
input.dtype == torch.int8 and out_dtype == torch.int32
|
||||
), "out_dtype is only supported for i8i8->i32 linear operator"
|
||||
assert input.dtype == torch.int8 and out_dtype == torch.int32, (
|
||||
"out_dtype is only supported for i8i8->i32 linear operator"
|
||||
)
|
||||
output = input.new_empty(
|
||||
output_sizes,
|
||||
dtype=input.dtype if out_dtype is None else out_dtype,
|
||||
|
|
@ -566,9 +566,9 @@ def meta_sparse_structured_mm(
|
|||
output_sizes = [mat1.size(0), mat2.size(1)]
|
||||
|
||||
if out_dtype is not None:
|
||||
assert (
|
||||
mat2.dtype == torch.int8 and out_dtype == torch.int32
|
||||
), "out_dtype is only supported for i8i8->i32 linear operator"
|
||||
assert mat2.dtype == torch.int8 and out_dtype == torch.int32, (
|
||||
"out_dtype is only supported for i8i8->i32 linear operator"
|
||||
)
|
||||
output = mat2.new_empty(
|
||||
output_sizes,
|
||||
dtype=mat2.dtype if out_dtype is None else out_dtype,
|
||||
|
|
@ -588,22 +588,22 @@ def meta_sparse_structured_addmm(
|
|||
beta=1,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
assert (
|
||||
len(input.shape) == 1
|
||||
), "only input broadcasted to columns of mat1 * mat2 product is supported"
|
||||
assert len(input.shape) == 1, (
|
||||
"only input broadcasted to columns of mat1 * mat2 product is supported"
|
||||
)
|
||||
assert len(mat1.shape) == 2
|
||||
assert len(mat1_meta.shape) == 2
|
||||
assert len(mat2.shape) == 2
|
||||
assert input.size(0) == mat1.size(
|
||||
0
|
||||
), "only input broadcasted to columns of mat1 * mat2 product is supported"
|
||||
assert input.size(0) == mat1.size(0), (
|
||||
"only input broadcasted to columns of mat1 * mat2 product is supported"
|
||||
)
|
||||
assert mat1.size(1) == mat2.size(0) / 2
|
||||
output_sizes = [mat1.size(0), mat2.size(1)]
|
||||
|
||||
if out_dtype is not None:
|
||||
assert (
|
||||
mat2.dtype == torch.int8 and out_dtype == torch.int32
|
||||
), "out_dtype is only supported for i8i8->i32 linear operator"
|
||||
assert mat2.dtype == torch.int8 and out_dtype == torch.int32, (
|
||||
"out_dtype is only supported for i8i8->i32 linear operator"
|
||||
)
|
||||
output = mat2.new_empty(
|
||||
output_sizes,
|
||||
dtype=mat2.dtype if out_dtype is None else out_dtype,
|
||||
|
|
@ -638,9 +638,9 @@ def meta__cslt_sparse_mm(
|
|||
compression_factor = 10 if is_8bit_input_type else 9
|
||||
|
||||
if is_8bit_input_type:
|
||||
assert (
|
||||
not dense_B.is_contiguous()
|
||||
), "dense input must be transposed for 8bit dtypes"
|
||||
assert not dense_B.is_contiguous(), (
|
||||
"dense input must be transposed for 8bit dtypes"
|
||||
)
|
||||
|
||||
k = dense_B.size(0)
|
||||
n = dense_B.size(1)
|
||||
|
|
@ -649,16 +649,14 @@ def meta__cslt_sparse_mm(
|
|||
assert m == bias.size(0)
|
||||
|
||||
if out_dtype is not None:
|
||||
assert (
|
||||
is_8bit_input_type
|
||||
and out_dtype
|
||||
in {
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.int32,
|
||||
torch.float8_e4m3fn,
|
||||
}
|
||||
), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
|
||||
assert is_8bit_input_type and out_dtype in {
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.int32,
|
||||
torch.float8_e4m3fn,
|
||||
}, (
|
||||
"out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
|
||||
)
|
||||
output_shape = (n, m) if transpose_result else (m, n)
|
||||
return dense_B.new_empty(output_shape, dtype=out_dtype)
|
||||
|
||||
|
|
@ -861,12 +859,12 @@ def functional_assert_async_meta(val, assert_msg, dep_token):
|
|||
|
||||
# From aten/src/ATen/native/LinearAlgebraUtils.h
|
||||
def squareCheckInputs(self: Tensor, f_name: str):
|
||||
assert (
|
||||
self.dim() >= 2
|
||||
), f"{f_name}: The input tensor must have at least 2 dimensions."
|
||||
assert (
|
||||
self.size(-1) == self.size(-2)
|
||||
), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
|
||||
assert self.dim() >= 2, (
|
||||
f"{f_name}: The input tensor must have at least 2 dimensions."
|
||||
)
|
||||
assert self.size(-1) == self.size(-2), (
|
||||
f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
|
||||
)
|
||||
|
||||
|
||||
# Validates input shapes and devices
|
||||
|
|
@ -6500,9 +6498,9 @@ def topk_meta(self, k, dim=-1, largest=True, sorted=True):
|
|||
def meta__segment_reduce_backward(
|
||||
grad, output, data, reduce, lengths=None, offsets=None, axis=0, initial=None
|
||||
):
|
||||
assert (
|
||||
lengths is not None or offsets is not None
|
||||
), "segment_reduce(): Either lengths or offsets must be defined"
|
||||
assert lengths is not None or offsets is not None, (
|
||||
"segment_reduce(): Either lengths or offsets must be defined"
|
||||
)
|
||||
data_contig = data.contiguous()
|
||||
grad_contig = grad.contiguous()
|
||||
return torch.empty_like(
|
||||
|
|
@ -6579,7 +6577,9 @@ def linear_backward(input_, grad_output_, weight_, output_mask):
|
|||
def meta_pixel_shuffle(self, upscale_factor):
|
||||
assert (
|
||||
len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
|
||||
), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"
|
||||
), (
|
||||
f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"
|
||||
)
|
||||
|
||||
def is_channels_last(ten):
|
||||
return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
|
||||
|
|
@ -6753,15 +6753,14 @@ def nan_to_num(self, nan=None, posinf=None, neginf=None):
|
|||
|
||||
@register_meta(torch.ops.aten.transpose_)
|
||||
def transpose_(self, dim0, dim1):
|
||||
assert (
|
||||
self.layout
|
||||
not in {
|
||||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc,
|
||||
}
|
||||
), f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
|
||||
assert self.layout not in {
|
||||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc,
|
||||
}, (
|
||||
f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
|
||||
)
|
||||
|
||||
ndims = self.ndim
|
||||
|
||||
|
|
@ -6788,13 +6787,14 @@ def t_(self):
|
|||
if self.is_sparse:
|
||||
sparse_dim = self.sparse_dim()
|
||||
dense_dim = self.dense_dim()
|
||||
assert (
|
||||
sparse_dim <= 2 and dense_dim == 0
|
||||
), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions" # noqa: B950
|
||||
assert sparse_dim <= 2 and dense_dim == 0, (
|
||||
f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, "
|
||||
f"but got {sparse_dim} sparse and {dense_dim} dense dimensions"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.dim() <= 2
|
||||
), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
|
||||
assert self.dim() <= 2, (
|
||||
f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
|
||||
)
|
||||
|
||||
return transpose_(self, 0, 0 if ndims < 2 else 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -133,9 +133,9 @@ class OperatorBase:
|
|||
return fn
|
||||
|
||||
assert isinstance(k, DispatchKey)
|
||||
assert (
|
||||
k != DispatchKey.Python
|
||||
), "Please register a mode for the DispatchKey.Python key instead."
|
||||
assert k != DispatchKey.Python, (
|
||||
"Please register a mode for the DispatchKey.Python key instead."
|
||||
)
|
||||
|
||||
if k in self.py_kernels:
|
||||
raise RuntimeError(
|
||||
|
|
@ -422,12 +422,12 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
|
|||
DispatchKey.Python
|
||||
):
|
||||
curr_mode = _get_current_dispatch_mode_pre_dispatch()
|
||||
assert (
|
||||
curr_mode is not None
|
||||
), "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode."
|
||||
assert (
|
||||
type(curr_mode) in self.python_key_table
|
||||
), f"Current active mode {curr_mode} not registered"
|
||||
assert curr_mode is not None, (
|
||||
"Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode."
|
||||
)
|
||||
assert type(curr_mode) in self.python_key_table, (
|
||||
f"Current active mode {curr_mode} not registered"
|
||||
)
|
||||
handler = self.python_key_table[type(curr_mode)]
|
||||
with _pop_mode_temporarily(functionality_key) as mode:
|
||||
return handler(mode, *args, **kwargs)
|
||||
|
|
@ -828,9 +828,9 @@ class OpOverload(OperatorBase):
|
|||
# TODO: We also need to handle tensor subclasses here
|
||||
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
|
||||
curr_mode = type(_get_current_dispatch_mode())
|
||||
assert (
|
||||
curr_mode is not None
|
||||
), "Illegal invocation of dispatch on DispatchKey.Python without a mode."
|
||||
assert curr_mode is not None, (
|
||||
"Illegal invocation of dispatch on DispatchKey.Python without a mode."
|
||||
)
|
||||
|
||||
if curr_mode not in self.python_key_table:
|
||||
if isinstance(self, TorchBindOpOverload):
|
||||
|
|
|
|||
|
|
@ -81,9 +81,9 @@ def _to(self, device, non_blocking=False):
|
|||
return untyped_storage
|
||||
|
||||
device_module = getattr(torch, device.type, None)
|
||||
assert (
|
||||
device_module is not None
|
||||
), f"{device.type.upper()} device module is not loaded"
|
||||
assert device_module is not None, (
|
||||
f"{device.type.upper()} device module is not loaded"
|
||||
)
|
||||
with device_module.device(device):
|
||||
if self.is_sparse and hasattr(device_module, "sparse"):
|
||||
new_type = getattr(device_module.sparse, self.__class__.__name__)
|
||||
|
|
@ -95,9 +95,9 @@ def _to(self, device, non_blocking=False):
|
|||
)
|
||||
return new_type(indices, values, self.size())
|
||||
else:
|
||||
assert (
|
||||
not self.is_sparse
|
||||
), f"sparse storage is not supported for {device.type.upper()} tensors"
|
||||
assert not self.is_sparse, (
|
||||
f"sparse storage is not supported for {device.type.upper()} tensors"
|
||||
)
|
||||
untyped_storage = torch.UntypedStorage(self.size(), device=device)
|
||||
untyped_storage.copy_(self, non_blocking)
|
||||
return untyped_storage
|
||||
|
|
|
|||
|
|
@ -165,18 +165,18 @@ _context = engine.background_context
|
|||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def create_export_diagnostic_context() -> (
|
||||
Generator[infra.DiagnosticContext, None, None]
|
||||
):
|
||||
def create_export_diagnostic_context() -> Generator[
|
||||
infra.DiagnosticContext, None, None
|
||||
]:
|
||||
"""Create a diagnostic context for export.
|
||||
|
||||
This is a workaround for code robustness since diagnostic context is accessed by
|
||||
export internals via global variable. See `ExportDiagnosticEngine` for more details.
|
||||
"""
|
||||
global _context
|
||||
assert (
|
||||
_context == engine.background_context
|
||||
), "Export context is already set. Nested export is not supported."
|
||||
assert _context == engine.background_context, (
|
||||
"Export context is already set. Nested export is not supported."
|
||||
)
|
||||
_context = engine.create_diagnostic_context(
|
||||
"torch.onnx.export",
|
||||
torch.__version__,
|
||||
|
|
|
|||
|
|
@ -98,9 +98,9 @@ def _construct_named_inputs_and_attrs(
|
|||
else:
|
||||
# Handle attributes
|
||||
attribute: ValidAttributeType | ir.Attr
|
||||
assert isinstance(
|
||||
param, _schemas.AttributeParameter
|
||||
), f"Expected AttributeParameter, got {type(param)}"
|
||||
assert isinstance(param, _schemas.AttributeParameter), (
|
||||
f"Expected AttributeParameter, got {type(param)}"
|
||||
)
|
||||
if reversed_args_stack:
|
||||
# First exhaust the positional arguments
|
||||
attribute = reversed_args_stack.pop() # type: ignore[assignment]
|
||||
|
|
@ -166,9 +166,9 @@ def _resolve_parameter_dtypes(
|
|||
type_binding = {}
|
||||
for name, arg in named_inputs.items():
|
||||
param = signature.params_map[name]
|
||||
assert isinstance(
|
||||
param, _schemas.Parameter
|
||||
), f"Expected Parameter, got {type(param)}"
|
||||
assert isinstance(param, _schemas.Parameter), (
|
||||
f"Expected Parameter, got {type(param)}"
|
||||
)
|
||||
if isinstance(arg, (int, float, bool, str, Sequence, torch.Tensor)):
|
||||
# Skip the Python constants because we do not know what dtype they should take yet
|
||||
continue
|
||||
|
|
@ -318,9 +318,9 @@ def _process_python_constants(
|
|||
# - Otherwise, set named_inputs[param.name] = Constant(value)
|
||||
for name, arg in named_inputs.items():
|
||||
param = signature.params_map[name]
|
||||
assert isinstance(
|
||||
param, _schemas.Parameter
|
||||
), f"Expected Parameter, got {type(param)}"
|
||||
assert isinstance(param, _schemas.Parameter), (
|
||||
f"Expected Parameter, got {type(param)}"
|
||||
)
|
||||
|
||||
if isinstance(arg, ir.Value):
|
||||
# TODO(justinchuby): Cast the ir.Value here if needed
|
||||
|
|
@ -384,9 +384,9 @@ def _process_python_sequences(
|
|||
"""
|
||||
for name, arg in named_inputs.items():
|
||||
param = signature.params_map[name]
|
||||
assert isinstance(
|
||||
param, _schemas.Parameter
|
||||
), f"Expected Parameter, got {type(param)}"
|
||||
assert isinstance(param, _schemas.Parameter), (
|
||||
f"Expected Parameter, got {type(param)}"
|
||||
)
|
||||
|
||||
if not isinstance(arg, (tuple, list)):
|
||||
continue
|
||||
|
|
@ -447,9 +447,9 @@ def _process_python_sequences(
|
|||
)
|
||||
else:
|
||||
# Turn the Python constant into 1D tensor for the constant
|
||||
assert isinstance(
|
||||
val, (bool, int, float)
|
||||
), f"Expected int or float, got {type(val)}"
|
||||
assert isinstance(val, (bool, int, float)), (
|
||||
f"Expected int or float, got {type(val)}"
|
||||
)
|
||||
new_args.append(
|
||||
_get_or_create_constant(constant_farm, [val], dtype, opset) # type: ignore[arg-type]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -324,9 +324,9 @@ def _handle_getitem_node(
|
|||
assert len(node.all_input_nodes) == 1
|
||||
source = node.all_input_nodes[0]
|
||||
source_outputs = node_name_to_values[source.name]
|
||||
assert isinstance(
|
||||
source_outputs, Sequence
|
||||
), f"Expected {source.name} to output sequence, got {node_name_to_values[source.name]}"
|
||||
assert isinstance(source_outputs, Sequence), (
|
||||
f"Expected {source.name} to output sequence, got {node_name_to_values[source.name]}"
|
||||
)
|
||||
index = typing.cast(int, node.args[1])
|
||||
value = source_outputs[index]
|
||||
# Save the getitem value to the values mapping to in case
|
||||
|
|
@ -649,9 +649,9 @@ def _handle_output_node(
|
|||
# for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs
|
||||
for output in node.args[0]: # type: ignore[index,union-attr]
|
||||
output_value_name = output.name # type: ignore[union-attr]
|
||||
assert isinstance(
|
||||
output_value_name, str
|
||||
), f"Bug: Expected {output_value_name!r} to be a string"
|
||||
assert isinstance(output_value_name, str), (
|
||||
f"Bug: Expected {output_value_name!r} to be a string"
|
||||
)
|
||||
values = node_name_to_values[output_value_name]
|
||||
if isinstance(values, Sequence):
|
||||
graph_like.outputs.extend(values)
|
||||
|
|
@ -754,9 +754,9 @@ def _get_inputs_and_attributes(
|
|||
return inputs, {}, [], [node.name] # type: ignore[return-value]
|
||||
|
||||
# The target should be an ATen operator now
|
||||
assert hasattr(
|
||||
node.target, "_schema"
|
||||
), f"The target should be an ATen operator now, but node target {node.target} has no schema"
|
||||
assert hasattr(node.target, "_schema"), (
|
||||
f"The target should be an ATen operator now, but node target {node.target} has no schema"
|
||||
)
|
||||
node_schema: torch.FunctionSchema = node.target._schema
|
||||
|
||||
# This function assumes the order of arguments in FX op is the
|
||||
|
|
@ -1050,9 +1050,9 @@ def _exported_program_to_onnx_program(
|
|||
persistent = spec.persistent
|
||||
value = values[value_name]
|
||||
|
||||
assert not isinstance(
|
||||
value, Sequence
|
||||
), f"Input '{value_name}' should not be a sequence. This is unexpected."
|
||||
assert not isinstance(value, Sequence), (
|
||||
f"Input '{value_name}' should not be a sequence. This is unexpected."
|
||||
)
|
||||
|
||||
value.metadata_props["pkg.torch.export.graph_signature.InputSpec.kind"] = (
|
||||
input_kind.name
|
||||
|
|
|
|||
|
|
@ -307,9 +307,9 @@ def _get_allowed_types_from_type_annotation(
|
|||
allowed_types = set()
|
||||
subtypes = typing.get_args(type_)
|
||||
for subtype in subtypes:
|
||||
assert (
|
||||
subtype is not type(None)
|
||||
), "Union should not contain None type because it is handled by _is_optional."
|
||||
assert subtype is not type(None), (
|
||||
"Union should not contain None type because it is handled by _is_optional."
|
||||
)
|
||||
allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
|
||||
return allowed_types
|
||||
|
||||
|
|
|
|||
|
|
@ -602,9 +602,9 @@ class FxOnnxInterpreter:
|
|||
raise RuntimeError(
|
||||
f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}"
|
||||
)
|
||||
assert (
|
||||
output is not None
|
||||
), f"Node creates None with target={node.target} and name={node.name}"
|
||||
assert output is not None, (
|
||||
f"Node creates None with target={node.target} and name={node.name}"
|
||||
)
|
||||
|
||||
assert isinstance(output, onnxscript_graph_building.TorchScriptTensor)
|
||||
assert isinstance(output, onnxscript.tensor.Tensor)
|
||||
|
|
@ -631,9 +631,9 @@ class FxOnnxInterpreter:
|
|||
onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index]
|
||||
index = node.args[1]
|
||||
value = onnx_tensor_tuple[index] # type: ignore[index]
|
||||
assert (
|
||||
value is not None
|
||||
), f"Node creates None with target={node.target} and name={node.name}"
|
||||
assert value is not None, (
|
||||
f"Node creates None with target={node.target} and name={node.name}"
|
||||
)
|
||||
assert isinstance(
|
||||
value, (onnxscript_graph_building.TorchScriptTensor, tuple)
|
||||
), type(value)
|
||||
|
|
@ -664,9 +664,9 @@ class FxOnnxInterpreter:
|
|||
onnxscript_graph_building.TorchScriptTensor
|
||||
| tuple[onnxscript_graph_building.TorchScriptTensor, ...]
|
||||
) = symbolic_fn(*onnx_args, **onnx_kwargs)
|
||||
assert (
|
||||
output is not None
|
||||
), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
|
||||
assert output is not None, (
|
||||
f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
|
||||
)
|
||||
# Assign type and shape from fx graph.
|
||||
_fill_tensor_shape_type(output, node.name, node.meta["val"])
|
||||
# One fx node could produce multiple outputs (e.g., tuple of tensors); in
|
||||
|
|
@ -694,9 +694,9 @@ class FxOnnxInterpreter:
|
|||
# tensor, etc), we flatten the collection and register each element as output.
|
||||
flat_args, _ = _pytree.tree_flatten(node.args[0])
|
||||
for arg in flat_args:
|
||||
assert isinstance(
|
||||
arg, torch.fx.Node
|
||||
), f"arg must be a torch.fx.Node, not {type(arg)}"
|
||||
assert isinstance(arg, torch.fx.Node), (
|
||||
f"arg must be a torch.fx.Node, not {type(arg)}"
|
||||
)
|
||||
onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name]
|
||||
onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
|
||||
|
||||
|
|
@ -735,15 +735,15 @@ class FxOnnxInterpreter:
|
|||
root_fx_graph_module: The root FX module.
|
||||
onnxfunction_dispatcher: The dispatcher.
|
||||
"""
|
||||
assert isinstance(
|
||||
node.target, str
|
||||
), f"node.target must be a str, not {type(node.target)} for node {node}."
|
||||
assert isinstance(node.target, str), (
|
||||
f"node.target must be a str, not {type(node.target)} for node {node}."
|
||||
)
|
||||
|
||||
sub_module = root_fx_graph_module.get_submodule(node.target)
|
||||
|
||||
assert isinstance(
|
||||
sub_module, torch.fx.GraphModule
|
||||
), f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}."
|
||||
assert isinstance(sub_module, torch.fx.GraphModule), (
|
||||
f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}."
|
||||
)
|
||||
|
||||
sub_onnxscript_graph = self.run(
|
||||
sub_module, onnxfunction_dispatcher, parent_onnxscript_graph
|
||||
|
|
|
|||
|
|
@ -41,9 +41,9 @@ class RestoreParameterAndBufferNames(_pass.Transform):
|
|||
) -> None:
|
||||
"""Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target."""
|
||||
assert len(nodes) > 0, "`nodes` cannot be empty"
|
||||
assert (
|
||||
len({node.target for node in nodes}) == 1
|
||||
), "`nodes` must all have same `target`"
|
||||
assert len({node.target for node in nodes}) == 1, (
|
||||
"`nodes` must all have same `target`"
|
||||
)
|
||||
old_name = nodes[0].target
|
||||
assert isinstance(old_name, str), f"Expected str, got type({old_name})"
|
||||
# Parameter/buffer name cannot contain "."
|
||||
|
|
@ -74,9 +74,9 @@ class RestoreParameterAndBufferNames(_pass.Transform):
|
|||
to the same objects, allowing us to use it as key to retrieve the original name.
|
||||
"""
|
||||
assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args"
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), "RestoreParameterAndBufferNames does not take any kwargs"
|
||||
assert len(kwargs) == 0, (
|
||||
"RestoreParameterAndBufferNames does not take any kwargs"
|
||||
)
|
||||
# state_to_readable_name[parameter/buffer] returns the original readable name of
|
||||
# the parameter/buffer. E.g., "self.linear.weight".
|
||||
state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {}
|
||||
|
|
@ -95,9 +95,9 @@ class RestoreParameterAndBufferNames(_pass.Transform):
|
|||
|
||||
for node in self.module.graph.nodes:
|
||||
if node.op == "get_attr":
|
||||
assert isinstance(
|
||||
node.target, str
|
||||
), f"Expected str, got type({node.target})"
|
||||
assert isinstance(node.target, str), (
|
||||
f"Expected str, got type({node.target})"
|
||||
)
|
||||
if node.target.find(".") != -1:
|
||||
raise RuntimeError(
|
||||
f"Unexpected target {node.target} in get_attr, found '.' in target. "
|
||||
|
|
|
|||
|
|
@ -280,9 +280,9 @@ class ReductionTypePromotionRule(TypePromotionRule):
|
|||
def preview_type_promotion(
|
||||
self, args: tuple, kwargs: dict
|
||||
) -> TypePromotionSnapshot:
|
||||
assert (
|
||||
len(args) >= 1
|
||||
), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
assert len(args) >= 1, (
|
||||
f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
)
|
||||
arg = args[0]
|
||||
assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
|
||||
dtype: torch.dtype | None = kwargs.get("dtype", None)
|
||||
|
|
@ -319,9 +319,9 @@ class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule):
|
|||
def preview_type_promotion(
|
||||
self, args: tuple, kwargs: dict
|
||||
) -> TypePromotionSnapshot:
|
||||
assert (
|
||||
len(args) >= 1
|
||||
), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
assert len(args) >= 1, (
|
||||
f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
)
|
||||
arg = args[0]
|
||||
assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
|
||||
computation_dtype = torch.bool
|
||||
|
|
@ -344,9 +344,9 @@ class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
|
|||
def preview_type_promotion(
|
||||
self, args: tuple, kwargs: dict
|
||||
) -> TypePromotionSnapshot:
|
||||
assert (
|
||||
len(args) >= 1
|
||||
), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
assert len(args) >= 1, (
|
||||
f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
|
||||
)
|
||||
arg = args[0]
|
||||
assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
|
||||
dtype: torch.dtype | None = kwargs.get("dtype", None)
|
||||
|
|
@ -1319,17 +1319,17 @@ def find_compatible_op_overload(
|
|||
op_trace_dispatch_mode = _OpTraceDispatchMode()
|
||||
with op_trace_dispatch_mode:
|
||||
op(*args, **kwargs)
|
||||
assert (
|
||||
len(op_trace_dispatch_mode.traced_ops) >= 1
|
||||
), "Expected at least 1 traced op, got 0"
|
||||
assert len(op_trace_dispatch_mode.traced_ops) >= 1, (
|
||||
"Expected at least 1 traced op, got 0"
|
||||
)
|
||||
|
||||
new_op_overload = op_trace_dispatch_mode.traced_ops[0]
|
||||
assert isinstance(
|
||||
new_op_overload, torch._ops.OpOverload
|
||||
), f"Expected OpOverload, got {type(new_op_overload)}"
|
||||
assert (
|
||||
new_op_overload.overloadpacket == op
|
||||
), f"Expected same OpOverload packet, got {new_op_overload.overloadpacket} != {op}"
|
||||
assert isinstance(new_op_overload, torch._ops.OpOverload), (
|
||||
f"Expected OpOverload, got {type(new_op_overload)}"
|
||||
)
|
||||
assert new_op_overload.overloadpacket == op, (
|
||||
f"Expected same OpOverload packet, got {new_op_overload.overloadpacket} != {op}"
|
||||
)
|
||||
|
||||
return new_op_overload
|
||||
|
||||
|
|
@ -1398,9 +1398,9 @@ class _TypePromotionInterpreter(torch.fx.Interpreter):
|
|||
assert node_val is not None, f"Node {node} node.meta['val'] is not set."
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
||||
target = node.target
|
||||
assert isinstance(
|
||||
target, torch._ops.OpOverload
|
||||
), f"Expected OpOverload, got {type(target)}"
|
||||
assert isinstance(target, torch._ops.OpOverload), (
|
||||
f"Expected OpOverload, got {type(target)}"
|
||||
)
|
||||
node.target = find_compatible_op_overload(target.overloadpacket, args, kwargs)
|
||||
|
||||
new_node_val = self._run_node_and_set_meta(node)
|
||||
|
|
|
|||
|
|
@ -48,9 +48,9 @@ class ReplaceGetAttrWithPlaceholder(_pass.Transform):
|
|||
@property
|
||||
def replaced_attrs(self) -> tuple[torch.Tensor, ...]:
|
||||
"""The list of replaced weight tensors."""
|
||||
assert (
|
||||
self._replaced_attrs is not None
|
||||
), "Must run ReplaceGetAttrWithPlaceholder first"
|
||||
assert self._replaced_attrs is not None, (
|
||||
"Must run ReplaceGetAttrWithPlaceholder first"
|
||||
)
|
||||
return self._replaced_attrs
|
||||
|
||||
def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
|
||||
|
|
|
|||
|
|
@ -639,9 +639,9 @@ class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep):
|
|||
flattened_outputs: The flattened model outputs.
|
||||
"""
|
||||
|
||||
assert isinstance(
|
||||
model, torch_export.ExportedProgram
|
||||
), "'model' must be torch_export.ExportedProgram"
|
||||
assert isinstance(model, torch_export.ExportedProgram), (
|
||||
"'model' must be torch_export.ExportedProgram"
|
||||
)
|
||||
ordered_buffers = tuple(
|
||||
model.state_dict[name]
|
||||
if name in model.state_dict
|
||||
|
|
|
|||
|
|
@ -411,9 +411,9 @@ def quantized_args(
|
|||
output = fn(g, *non_quantized_args, **kwargs)
|
||||
|
||||
assert _scale is not None, "Bug: Scale must be set for quantized operator"
|
||||
assert (
|
||||
_zero_point is not None
|
||||
), "Bug: Zero point must be set for quantized operator"
|
||||
assert _zero_point is not None, (
|
||||
"Bug: Zero point must be set for quantized operator"
|
||||
)
|
||||
|
||||
if quantize_output:
|
||||
return quantize_helper(g, output, _scale, _zero_point)
|
||||
|
|
|
|||
|
|
@ -145,10 +145,12 @@ def scaled_dot_product_attention(
|
|||
scale: torch._C.Value | None = None,
|
||||
enable_gqa: bool = False,
|
||||
):
|
||||
assert (not is_causal) or (
|
||||
is_causal and symbolic_helper._is_none(attn_mask)
|
||||
), "is_causal and attn_mask cannot be set at the same time"
|
||||
assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
|
||||
assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
|
||||
"is_causal and attn_mask cannot be set at the same time"
|
||||
)
|
||||
assert not enable_gqa, (
|
||||
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
|
||||
)
|
||||
|
||||
if symbolic_helper._is_none(scale):
|
||||
scale = _attention_scale(g, query)
|
||||
|
|
|
|||
|
|
@ -1199,9 +1199,9 @@ def prelu(g: jit_utils.GraphContext, self, weight):
|
|||
weight_rank = 0
|
||||
|
||||
if self_rank is not None and weight_rank is not None:
|
||||
assert (
|
||||
self_rank >= weight_rank
|
||||
), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}"
|
||||
assert self_rank >= weight_rank, (
|
||||
f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}"
|
||||
)
|
||||
return g.op("PRelu", self, weight)
|
||||
|
||||
|
||||
|
|
@ -4056,9 +4056,9 @@ def repeat_interleave(
|
|||
"Unsupported for cases with dynamic repeats",
|
||||
self,
|
||||
)
|
||||
assert (
|
||||
repeats_sizes[0] == input_sizes[dim]
|
||||
), "repeats must have the same size as input along dim"
|
||||
assert repeats_sizes[0] == input_sizes[dim], (
|
||||
"repeats must have the same size as input along dim"
|
||||
)
|
||||
reps = repeats_sizes[0]
|
||||
else:
|
||||
raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self)
|
||||
|
|
|
|||
|
|
@ -229,9 +229,9 @@ def _compare_onnx_pytorch_outputs_in_np(
|
|||
pt_outs: _OutputsType,
|
||||
options: VerificationOptions,
|
||||
):
|
||||
assert (
|
||||
len(onnx_outs) == len(pt_outs)
|
||||
), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
|
||||
assert len(onnx_outs) == len(pt_outs), (
|
||||
f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
|
||||
)
|
||||
acceptable_error_percentage = options.acceptable_error_percentage
|
||||
if acceptable_error_percentage and (
|
||||
acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
|
||||
|
|
@ -1561,9 +1561,9 @@ class GraphInfo:
|
|||
pt_outs = self.pt_outs
|
||||
graph_outputs = list(self.graph.outputs())
|
||||
assert pt_outs is not None
|
||||
assert len(graph_outputs) == len(
|
||||
pt_outs
|
||||
), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
|
||||
assert len(graph_outputs) == len(pt_outs), (
|
||||
f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
|
||||
)
|
||||
return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)}
|
||||
|
||||
def _args_and_params_for_partition_graph(
|
||||
|
|
@ -1577,9 +1577,9 @@ class GraphInfo:
|
|||
args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs)
|
||||
args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs)
|
||||
params = {k: full_params[k] for k in input_names if k in full_params}
|
||||
assert len(args) + len(params) == len(
|
||||
input_names
|
||||
), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
|
||||
assert len(args) + len(params) == len(input_names), (
|
||||
f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
|
||||
)
|
||||
return args, params
|
||||
|
||||
def verify_export(
|
||||
|
|
|
|||
|
|
@ -1805,9 +1805,9 @@ has_torch_function_variadic = _add_docstr(
|
|||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_overridable_functions() -> (
|
||||
tuple[dict[Any, list[Callable]], dict[Callable, str]]
|
||||
):
|
||||
def _get_overridable_functions() -> tuple[
|
||||
dict[Any, list[Callable]], dict[Callable, str]
|
||||
]:
|
||||
overridable_funcs = collections.defaultdict(list)
|
||||
index = {}
|
||||
tested_namespaces = [
|
||||
|
|
|
|||
|
|
@ -2023,9 +2023,9 @@ def _load(
|
|||
typename = _maybe_decode_ascii(saved_id[0])
|
||||
data = saved_id[1:]
|
||||
|
||||
assert (
|
||||
typename == "storage"
|
||||
), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
|
||||
assert typename == "storage", (
|
||||
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
|
||||
)
|
||||
storage_type, key, location, numel = data
|
||||
if storage_type is torch.UntypedStorage:
|
||||
dtype = torch.uint8
|
||||
|
|
|
|||
|
|
@ -208,9 +208,9 @@ class DecisionTree:
|
|||
if name in dummy_col_2_col_val:
|
||||
(orig_name, value) = dummy_col_2_col_val[name]
|
||||
predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
|
||||
assert (
|
||||
threshold == 0.5
|
||||
), f"expected threshold to be 0.5 but is {threshold}"
|
||||
assert threshold == 0.5, (
|
||||
f"expected threshold to be 0.5 but is {threshold}"
|
||||
)
|
||||
else:
|
||||
predicate = (
|
||||
f"{indent}if context.get_value('{name}') <= {threshold}:"
|
||||
|
|
|
|||
|
|
@ -446,9 +446,9 @@ class AHTrainDecisionTree(AHTrain):
|
|||
for row in group.itertuples():
|
||||
choice2time[row.choice] = row.median_execution_time
|
||||
|
||||
assert (
|
||||
len(unique_choices) == len(group)
|
||||
), f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}"
|
||||
assert len(unique_choices) == len(group), (
|
||||
f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}"
|
||||
)
|
||||
|
||||
return pd.Series(
|
||||
{
|
||||
|
|
@ -869,9 +869,9 @@ class DecisionEvaluator:
|
|||
top_k_choices = self.top_k_classes(
|
||||
self.model, prob, k=self.k, avail_choices=avail_choices
|
||||
)
|
||||
assert (
|
||||
true in avail_choices
|
||||
), f"Best choice {true} not in available choices {avail_choices}"
|
||||
assert true in avail_choices, (
|
||||
f"Best choice {true} not in available choices {avail_choices}"
|
||||
)
|
||||
default_config = self.train.get_default_config(self.df.iloc[i])
|
||||
self.eval_prediction(
|
||||
avail_choices,
|
||||
|
|
|
|||
|
|
@ -405,9 +405,9 @@ class AHTrainRegressionTree(AHTrain):
|
|||
if name in dummy_col_2_col_val:
|
||||
(orig_name, value) = dummy_col_2_col_val[name]
|
||||
predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
|
||||
assert (
|
||||
threshold == 0.5
|
||||
), f"expected threshold to be 0.5 but is {threshold}"
|
||||
assert threshold == 0.5, (
|
||||
f"expected threshold to be 0.5 but is {threshold}"
|
||||
)
|
||||
else:
|
||||
predicate = (
|
||||
f"{indent}if context.get_value('{name}') <= {threshold}:"
|
||||
|
|
|
|||
|
|
@ -250,7 +250,9 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
|
|||
elif t.name == BaseTy.Scalar:
|
||||
return BaseCType(scalarT)
|
||||
elif isinstance(t, ListType):
|
||||
assert not mutable, "Native functions should never return a mutable tensor list. They should return void."
|
||||
assert not mutable, (
|
||||
"Native functions should never return a mutable tensor list. They should return void."
|
||||
)
|
||||
elem = returntype_type(t.elem, mutable=False)
|
||||
assert t.size is None, f"fixed size list returns not supported: {t}"
|
||||
return VectorCType(elem)
|
||||
|
|
|
|||
|
|
@ -237,9 +237,9 @@ class LazyArgument:
|
|||
|
||||
@property
|
||||
def lazy_type(self) -> CType:
|
||||
assert (
|
||||
self.lazy_type_ is not None
|
||||
), f"Attempted to access lazy_type for invalid argument {self.name}"
|
||||
assert self.lazy_type_ is not None, (
|
||||
f"Attempted to access lazy_type for invalid argument {self.name}"
|
||||
)
|
||||
return self.lazy_type_
|
||||
|
||||
|
||||
|
|
@ -374,9 +374,9 @@ class LazyIrSchema:
|
|||
curr_args = curr_args.all()
|
||||
for arg in curr_args:
|
||||
if isGeneratorType(arg.type):
|
||||
assert (
|
||||
self.generator_arg is None
|
||||
), "We expect there is only one generator arg"
|
||||
assert self.generator_arg is None, (
|
||||
"We expect there is only one generator arg"
|
||||
)
|
||||
self.generator_arg = NamedCType(
|
||||
arg.name,
|
||||
arg.type, # type:ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -408,7 +408,9 @@ def kernel_signature(
|
|||
meta = backend_index.get_kernel(f)
|
||||
symint = meta is not None and meta.supports_symint()
|
||||
if symint:
|
||||
assert f.func.has_symint(), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
|
||||
assert f.func.has_symint(), (
|
||||
f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
|
||||
)
|
||||
if backend_index.external:
|
||||
return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -481,9 +481,9 @@ class GenLazyNativeFuncDefinition:
|
|||
optional_devices = [
|
||||
a.name for a in scalar_args if a.lazy_type == optional_device
|
||||
]
|
||||
assert (
|
||||
len(value_types_names) > 0 or len(optional_devices) > 0
|
||||
), "Expected at least one Value or Device type"
|
||||
assert len(value_types_names) > 0 or len(optional_devices) > 0, (
|
||||
"Expected at least one Value or Device type"
|
||||
)
|
||||
get_device_str = (
|
||||
f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
|
||||
)
|
||||
|
|
@ -580,9 +580,9 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
|
|||
# xla uses an instance method for tensor creation, for the time being
|
||||
if self.create_from_first_tensor:
|
||||
# TODO(whc) remove this if XLA switches to using static method for creation
|
||||
assert (
|
||||
first_tensor_name is not None
|
||||
), "Requires first tensor to create lazy tensor"
|
||||
assert first_tensor_name is not None, (
|
||||
"Requires first tensor to create lazy tensor"
|
||||
)
|
||||
return f"{first_tensor_name}.{self.create_tensor}"
|
||||
return f"{self.backend_namespace}::{self.create_tensor}"
|
||||
|
||||
|
|
@ -595,9 +595,9 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
|
|||
{self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
|
||||
|
||||
if returns_length > 1:
|
||||
assert (
|
||||
len(value_types_names) > 0
|
||||
), "Code below assumes there is at least one tensor arg"
|
||||
assert len(value_types_names) > 0, (
|
||||
"Code below assumes there is at least one tensor arg"
|
||||
)
|
||||
bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
|
||||
for (int i = 0; i < {returns_length}; i++) {{
|
||||
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
|
||||
|
|
|
|||
|
|
@ -186,9 +186,9 @@ def compute_ufunc_cuda_functors(
|
|||
ufunc_name = loops[lk].name
|
||||
else:
|
||||
# See Note [ScalarOnly and Generic must match names for CUDA]
|
||||
assert (
|
||||
ufunc_name == loops[lk].name
|
||||
), "ScalarOnly and Generic must have same ufunc name"
|
||||
assert ufunc_name == loops[lk].name, (
|
||||
"ScalarOnly and Generic must have same ufunc name"
|
||||
)
|
||||
supported_dtypes |= loops[lk].supported_dtypes
|
||||
assert ufunc_name is not None
|
||||
|
||||
|
|
|
|||
|
|
@ -65,9 +65,9 @@ class ComputeNativeFunctionStub:
|
|||
{comma.join([r.name for r in f.func.arguments.out])}
|
||||
)"""
|
||||
else:
|
||||
assert all(
|
||||
a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
|
||||
), f"Only support tensor returns but got {f.func.returns}"
|
||||
assert all(a.type == BaseType(BaseTy.Tensor) for a in f.func.returns), (
|
||||
f"Only support tensor returns but got {f.func.returns}"
|
||||
)
|
||||
# Returns a tuple of empty tensors
|
||||
tensor_type = "at::Tensor"
|
||||
comma = ", "
|
||||
|
|
|
|||
|
|
@ -181,7 +181,9 @@ def returntype_type(t: Type, *, mutable: bool) -> CType:
|
|||
elif t.name == BaseTy.Scalar:
|
||||
return BaseCType(scalarT)
|
||||
elif isinstance(t, ListType):
|
||||
assert not mutable, "Native functions should never return a mutable tensor list. They should return void."
|
||||
assert not mutable, (
|
||||
"Native functions should never return a mutable tensor list. They should return void."
|
||||
)
|
||||
elem = returntype_type(t.elem, mutable=False)
|
||||
assert t.size is None, f"fixed size list returns not supported: {t}"
|
||||
return VectorCType(elem)
|
||||
|
|
|
|||
|
|
@ -94,9 +94,9 @@ class ETKernelKey:
|
|||
assert type_alias in type_alias_map, "Undefined type alias: " + str(
|
||||
type_alias
|
||||
)
|
||||
assert (
|
||||
dim_order in dim_order_alias_map
|
||||
), f"Undefined dim_order alias: {dim_order}"
|
||||
assert dim_order in dim_order_alias_map, (
|
||||
f"Undefined dim_order alias: {dim_order}"
|
||||
)
|
||||
dtype_alias_used.add(type_alias)
|
||||
|
||||
# Generate all permutations of dtype alias values
|
||||
|
|
@ -193,9 +193,9 @@ class ETKernelIndex:
|
|||
index: dict[OperatorName, BackendMetadata] = {}
|
||||
for op in self.index:
|
||||
kernel_dict = self.index[op]
|
||||
assert (
|
||||
len(kernel_dict.values()) == 1
|
||||
), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
|
||||
assert len(kernel_dict.values()) == 1, (
|
||||
f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
|
||||
)
|
||||
index[op] = kernel_dict.get(
|
||||
ETKernelKey(default=True),
|
||||
BackendMetadata(kernel="", structured=False, cpp_namespace=""),
|
||||
|
|
|
|||
|
|
@ -1433,9 +1433,9 @@ def get_grouped_by_view_native_functions(
|
|||
assert kind not in grouped_by_views[schema]
|
||||
grouped_by_views[schema][kind] = f
|
||||
else:
|
||||
assert (
|
||||
view_kind not in grouped_by_views[schema]
|
||||
), f"{view_kind} already in {grouped_by_views[schema].keys()}"
|
||||
assert view_kind not in grouped_by_views[schema], (
|
||||
f"{view_kind} already in {grouped_by_views[schema].keys()}"
|
||||
)
|
||||
grouped_by_views[schema][view_kind] = f
|
||||
|
||||
return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
|
||||
|
|
@ -1483,9 +1483,9 @@ def get_ns_grouped_kernels(
|
|||
native_function_namespaces.add(namespace)
|
||||
else:
|
||||
namespace = DEFAULT_KERNEL_NAMESPACE
|
||||
assert (
|
||||
len(native_function_namespaces) <= 1
|
||||
), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
|
||||
assert len(native_function_namespaces) <= 1, (
|
||||
f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
|
||||
)
|
||||
ns_grouped_kernels[namespace].extend(
|
||||
native_function_decl_gen(f, backend_idx)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -80,34 +80,34 @@ def parse_backend_yaml(
|
|||
|
||||
# Mostly just defaulting to false to stick with LazyTensor convention.
|
||||
use_out_as_primary = yaml_values.pop("use_out_as_primary", False)
|
||||
assert isinstance(
|
||||
use_out_as_primary, bool
|
||||
), f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}"
|
||||
assert isinstance(use_out_as_primary, bool), (
|
||||
f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}"
|
||||
)
|
||||
|
||||
use_device_guard = yaml_values.pop("device_guard", False)
|
||||
assert isinstance(
|
||||
use_device_guard, bool
|
||||
), f"You must provide either True or False for device_guard. Provided: {use_device_guard}"
|
||||
assert isinstance(use_device_guard, bool), (
|
||||
f"You must provide either True or False for device_guard. Provided: {use_device_guard}"
|
||||
)
|
||||
|
||||
supported = yaml_values.pop("supported", [])
|
||||
if supported is None:
|
||||
supported = [] # Allow an empty list of supported ops
|
||||
assert isinstance(
|
||||
supported, list
|
||||
), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
|
||||
assert isinstance(supported, list), (
|
||||
f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
|
||||
)
|
||||
|
||||
symint = yaml_values.pop("symint", [])
|
||||
if symint is None:
|
||||
symint = [] # Allow an empty list of symint ops
|
||||
assert isinstance(
|
||||
symint, list
|
||||
), f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})'
|
||||
assert isinstance(symint, list), (
|
||||
f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})'
|
||||
)
|
||||
symint_set = set(symint)
|
||||
|
||||
supported_autograd = yaml_values.pop("autograd", [])
|
||||
assert isinstance(
|
||||
supported_autograd, list
|
||||
), f'expected "autograd" to be a list, but got: {supported_autograd}'
|
||||
assert isinstance(supported_autograd, list), (
|
||||
f'expected "autograd" to be a list, but got: {supported_autograd}'
|
||||
)
|
||||
|
||||
# full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
|
||||
full_codegen = yaml_values.pop("full_codegen", [])
|
||||
|
|
@ -135,9 +135,9 @@ def parse_backend_yaml(
|
|||
metadata: dict[OperatorName, BackendMetadata] = {}
|
||||
for op in backend_ops:
|
||||
op_name = OperatorName.parse(op)
|
||||
assert (
|
||||
op_name in native_functions_map
|
||||
), f"Found an invalid operator name: {op_name}"
|
||||
assert op_name in native_functions_map, (
|
||||
f"Found an invalid operator name: {op_name}"
|
||||
)
|
||||
# See Note [External Backends Follow Dispatcher API]
|
||||
kernel_name = dispatcher.name(native_functions_map[op_name].func)
|
||||
if op in symint_ops:
|
||||
|
|
@ -238,11 +238,11 @@ the behavior of autograd for some operators on your backend. However "Autograd{b
|
|||
|
||||
forward_kernels = [f for f in forward_kernels if f is not None]
|
||||
backward_kernels = [f for f in backward_kernels if f is not None]
|
||||
assert (
|
||||
len(forward_kernels) == 0 or len(backward_kernels) == 0
|
||||
), f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \
|
||||
assert len(forward_kernels) == 0 or len(backward_kernels) == 0, (
|
||||
f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \
|
||||
autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! \
|
||||
{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".'
|
||||
)
|
||||
|
||||
return ParsedExternalYaml(
|
||||
backend_key, autograd_key, class_name, cpp_namespace, backend_indices
|
||||
|
|
|
|||
|
|
@ -444,9 +444,9 @@ def get_ns_grouped_kernels(
|
|||
native_function_namespaces.add(namespace)
|
||||
else:
|
||||
namespace = DEFAULT_KERNEL_NAMESPACE
|
||||
assert (
|
||||
len(native_function_namespaces) <= 1
|
||||
), f"Codegen only supports one namespace per operator, got {native_function_namespaces}"
|
||||
assert len(native_function_namespaces) <= 1, (
|
||||
f"Codegen only supports one namespace per operator, got {native_function_namespaces}"
|
||||
)
|
||||
ns_grouped_kernels[namespace].extend(
|
||||
native_function_decl_gen(f, kernel_index)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -40,8 +40,7 @@ def use_const_ref_for_mutable_tensors() -> bool:
|
|||
|
||||
def use_ilistref_for_tensor_lists() -> bool:
|
||||
assert _locals.use_ilistref_for_tensor_lists is not None, (
|
||||
"need to initialize local.use_ilistref_for_tensor_lists with "
|
||||
"local.parametrize"
|
||||
"need to initialize local.use_ilistref_for_tensor_lists with local.parametrize"
|
||||
)
|
||||
return _locals.use_ilistref_for_tensor_lists
|
||||
|
||||
|
|
|
|||
|
|
@ -627,9 +627,9 @@ class NativeFunction:
|
|||
assert isinstance(use_const_ref_for_mutable_tensors, bool)
|
||||
|
||||
if use_const_ref_for_mutable_tensors:
|
||||
assert (
|
||||
not func.arguments.out
|
||||
), "see https://github.com/pytorch/pytorch/issues/145522"
|
||||
assert not func.arguments.out, (
|
||||
"see https://github.com/pytorch/pytorch/issues/145522"
|
||||
)
|
||||
|
||||
variants_s = e.pop("variants", "function")
|
||||
assert isinstance(variants_s, str)
|
||||
|
|
@ -643,9 +643,9 @@ class NativeFunction:
|
|||
raise AssertionError(f"illegal variant {v}")
|
||||
|
||||
manual_kernel_registration = e.pop("manual_kernel_registration", False)
|
||||
assert isinstance(
|
||||
manual_kernel_registration, bool
|
||||
), f"not a bool: {manual_kernel_registration}"
|
||||
assert isinstance(manual_kernel_registration, bool), (
|
||||
f"not a bool: {manual_kernel_registration}"
|
||||
)
|
||||
|
||||
manual_cpp_binding = e.pop("manual_cpp_binding", False)
|
||||
assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"
|
||||
|
|
@ -654,9 +654,9 @@ class NativeFunction:
|
|||
assert isinstance(device_guard, bool), f"not a bool: {device_guard}"
|
||||
|
||||
device_check_s = e.pop("device_check", None)
|
||||
assert device_check_s is None or isinstance(
|
||||
device_check_s, str
|
||||
), f"not a str: {device_check_s}"
|
||||
assert device_check_s is None or isinstance(device_check_s, str), (
|
||||
f"not a str: {device_check_s}"
|
||||
)
|
||||
assert (
|
||||
device_check_s is None or device_check_s in DeviceCheckType.__members__
|
||||
), f"illegal device_check: {device_check_s}"
|
||||
|
|
@ -682,26 +682,26 @@ class NativeFunction:
|
|||
structured_delegate = OperatorName.parse(structured_delegate_s)
|
||||
|
||||
structured_inherits = e.pop("structured_inherits", None)
|
||||
assert structured_inherits is None or isinstance(
|
||||
structured_inherits, str
|
||||
), f"not a str: {structured_inherits}"
|
||||
assert structured_inherits is None or isinstance(structured_inherits, str), (
|
||||
f"not a str: {structured_inherits}"
|
||||
)
|
||||
assert structured_inherits is None or "::" not in structured_inherits, (
|
||||
"namespace is not supported in structured inherits,"
|
||||
" using the same namespace as the native function"
|
||||
)
|
||||
|
||||
python_module = e.pop("python_module", None)
|
||||
assert python_module is None or isinstance(
|
||||
python_module, str
|
||||
), f"not a str: {python_module}"
|
||||
assert (
|
||||
python_module is None or Variant.method not in variants
|
||||
), "functions in modules cannot be methods"
|
||||
assert python_module is None or isinstance(python_module, str), (
|
||||
f"not a str: {python_module}"
|
||||
)
|
||||
assert python_module is None or Variant.method not in variants, (
|
||||
"functions in modules cannot be methods"
|
||||
)
|
||||
|
||||
category_override = e.pop("category_override", None)
|
||||
assert category_override is None or isinstance(
|
||||
category_override, str
|
||||
), f"not a str: {category_override}"
|
||||
assert category_override is None or isinstance(category_override, str), (
|
||||
f"not a str: {category_override}"
|
||||
)
|
||||
|
||||
precomputed_dict = e.pop("precomputed", None)
|
||||
assert precomputed_dict is None or structured is True
|
||||
|
|
@ -740,12 +740,12 @@ class NativeFunction:
|
|||
for ks, v in raw_dispatch.items():
|
||||
if ks == "__line__":
|
||||
continue # not worth tracking line numbers for dispatch entries
|
||||
assert isinstance(
|
||||
ks, str
|
||||
), f"illegal dispatch key '{ks}' in {raw_dispatch}"
|
||||
assert isinstance(
|
||||
v, str
|
||||
), f"illegal dispatch value '{v}' in {raw_dispatch}"
|
||||
assert isinstance(ks, str), (
|
||||
f"illegal dispatch key '{ks}' in {raw_dispatch}"
|
||||
)
|
||||
assert isinstance(v, str), (
|
||||
f"illegal dispatch value '{v}' in {raw_dispatch}"
|
||||
)
|
||||
for k in ks.split(","):
|
||||
dispatch_key = DispatchKey.parse(k.strip())
|
||||
num_dispatch_keys += 1
|
||||
|
|
@ -879,9 +879,9 @@ class NativeFunction:
|
|||
import torchgen.api.ufunc as ufunc
|
||||
|
||||
for dispatch_key in UFUNC_DISPATCH_KEYS:
|
||||
assert (
|
||||
dispatch_key not in dispatch
|
||||
), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
|
||||
assert dispatch_key not in dispatch, (
|
||||
f"ufunc should not have explicit dispatch entry for {dispatch_key}"
|
||||
)
|
||||
dispatch[dispatch_key] = BackendMetadata(
|
||||
kernel=ufunc.schema_kernel_name(func, dispatch_key),
|
||||
structured=True,
|
||||
|
|
@ -997,31 +997,31 @@ class NativeFunction:
|
|||
"Put structured field on the out= "
|
||||
"variant of a function; did you mean structured_delegate?"
|
||||
)
|
||||
assert (
|
||||
self.device_guard
|
||||
), "device_guard: False is not respected by structured kernels"
|
||||
assert self.device_guard, (
|
||||
"device_guard: False is not respected by structured kernels"
|
||||
)
|
||||
if self.structured_delegate:
|
||||
assert self.func.kind() != SchemaKind.out, (
|
||||
"structured_delegate field not allowed "
|
||||
"on out= functions; did you mean structured?"
|
||||
)
|
||||
assert (
|
||||
self.device_guard
|
||||
), "device_guard: False is not respected by structured kernels"
|
||||
assert self.device_guard, (
|
||||
"device_guard: False is not respected by structured kernels"
|
||||
)
|
||||
# Technically, with the asserts above, this assert is impossible to
|
||||
# happen
|
||||
assert not (
|
||||
self.structured and self.structured_delegate
|
||||
), "Cannot have both structured and structured_delegate on function"
|
||||
assert not (self.structured and self.structured_delegate), (
|
||||
"Cannot have both structured and structured_delegate on function"
|
||||
)
|
||||
defaulted_arguments = {
|
||||
a.name for a in self.func.schema_order_arguments() if a.default is not None
|
||||
}
|
||||
invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
|
||||
assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
|
||||
if self.structured_inherits is not None:
|
||||
assert (
|
||||
self.structured
|
||||
), "structured_inherits must also imply structured: True"
|
||||
assert self.structured, (
|
||||
"structured_inherits must also imply structured: True"
|
||||
)
|
||||
if str(self.func.name).startswith("_foreach"):
|
||||
assert self.device_check == DeviceCheckType.NoCheck, (
|
||||
"foreach kernels fall back to slow path when tensor are on different devices, "
|
||||
|
|
@ -1299,9 +1299,9 @@ class BackendIndex:
|
|||
) -> None:
|
||||
for k, v in child_index.items():
|
||||
for op_name, metadata in v.items():
|
||||
assert (
|
||||
op_name not in parent_index[k]
|
||||
), f"duplicate operator {op_name} for dispatch key {k}"
|
||||
assert op_name not in parent_index[k], (
|
||||
f"duplicate operator {op_name} for dispatch key {k}"
|
||||
)
|
||||
parent_index[k][op_name] = metadata
|
||||
|
||||
def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
|
||||
|
|
@ -1450,9 +1450,9 @@ class FunctionSchema:
|
|||
# We also enforce that if you have any mutable, positional args, then they are not returned.
|
||||
# This makes it easier to group these functions properly with their functional/out= counterparts.
|
||||
for a in self.arguments.post_self_positional_mutable:
|
||||
assert not any(
|
||||
a.annotation == r.annotation for r in self.returns
|
||||
), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
|
||||
assert not any(a.annotation == r.annotation for r in self.returns), (
|
||||
f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
|
||||
)
|
||||
# Invariant: we expect out arguments to appear as keyword arguments in the schema.
|
||||
# This means that all mutable returns should be aliased to a keyword argument
|
||||
# (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
|
||||
|
|
@ -1475,9 +1475,9 @@ class FunctionSchema:
|
|||
# (1) It's more annoying to handle properly
|
||||
# (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
|
||||
# Instead, we expect the (a!) argument to not be returned.
|
||||
assert (
|
||||
len(mutable_returns) == 0 or len(immutable_returns) == 0
|
||||
), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
|
||||
assert len(mutable_returns) == 0 or len(immutable_returns) == 0, (
|
||||
f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
|
||||
)
|
||||
for ret in mutable_returns:
|
||||
assert any(ret.annotation == arg.annotation for arg in out_and_self), (
|
||||
'All mutable returns must be aliased either to a keyword argument, or to "self". '
|
||||
|
|
@ -1490,23 +1490,22 @@ class FunctionSchema:
|
|||
# and all other types of out= op schemas should return void.
|
||||
# There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
|
||||
if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
|
||||
assert (
|
||||
len(self.returns) == 0
|
||||
), "out= ops that accept tensor lists as out arguments "
|
||||
assert len(self.returns) == 0, (
|
||||
"out= ops that accept tensor lists as out arguments "
|
||||
)
|
||||
"are expected to have no return type (since you can't do method chaining on them)"
|
||||
else:
|
||||
# mutable keyword arguments whose name has _scratch_ prefix are
|
||||
# scratch tensors for memory planning and should not be returned
|
||||
assert (
|
||||
len(
|
||||
[
|
||||
arg
|
||||
for arg in self.arguments.out
|
||||
if not arg.name.startswith("_scratch_")
|
||||
]
|
||||
)
|
||||
== len(self.returns)
|
||||
), "Must return as many arguments as there are out arguments, or no return at all"
|
||||
assert len(
|
||||
[
|
||||
arg
|
||||
for arg in self.arguments.out
|
||||
if not arg.name.startswith("_scratch_")
|
||||
]
|
||||
) == len(self.returns), (
|
||||
"Must return as many arguments as there are out arguments, or no return at all"
|
||||
)
|
||||
|
||||
if self.name.name.inplace:
|
||||
self_a = self.arguments.self_arg
|
||||
|
|
@ -1599,12 +1598,14 @@ class FunctionSchema:
|
|||
if is_inplace:
|
||||
return SchemaKind.inplace
|
||||
elif is_scratch:
|
||||
assert (
|
||||
is_out
|
||||
), "invariant: all scratch operators are expected to be out= operators too"
|
||||
assert is_out, (
|
||||
"invariant: all scratch operators are expected to be out= operators too"
|
||||
)
|
||||
return SchemaKind.scratch
|
||||
elif is_out:
|
||||
assert not is_scratch, "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!" # noqa: B950
|
||||
assert not is_scratch, (
|
||||
"We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
|
||||
) # noqa: B950
|
||||
return SchemaKind.out
|
||||
elif is_mutable:
|
||||
return SchemaKind.mutable
|
||||
|
|
@ -1801,13 +1802,13 @@ class Annotation:
|
|||
before_alias = m.group(1) + (m.group(2) if m.group(2) else "")
|
||||
alias_set = tuple(before_alias.split("|"))
|
||||
is_write = m.group(3) == "!"
|
||||
assert not (
|
||||
is_write and len(alias_set) > 1
|
||||
), f"alias set larger than 1 is not mutable, got {ann} instead."
|
||||
assert not (is_write and len(alias_set) > 1), (
|
||||
f"alias set larger than 1 is not mutable, got {ann} instead."
|
||||
)
|
||||
after_set = tuple(m.group(5).split("|")) if m.group(5) else ()
|
||||
assert not (
|
||||
len(before_alias) > 1 and len(after_set) > 1
|
||||
), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
|
||||
assert not (len(before_alias) > 1 and len(after_set) > 1), (
|
||||
f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
|
||||
)
|
||||
r = Annotation(
|
||||
alias_set=alias_set, is_write=is_write, alias_set_after=after_set
|
||||
)
|
||||
|
|
@ -2348,9 +2349,9 @@ class Arguments:
|
|||
if not arg:
|
||||
continue
|
||||
if arg == "*":
|
||||
assert (
|
||||
arguments_acc is positional
|
||||
), "invalid syntax: kwarg-only specifier * can only occur once"
|
||||
assert arguments_acc is positional, (
|
||||
"invalid syntax: kwarg-only specifier * can only occur once"
|
||||
)
|
||||
arguments_acc = kwarg_only
|
||||
continue
|
||||
parg = Argument.parse(arg)
|
||||
|
|
@ -2475,9 +2476,9 @@ class Arguments:
|
|||
for a in self.pre_self_positional
|
||||
if a.annotation is not None and a.annotation.is_write
|
||||
]
|
||||
assert (
|
||||
len(mutable_pre_self_positionals) == 0
|
||||
), "mutable pre_self_positional arguments are not currently supported in the schema"
|
||||
assert len(mutable_pre_self_positionals) == 0, (
|
||||
"mutable pre_self_positional arguments are not currently supported in the schema"
|
||||
)
|
||||
|
||||
|
||||
# Names that validly are __iXXX__ indicating inplace operations.
|
||||
|
|
@ -2831,9 +2832,9 @@ class Precompute:
|
|||
)
|
||||
|
||||
arg, with_list_raw = raw_replace_item.split(" -> ")
|
||||
assert (
|
||||
" " not in arg
|
||||
), f"illegal kernel param name '{arg}' in precomputed parameters'"
|
||||
assert " " not in arg, (
|
||||
f"illegal kernel param name '{arg}' in precomputed parameters'"
|
||||
)
|
||||
with_list = with_list_raw.split(",")
|
||||
with_list_args = [Argument.parse(name.strip()) for name in with_list]
|
||||
replace[arg] = with_list_args
|
||||
|
|
|
|||
|
|
@ -246,9 +246,9 @@ class FileManager:
|
|||
for key in sharded_keys:
|
||||
for shard in all_shards:
|
||||
if key in shard:
|
||||
assert isinstance(
|
||||
shard[key], list
|
||||
), "sharded keys in base_env must be a list"
|
||||
assert isinstance(shard[key], list), (
|
||||
"sharded keys in base_env must be a list"
|
||||
)
|
||||
shard[key] = shard[key].copy()
|
||||
else:
|
||||
shard[key] = []
|
||||
|
|
@ -441,9 +441,9 @@ class NamespaceHelper:
|
|||
) -> None:
|
||||
# cpp_namespace can be a colon joined string such as torch::lazy
|
||||
cpp_namespaces = namespace_str.split("::")
|
||||
assert (
|
||||
len(cpp_namespaces) <= max_level
|
||||
), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
|
||||
assert len(cpp_namespaces) <= max_level, (
|
||||
f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
|
||||
)
|
||||
self.cpp_namespace_ = namespace_str
|
||||
self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
|
||||
self.epilogue_ = "\n".join(
|
||||
|
|
|
|||
|
|
@ -18,9 +18,9 @@ class YamlLoader(Loader):
|
|||
mapping = []
|
||||
for key_node, value_node in node.value:
|
||||
key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
|
||||
assert (
|
||||
key not in mapping
|
||||
), f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
|
||||
assert key not in mapping, (
|
||||
f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
|
||||
)
|
||||
mapping.append(key)
|
||||
mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call]
|
||||
return mapping
|
||||
|
|
|
|||
Loading…
Reference in a new issue