mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
parent
02ed609285
commit
ed73ae210b
1 changed files with 3 additions and 0 deletions
|
|
@ -361,6 +361,9 @@ def is_torch_sdpa_available():
|
|||
# NOTE: MLU is OK with non-contiguous inputs.
|
||||
if is_torch_mlu_available():
|
||||
return version.parse(_torch_version) >= version.parse("2.1.0")
|
||||
# NOTE: NPU can use SDPA in Transformers with torch>=2.1.0.
|
||||
if is_torch_npu_available():
|
||||
return version.parse(_torch_version) >= version.parse("2.1.0")
|
||||
# NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
|
||||
return version.parse(_torch_version) >= version.parse("2.1.1")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue