mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
### Description <!-- Describe your changes. --> 1. Upgrade CUTLASS to 3.0, which contains attn_bias support. 2. Extend Attention/MHA to use memory-efficient attention when rel_pos_bias with shape [1, num_head, s, s*] and a 1D mask with shape [2 * batch_size + 1] are present. A new mask format is introduced: MASK_1D_KEY_SEQ_LEN_START, with shape [3 * batch_size + 2] and layout [key_len[0], ..., key_len[batch_size - 1], query_start[0], ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., key_start[batch_size - 1], key_end[batch_size - 1]]. E.g., the 2D mask [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]] converts to the 1D mask [3, 5, 0, 6, 12, 0, 6, 12]: here batch_size = 2 and each sequence is padded to length 6, so the key lengths are 3 and 5, the query/key start offsets are 0 and 6, and the end offset is 12. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> It potentially benefits TNLRv6 and T5 (encoder). --------- Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net> Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com> Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
92 lines
2.4 KiB
Diff
92 lines
2.4 KiB
Diff
diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp
|
|
index 3790ebd3..cf727d09 100644
|
|
--- a/include/cute/numeric/complex.hpp
|
|
+++ b/include/cute/numeric/complex.hpp
|
|
@@ -41,10 +41,14 @@
|
|
// With CUDA 11.4, builds show spurious "-Wconversion" warnings
|
|
// on line 656 of thrust/detail/type_traits.h.
|
|
// These pragmas suppress the warnings.
|
|
+#ifdef __GNUC__
|
|
#pragma GCC diagnostic push
|
|
#pragma GCC diagnostic ignored "-Wconversion"
|
|
+#endif
|
|
#include <thrust/complex.h>
|
|
+#ifdef __GNUC__
|
|
#pragma GCC diagnostic pop
|
|
+#endif
|
|
|
|
#include <cute/config.hpp>
|
|
|
|
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
|
|
index 59aec46a..8f2a913a 100644
|
|
--- a/include/cutlass/functional.h
|
|
+++ b/include/cutlass/functional.h
|
|
@@ -89,7 +89,7 @@ struct multiplies {
|
|
}
|
|
};
|
|
|
|
-#if defined(__CUDA_ARCH__)
|
|
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
|
/// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set
|
|
template<>
|
|
struct plus<__half2> {
|
|
@@ -143,12 +143,12 @@ struct multiplies<__half> {
|
|
|
|
|
|
// Maximum with nan propogation
|
|
-// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN
|
|
+// To propagate NaNs, the "max" of two elements where either one is NaN should also return a NaN
|
|
template <typename T>
|
|
struct maximum_with_nan_propogation {
|
|
CUTLASS_HOST_DEVICE
|
|
T operator()(T const &lhs, T const &rhs) const {
|
|
- return lhs > rhs or std::isnan(lhs) ? lhs : rhs;
|
|
+ return lhs > rhs or isnan(lhs) ? lhs : rhs;
|
|
}
|
|
};
|
|
|
|
@@ -160,7 +160,7 @@ struct maximum_with_nan_propogation<float> {
|
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
|
asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs));
|
|
#else
|
|
- res = lhs > rhs or std::isnan(lhs) ? lhs : rhs;
|
|
+ res = lhs > rhs or isnan(lhs) ? lhs : rhs;
|
|
#endif
|
|
return res;
|
|
}
|
|
@@ -233,7 +233,7 @@ struct negate {
|
|
}
|
|
};
|
|
|
|
-/// Greater equal
|
|
+/// Greater equal
|
|
template <typename T>
|
|
struct greater_equal {
|
|
CUTLASS_HOST_DEVICE
|
|
@@ -242,7 +242,7 @@ struct greater_equal {
|
|
}
|
|
};
|
|
|
|
-/// Greater
|
|
+/// Greater
|
|
template <typename T>
|
|
struct greater {
|
|
CUTLASS_HOST_DEVICE
|
|
@@ -251,7 +251,7 @@ struct greater {
|
|
}
|
|
};
|
|
|
|
-/// Less equal
|
|
+/// Less equal
|
|
template <typename T>
|
|
struct less_equal {
|
|
CUTLASS_HOST_DEVICE
|
|
@@ -260,7 +260,7 @@ struct less_equal {
|
|
}
|
|
};
|
|
|
|
-/// Less
|
|
+/// Less
|
|
template <typename T>
|
|
struct less {
|
|
CUTLASS_HOST_DEVICE
|