Added code for FusedMatMul inside matmul op primitive (#11077)

This commit is contained in:
chethanpk 2022-04-04 10:00:02 -07:00 committed by GitHub
parent ea004e953f
commit 112dec6565
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 234 additions and 12 deletions

View file

@ -19,6 +19,7 @@ DnnlOpManager::DnnlOpManager() {
dnnl_ops_map_.emplace(std::make_pair("Erf", std::unique_ptr<DnnlNodeCapability>(new DnnlErfNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Exp", std::unique_ptr<DnnlNodeCapability>(new DnnlElementwiseCapability())));
dnnl_ops_map_.emplace(std::make_pair("FastGelu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("FusedMatMul", std::unique_ptr<DnnlNodeCapability>(new DnnlMatMulNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Gelu", std::unique_ptr<DnnlNodeCapability>(new DnnlDefaultNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("Gemm", std::unique_ptr<DnnlNodeCapability>(new DnnlGemmNodeCapability())));
dnnl_ops_map_.emplace(std::make_pair("GlobalAveragePool", std::unique_ptr<DnnlNodeCapability>(new DnnlPoolNodeCapability())));

View file

@ -4,6 +4,7 @@
#include "dnnl_matmul.h"
#include "dnnl_subgraph.h"
#include "dnnl_subgraph_primitive.h"
#include <vector>
namespace onnxruntime {
namespace ort_dnnl {
@ -20,27 +21,130 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
assert(node.Input(IN_BINARY).Exists());
}
bool is_fusedmatmul = false;
bool transA = false;
bool transBatchA = false;
bool transB = false;
bool transBatchB = false;
float alpha = 1.0;
if (node.OpType() == "FusedMatMul") {
// Fused matmul is matmul modified to behave like numpy:
//https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
is_fusedmatmul = true;
transA = GetTransA(node);
transBatchA = GetTransBatchA(node);
transB = GetTransB(node);
transBatchB = GetTransBatchB(node);
alpha = GetAlpha(node);
}
auto src_dims = sp.GetMemory(node.Input(IN_A)).get_desc().dims();
auto weights_dims = sp.GetMemory(node.Input(IN_B)).get_desc().dims();
//If this is required for transposed inputs, then this will be done later on in the code.
if (src_dims.size() != weights_dims.size()) {
while (src_dims.size() < weights_dims.size()) {
src_dims.insert(src_dims.begin(), 1);
while (src_dims.size() < weights_dims.size() && (!transA && !transBatchA)) {
src_dims.insert(src_dims.begin(), 1);
}
while (src_dims.size() > weights_dims.size() && (!transB && !transBatchB)) {
weights_dims.insert(weights_dims.begin(), 1);
}
}
auto dataA_dims = src_dims;
auto ndataA_dims = src_dims.size();
dnnl::memory::dims transposedA_dims(ndataA_dims, 0);
auto dataB_dims = weights_dims;
auto ndataB_dims = weights_dims.size();
dnnl::memory::dims transposedB_dims(ndataB_dims, 0);
auto dataA_mem = sp.GetMemory(node.Input(IN_A));
auto dataB_mem = sp.GetMemory(node.Input(IN_B));
//Holds transposed matrices A and B. ToDo: Eliminate its usage if in place transpose is possbile for FusedMatmul
dnnl::memory::desc intermediateA_md;
dnnl::memory intermediateA_mem;
dnnl::memory::desc intermediateB_md;
dnnl::memory intermediateB_mem;
if (is_fusedmatmul)
{
if (transA || transBatchA) {
dnnl::memory::dims strides = GetStrides(dataA_dims, transA, transBatchA, transposedA_dims);
intermediateA_md = dnnl::memory::desc(dataA_dims, node.Input(IN_A).Type(), strides);
intermediateA_mem = dnnl::memory(intermediateA_md, eng);
auto traspose_primitive = dnnl::reorder(dataA_mem, intermediateA_mem);
sp.AddPrimitive(traspose_primitive, {{DNNL_ARG_FROM, dataA_mem},
{DNNL_ARG_TO, intermediateA_mem}});
// The reorder from above will get the memory in the right order. The next few lines will create a memory and memory descriptor
// that will have the correct dimentions and correct memory::format
dnnl::memory::desc transposed_md = dnnl::memory::desc(transposedA_dims, node.Input(IN_A).Type(), sp.GetDnnlFormat(dataA_dims.size()));
dnnl::memory transposed_mem = dnnl::memory(transposed_md, eng, nullptr);
void* handle = intermediateA_mem.get_data_handle();
transposed_mem.set_data_handle(handle);
while (transposedA_dims.size() < weights_dims.size()) {
transposedA_dims.insert(transposedA_dims.begin(), 1);
}
}
while (src_dims.size() > weights_dims.size()) {
weights_dims.insert(weights_dims.begin(), 1);
if (transB || transBatchB) { // Exact same logic for matrix B as used for matrix A
dnnl::memory::dims strides = GetStrides(dataB_dims, transB, transBatchB, transposedB_dims);
intermediateB_md = dnnl::memory::desc(dataB_dims, node.Input(IN_B).Type(), strides);
intermediateB_mem = dnnl::memory(intermediateB_md, eng);
auto traspose_primitive = dnnl::reorder(dataB_mem, intermediateB_mem);
sp.AddPrimitive(traspose_primitive, {{DNNL_ARG_FROM, dataB_mem},
{DNNL_ARG_TO, intermediateB_mem}});
// The reorder from above will get the memory in the right order. The next few lines will create a memory and memory descriptor
// that will have the correct dimentions and correct memory::format
dnnl::memory::desc transposed_md = dnnl::memory::desc(transposedB_dims, node.Input(IN_A).Type(), sp.GetDnnlFormat(dataB_dims.size()));
dnnl::memory transposed_mem = dnnl::memory(transposed_md, eng, nullptr);
void* handle = intermediateB_mem.get_data_handle();
transposed_mem.set_data_handle(handle);
while (src_dims.size() > transposedB_dims.size()) {
transposedB_dims.insert(transposedB_dims.begin(), 1);
}
}
}
auto src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any);
auto weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any);
dnnl::memory::desc src_md;
if (transA || transBatchA) {
src_md = dnnl::memory::desc(transposedA_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any);
} else {
src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any);
}
dnnl::memory::desc weights_md;
if (transB || transBatchB) {
weights_md = dnnl::memory::desc(transposedB_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any);
} else {
weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any);
}
auto output_shape = src_dims;
if (transA || transBatchA)
output_shape = transposedA_dims;
output_shape.pop_back();
output_shape.emplace_back(weights_dims.back());
if (transB || transBatchB)
output_shape.emplace_back(transposedB_dims.back());
else
output_shape.emplace_back(weights_dims.back());
for (size_t i = 0; i < output_shape.size() - 2; i++) {
if (output_shape[i] == 1) {
output_shape[i] = weights_dims[i];
if (transB || transBatchB)
output_shape[i] = transposedB_dims[i];
else
output_shape[i] = weights_dims[i];
}
}
@ -75,16 +179,32 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
attr.set_post_ops(ops);
}
if (is_fusedmatmul) { // Set the scaling of output as a post op in the primitive attribute, taking the value from alpha attribute
std::vector<float> alphaScale({alpha});
attr.set_output_scales(0, alphaScale);
}
auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
auto matmul_d = dnnl::matmul::desc(src_md, weights_md, dst_md);
auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, eng);
auto matmul_src_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng);
auto matmul_weights_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng);
dnnl::memory matmul_src_mem, matmul_weights_mem;
auto matmul_dst_mem = dnnl::memory(matmul_pd.dst_desc(), eng);
auto matmul_prim = dnnl::matmul(matmul_pd);
if (transA || transBatchA) {
matmul_src_mem = intermediateA_mem;
} else {
matmul_src_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng);
}
if (transB || transBatchB) {
matmul_weights_mem = intermediateB_mem;
} else {
matmul_weights_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng);
}
//a default memory map for matmul
std::unordered_map<int, dnnl::memory> mem_map({{DNNL_ARG_SRC, matmul_src_mem},
{DNNL_ARG_WEIGHTS, matmul_weights_mem},
@ -105,5 +225,95 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
sp.SetMemory(node.Output(OUT_Y), matmul_dst_mem);
}
dnnl::memory::dims DnnlMatMul::GetStrides(dnnl::memory::dims& data_dims,
bool trans,
bool transBatch,
dnnl::memory::dims& transposed_dims) {
std::vector<uint32_t> permA;
std::vector<uint32_t> N_A;
auto ndata_dims = data_dims.size();
uint32_t M_A, Batch;
for (uint32_t i = 0; i < ndata_dims; i++) // Temp vector to hold indices of the dims, will be used to track transposes required
permA.push_back(i);
Batch = permA[0]; // Batch Dimension
M_A = permA[ndata_dims - 1]; // M Dimension
if (ndata_dims == 4) // This will only be used if transBatch is true
N_A.push_back(permA[ndata_dims - 3]);
N_A.push_back(permA[ndata_dims - 2]);
if (trans && !transBatch) { // Swap last two dimensions for Trans only
auto n = permA[ndata_dims - 1];
permA[ndata_dims - 1] = permA[ndata_dims - 2];
permA[ndata_dims - 2] = n;
} else if (!trans && transBatch) { // If transBatch only, {Batch, N, M} ---> {N, Batch, M}
uint32_t i;
for (i = 0; i < N_A.size(); i++) {
permA[i] = N_A[i];
}
permA[i] = Batch;
} else { // If both trans and transBatch is true, then end result should be {Batch, N, M} ----> {N, M, Batch}
uint32_t i;
for (i = 0; i < N_A.size(); i++) {
permA[i] = N_A[i];
}
permA[i] = M_A;
permA[i + 1] = Batch;
}
dnnl::memory::dims strides(ndata_dims, 0);
dnnl::memory::dim total_stride = 1;
for (int i = (int)ndata_dims - 1; i >= 0; i--) {
transposed_dims[i] = data_dims[permA[i]];
strides[permA[i]] = total_stride;
total_stride *= data_dims[permA[i]];
}
dnnl::memory::dims strides_inverse;
strides_inverse.reserve(ndata_dims);
for (size_t i = 0; i < ndata_dims; ++i) {
strides_inverse.push_back(strides[ndata_dims - i - 1]);
}
return strides;
}
bool DnnlMatMul::GetTransA(DnnlNode& node) {
auto attr = node.Attributes().find("transA");
if (attr != node.Attributes().end()) {
return (attr->second().i() != 0);
}
return false;
}
bool DnnlMatMul::GetTransBatchA(DnnlNode& node) {
auto attr = node.Attributes().find("transBatchA");
if (attr != node.Attributes().end()) {
return (attr->second().i() != 0);
}
return false;
}
bool DnnlMatMul::GetTransB(DnnlNode& node) {
auto attr = node.Attributes().find("transB");
if (attr != node.Attributes().end()) {
return (attr->second().i() != 0);
}
return false;
}
bool DnnlMatMul::GetTransBatchB(DnnlNode& node) {
auto attr = node.Attributes().find("transBatchB");
if (attr != node.Attributes().end()) {
return (attr->second().i() != 0);
}
return false;
}
float DnnlMatMul::GetAlpha(DnnlNode& node) {
auto attr = node.Attributes().find("alpha");
if (attr != node.Attributes().end()) {
return attr->second().f();
}
return 1.0;
}
} // namespace ort_dnnl
} // namespace onnxruntime

