mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
fix autotuning init issues (#132837)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132837 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
8b50d5398f
commit
ffd0d92c18
2 changed files with 4 additions and 2 deletions
|
|
@ -715,8 +715,10 @@ def flex_attention(
|
|||
+ list(mask_mod_other_buffers)
|
||||
)
|
||||
input_gen_fns = {
|
||||
4: create_num_blocks_fake_generator(full_kv_indices),
|
||||
4: create_num_blocks_fake_generator(kv_indices),
|
||||
5: create_indices_fake,
|
||||
6: create_num_blocks_fake_generator(full_kv_indices),
|
||||
7: create_indices_fake,
|
||||
}
|
||||
return (
|
||||
autotune_select_algorithm(
|
||||
|
|
|
|||
|
|
@ -769,7 +769,7 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
|
|||
|
||||
|
||||
def _apply_kernel_options(query, key, value, kernel_options):
|
||||
kernel_options = {} if kernel_options is None else kernel_options
|
||||
kernel_options = {} if kernel_options is None else dict(kernel_options)
|
||||
|
||||
if "ROWS_GUARANTEED_SAFE" not in kernel_options:
|
||||
kernel_options["ROWS_GUARANTEED_SAFE"] = False
|
||||
|
|
|
|||
Loading…
Reference in a new issue