mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Added code for FusedMatMul inside matmul op primitive (#11077)
This commit is contained in:
parent
ea004e953f
commit
112dec6565
4 changed files with 234 additions and 12 deletions
|
|
@ -19,6 +19,7 @@ DnnlOpManager::DnnlOpManager() {
|
|||
dnnl_ops_map_.emplace(std::make_pair("Erf", std::unique_ptr<DnnlNodeCapability>(new DnnlErfNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Exp", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("FastGelu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("FusedMatMul", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Gelu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr<DnnlNodeCapability>(new DnnlGemmNodeCapability())));
|
||||
dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "dnnl_matmul.h"
|
||||
#include "dnnl_subgraph.h"
|
||||
#include "dnnl_subgraph_primitive.h"
|
||||
#include <vector>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ort_dnnl {
|
||||
|
|
@ -20,27 +21,130 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
|
|||
assert(node.Input(IN_BINARY).Exists());
|
||||
}
|
||||
|
||||
bool is_fusedmatmul = false;
|
||||
bool transA = false;
|
||||
bool transBatchA = false;
|
||||
bool transB = false;
|
||||
bool transBatchB = false;
|
||||
float alpha = 1.0;
|
||||
if (node.OpType() == "FusedMatMul") {
|
||||
// Fused matmul is matmul modified to behave like numpy:
|
||||
//https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
|
||||
is_fusedmatmul = true;
|
||||
transA = GetTransA(node);
|
||||
transBatchA = GetTransBatchA(node);
|
||||
transB = GetTransB(node);
|
||||
transBatchB = GetTransBatchB(node);
|
||||
alpha = GetAlpha(node);
|
||||
}
|
||||
|
||||
auto src_dims = sp.GetMemory(node.Input(IN_A)).get_desc().dims();
|
||||
auto weights_dims = sp.GetMemory(node.Input(IN_B)).get_desc().dims();
|
||||
|
||||
|
||||
//If this is required for transposed inputs, then this will be done later on in the code.
|
||||
if (src_dims.size() != weights_dims.size()) {
|
||||
while (src_dims.size() < weights_dims.size()) {
|
||||
src_dims.insert(src_dims.begin(), 1);
|
||||
while (src_dims.size() < weights_dims.size() && (!transA && !transBatchA)) {
|
||||
src_dims.insert(src_dims.begin(), 1);
|
||||
}
|
||||
while (src_dims.size() > weights_dims.size() && (!transB && !transBatchB)) {
|
||||
weights_dims.insert(weights_dims.begin(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
auto dataA_dims = src_dims;
|
||||
auto ndataA_dims = src_dims.size();
|
||||
dnnl::memory::dims transposedA_dims(ndataA_dims, 0);
|
||||
|
||||
auto dataB_dims = weights_dims;
|
||||
auto ndataB_dims = weights_dims.size();
|
||||
dnnl::memory::dims transposedB_dims(ndataB_dims, 0);
|
||||
|
||||
auto dataA_mem = sp.GetMemory(node.Input(IN_A));
|
||||
auto dataB_mem = sp.GetMemory(node.Input(IN_B));
|
||||
|
||||
|
||||
// Holds the transposed matrices A and B. TODO: eliminate their usage if an in-place transpose is possible for FusedMatMul
|
||||
dnnl::memory::desc intermediateA_md;
|
||||
dnnl::memory intermediateA_mem;
|
||||
|
||||
dnnl::memory::desc intermediateB_md;
|
||||
dnnl::memory intermediateB_mem;
|
||||
|
||||
if (is_fusedmatmul)
|
||||
{
|
||||
if (transA || transBatchA) {
|
||||
dnnl::memory::dims strides = GetStrides(dataA_dims, transA, transBatchA, transposedA_dims);
|
||||
|
||||
intermediateA_md = dnnl::memory::desc(dataA_dims, node.Input(IN_A).Type(), strides);
|
||||
intermediateA_mem = dnnl::memory(intermediateA_md, eng);
|
||||
|
||||
auto traspose_primitive = dnnl::reorder(dataA_mem, intermediateA_mem);
|
||||
sp.AddPrimitive(traspose_primitive, {{DNNL_ARG_FROM, dataA_mem},
|
||||
{DNNL_ARG_TO, intermediateA_mem}});
|
||||
|
||||
// The reorder from above will get the memory in the right order. The next few lines will create a memory and memory descriptor
|
||||
// that will have the correct dimensions and correct memory::format
|
||||
dnnl::memory::desc transposed_md = dnnl::memory::desc(transposedA_dims, node.Input(IN_A).Type(), sp.GetDnnlFormat(dataA_dims.size()));
|
||||
dnnl::memory transposed_mem = dnnl::memory(transposed_md, eng, nullptr);
|
||||
void* handle = intermediateA_mem.get_data_handle();
|
||||
transposed_mem.set_data_handle(handle);
|
||||
while (transposedA_dims.size() < weights_dims.size()) {
|
||||
transposedA_dims.insert(transposedA_dims.begin(), 1);
|
||||
}
|
||||
}
|
||||
while (src_dims.size() > weights_dims.size()) {
|
||||
weights_dims.insert(weights_dims.begin(), 1);
|
||||
if (transB || transBatchB) { // Exact same logic for matrix B as used for matrix A
|
||||
dnnl::memory::dims strides = GetStrides(dataB_dims, transB, transBatchB, transposedB_dims);
|
||||
|
||||
intermediateB_md = dnnl::memory::desc(dataB_dims, node.Input(IN_B).Type(), strides);
|
||||
intermediateB_mem = dnnl::memory(intermediateB_md, eng);
|
||||
|
||||
auto traspose_primitive = dnnl::reorder(dataB_mem, intermediateB_mem);
|
||||
sp.AddPrimitive(traspose_primitive, {{DNNL_ARG_FROM, dataB_mem},
|
||||
{DNNL_ARG_TO, intermediateB_mem}});
|
||||
|
||||
// The reorder from above will get the memory in the right order. The next few lines will create a memory and memory descriptor
|
||||
// that will have the correct dimensions and correct memory::format
|
||||
dnnl::memory::desc transposed_md = dnnl::memory::desc(transposedB_dims, node.Input(IN_A).Type(), sp.GetDnnlFormat(dataB_dims.size()));
|
||||
dnnl::memory transposed_mem = dnnl::memory(transposed_md, eng, nullptr);
|
||||
void* handle = intermediateB_mem.get_data_handle();
|
||||
transposed_mem.set_data_handle(handle);
|
||||
while (src_dims.size() > transposedB_dims.size()) {
|
||||
transposedB_dims.insert(transposedB_dims.begin(), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any);
|
||||
auto weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc src_md;
|
||||
if (transA || transBatchA) {
|
||||
src_md = dnnl::memory::desc(transposedA_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any);
|
||||
} else {
|
||||
src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any);
|
||||
}
|
||||
|
||||
dnnl::memory::desc weights_md;
|
||||
if (transB || transBatchB) {
|
||||
weights_md = dnnl::memory::desc(transposedB_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any);
|
||||
} else {
|
||||
weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any);
|
||||
}
|
||||
|
||||
|
||||
auto output_shape = src_dims;
|
||||
if (transA || transBatchA)
|
||||
output_shape = transposedA_dims;
|
||||
output_shape.pop_back();
|
||||
output_shape.emplace_back(weights_dims.back());
|
||||
if (transB || transBatchB)
|
||||
output_shape.emplace_back(transposedB_dims.back());
|
||||
else
|
||||
output_shape.emplace_back(weights_dims.back());
|
||||
for (size_t i = 0; i < output_shape.size() - 2; i++) {
|
||||
if (output_shape[i] == 1) {
|
||||
output_shape[i] = weights_dims[i];
|
||||
if (transB || transBatchB)
|
||||
output_shape[i] = transposedB_dims[i];
|
||||
else
|
||||
output_shape[i] = weights_dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -75,16 +179,32 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
|
|||
attr.set_post_ops(ops);
|
||||
}
|
||||
|
||||
if (is_fusedmatmul) { // Set the scaling of output as a post op in the primitive attribute, taking the value from alpha attribute
|
||||
std::vector<float> alphaScale({alpha});
|
||||
attr.set_output_scales(0, alphaScale);
|
||||
}
|
||||
|
||||
auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
|
||||
|
||||
auto matmul_d = dnnl::matmul::desc(src_md, weights_md, dst_md);
|
||||
|
||||
auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, eng);
|
||||
|
||||
auto matmul_src_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng);
|
||||
auto matmul_weights_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng);
|
||||
dnnl::memory matmul_src_mem, matmul_weights_mem;
|
||||
auto matmul_dst_mem = dnnl::memory(matmul_pd.dst_desc(), eng);
|
||||
auto matmul_prim = dnnl::matmul(matmul_pd);
|
||||
|
||||
if (transA || transBatchA) {
|
||||
matmul_src_mem = intermediateA_mem;
|
||||
} else {
|
||||
matmul_src_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng);
|
||||
}
|
||||
if (transB || transBatchB) {
|
||||
matmul_weights_mem = intermediateB_mem;
|
||||
} else {
|
||||
matmul_weights_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng);
|
||||
}
|
||||
|
||||
//a default memory map for matmul
|
||||
std::unordered_map<int, dnnl::memory> mem_map({{DNNL_ARG_SRC, matmul_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, matmul_weights_mem},
|
||||
|
|
@ -105,5 +225,95 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
|
|||
sp.SetMemory(node.Output(OUT_Y), matmul_dst_mem);
|
||||
}
|
||||
|
||||
// Builds the strides that describe `data_dims` when its dimensions are laid
// out in memory in the permuted order selected by `trans` / `transBatch`, and
// writes the permuted shape into `transposed_dims`.
//
// Permutation of the dimension indices (written for a {Batch, N, M} input):
//   trans only        : the last two dimensions are swapped
//   transBatch only   : {Batch, N, M} -> {N, Batch, M}
//   trans + transBatch: {Batch, N, M} -> {N, M, Batch}
//   neither           : identity (dims unchanged, dense row-major strides)
//
// Parameters:
//   data_dims       - dims of the tensor to be permuted; rank must be >= 2.
//   trans           - swap the two innermost dimensions.
//   transBatch      - move the leading (batch) dimension behind the N dims.
//   transposed_dims - output; must already hold at least data_dims.size()
//                     elements. Entry i receives data_dims[perm[i]].
//
// Returns a vector with one stride per ORIGINAL dimension: strides[perm[i]]
// is the product of the permuted dims that come after position i.
//
// NOTE(review): the N dimensions are only collected for rank 3 (one N dim) and
// rank 4 (two N dims); for other ranks the transBatch permutations repeat
// indices and are not valid permutations. Confirm that callers restrict
// transBatch to rank 3/4 inputs.
dnnl::memory::dims DnnlMatMul::GetStrides(dnnl::memory::dims& data_dims,
                                          bool trans,
                                          bool transBatch,
                                          dnnl::memory::dims& transposed_dims) {
  const size_t ndata_dims = data_dims.size();
  // The code below indexes permA[ndata_dims - 2]; with an unsigned rank < 2
  // that index would wrap around.
  assert(ndata_dims >= 2);
  assert(transposed_dims.size() >= ndata_dims);

  // permA[i] is the index of the original dimension that ends up at position i.
  std::vector<uint32_t> permA;
  permA.reserve(ndata_dims);
  for (uint32_t i = 0; i < ndata_dims; i++)
    permA.push_back(i);

  const uint32_t Batch = permA[0];             // Batch dimension
  const uint32_t M_A = permA[ndata_dims - 1];  // M dimension

  // N dimension(s); only consulted when transBatch is true.
  std::vector<uint32_t> N_A;
  if (ndata_dims == 4)
    N_A.push_back(permA[ndata_dims - 3]);
  N_A.push_back(permA[ndata_dims - 2]);

  if (trans && !transBatch) {
    // Swap the last two dimensions.
    auto n = permA[ndata_dims - 1];
    permA[ndata_dims - 1] = permA[ndata_dims - 2];
    permA[ndata_dims - 2] = n;
  } else if (!trans && transBatch) {
    // {Batch, N, M} ---> {N, Batch, M}
    uint32_t i;
    for (i = 0; i < N_A.size(); i++) {
      permA[i] = N_A[i];
    }
    permA[i] = Batch;
  } else if (trans && transBatch) {
    // {Batch, N, M} ---> {N, M, Batch}
    // Writing permA[N_A.size() + 1] needs at least one more slot than the
    // rank-2 case provides.
    assert(N_A.size() + 1 < ndata_dims);
    uint32_t i;
    for (i = 0; i < N_A.size(); i++) {
      permA[i] = N_A[i];
    }
    permA[i] = M_A;
    permA[i + 1] = Batch;
  }
  // Neither flag set: permA stays the identity permutation.

  // Walk the permuted order from innermost to outermost, accumulating the
  // running element count as the stride of each original dimension.
  dnnl::memory::dims strides(ndata_dims, 0);
  dnnl::memory::dim total_stride = 1;
  for (int i = (int)ndata_dims - 1; i >= 0; i--) {
    transposed_dims[i] = data_dims[permA[i]];
    strides[permA[i]] = total_stride;
    total_stride *= data_dims[permA[i]];
  }

  return strides;
}
|
||||
|
||||
// Reads the node's "transA" attribute.
// Returns true only when the attribute is present and its integer value is
// non-zero; returns false when it is absent or zero.
bool DnnlMatMul::GetTransA(DnnlNode& node) {
  auto trans_a = node.Attributes().find("transA");
  const bool is_present = trans_a != node.Attributes().end();
  return is_present && trans_a->second().i() != 0;
}
|
||||
|
||||
// Reads the node's "transBatchA" attribute.
// Returns true only when the attribute is present and its integer value is
// non-zero; returns false when it is absent or zero.
bool DnnlMatMul::GetTransBatchA(DnnlNode& node) {
  auto trans_batch_a = node.Attributes().find("transBatchA");
  const bool is_present = trans_batch_a != node.Attributes().end();
  return is_present && trans_batch_a->second().i() != 0;
}
|
||||
|
||||
// Reads the node's "transB" attribute.
// Returns true only when the attribute is present and its integer value is
// non-zero; returns false when it is absent or zero.
bool DnnlMatMul::GetTransB(DnnlNode& node) {
  auto trans_b = node.Attributes().find("transB");
  const bool is_present = trans_b != node.Attributes().end();
  return is_present && trans_b->second().i() != 0;
}
|
||||
|
||||
// Reads the node's "transBatchB" attribute.
// Returns true only when the attribute is present and its integer value is
// non-zero; returns false when it is absent or zero.
bool DnnlMatMul::GetTransBatchB(DnnlNode& node) {
  auto trans_batch_b = node.Attributes().find("transBatchB");
  const bool is_present = trans_batch_b != node.Attributes().end();
  return is_present && trans_batch_b->second().i() != 0;
}
|
||||
|
||||
// Reads the node's "alpha" attribute.
// Returns the attribute's float value when it is present, and 1.0 when the
// node carries no "alpha" attribute.
float DnnlMatMul::GetAlpha(DnnlNode& node) {
  auto alpha_attr = node.Attributes().find("alpha");
  const bool has_alpha = alpha_attr != node.Attributes().end();
  return has_alpha ? alpha_attr->second().f() : 1.0f;
}
|
||||
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -22,6 +22,17 @@ class DnnlMatMul {
|
|||
|
||||
DnnlMatMul();
|
||||
void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
|
||||
|
||||
private:
|
||||
bool GetTransA(DnnlNode& node);
|
||||
bool GetTransBatchA(DnnlNode& node);
|
||||
bool GetTransB(DnnlNode& node);
|
||||
bool GetTransBatchB(DnnlNode& node);
|
||||
float GetAlpha(DnnlNode& node);
|
||||
dnnl::memory::dims GetStrides(dnnl::memory::dims& data_dims,
|
||||
bool trans,
|
||||
bool transBatch,
|
||||
dnnl::memory::dims& transposed_dims);
|
||||
};
|
||||
|
||||
} // namespace ort_dnnl
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) {
|
|||
}
|
||||
|
||||
for (auto& data : data_vec) {
|
||||
printf("%.6f \n", data);
|
||||
printf("%.6f ", data);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
|
@ -152,7 +152,7 @@ void DnnlSubgraphPrimitive::AddKernels() {
|
|||
DnnlGemm().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "LRN") {
|
||||
DnnlLrn().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "MatMul" || node.OpType() == "MatMulAdd") {
|
||||
} else if (node.OpType() == "MatMul" || node.OpType() == "MatMulAdd" || node.OpType() == "FusedMatMul") {
|
||||
DnnlMatMul().CreatePrimitive(*this, node);
|
||||
} else if (node.OpType() == "MatMulInteger") {
|
||||
DnnlMatMulInteger().CreatePrimitive(*this, node);
|
||||
|
|
|
|||
Loading…
Reference in a new issue