mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Clean up
This commit is contained in:
parent
34f37bfac7
commit
9322904c1d
5 changed files with 80 additions and 61 deletions
|
|
@ -34,7 +34,7 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
|
|||
(query_size_last != value_size_last)) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"OneDNN Graph's attention requires q,k,v to have the same last dimension.",
|
||||
"OneDNN attention requires q,k,v to have the same last dimension.",
|
||||
" Got Query.size(-1): ",
|
||||
query_size_last,
|
||||
", Key.size(-1): ",
|
||||
|
|
@ -48,7 +48,7 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
|
|||
if (query_size_last > 256) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"OneDNN Graph's attention requires q,k,v to have head dimension less than 256.",
|
||||
"OneDNN attention requires q,k,v to have head dimension less than 256.",
|
||||
" Got ",
|
||||
query_size_last,
|
||||
" instead.");
|
||||
|
|
@ -97,7 +97,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
|
|||
// 1. Flash Attention
|
||||
// 2. Math fallback
|
||||
auto& ctx = at::globalContext();
|
||||
// use overrideable linked to onednn graph as mem efficient implementation
|
||||
// use overrideable linked to onednn as overrideable implementation
|
||||
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP()) {
|
||||
return sdp::SDPBackend::error;
|
||||
}
|
||||
|
|
@ -135,7 +135,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
|
|||
// reason why the kernel was not selected
|
||||
|
||||
print_debug = true;
|
||||
TORCH_WARN("OneDNN Graph kernel not used because:");
|
||||
TORCH_WARN("OneDNN kernel not used because:");
|
||||
use_overrideable_xpu(kernel_params, print_debug);
|
||||
TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
|
||||
return sdp::SDPBackend::error;
|
||||
|
|
@ -184,17 +184,17 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
|
|||
bool is_causal,
|
||||
bool return_debug_mask,
|
||||
std::optional<double> scale) {
|
||||
TORCH_CHECK(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
|
||||
"scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
|
||||
TORCH_CHECK(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
(key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
|
||||
(key.size(2) == value.size(2)),
|
||||
"scaled_dot_product_fused_attention_overrideable_xpu: K/V should have the same batch / seq / num_head");
|
||||
TORCH_CHECK(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dropout_p == 0.0,
|
||||
"scaled_dot_product_fused_attention_overrideable_xpu: Currently do not support dropout > 0");
|
||||
TORCH_CHECK(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!(attn_bias.has_value() && is_causal),
|
||||
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
|
||||
|
||||
|
|
@ -225,7 +225,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
|
|||
attn_mask_fallback = std::nullopt;
|
||||
}
|
||||
}
|
||||
at::native::onednn::graph::gpu_float_sdpa(
|
||||
at::native::onednn::gpu_float_sdpa(
|
||||
batch_size,
|
||||
seq_len_q,
|
||||
seq_len_kv,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Graph.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
|
||||
#include <omp.h>
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
|
||||
using namespace at::native::onednn::graph;
|
||||
using namespace at::native::onednn;
|
||||
using logical_tensor = dnnl::graph::logical_tensor;
|
||||
using data_type = logical_tensor::data_type;
|
||||
using dims = logical_tensor::dims;
|
||||
using op = dnnl::graph::op;
|
||||
using partition = dnnl::graph::partition;
|
||||
|
||||
namespace {
|
||||
struct SDPALogicalParams {
|
||||
enum class TensorID {
|
||||
|
|
@ -126,7 +131,7 @@ partition create_sdpa_graph_partition(
|
|||
{masked_qk_out.value()},
|
||||
"mask_add"};
|
||||
} else if (is_causal) {
|
||||
TORCH_CHECK(false, "Causal mask must use fallback mask for now.");
|
||||
TORCH_INTERNAL_ASSERT(false, "Causal mask must use fallback mask for now.");
|
||||
}
|
||||
|
||||
op softmax{op_id++, op::kind::SoftMax, "softmax"};
|
||||
|
|
@ -143,8 +148,8 @@ partition create_sdpa_graph_partition(
|
|||
{params.output},
|
||||
"matmul_v"};
|
||||
|
||||
engine::kind ekind = engine::kind::gpu;
|
||||
graph g(ekind);
|
||||
constexpr auto ekind = dnnl::engine::kind::gpu;
|
||||
dnnl::graph::graph g(ekind);
|
||||
g.add_op(matmul_qk);
|
||||
g.add_op(scale_div);
|
||||
if (mask_add.has_value()) {
|
||||
|
|
@ -154,14 +159,14 @@ partition create_sdpa_graph_partition(
|
|||
g.add_op(matmul_v);
|
||||
g.finalize();
|
||||
auto partitions = g.get_partitions();
|
||||
TORCH_CHECK(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
(partitions.size() == 1) && partitions[0].is_supported(),
|
||||
"oneDNN Graph doesn't support this fusion pattern. If you'd like its support, please submit a issue.");
|
||||
"oneDNN doesn't support this fusion pattern. If you'd like its support, please submit a issue.");
|
||||
return partitions[0];
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace at::native::onednn::graph {
|
||||
namespace at::native::onednn {
|
||||
void gpu_float_sdpa(
|
||||
int batch_size,
|
||||
int seq_len_q,
|
||||
|
|
@ -188,11 +193,11 @@ void gpu_float_sdpa(
|
|||
: query.scalar_type() == c10::ScalarType::Half ? data_type::f16
|
||||
: query.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16
|
||||
: data_type::undef;
|
||||
TORCH_CHECK(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
(logical_tensor_dtype != data_type::undef),
|
||||
"Only FP16/BF16/FP32 datatypes are currently supported");
|
||||
|
||||
thread_local static GraphCache cache;
|
||||
thread_local static PartitionCache cache;
|
||||
|
||||
// cache key creation
|
||||
// patternID is determined on the basis of the arguments provided
|
||||
|
|
@ -239,49 +244,34 @@ void gpu_float_sdpa(
|
|||
attn_mask->strides().end());
|
||||
}
|
||||
|
||||
auto cp_entry_ref = cache.find_kernel(map_key);
|
||||
if (!cp_entry_ref.has_value()) {
|
||||
SDPALogicalParams logical_params(
|
||||
query, key, value, attn_mask, output, logical_tensor_dtype);
|
||||
|
||||
auto partition_ = cache.find_partition(patternID);
|
||||
if (!partition_.has_value()) {
|
||||
// partition cache no hit
|
||||
// graph building and partitioning
|
||||
partition sdp_partition = create_sdpa_graph_partition(
|
||||
batch_size,
|
||||
seq_len_q,
|
||||
seq_len_k,
|
||||
num_head,
|
||||
head_dim,
|
||||
is_causal,
|
||||
logical_tensor_dtype,
|
||||
logical_params);
|
||||
partition_ = cache.insert_partition_cache(patternID, sdp_partition);
|
||||
}
|
||||
cp_entry sdp_cp_entry{
|
||||
/*.partition_ = */ partition_->get(),
|
||||
/*.input_logical_tensors = */ logical_params.get_input(),
|
||||
/*.output_logical_tensors = */ logical_params.get_output(),
|
||||
};
|
||||
// partition compilation
|
||||
sdp_cp_entry.cp = sdp_cp_entry.partition_.compile(
|
||||
sdp_cp_entry.input_logical_tensors,
|
||||
sdp_cp_entry.output_logical_tensors,
|
||||
eng);
|
||||
cp_entry_ref = cache.insert_fused_kernel_cache(map_key, sdp_cp_entry);
|
||||
const SDPALogicalParams logical_params(
|
||||
query, key, value, attn_mask, output, logical_tensor_dtype);
|
||||
auto partition_ = cache.find_partition(patternID);
|
||||
if (!partition_.has_value()) {
|
||||
// partition cache no hit
|
||||
// graph building and partitioning
|
||||
partition sdp_partition = create_sdpa_graph_partition(
|
||||
batch_size,
|
||||
seq_len_q,
|
||||
seq_len_k,
|
||||
num_head,
|
||||
head_dim,
|
||||
is_causal,
|
||||
logical_tensor_dtype,
|
||||
logical_params);
|
||||
partition_ = cache.insert_partition_cache(patternID, sdp_partition);
|
||||
}
|
||||
|
||||
// partition execution
|
||||
auto& sdp_cp_entry = cp_entry_ref->get();
|
||||
const auto& l_inputs = sdp_cp_entry.input_logical_tensors;
|
||||
const auto& l_outputs = sdp_cp_entry.output_logical_tensors;
|
||||
const auto l_inputs = logical_params.get_input();
|
||||
const auto l_outputs = logical_params.get_output();
|
||||
// partition compilation
|
||||
auto compiled_partition = partition_->get().compile(l_inputs, l_outputs, eng);
|
||||
|
||||
std::vector<dnnl::graph::tensor> outputs = {
|
||||
{l_outputs[0], eng, output.data_ptr()},
|
||||
};
|
||||
size_t i = 0;
|
||||
std::vector<dnnl::graph::tensor> inputs;
|
||||
inputs.reserve(l_inputs.size());
|
||||
inputs.emplace_back(l_inputs[i++], eng, query.data_ptr());
|
||||
inputs.emplace_back(l_inputs[i++], eng, key.data_ptr());
|
||||
inputs.emplace_back(l_inputs[i++], eng, softmax_scale1.data_ptr());
|
||||
|
|
@ -289,6 +279,6 @@ void gpu_float_sdpa(
|
|||
inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr());
|
||||
}
|
||||
inputs.emplace_back(l_inputs[i++], eng, value.data_ptr());
|
||||
sdp_cp_entry.cp.execute(strm, inputs, outputs);
|
||||
compiled_partition.execute(strm, inputs, outputs);
|
||||
}
|
||||
} // namespace at::native::onednn::graph
|
||||
} // namespace at::native::onednn
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
#include <bitset>
|
||||
#include <list>
|
||||
|
||||
namespace at::native::onednn::graph {
|
||||
namespace at::native::onednn {
|
||||
|
||||
using namespace dnnl::graph;
|
||||
using data_type = logical_tensor::data_type;
|
||||
|
|
@ -93,4 +93,4 @@ struct GraphCache {
|
|||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
} // namespace at::native::onednn::graph
|
||||
} // namespace at::native::onednn
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
#include <ATen/core/grad_mode.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
#include <oneapi/dnnl/dnnl_graph.hpp>
|
||||
#include <oneapi/dnnl/dnnl_graph_sycl.hpp>
|
||||
#include <oneapi/dnnl/dnnl_sycl.hpp>
|
||||
#include <oneapi/dnnl/dnnl_version.h>
|
||||
|
||||
|
|
@ -96,4 +98,33 @@ dnnl::memory dnnl_memory_from_host_scalar(
|
|||
return mem;
|
||||
}
|
||||
|
||||
struct PartitionCache {
|
||||
std::unordered_map<std::bitset<32>, dnnl::graph::partition> partition_map_{};
|
||||
|
||||
// The first 8 bits are reserved
|
||||
// bit 0: is int8
|
||||
// bit 1: is uint8
|
||||
// bit 2: fp16(0) / bf16(1)
|
||||
// bit 3: is fp32
|
||||
// bit 4: is sdp pattern
|
||||
// bit 5-7: N/A
|
||||
// The rest of the bits depend upon the arguments provided
|
||||
// However, down the line, we might have different bitsets for different
|
||||
// patterns
|
||||
dnnl::graph::partition& insert_partition_cache(
|
||||
std::bitset<32>& patternID,
|
||||
dnnl::graph::partition& p) {
|
||||
partition_map_[patternID] = std::move(p);
|
||||
return partition_map_[patternID];
|
||||
}
|
||||
std::optional<std::reference_wrapper<dnnl::graph::partition>> find_partition(
|
||||
std::bitset<32>& patternID) {
|
||||
auto iter = partition_map_.find(patternID);
|
||||
if (iter != partition_map_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace at::native::onednn
|
||||
|
|
|
|||
|
|
@ -133,7 +133,6 @@ at::Tensor quantized_convolution(
|
|||
torch::List<std::optional<at::Scalar>> unary_scalars,
|
||||
std::optional<std::string_view> unary_algorithm);
|
||||
|
||||
namespace graph {
|
||||
void gpu_float_sdpa(
|
||||
int batch_size,
|
||||
int seq_len_q,
|
||||
|
|
@ -149,5 +148,4 @@ void gpu_float_sdpa(
|
|||
bool is_causal,
|
||||
float softmax_scale,
|
||||
const Tensor& output);
|
||||
} // namespace graph
|
||||
} // namespace at::native::onednn
|
||||
|
|
|
|||
Loading…
Reference in a new issue