diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index fd3e51bdcca..de1a4b4e3c4 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -34,7 +34,7 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) { (query_size_last != value_size_last)) { if (debug) { TORCH_WARN( - "OneDNN Graph's attention requires q,k,v to have the same last dimension.", + "OneDNN attention requires q,k,v to have the same last dimension.", " Got Query.size(-1): ", query_size_last, ", Key.size(-1): ", @@ -48,7 +48,7 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) { if (query_size_last > 256) { if (debug) { TORCH_WARN( - "OneDNN Graph's attention requires q,k,v to have head dimension less than 256.", + "OneDNN attention requires q,k,v to have head dimension less than 256.", " Got ", query_size_last, " instead."); @@ -97,7 +97,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) { // 1. Flash Attention // 2. Math fallback auto& ctx = at::globalContext(); - // use overrideable linked to onednn graph as mem efficient implementation + // use overrideable linked to onednn as overrideable implementation if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP()) { return sdp::SDPBackend::error; } @@ -135,7 +135,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) { // reason why the kernel was not selected print_debug = true; - TORCH_WARN("OneDNN Graph kernel not used because:"); + TORCH_WARN("OneDNN kernel not used because:"); use_overrideable_xpu(kernel_params, print_debug); TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") return sdp::SDPBackend::error; @@ -184,17 +184,17 @@ _scaled_dot_product_fused_attention_overrideable_xpu( bool is_causal, bool return_debug_mask, std::optional scale) { - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( query.dim() == 4 && key.dim() == 4 && value.dim() == 4, "scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}"); - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) && (key.size(2) == value.size(2)), "scaled_dot_product_fused_attention_overrideable_xpu: K/V should have the same batch / seq / num_head"); - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( dropout_p == 0.0, "scaled_dot_product_fused_attention_overrideable_xpu: Currently do not support dropout > 0"); - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( !(attn_bias.has_value() && is_causal), "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal"); @@ -225,7 +225,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu( attn_mask_fallback = std::nullopt; } } - at::native::onednn::graph::gpu_float_sdpa( + at::native::onednn::gpu_float_sdpa( batch_size, seq_len_q, seq_len_kv, diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp index 1f666c45d8e..c7a71b2ad4a 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp @@ -1,12 +1,17 @@ #include -#include #include #include #include #include -using namespace at::native::onednn::graph; +using namespace at::native::onednn; +using logical_tensor = dnnl::graph::logical_tensor; +using data_type = logical_tensor::data_type; +using dims = logical_tensor::dims; +using op = dnnl::graph::op; +using partition = dnnl::graph::partition; + namespace { struct SDPALogicalParams { enum class TensorID { @@ -126,7 +131,7 @@ partition create_sdpa_graph_partition( {masked_qk_out.value()}, "mask_add"}; } else if (is_causal) { - TORCH_CHECK(false, "Causal mask must use fallback mask for now."); + TORCH_INTERNAL_ASSERT(false, "Causal mask must use fallback mask for now."); } op softmax{op_id++, op::kind::SoftMax, "softmax"}; @@ -143,8 +148,8 @@ partition create_sdpa_graph_partition( {params.output}, "matmul_v"}; - engine::kind ekind = engine::kind::gpu; - graph g(ekind); + constexpr auto ekind = dnnl::engine::kind::gpu; + dnnl::graph::graph g(ekind); g.add_op(matmul_qk); g.add_op(scale_div); if (mask_add.has_value()) { @@ -154,14 +159,14 @@ partition create_sdpa_graph_partition( g.add_op(matmul_v); g.finalize(); auto partitions = g.get_partitions(); - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( (partitions.size() == 1) && partitions[0].is_supported(), - "oneDNN Graph doesn't support this fusion pattern. If you'd like its support, please submit a issue."); + "oneDNN doesn't support this fusion pattern. If you'd like its support, please submit a issue."); return partitions[0]; } } // namespace -namespace at::native::onednn::graph { +namespace at::native::onednn { void gpu_float_sdpa( int batch_size, int seq_len_q, @@ -188,11 +193,11 @@ void gpu_float_sdpa( : query.scalar_type() == c10::ScalarType::Half ? data_type::f16 : query.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16 : data_type::undef; - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( (logical_tensor_dtype != data_type::undef), "Only FP16/BF16/FP32 datatypes are currently supported"); - thread_local static GraphCache cache; + thread_local static PartitionCache cache; // cache key creation // patternID is determined on the basis of the arguments provided @@ -239,49 +244,34 @@ void gpu_float_sdpa( attn_mask->strides().end()); } - auto cp_entry_ref = cache.find_kernel(map_key); - if (!cp_entry_ref.has_value()) { - SDPALogicalParams logical_params( - query, key, value, attn_mask, output, logical_tensor_dtype); - - auto partition_ = cache.find_partition(patternID); - if (!partition_.has_value()) { - // partition cache no hit - // graph building and partitioning - partition sdp_partition = create_sdpa_graph_partition( - batch_size, - seq_len_q, - seq_len_k, - num_head, - head_dim, - is_causal, - logical_tensor_dtype, - logical_params); - partition_ = cache.insert_partition_cache(patternID, sdp_partition); - } - cp_entry sdp_cp_entry{ - /*.partition_ = */ partition_->get(), - /*.input_logical_tensors = */ logical_params.get_input(), - /*.output_logical_tensors = */ logical_params.get_output(), - }; - // partition compilation - sdp_cp_entry.cp = sdp_cp_entry.partition_.compile( - sdp_cp_entry.input_logical_tensors, - sdp_cp_entry.output_logical_tensors, - eng); - cp_entry_ref = cache.insert_fused_kernel_cache(map_key, sdp_cp_entry); + const SDPALogicalParams logical_params( + query, key, value, attn_mask, output, logical_tensor_dtype); + auto partition_ = cache.find_partition(patternID); + if (!partition_.has_value()) { + // partition cache no hit + // graph building and partitioning + partition sdp_partition = create_sdpa_graph_partition( + batch_size, + seq_len_q, + seq_len_k, + num_head, + head_dim, + is_causal, + logical_tensor_dtype, + logical_params); + partition_ = cache.insert_partition_cache(patternID, sdp_partition); } - - // partition execution - auto& sdp_cp_entry = cp_entry_ref->get(); - const auto& l_inputs = sdp_cp_entry.input_logical_tensors; - const auto& l_outputs = sdp_cp_entry.output_logical_tensors; + const auto l_inputs = logical_params.get_input(); + const auto l_outputs = logical_params.get_output(); + // partition compilation + auto compiled_partition = partition_->get().compile(l_inputs, l_outputs, eng); std::vector outputs = { {l_outputs[0], eng, output.data_ptr()}, }; size_t i = 0; std::vector inputs; + inputs.reserve(l_inputs.size()); inputs.emplace_back(l_inputs[i++], eng, query.data_ptr()); inputs.emplace_back(l_inputs[i++], eng, key.data_ptr()); inputs.emplace_back(l_inputs[i++], eng, softmax_scale1.data_ptr()); @@ -289,6 +279,6 @@ void gpu_float_sdpa( inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr()); } inputs.emplace_back(l_inputs[i++], eng, value.data_ptr()); - sdp_cp_entry.cp.execute(strm, inputs, outputs); + compiled_partition.execute(strm, inputs, outputs); } -} // namespace at::native::onednn::graph +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Graph.h b/aten/src/ATen/native/mkldnn/xpu/detail/Graph.h index 6345c7235e9..c4cb0791488 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Graph.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Graph.h @@ -7,7 +7,7 @@ #include #include -namespace at::native::onednn::graph { +namespace at::native::onednn { using namespace dnnl::graph; using data_type = logical_tensor::data_type; @@ -93,4 +93,4 @@ struct GraphCache { return std::nullopt; } }; -} // namespace at::native::onednn::graph +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h index d793789607f..8c14e5593c6 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include @@ -96,4 +98,33 @@ dnnl::memory dnnl_memory_from_host_scalar( return mem; } +struct PartitionCache { + std::unordered_map, dnnl::graph::partition> partition_map_{}; + + // The first 8 bits are reserved + // bit 0: is int8 + // bit 1: is uint8 + // bit 2: fp16(0) / bf16(1) + // bit 3: is fp32 + // bit 4: is sdp pattern + // bit 5-7: N/A + // The rest of the bits depend upon the arguments provided + // However, down the line, we might have different bitsets for different + // patterns + dnnl::graph::partition& insert_partition_cache( + std::bitset<32>& patternID, + dnnl::graph::partition& p) { + partition_map_[patternID] = std::move(p); + return partition_map_[patternID]; + } + std::optional> find_partition( + std::bitset<32>& patternID) { + auto iter = partition_map_.find(patternID); + if (iter != partition_map_.end()) { + return iter->second; + } + return std::nullopt; + } +}; + } // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h index 4c97bd8a2f5..7c696755145 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -133,7 +133,6 @@ at::Tensor quantized_convolution( torch::List> unary_scalars, std::optional unary_algorithm); -namespace graph { void gpu_float_sdpa( int batch_size, int seq_len_q, @@ -149,5 +148,4 @@ void gpu_float_sdpa( bool is_causal, float softmax_scale, const Tensor& output); -} // namespace graph } // namespace at::native::onednn