fix autotuning init issues (#132837)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132837
Approved by: https://github.com/yanboliang
This commit is contained in:
chilli 2024-08-06 23:19:04 -07:00 committed by PyTorch MergeBot
parent 8b50d5398f
commit ffd0d92c18
2 changed files with 4 additions and 2 deletions

View file

@@ -715,8 +715,10 @@ def flex_attention(
+ list(mask_mod_other_buffers)
)
input_gen_fns = {
-        4: create_num_blocks_fake_generator(full_kv_indices),
+        4: create_num_blocks_fake_generator(kv_indices),
5: create_indices_fake,
+        6: create_num_blocks_fake_generator(full_kv_indices),
+        7: create_indices_fake,
}
return (
autotune_select_algorithm(

View file

@@ -769,7 +769,7 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
def _apply_kernel_options(query, key, value, kernel_options):
-    kernel_options = {} if kernel_options is None else kernel_options
+    kernel_options = {} if kernel_options is None else dict(kernel_options)
if "ROWS_GUARANTEED_SAFE" not in kernel_options:
kernel_options["ROWS_GUARANTEED_SAFE"] = False