realize stride symbols in estimate_runtime

ghstack-source-id: 1e948df9fc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146752
This commit is contained in:
Laith Sakka 2025-02-07 23:30:04 -08:00
parent c70362fac8
commit c5ffc85dab

View file

@ -1452,8 +1452,9 @@ def estimate_runtime(node):
return hint_int(d, fallback=4096)
shape = [realize_symbol(s) for s in shape]
stride = [realize_symbol(s) for s in x.meta["tensor_meta"].stride]
return x.meta["val"].new_empty_strided(
shape, stride=x.meta["tensor_meta"].stride
shape, stride=stride
)
elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt):
return hint_int(x.meta["val"], fallback=4096)