Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
This PR addresses 2 issues with derived dim suggested fixes: 1) newly introduced roots, and 2) root swapping.

1 | Newly introduced roots appear with modulo guards. For example, `Mod(dx, 2) = 0` suggests that `dx` is a derived dim equal to `2 * _dx`, which introduces a new root `_dx`. The final suggested fixes already handle this correctly. However, intermediate results can contain related derived dims that don't rely on a unified root and that mix min/max range suggestions with derived suggestions. For example:
```
"dx": {"eq": 3*_dx-1, "max": 36}
"dy": {"eq": dx+1}
```
This should lead to the suggested fixes:
```
_dx = Dim('_dx', max=12)
dx = 3 * _dx - 1
dy = 3 * _dx
```
This PR prettifies the suggested-fixes routine by unifying to a single root. It also makes each intermediate suggestion either a derived dim or a min/max range, not both.

2 | The current suggested fixes for derived dims can swap root dims and derived dims, e.g. `dy - 1, dy` -> `dx, dx + 1`. This leads to problematic suggested fixes such as `dy - 1 = Dim("dy - 1")`, because we don't have access to the original variable name. This PR only adds a suggested fix for the root dim and removes all other derived suggestions.
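Unifying the issue 1 example to a single root comes down to composing affine maps and then bounding the root. Below is a minimal plain-Python sketch. It assumes every derived dim has the form `coeff * root + offset` with a positive coefficient. The `Affine` class and its `compose`/`render` methods are hypothetical illustrations; this is not the code in this PR or the `torch.export` API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Affine:
    """Hypothetical helper: a derived dim written as coeff * root + offset."""

    coeff: int
    offset: int

    def compose(self, inner: "Affine") -> "Affine":
        # self(inner(r)) = self.coeff * (inner.coeff * r + inner.offset) + self.offset
        return Affine(self.coeff * inner.coeff, self.coeff * inner.offset + self.offset)

    def render(self, root: str) -> str:
        # Format as e.g. "3 * _dx - 1", omitting a unit coefficient or zero offset.
        expr = root if self.coeff == 1 else f"{self.coeff} * {root}"
        if self.offset > 0:
            expr += f" + {self.offset}"
        elif self.offset < 0:
            expr += f" - {-self.offset}"
        return expr


root = "_dx"
dx = Affine(3, -1)             # "dx": {"eq": 3*_dx-1, "max": 36}
dy = Affine(1, 1).compose(dx)  # "dy": {"eq": dx+1}, rewritten in terms of _dx
dx_max = 36

# Largest root value r with 3 * r - 1 <= 36 (floor division is valid since coeff > 0).
root_max = (dx_max - dx.offset) // dx.coeff

fixes = [
    f"{root} = Dim('{root}', max={root_max})",
    f"dx = {dx.render(root)}",
    f"dy = {dy.render(root)}",
]
print("\n".join(fixes))
```

Running it prints the three suggested fixes from the issue 1 example. Because `dy` is expressed directly in terms of `_dx`, the output never mixes a range suggestion with a derived one for the same dim.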
For example, take the export test case `test_derived_dim_out_of_order_simplified`:
```
import torch
from torch.export import export

_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1
dimx = dimy - 1
dimz = torch.export.Dim("dimz", min=6, max=8)  # doesn't work, should be = _dimz

class Foo(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y[1:] + z[2:]

foo = Foo()
u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
export(
    foo,
    (u, v, w),
    dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
```
Before:
```
Suggested fixes:
_dimz = Dim('_dimz', min=3, max=9223372036854775807)  # 2 <= _dimz - 1 <= 9223372036854775806
_dimz - 2 = Dim('_dimz - 2', min=4, max=6)
_dimz = Dim('_dimz', min=2, max=9223372036854775806)  # 2 <= _dimz <= 9223372036854775806
_dimz - 1 = _dimz - 1
dimz = _dimz
```
New suggested fixes:
```
Suggested fixes:
dimz = _dimz
```
Note: this assumes that the specified derived relations between dims are correct. This should be valid because:
1) if the relation is plain wrong (e.g. `(dx, dx - 1)` provided with inputs `(6, 4)`), this gets caught beforehand in `produce_guards`.
2) if the relation is correct but does not match the emitted guard, for example:
```
import torch
from torch.export import Dim, export

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x.reshape([-1]) + y  # guard: s0 * 2 = s1

model = Model()
dx = Dim("dx")
export(
    model,
    (torch.randn(6, 2), torch.randn(12)),
    dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )},
)
```
then this produces two linear equations. That leads to specialization because a) `produce_guards` is able to solve for a concrete value, and b) the export constraint solver will force specializations anyway due to range constraints.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125543
Approved by: https://github.com/avikchaudhuri
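Both points in the note reduce to arithmetic on affine size relations. The plain-Python sketch below assumes that framing. `relation_holds` and `solve_affine_equal` are hypothetical helpers; they neither use nor reproduce `produce_guards` or the export constraint solver.

```python
from fractions import Fraction
from typing import Optional, Sequence


def relation_holds(sizes: Sequence[int], offsets: Sequence[int]) -> bool:
    """Check sizes against a declared relation of the form (d + o0, d + o1, ...).

    The root value d is recovered from the first size; every other size must
    then equal d plus its offset.
    """
    root = sizes[0] - offsets[0]
    return all(size == root + off for size, off in zip(sizes, offsets))


def solve_affine_equal(a1: int, b1: int, a2: int, b2: int) -> Optional[Fraction]:
    """Solve a1 * s + b1 == a2 * s + b2 for s; None if there is no unique solution."""
    if a1 == a2:
        return None
    return Fraction(b2 - b1, a1 - a2)


# Point 1: (dx, dx - 1) declared, but the inputs have sizes (6, 4); 6 - 1 != 4.
print(relation_holds((6, 4), (0, -1)))  # False

# Point 2: the reshape emits the guard s1 == 2 * s0, while the user declared
# s1 == s0 + 6 (y's dim is dx + 6). Two independent linear equations pin s0.
s0 = solve_affine_equal(2, 0, 1, 6)
s1 = 2 * s0
print(s0, s1)  # 6 12
```

The unique solution `s0 = 6, s1 = 12` matches the example inputs of shape `(6, 2)` and `(12,)`. This is why the dim can only be specialized rather than left dynamic.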
| Name |
|---|
| __init__.py |
| opinfo_schema.py |
| test_converter.py |
| test_db.py |
| test_experimental.py |
| test_export.py |
| test_export_nonstrict.py |
| test_functionalized_assertions.py |
| test_hop.py |
| test_lift_unlift.py |
| test_pass_infra.py |
| test_passes.py |
| test_retraceability.py |
| test_schema.py |
| test_serdes.py |
| test_serialize.py |
| test_sparse.py |
| test_tools.py |
| test_torchbind.py |
| test_tree_utils.py |
| test_unflatten.py |
| test_verifier.py |
| testing.py |