pytorch/test/export
Pian Pawakapan f206c5c628 [export] handle new roots & root swapping in derived dims suggested fixes (#125543)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125543

This PR addresses two issues with derived-dim suggested fixes: 1) newly introduced roots, and 2) root swapping.

1 | Newly introduced roots appear with modulo guards. For example, `Mod(dx, 2) = 0` suggests that dx is a derived dim equal to `2 * _dx`, which introduces a new root `_dx`. The final suggested fixes already handle this correctly. However, intermediate results can contain related derived dims that don't share a unified root, and these mix min/max range suggestions with derived suggestions.

For example:
```
"dx": {"eq": 3*_dx-1, "max": 36}
"dy": {"eq": dx+1}
```
This should lead to the suggested fixes:
```
  _dx = Dim('_dx', max=12)
  dx = 3 * _dx - 1
  dy = 3 * _dx
```
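The following is a minimal pure-Python sketch of unifying to a single root. It assumes that every derived dim is linear in its root (`coeff * root + offset` with a positive integer `coeff`). `Linear`, `from_mod_guard`, and `root_max` are hypothetical helpers for illustration. They are not the code in `torch.export` or the symbolic-shapes solver. The sketch shows three things:

- how a modulo guard introduces a new root;
- how `dy = dx + 1` is re-expressed over the same root;
- how the max of 36 on `dx` becomes `max=12` on `_dx`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Linear:
    """Hypothetical helper: a dim written as coeff * root + offset, coeff > 0."""

    coeff: int
    offset: int

    def then(self, coeff: int, offset: int) -> "Linear":
        # Apply another linear step (coeff * self + offset). The result is
        # still linear in the same root, so related dims share one root.
        return Linear(coeff * self.coeff, coeff * self.offset + offset)

    def render(self, root: str) -> str:
        # Format as a suggested-fix expression, e.g. "3 * _dx - 1".
        term = root if self.coeff == 1 else f"{self.coeff} * {root}"
        if self.offset > 0:
            return f"{term} + {self.offset}"
        if self.offset < 0:
            return f"{term} - {-self.offset}"
        return term


def from_mod_guard(modulus: int, remainder: int) -> Linear:
    # Mod(dx, modulus) = remainder  =>  dx = modulus * _dx + remainder
    return Linear(modulus, remainder)


def root_max(dim: Linear, upper: int) -> int:
    # coeff * root + offset <= upper  =>  root <= floor((upper - offset) / coeff)
    return (upper - dim.offset) // dim.coeff


# Mod(dx, 2) = 0 introduces a new root _dx.
print("dx =", from_mod_guard(2, 0).render("_dx"))

# The example above: dx = 3*_dx - 1 with max 36, and dy = dx + 1.
dx = Linear(3, -1)
dy = dx.then(1, 1)
print(f"_dx = Dim('_dx', max={root_max(dx, 36)})")
print(f"dx = {dx.render('_dx')}")
print(f"dy = {dy.render('_dx')}")
```

Every dim is stored relative to the one root. As a result, each printed line is either a pure range (for `_dx`) or a pure derived expression, never a mix of both.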

This PR prettifies the suggested-fixes routine in two ways. It unifies related dims under a single root, and it makes each intermediate suggestion either a derived dim or a min/max range, not both.

2 | The current suggested fixes for derived dims can swap root dims and derived dims, e.g. `dy - 1, dy` -> `dx, dx + 1`. This leads to problematic suggested fixes like `dy - 1 = Dim("dy - 1")`, because we don't have access to the original variable name.

This PR adds a suggested fix only for the root dim and removes all other derived suggestions.
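This is a minimal sketch of why swapping is possible at all. It assumes unit-coefficient relations like the `dy - 1, dy` pair above. `Shift`, `swapped`, and `pair_sizes` are hypothetical helpers, not PyTorch APIs. Inverting `dx = dy - 1` gives `dy = dx + 1`, and both views describe the same concrete shapes. Nothing in the relation itself records which symbol the user declared as the root.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shift:
    """Hypothetical helper: derived = root + offset (unit coefficient only)."""

    root: str
    derived: str
    offset: int

    def swapped(self) -> "Shift":
        # derived = root + offset  <=>  root = derived - offset,
        # so either symbol can serve as the root of the pair.
        return Shift(root=self.derived, derived=self.root, offset=-self.offset)


def pair_sizes(rel: Shift, root_value: int) -> dict:
    # Concrete sizes for both dims once the root is fixed.
    return {rel.root: root_value, rel.derived: root_value + rel.offset}


# The user wrote the pair as (dy - 1, dy): root dy, derived dx = dy - 1.
user_spec = Shift(root="dy", derived="dx", offset=-1)
# A solver may instead pick dx as the root and report (dx, dx + 1).
solver_view = user_spec.swapped()

# Both views describe exactly the same shapes for every size checked here.
for dy_value in range(3, 10):
    dx_value = dy_value - 1
    assert pair_sizes(user_spec, dy_value) == pair_sizes(solver_view, dx_value)
print(solver_view)
```

Suppose a suggestion is phrased over the swapped root but printed with the user's names. It would then have to name an expression like `dy - 1` as a dim, which is the problem described above.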

For example, with the export test case test_derived_dim_out_of_order_simplified:
```
import torch
from torch.export import export

_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1
dimx = dimy - 1
dimz = torch.export.Dim("dimz", min=6, max=8)  # doesn't work, should be = _dimz

class Foo(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y[1:] + z[2:]

foo = Foo()
u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
export(
    foo,
    (u, v, w),
    dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
```

Before:
```
Suggested fixes:
  _dimz = Dim('_dimz', min=3, max=9223372036854775807)  # 2 <= _dimz - 1 <= 9223372036854775806
  _dimz - 2 = Dim('_dimz - 2', min=4, max=6)
  _dimz = Dim('_dimz', min=2, max=9223372036854775806)  # 2 <= _dimz <= 9223372036854775806
  _dimz - 1 = _dimz - 1
  dimz = _dimz
```

New suggested fixes:
```
Suggested fixes:
  dimz = _dimz
```
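The `dimz = _dimz` fix can be checked by hand. `x + y[1:] + z[2:]` needs `len(x) == len(y) - 1 == len(z) - 2`, assuming no size-1 broadcasting. With `dimy = _dimz - 1` and `dimx = dimy - 1`, that forces `z` to have size `_dimz`. The plain-Python sketch below walks the declared range of `_dimz`. `implied_sizes` is a hypothetical helper. The resulting `dimx` spans 4 to 6, the same range the old output attached to `_dimz - 2`.

```python
def implied_sizes(_dimz: int) -> dict:
    dimy = _dimz - 1  # declared in the test case: dimy = _dimz - 1
    dimx = dimy - 1  # declared in the test case: dimx = dimy - 1
    # x + y[1:] + z[2:] requires len(x) == len(y) - 1 == len(z) - 2;
    # solving len(z) - 2 == len(x) forces z's size to dimx + 2.
    return {"x": dimx, "y": dimy, "z": dimx + 2}


# _dimz = Dim("_dimz", min=6, max=8) in the test case.
for _dimz in range(6, 9):
    sizes = implied_sizes(_dimz)
    print(_dimz, sizes, sizes["z"] == _dimz)
```

For the example inputs of sizes 5, 6, and 7, `implied_sizes(7)` reproduces exactly those sizes. This is consistent with `z` being tied to `_dimz` rather than to an independent `Dim("dimz")`.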

Note: this assumes the specified derived relations between dims are correct. This should be valid for two reasons. 1) If the relation is plain wrong (e.g. (dx, dx - 1) provided with inputs (6, 4)), it gets caught beforehand in produce_guards. 2) The relation may be correct but not match the emitted guard, for example:
```
import torch
from torch.export import Dim, export

class Foo(torch.nn.Module):
    def forward(self, x, y):
        return x.reshape([-1]) + y  # guard: s0 * 2 = s1

model = Foo()
dx = Dim("dx")
export(
    model,
    (torch.randn(6, 2), torch.randn(12)),
    dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )}
)
```
This produces two linear equations (s1 = 2 * s0 from the guard, and s1 = s0 + 6 from the spec). That leads to specialization, since a) produce_guards is able to solve for a concrete value, and b) the export constraint solver will force specializations anyway due to range constraints.
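Here is a small sketch of the arithmetic in this note, using exact fractions. It equates the emitted guard `s1 = 2 * s0` with the user's relation `s1 = s0 + 6` (from `dx + 6`), which pins `s0` to a single value. `relation_holds` mirrors case 1), where `(dx, dx - 1)` cannot match inputs of sizes `(6, 4)`. Both helpers are hypothetical and are not the logic inside produce_guards.

```python
from fractions import Fraction
from typing import Optional


def solve_linear_pair(a1: int, b1: int, a2: int, b2: int) -> Optional[Fraction]:
    # Solve a1 * s + b1 == a2 * s + b2 for s; None when the lines are parallel.
    if a1 == a2:
        return None
    return Fraction(b2 - b1, a1 - a2)


def relation_holds(coeff: int, offset: int, root_size: int, derived_size: int) -> bool:
    # Does derived_size == coeff * root_size + offset for these concrete sizes?
    return coeff * root_size + offset == derived_size


# Emitted guard: s1 = 2 * s0.  User relation: s1 = s0 + 6 (y uses dx + 6).
s0 = solve_linear_pair(2, 0, 1, 6)
s1 = 2 * s0
print(s0, s1)  # a single concrete solution, so dx gets specialized

# Case 1): (dx, dx - 1) with inputs of sizes (6, 4) is simply inconsistent.
print(relation_holds(1, -1, 6, 4))
```

The solved sizes match the example inputs `torch.randn(6, 2)` and `torch.randn(12)`, which is why `dx` cannot stay dynamic there.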

Approved by: https://github.com/avikchaudhuri
2024-05-28 20:41:43 +00:00
__init__.py
opinfo_schema.py [export] add SchemaCheckMode testing for pre-dispatch export, OpInfo (#125481) 2024-05-14 21:07:21 +00:00
test_converter.py TorchScript 2 ExportedProgram Converter (#126920) 2024-05-23 17:00:18 +00:00
test_db.py UFMT formatting on test/export (#123520) 2024-04-10 05:38:42 +00:00
test_experimental.py [Fix]: populate input parameter name when convert TorchScript to ExportedProgram (#126787) 2024-05-28 17:33:44 +00:00
test_export.py [export] handle new roots & root swapping in derived dims suggested fixes (#125543) 2024-05-28 20:41:43 +00:00
test_export_nonstrict.py Remove several expectedFailureNonStrict (#122802) 2024-03-28 00:42:49 +00:00
test_functionalized_assertions.py
test_hop.py [while_loop] add a simiple op_info test (#123814) 2024-04-11 19:59:04 +00:00
test_lift_unlift.py [export] handle constant aliasing for export (#125509) 2024-05-10 00:14:37 +00:00
test_pass_infra.py [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126) 2024-05-27 14:49:57 +00:00
test_passes.py [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126) 2024-05-27 14:49:57 +00:00
test_retraceability.py [Export] Add runtime assert to non-strict export (#123681) 2024-04-18 16:13:27 +00:00
test_schema.py
test_serdes.py [export] Delete predispatch tests (#126459) 2024-05-17 00:48:32 +00:00
test_serialize.py [RELAND] Switch default behavoir of export IR to be predispatch (#125860) 2024-05-10 17:36:53 +00:00
test_sparse.py [traced-graph][sparse] propagate sparsity metadata into traced graph (#117907) 2024-05-23 22:46:46 +00:00
test_tools.py Tool for scouting exportability in one shot (#126471) 2024-05-18 00:10:46 +00:00
test_torchbind.py [export] handle constant aliasing for export (#125509) 2024-05-10 00:14:37 +00:00
test_tree_utils.py
test_unflatten.py [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126) 2024-05-27 14:49:57 +00:00
test_verifier.py [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126) 2024-05-27 14:49:57 +00:00
testing.py Add tests for pre_dispatch + run_decomp flow and taskify failures (#122508) 2024-03-29 01:47:07 +00:00