mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
realize stride symbols in estimate_runtime
ghstack-source-id: 1e948df9fc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146752
This commit is contained in:
parent
c70362fac8
commit
c5ffc85dab
1 changed files with 2 additions and 1 deletions
|
|
@ -1452,8 +1452,9 @@ def estimate_runtime(node):
|
|||
return hint_int(d, fallback=4096)
|
||||
|
||||
shape = [realize_symbol(s) for s in shape]
|
||||
stride = [realize_symbol(s) for s in x.meta["tensor_meta"].stride]
|
||||
return x.meta["val"].new_empty_strided(
|
||||
shape, stride=x.meta["tensor_meta"].stride
|
||||
shape, stride=stride
|
||||
)
|
||||
elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt):
|
||||
return hint_int(x.meta["val"], fallback=4096)
|
||||
|
|
|
|||
Loading…
Reference in a new issue