diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 4cae6c6e93c..2accbc7d6e4 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1452,8 +1452,9 @@ def estimate_runtime(node): return hint_int(d, fallback=4096) shape = [realize_symbol(s) for s in shape] + stride = [realize_symbol(s) for s in x.meta["tensor_meta"].stride] return x.meta["val"].new_empty_strided( - shape, stride=x.meta["tensor_meta"].stride + shape, stride=stride ) elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): return hint_int(x.meta["val"], fallback=4096)