mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[while_loop][inductor] support sym expression as cond_fn output (#146222)"
This reverts commit 5ecdc428b2.
Reverted https://github.com/pytorch/pytorch/pull/146222 on behalf of https://github.com/atalman due to Internal failure, please see associated diff ([comment](https://github.com/pytorch/pytorch/pull/146222#issuecomment-2643379933))
This commit is contained in:
parent
5d7532140f
commit
076717785c
5 changed files with 13 additions and 79 deletions
|
|
@ -1511,26 +1511,6 @@ class AOTInductorTestsTemplate:
|
|||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
@common_utils.parametrize("dynamic", [False, True])
|
||||
def test_while_loop_with_sym_expr_cond(self, dynamic):
|
||||
inputs = (
|
||||
torch.randn(10, 20, device=self.device),
|
||||
torch.randn(10, 20, device=self.device),
|
||||
)
|
||||
dim0_ab = Dim("s0", min=2, max=1024)
|
||||
dynamic_shapes = None
|
||||
if dynamic:
|
||||
dynamic_shapes = {
|
||||
"c": {},
|
||||
"a": {0: dim0_ab, 1: None},
|
||||
"b": {0: dim0_ab, 1: None},
|
||||
}
|
||||
self.check_model_with_multiple_inputs(
|
||||
WhileLoopModels.SymExprCond(),
|
||||
prepend_counters(inputs),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
@config.patch({"is_predispatch": True})
|
||||
def test_constant(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -876,23 +876,6 @@ class WhileLoopModels:
|
|||
[c, a, b],
|
||||
)
|
||||
|
||||
class SymExprCond(torch.nn.Module):
|
||||
def forward(self, c, a, b):
|
||||
d = a.sum().to(torch.int64).item()
|
||||
e = torch.nonzero(b).size(0)
|
||||
|
||||
def cond_fn(c, a, b):
|
||||
return d + e + a.shape[0] - b.shape[0] < 10
|
||||
|
||||
def body_fn(c, a, b):
|
||||
return c + 1, a + e, b + d
|
||||
|
||||
return torch._higher_order_ops.while_loop(
|
||||
cond_fn,
|
||||
body_fn,
|
||||
[c, a, b],
|
||||
)
|
||||
|
||||
|
||||
class WhileLoopTests(TestCase):
|
||||
def _run_test(
|
||||
|
|
@ -1156,23 +1139,6 @@ class WhileLoopTests(TestCase):
|
|||
dynamic=dynamic,
|
||||
)
|
||||
|
||||
@requires_gpu
|
||||
@parametrize("device", ["cpu", GPU_TYPE])
|
||||
@parametrize("dynamic", [True, False])
|
||||
@torch._dynamo.config.patch(
|
||||
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
|
||||
)
|
||||
def test_while_loop_with_sym_expr_cond(self, device, dynamic):
|
||||
self._run_test(
|
||||
model=WhileLoopModels.SymExprCond(),
|
||||
inputs=(
|
||||
torch.randn(10, 20),
|
||||
torch.randn(10, 20),
|
||||
),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
)
|
||||
|
||||
|
||||
class AssociativeScanTests(TestCase):
|
||||
@requires_gpu
|
||||
|
|
|
|||
|
|
@ -1594,14 +1594,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
subgraph.graph.graph_outputs, outer_outputs
|
||||
):
|
||||
src = inner_output.codegen_reference()
|
||||
if not isinstance(inner_output, ir.ShapeAsConstantBuffer):
|
||||
# in ABI-compatible mode, we need to std::move subgraph output (inner_output)
|
||||
# to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
|
||||
# constructor is deleted.
|
||||
src = f"std::move({src})"
|
||||
# in case the outer_output carried a value
|
||||
# before (e.g., in the while_loop codegen)
|
||||
self.writeline(f"{outer_output}.reset();")
|
||||
# in ABI-compatible mode, we need to std::move subgraph output (inner_output)
|
||||
# to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
|
||||
# constructor is deleted.
|
||||
src = f"std::move({src})"
|
||||
# in case the outer_output carried a value
|
||||
# before (e.g., in the while_loop codegen)
|
||||
self.writeline(f"{outer_output}.reset();")
|
||||
self.writeline(f"{outer_output} = {src};")
|
||||
|
||||
def codegen_invoke_subgraph(self, invoke_subgraph):
|
||||
|
|
@ -1663,9 +1662,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
self.pop_codegened_graph()
|
||||
|
||||
def codegen_while_loop(self, while_loop):
|
||||
is_bool_pred = isinstance(
|
||||
while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer
|
||||
)
|
||||
name = while_loop.get_name()
|
||||
outer_carried_inputs = [
|
||||
buf.codegen_reference() for buf in while_loop.carried_inputs
|
||||
|
|
@ -1674,10 +1670,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
buf.codegen_reference() for buf in while_loop.additional_inputs
|
||||
]
|
||||
cond_result_name = f"{name}_cond_result"
|
||||
if is_bool_pred:
|
||||
self.writeline(f"bool {cond_result_name};")
|
||||
else:
|
||||
self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
|
||||
self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
|
||||
|
||||
cond_outer_inputs = []
|
||||
for inp, out in zip(outer_carried_inputs, while_loop.outputs):
|
||||
|
|
@ -1707,11 +1700,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
|
||||
)
|
||||
|
||||
if is_bool_pred:
|
||||
cond_result = f"{cond_result_name}"
|
||||
else:
|
||||
cond_result = f"{cond_result_name}_scalar"
|
||||
self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
|
||||
cond_result = f"{cond_result_name}_scalar"
|
||||
self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
|
||||
self.writeline(f"if (!{cond_result}) break;")
|
||||
|
||||
self.writeline(ExitSubgraphLine(self))
|
||||
|
|
|
|||
|
|
@ -2551,7 +2551,7 @@ class PythonWrapperCodegen(CodeGen):
|
|||
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
|
||||
)
|
||||
self.writeline(
|
||||
f"if not {cond_outer_outputs[0]}: break"
|
||||
f"if not {cond_outer_outputs[0]}.item(): break"
|
||||
) # condition doesn't hold
|
||||
self.writeline(ExitSubgraphLine(self))
|
||||
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
|
||||
|
|
|
|||
|
|
@ -7407,10 +7407,8 @@ class WhileLoop(ExternKernel):
|
|||
|
||||
# make sure cond_fn returns a boolean scalar Tensor
|
||||
assert len(cond_outputs) == 1, cond_outputs
|
||||
p = cond_outputs[0]
|
||||
if not isinstance(p, ShapeAsConstantBuffer):
|
||||
assert p.get_dtype() == torch.bool, p
|
||||
assert len(p.get_size()) == 0, p
|
||||
assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs
|
||||
assert len(cond_outputs[0].get_size()) == 0, cond_outputs
|
||||
|
||||
assert (
|
||||
len(all_inputs) > 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue