Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
Update on "[dynamo][not ready] polyfill infra for classes"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
Commit f770084b53
2 changed files with 6 additions and 7 deletions
@@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify(
   T sinc(T a) {
     if (a == T(0)) {
       return T(1);
-    } else {
-      constexpr T pi = T(3.14159265358979323846L);
-      T product = pi * a;
-      return std::sin(product) / product;
     }
+    constexpr T pi = T(3.14159265358979323846L);
+    T product = pi * a;
+    return std::sin(product) / product;
   }
 ); // sinc_string
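The change above is a pure refactor of the jiterator sinc string: the else branch is flattened into straight-line code after the early return T(1), so the computed value is unchanged. Below is a minimal Python sketch of the same normalized-sinc definition, handy for sanity-checking against torch.special.sinc; the sample points and the use of assert_close are illustrative, not taken from the PyTorch test suite.

import math
import torch

def sinc_ref(a: float) -> float:
    # Mirrors the jiterator string: sinc(0) = 1, otherwise sin(pi * a) / (pi * a).
    if a == 0.0:
        return 1.0
    product = math.pi * a
    return math.sin(product) / product

# Compare against torch.special.sinc on a few sample points.
xs = torch.tensor([0.0, 0.5, 1.0, 2.5, -3.0], dtype=torch.float64)
expected = torch.tensor([sinc_ref(x) for x in xs.tolist()], dtype=torch.float64)
torch.testing.assert_close(torch.special.sinc(xs), expected)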
@@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     @supported_platform
     def test_strided_backwards(self):
         shape = (1, 2, 4096, 64)
-        Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
+        Q = torch.randn(shape, requires_grad=True, device="cuda")
+        K = torch.randn(shape, requires_grad=True, device="cuda")
+        V = torch.randn(shape, requires_grad=True, device="cuda")
         func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
 
         K_sliced = K[:, :, :-128]
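The hunk above only changes the dtype used to build the test tensors; the rest of the test (cut off in this view) exercises the backward pass of flex_attention through strided, non-contiguous inputs produced by slicing. A self-contained sketch of that pattern follows; the import path, the slicing of Q and V, and the sum/backward wiring are assumptions for illustration (a CUDA device is required), not the test's actual body.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Sketch only: build full-size tensors, then slice along the sequence
# dimension so the compiled kernel sees non-contiguous (strided) inputs.
shape = (1, 2, 4096, 64)
Q = torch.randn(shape, requires_grad=True, device="cuda")
K = torch.randn(shape, requires_grad=True, device="cuda")
V = torch.randn(shape, requires_grad=True, device="cuda")
func = torch.compile(flex_attention, dynamic=True, fullgraph=True)

Q_sliced = Q[:, :, :-128]  # hypothetical: the visible hunk only shows K_sliced
K_sliced = K[:, :, :-128]
V_sliced = V[:, :, :-128]

out = func(Q_sliced, K_sliced, V_sliced)
out.sum().backward()  # gradients must flow back through the strided views

Slicing full-size tensors rather than allocating smaller ones directly is what makes the inputs strided, which is the interesting case for the backward kernel.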