Update on "[dynamo][not ready] polyfill infra for classes"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
This commit is contained in:
Animesh Jain 2025-02-09 15:54:17 -08:00
commit f770084b53
2 changed files with 6 additions and 7 deletions

View file

@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify(
T sinc(T a) {
if (a == T(0)) {
return T(1);
} else {
constexpr T pi = T(3.14159265358979323846L);
T product = pi * a;
return std::sin(product) / product;
}
constexpr T pi = T(3.14159265358979323846L);
T product = pi * a;
return std::sin(product) / product;
}
); // sinc_string

View file

@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
@supported_platform
def test_strided_backwards(self):
shape = (1, 2, 4096, 64)
Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
Q = torch.randn(shape, requires_grad=True, device="cuda")
K = torch.randn(shape, requires_grad=True, device="cuda")
V = torch.randn(shape, requires_grad=True, device="cuda")
func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
K_sliced = K[:, :, :-128]