View file

@ -22,6 +22,17 @@ class DnnlMatMul {
DnnlMatMul();
void CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node);
private:
bool GetTransA(DnnlNode& node);
bool GetTransBatchA(DnnlNode& node);
bool GetTransB(DnnlNode& node);
bool GetTransBatchB(DnnlNode& node);
float GetAlpha(DnnlNode& node);
dnnl::memory::dims GetStrides(dnnl::memory::dims& data_dims,
bool trans,
bool transBatch,
dnnl::memory::dims& transposed_dims);
};
} // namespace ort_dnnl

View file

@ -74,7 +74,7 @@ void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) {
}
for (auto& data : data_vec) {
printf("%.6f \n", data);
printf("%.6f ", data);
}
printf("\n");
}
@ -152,7 +152,7 @@ void DnnlSubgraphPrimitive::AddKernels() {
DnnlGemm().CreatePrimitive(*this, node);
} else if (node.OpType() == "LRN") {
DnnlLrn().CreatePrimitive(*this, node);
} else if (node.OpType() == "MatMul" || node.OpType() == "MatMulAdd") {
} else if (node.OpType() == "MatMul" || node.OpType() == "MatMulAdd" || node.OpType() == "FusedMatMul") {
DnnlMatMul().CreatePrimitive(*this, node);
} else if (node.OpType() == "MatMulInteger") {
DnnlMatMulInteger().CreatePrimitive(*this, node);