This commit is contained in:
Ding, Yi1 2024-12-31 06:06:20 +00:00
parent 34f37bfac7
commit 9322904c1d
No known key found for this signature in database
GPG key ID: C226938FAA74C08C
5 changed files with 80 additions and 61 deletions

View file

@ -34,7 +34,7 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
(query_size_last != value_size_last)) {
if (debug) {
TORCH_WARN(
"OneDNN Graph's attention requires q,k,v to have the same last dimension.",
"OneDNN attention requires q,k,v to have the same last dimension.",
" Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
@ -48,7 +48,7 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
if (query_size_last > 256) {
if (debug) {
TORCH_WARN(
"OneDNN Graph's attention requires q,k,v to have head dimension less than 256.",
"OneDNN attention requires q,k,v to have head dimension less than 256.",
" Got ",
query_size_last,
" instead.");
@ -97,7 +97,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
// 1. Flash Attention
// 2. Math fallback
auto& ctx = at::globalContext();
// use overrideable linked to onednn graph as mem efficient implementation
// use overrideable linked to onednn as overrideable implementation
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP()) {
return sdp::SDPBackend::error;
}
@ -135,7 +135,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
// reason why the kernel was not selected
print_debug = true;
TORCH_WARN("OneDNN Graph kernel not used because:");
TORCH_WARN("OneDNN kernel not used because:");
use_overrideable_xpu(kernel_params, print_debug);
TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
return sdp::SDPBackend::error;
@ -184,17 +184,17 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
TORCH_CHECK(
TORCH_INTERNAL_ASSERT(
query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
"scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
TORCH_CHECK(
TORCH_INTERNAL_ASSERT(
(key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
(key.size(2) == value.size(2)),
"scaled_dot_product_fused_attention_overrideable_xpu: K/V should have the same batch / seq / num_head");
TORCH_CHECK(
TORCH_INTERNAL_ASSERT(
dropout_p == 0.0,
"scaled_dot_product_fused_attention_overrideable_xpu: Currently do not support dropout > 0");
TORCH_CHECK(
TORCH_INTERNAL_ASSERT(
!(attn_bias.has_value() && is_causal),
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
@ -225,7 +225,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
attn_mask_fallback = std::nullopt;
}
}
at::native::onednn::graph::gpu_float_sdpa(
at::native::onednn::gpu_float_sdpa(
batch_size,
seq_len_q,
seq_len_kv,

View file

@ -1,12 +1,17 @@
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Graph.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <omp.h>
#include <oneapi/dnnl/dnnl.hpp>
using namespace at::native::onednn::graph;
using namespace at::native::onednn;
using logical_tensor = dnnl::graph::logical_tensor;
using data_type = logical_tensor::data_type;
using dims = logical_tensor::dims;
using op = dnnl::graph::op;
using partition = dnnl::graph::partition;
namespace {
struct SDPALogicalParams {
enum class TensorID {
@ -126,7 +131,7 @@ partition create_sdpa_graph_partition(
{masked_qk_out.value()},
"mask_add"};
} else if (is_causal) {
TORCH_CHECK(false, "Causal mask must use fallback mask for now.");
TORCH_INTERNAL_ASSERT(false, "Causal mask must use fallback mask for now.");
}
op softmax{op_id++, op::kind::SoftMax, "softmax"};
@ -143,8 +148,8 @@ partition create_sdpa_graph_partition(
{params.output},
"matmul_v"};
engine::kind ekind = engine::kind::gpu;
graph g(ekind);
constexpr auto ekind = dnnl::engine::kind::gpu;
dnnl::graph::graph g(ekind);
g.add_op(matmul_qk);
g.add_op(scale_div);
if (mask_add.has_value()) {
@ -154,14 +159,14 @@ partition create_sdpa_graph_partition(
g.add_op(matmul_v);
g.finalize();
auto partitions = g.get_partitions();
TORCH_CHECK(
TORCH_INTERNAL_ASSERT(
(partitions.size() == 1) && partitions[0].is_supported(),
"oneDNN Graph doesn't support this fusion pattern. If you'd like its support, please submit a issue.");
"oneDNN doesn't support this fusion pattern. If you'd like its support, please submit a issue.");
return partitions[0];
}
} // namespace
namespace at::native::onednn::graph {
namespace at::native::onednn {
void gpu_float_sdpa(
int batch_size,
int seq_len_q,
@ -188,11 +193,11 @@ void gpu_float_sdpa(
: query.scalar_type() == c10::ScalarType::Half ? data_type::f16
: query.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16
: data_type::undef;
TORCH_CHECK(
TORCH_INTERNAL_ASSERT(
(logical_tensor_dtype != data_type::undef),
"Only FP16/BF16/FP32 datatypes are currently supported");
thread_local static GraphCache cache;
thread_local static PartitionCache cache;
// cache key creation
// patternID is determined on the basis of the arguments provided
@ -239,49 +244,34 @@ void gpu_float_sdpa(
attn_mask->strides().end());
}
auto cp_entry_ref = cache.find_kernel(map_key);
if (!cp_entry_ref.has_value()) {
SDPALogicalParams logical_params(
query, key, value, attn_mask, output, logical_tensor_dtype);
auto partition_ = cache.find_partition(patternID);
if (!partition_.has_value()) {
// partition cache no hit
// graph building and partitioning
partition sdp_partition = create_sdpa_graph_partition(
batch_size,
seq_len_q,
seq_len_k,
num_head,
head_dim,
is_causal,
logical_tensor_dtype,
logical_params);
partition_ = cache.insert_partition_cache(patternID, sdp_partition);
}
cp_entry sdp_cp_entry{
/*.partition_ = */ partition_->get(),
/*.input_logical_tensors = */ logical_params.get_input(),
/*.output_logical_tensors = */ logical_params.get_output(),
};
// partition compilation
sdp_cp_entry.cp = sdp_cp_entry.partition_.compile(
sdp_cp_entry.input_logical_tensors,
sdp_cp_entry.output_logical_tensors,
eng);
cp_entry_ref = cache.insert_fused_kernel_cache(map_key, sdp_cp_entry);
const SDPALogicalParams logical_params(
query, key, value, attn_mask, output, logical_tensor_dtype);
auto partition_ = cache.find_partition(patternID);
if (!partition_.has_value()) {
// partition cache no hit
// graph building and partitioning
partition sdp_partition = create_sdpa_graph_partition(
batch_size,
seq_len_q,
seq_len_k,
num_head,
head_dim,
is_causal,
logical_tensor_dtype,
logical_params);
partition_ = cache.insert_partition_cache(patternID, sdp_partition);
}
// partition execution
auto& sdp_cp_entry = cp_entry_ref->get();
const auto& l_inputs = sdp_cp_entry.input_logical_tensors;
const auto& l_outputs = sdp_cp_entry.output_logical_tensors;
const auto l_inputs = logical_params.get_input();
const auto l_outputs = logical_params.get_output();
// partition compilation
auto compiled_partition = partition_->get().compile(l_inputs, l_outputs, eng);
std::vector<dnnl::graph::tensor> outputs = {
{l_outputs[0], eng, output.data_ptr()},
};
size_t i = 0;
std::vector<dnnl::graph::tensor> inputs;
inputs.reserve(l_inputs.size());
inputs.emplace_back(l_inputs[i++], eng, query.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, key.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, softmax_scale1.data_ptr());
@ -289,6 +279,6 @@ void gpu_float_sdpa(
inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr());
}
inputs.emplace_back(l_inputs[i++], eng, value.data_ptr());
sdp_cp_entry.cp.execute(strm, inputs, outputs);
compiled_partition.execute(strm, inputs, outputs);
}
} // namespace at::native::onednn::graph
} // namespace at::native::onednn

View file

@ -7,7 +7,7 @@
#include <bitset>
#include <list>
namespace at::native::onednn::graph {
namespace at::native::onednn {
using namespace dnnl::graph;
using data_type = logical_tensor::data_type;
@ -93,4 +93,4 @@ struct GraphCache {
return std::nullopt;
}
};
} // namespace at::native::onednn::graph
} // namespace at::native::onednn

View file

@ -7,6 +7,8 @@
#include <ATen/core/grad_mode.h>
#include <c10/core/MemoryFormat.h>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <oneapi/dnnl/dnnl_graph_sycl.hpp>
#include <oneapi/dnnl/dnnl_sycl.hpp>
#include <oneapi/dnnl/dnnl_version.h>
@ -96,4 +98,33 @@ dnnl::memory dnnl_memory_from_host_scalar(
return mem;
}
// Caches compiled oneDNN graph partitions keyed by a 32-bit pattern
// fingerprint, so graph construction and partitioning run only once per
// distinct pattern (used thread_local by gpu_float_sdpa).
struct PartitionCache {
  std::unordered_map<std::bitset<32>, dnnl::graph::partition> partition_map_{};
  // The first 8 bits are reserved
  // bit 0: is int8
  // bit 1: is uint8
  // bit 2: fp16(0) / bf16(1)
  // bit 3: is fp32
  // bit 4: is sdp pattern
  // bit 5-7: N/A
  // The rest of the bits depend upon the arguments provided
  // However, down the line, we might have different bitsets for different
  // patterns

  // Stores (or replaces) the partition for patternID and returns a reference
  // to the cached entry. Takes ownership of `p` by moving from it.
  dnnl::graph::partition& insert_partition_cache(
      const std::bitset<32>& patternID,
      dnnl::graph::partition& p) {
    // insert_or_assign hashes the key once, unlike the previous
    // operator[]-assign followed by a second operator[] lookup.
    auto [iter, inserted] =
        partition_map_.insert_or_assign(patternID, std::move(p));
    (void)inserted; // caller does not care whether this was a hit or a miss
    return iter->second;
  }

  // Returns a reference to the cached partition for patternID, or
  // std::nullopt when the pattern has not been partitioned yet.
  std::optional<std::reference_wrapper<dnnl::graph::partition>> find_partition(
      const std::bitset<32>& patternID) {
    auto iter = partition_map_.find(patternID);
    if (iter != partition_map_.end()) {
      return iter->second;
    }
    return std::nullopt;
  }
};
} // namespace at::native::onednn

View file

@ -133,7 +133,6 @@ at::Tensor quantized_convolution(
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm);
namespace graph {
void gpu_float_sdpa(
int batch_size,
int seq_len_q,
@ -149,5 +148,4 @@ void gpu_float_sdpa(
bool is_causal,
float softmax_scale,
const Tensor& output);
} // namespace graph
} // namespace at::native::onednn