From c5ffc85dab33f89046fff13da0ffe209fa0b0deb Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 7 Feb 2025 23:30:04 -0800 Subject: [PATCH] realize stride symbols in estimate_runtime ghstack-source-id: 1e948df9fc53942c8c2bd79536b15e58c2d68384 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146752 --- torch/_functorch/partitioners.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 4cae6c6e93c..2accbc7d6e4 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1452,8 +1452,9 @@ def estimate_runtime(node): return hint_int(d, fallback=4096) shape = [realize_symbol(s) for s in shape] + stride = [realize_symbol(s) for s in x.meta["tensor_meta"].stride] return x.meta["val"].new_empty_strided( - shape, stride=x.meta["tensor_meta"].stride + shape, stride=stride ) elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): return hint_int(x.meta["val"], fallback=4096)