mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
parent
21c282ed54
commit
61ba9ac1bb
9 changed files with 374 additions and 21 deletions
6
cmake/external/dnnl.cmake
vendored
6
cmake/external/dnnl.cmake
vendored
|
|
@ -2,16 +2,16 @@ include (ExternalProject)
|
|||
|
||||
set(DNNL_URL https://github.com/oneapi-src/onednn)
|
||||
# If DNNL_TAG is updated, check if MKLML_VERSION and platform.cmake.patch need to be updated.
|
||||
set(DNNL_TAG v1.8.1)
|
||||
set(DNNL_TAG v2.2)
|
||||
|
||||
if(WIN32)
|
||||
set(DNNL_SHARED_LIB dnnl.dll)
|
||||
set(DNNL_IMPORT_LIB dnnl.lib)
|
||||
else()
|
||||
if (APPLE)
|
||||
set(DNNL_SHARED_LIB libdnnl.1.dylib)
|
||||
set(DNNL_SHARED_LIB libdnnl.2.dylib)
|
||||
else()
|
||||
set(DNNL_SHARED_LIB libdnnl.so.1)
|
||||
set(DNNL_SHARED_LIB libdnnl.so.2)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -402,9 +402,6 @@ if (onnxruntime_USE_DNNL)
|
|||
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dnnl DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
|
||||
set_target_properties(onnxruntime_providers_dnnl PROPERTIES FOLDER "ONNXRuntime")
|
||||
set_target_properties(onnxruntime_providers_dnnl PROPERTIES LINKER_LANGUAGE CXX)
|
||||
if (onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl")
|
||||
target_compile_definitions(onnxruntime_providers_dnnl PRIVATE USE_DNNL_GPU_OCL=1)
|
||||
endif()
|
||||
|
||||
if(APPLE)
|
||||
set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/dnnl/exported_symbols.lst")
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ void DNNLExecutionProvider::CreateOrUpdateDnnlNode(const Node* node,
|
|||
}
|
||||
#endif //ENABLE_TRAINING
|
||||
|
||||
if (node->OpType() == "Conv") {
|
||||
if (node->OpType() == "Conv" || node->OpType() == "MatMul") {
|
||||
dnnl_node.weight_name = node->InputDefs()[1]->Name();
|
||||
}
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
|
|||
|
|
@ -224,7 +224,7 @@ class DNNLExecutionProvider : public IExecutionProvider {
|
|||
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "MaxPoolGrad", "LRN"};
|
||||
#else
|
||||
std::set<std::string> dnnl_ops_ = {"Conv", "BatchNormalization", "Relu", "Sum",
|
||||
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "LRN"};
|
||||
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "LRN", "MatMul"};
|
||||
#endif // ENABLE_TRAINING
|
||||
|
||||
mutable std::unordered_map<std::string, std::shared_ptr<ort_dnnl::Subgraph>> mkl_subgraphs_;
|
||||
|
|
|
|||
|
|
@ -524,20 +524,14 @@ class DnnlConv : public DnnlKernel {
|
|||
filter_data = static_cast<T*>(filter_dst_mem->get_data_handle());
|
||||
filter_mem_->set_data_handle(static_cast<void*>(const_cast<T*>(filter_data)));
|
||||
} else { // gpu_available_
|
||||
#ifdef USE_DNNL_GPU_OCL
|
||||
std::lock_guard<OrtMutex> lock(provider_->GetMutex());
|
||||
filter_mem_gpu_->set_ocl_mem_object(filter_dst_mem->get_ocl_mem_object());
|
||||
#endif // USE_DNNL_GPU_OCL
|
||||
filter_mem_gpu_->set_data_handle(filter_dst_mem->get_data_handle());
|
||||
}
|
||||
#else // ENABLE_TRAINING
|
||||
if (!gpu_available_) {
|
||||
filter_data = static_cast<T*>(filter_dst_mem_->get_data_handle());
|
||||
filter_mem_->set_data_handle(static_cast<void*>(const_cast<T*>(filter_data)));
|
||||
} else if (gpu_available_) {
|
||||
#ifdef USE_DNNL_GPU_OCL
|
||||
std::lock_guard<OrtMutex> lock(provider_->GetMutex());
|
||||
filter_mem_gpu_->set_ocl_mem_object(filter_dst_mem_->get_ocl_mem_object());
|
||||
#endif // USE_DNNL_GPU_OCL
|
||||
filter_mem_gpu_->set_data_handle(filter_dst_mem_->get_data_handle());
|
||||
}
|
||||
#endif // ENABLE_TRAINING
|
||||
|
||||
|
|
|
|||
|
|
@ -545,9 +545,7 @@ class DnnlConvBatchNorm : public DnnlKernel {
|
|||
filter_data = static_cast<T*>(filter_dst_mem->get_data_handle());
|
||||
filter_mem_->set_data_handle(static_cast<void*>(const_cast<T*>(filter_data)));
|
||||
} else { // gpu_available_
|
||||
#ifdef USE_DNNL_GPU_OCL
|
||||
filter_mem_gpu_->set_ocl_mem_object(filter_dst_mem->get_ocl_mem_object());
|
||||
#endif
|
||||
filter_mem_gpu_->set_data_handle(filter_dst_mem->get_data_handle());
|
||||
}
|
||||
|
||||
std::shared_ptr<dnnl::memory> bias_mem = provider_->GetBiasMemoryBuffer(mklnode_ptr_->weight_name);
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include "core/providers/dnnl/subgraph/dnnl_pool.h"
|
||||
#include "core/providers/dnnl/subgraph/dnnl_sum.h"
|
||||
#include "core/providers/dnnl/subgraph/dnnl_lrn.h"
|
||||
#include "core/providers/dnnl/subgraph/dnnl_matmul.h"
|
||||
#ifdef ENABLE_TRAINING
|
||||
#include "core/providers/dnnl/subgraph/dnnl_convgrad.h"
|
||||
#include "core/providers/dnnl/subgraph/dnnl_relugrad.h"
|
||||
|
|
@ -202,6 +203,15 @@ class SubgraphPrimitive : public PrimitiveBase {
|
|||
kernel->parents_.push_back(context_.kernels[index]);
|
||||
}
|
||||
context_.kernels.push_back(kernel);
|
||||
} else if (dnnl_node.name == "MatMul") {
|
||||
std::ostringstream os;
|
||||
os << "MatMul-" << dnnl_node.node_index << "-";
|
||||
std::shared_ptr<DnnlMatmul<T>> kernel;
|
||||
kernel = std::make_shared<DnnlMatmul<T>>(dnnl_node, params.provider, *params.attributes, os.str());
|
||||
for (auto index : dnnl_node.parent_nodes) {
|
||||
kernel->parents_.push_back(context_.kernels[index]);
|
||||
}
|
||||
context_.kernels.push_back(kernel);
|
||||
}
|
||||
#ifdef ENABLE_TRAINING
|
||||
else if (dnnl_node.name == "ConvGrad") {
|
||||
|
|
|
|||
354
onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.h
Normal file
354
onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.h
Normal file
|
|
@ -0,0 +1,354 @@
|
|||
// Copyright(C) 2020 Intel Corporation
// Licensed under the MIT License

#pragma once
#include "core/providers/dnnl/dnnl_fwd.h"
#include "core/providers/dnnl/dnnl_execution_provider.h"
#include "core/providers/dnnl/subgraph/dnnl_kernel.h"

namespace onnxruntime {
namespace ort_dnnl {

// DNNL (oneDNN) kernel implementing the ONNX MatMul op via the dnnl::matmul
// primitive. Supports running on the CPU engine or, when a GPU engine is
// present in the engine map, on the GPU (inputs are reordered from the CPU
// engine into GPU memory). Weights (the second MatMul input) are assumed to
// be constant: they are reordered once into the primitive's preferred layout
// and cached in the provider under mklnode_ptr_->weight_name.
template <typename T>
class DnnlMatmul : public DnnlKernel {
 public:
  // node            - subgraph node this kernel executes.
  // provider        - owning execution provider (used for weight cache, mutex, allocator).
  // attributes      - ONNX node attributes (MatMul has none, but the base hook is honored).
  // attributes_prefix - prefix used to look up this node's attributes in a fused subgraph.
  DnnlMatmul(const DnnlNode& node,
             DNNLExecutionProvider* provider,
             const NodeAttributes& attributes,
             const std::string attributes_prefix = "") : DnnlKernel(node, provider) {
    ReadAttributes(attributes, attributes_prefix);
  }

  // Builds the dnnl::matmul primitive and all memory objects, and appends the
  // primitive (plus any needed reorders) to `net`/`net_args` for later execution.
  // Shapes are read from the ORT input tensors (or from the parent kernel's
  // output when this node is fused mid-subgraph). Memory objects created with a
  // nullptr handle get their data pointers attached later, in Bind().
  void CreatePrimitives(const OrtCustomOpApi* api,
                        OrtKernelContext* context,
                        const std::unordered_map<dnnl::engine::kind, dnnl::engine>& dnnl_engine,
                        std::vector<dnnl::primitive>& net,
                        std::vector<std::unordered_map<int, dnnl::memory>>& net_args) {
    // Pick the engine: CPU by default, GPU when available (GPU wins).
    dnnl::engine cpu_engine;
    dnnl::engine engine_to_use;
    std::unordered_map<dnnl::engine::kind, dnnl::engine>::const_iterator iter = dnnl_engine.find(dnnl::engine::kind::cpu);
    if (iter != dnnl_engine.end()) {
      dnnl_engine_cpu_ = iter->second;
      cpu_engine = iter->second;
      engine_to_use = cpu_engine;
    }
    gpu_available_ = false;
    dnnl::engine gpu_engine;
    iter = dnnl_engine.find(dnnl::engine::kind::gpu);
    if (iter != dnnl_engine.end()) {
      dnnl_engine_gpu_ = iter->second;
      gpu_engine = iter->second;
      gpu_available_ = true;
      engine_to_use = gpu_engine;
      LOGS_DEFAULT(INFO) << "gpu engine found" << std::endl;
    }
    Ort::CustomOpApi ort{*api};

    // input_start_index < 0 means this kernel reads the subgraph's first input.
    int input_index = mklnode_ptr_->input_start_index < 0 ? 0 : mklnode_ptr_->input_start_index;

    // Source (A) shape/format: from the ORT input tensor when this is the first
    // node of the subgraph, otherwise propagated from the parent kernel so the
    // DNNL blocked layout flows through without intermediate reorders.
    TensorShape x_shape;
    if (mklnode_ptr_->parent_nodes.empty()) {
      const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index);
      auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
      auto tensor_shape = ort.GetTensorShape(tensor_info);
      ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
      auto xshape = tensor_shape.data();
      auto xdim = tensor_shape.size();
      x_shape = TensorShape(xshape, xdim);
      ort_source_format_ = GetSourceFormat(static_cast<int>(xdim));
      ort_source_desc_ = dnnl::memory::desc(
          {dnnl::memory::dims(x_shape.GetDims().begin(), x_shape.GetDims().end())}, DnnnType<T>(), ort_source_format_);
      source_desc_ = ort_source_desc_;
    } else {
      // get the output of previous node (Dnnl block propagation).
      x_shape = parents_[0].get()->primitive_dst_shape_;
      ort_source_format_ = parents_[0].get()->ort_source_format_;
      ort_source_desc_ = parents_[0].get()->ort_source_desc_;
      source_desc_ = parents_[0].get()->primitive_dst_desc_;
    }

    // Weights (B) shape always comes straight from the ORT input tensor.
    const OrtValue* winput_tensor = ort.KernelContext_GetInput(context, input_index + 1);
    auto wtensor_info = ort.GetTensorTypeAndShape(winput_tensor);
    auto wtensor_shape = ort.GetTensorShape(wtensor_info);
    ort.ReleaseTensorTypeAndShapeInfo(wtensor_info);
    auto wshape = wtensor_shape.data();
    auto wdim = wtensor_shape.size();
    TensorShape w_shape(wshape, wdim);

    // dnnl::matmul requires src and weights to have equal rank; pad the
    // lower-rank shape with leading 1s (relies on oneDNN batch broadcasting).
    AdjustSrcWeightsShape(x_shape, w_shape);
    weights_shape_ = w_shape;
    weights_format_ = GetSourceFormat(static_cast<int>(w_shape.NumDimensions()));

    std::vector<int64_t> y_dims;
    InferOutputShape(x_shape, w_shape, y_dims);
    primitive_dst_shape_ = TensorShape(y_dims);

    // format_tag::any lets oneDNN choose its preferred (possibly blocked)
    // layouts; mismatches with the ORT layout are fixed up with reorders below.
    std::unique_ptr<dnnl::memory::desc> src_md = onnxruntime::make_unique<dnnl::memory::desc>(
        dnnl::memory::dims(x_shape.GetDims().begin(), x_shape.GetDims().end()), DnnnType<T>(), dnnl::memory::format_tag::any);

    std::unique_ptr<dnnl::memory::desc> weights_md = onnxruntime::make_unique<dnnl::memory::desc>(
        dnnl::memory::dims(w_shape.GetDims().begin(), w_shape.GetDims().end()), DnnnType<T>(), dnnl::memory::format_tag::any);

    primitive_dst_md_ = onnxruntime::make_unique<dnnl::memory::desc>(
        dnnl::memory::dims(y_dims.begin(), y_dims.end()), DnnnType<T>(), dnnl::memory::format_tag::any);

    std::unique_ptr<dnnl::matmul::desc> matmul_desc = onnxruntime::make_unique<dnnl::matmul::desc>(*src_md, *weights_md, *primitive_dst_md_);
    matmul_pd_ = onnxruntime::make_unique<dnnl::matmul::primitive_desc>(*matmul_desc, engine_to_use);
    matmul_ = onnxruntime::make_unique<dnnl::matmul>(dnnl::matmul(*matmul_pd_));

    // Descriptors the primitive actually chose (may differ from the plain ORT layout).
    primitive_src_desc_ = static_cast<dnnl::memory::desc>(matmul_pd_.get()->src_desc());
    primitive_dst_desc_ = static_cast<dnnl::memory::desc>(matmul_pd_.get()->dst_desc());

    weights_size_ = matmul_pd_.get()->weights_desc().get_size();
    dst_size_ = matmul_pd_.get()->dst_desc().get_size();

    // Weight memory is created with a nullptr handle; the real (cached,
    // reordered) pointer is attached in Bind() via ReorderWeights().
    weights_mem_ = onnxruntime::make_unique<dnnl::memory>(
        dnnl::memory(matmul_pd_.get()->weights_desc(), cpu_engine, nullptr));
    if (gpu_available_) {
      weights_mem_gpu_ = onnxruntime::make_unique<dnnl::memory>(
          dnnl::memory(matmul_pd_.get()->weights_desc(), gpu_engine, nullptr));
    }

    if (!gpu_available_) {
      // CPU path: if the primitive wants a different src layout than what we
      // receive, insert a reorder from src_mem_from_ into src_mem_.
      if (primitive_src_desc_ != source_desc_) {
        if (mklnode_ptr_->parent_nodes.empty()) {
          dnnl::memory::dims src_dims(x_shape.GetDims().begin(), x_shape.GetDims().end());
          auto pd = dnnl::memory::desc({{src_dims}, DnnnType<T>(), ort_source_format_});
          src_mem_from_ = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(pd, cpu_engine, nullptr));
        }
        else
          src_mem_from_ = parents_[0].get()->primitive_dst_mem_;

        src_mem_ = onnxruntime::make_unique<dnnl::memory>(
            dnnl::memory(matmul_pd_->src_desc(), cpu_engine, nullptr));
        net.push_back(dnnl::reorder(*src_mem_from_, *src_mem_));
        net_args.push_back({{DNNL_ARG_FROM, *src_mem_from_},
                            {DNNL_ARG_TO, *src_mem_}});
      } else {
        // Layouts already match: feed the ORT input (or parent output) directly.
        if (mklnode_ptr_->parent_nodes.empty()) {
          src_mem_ = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(matmul_pd_->src_desc(), cpu_engine, nullptr));
        } else {
          src_mem_ = parents_[0].get()->primitive_dst_mem_;
        }
      }

      if (mklnode_ptr_->output_index >= 0) {
        // This node produces a subgraph output. If the primitive's dst layout
        // differs from ORT's, let DNNL own the intermediate buffer (reordered
        // to the ORT output in InitDstReorderOutput); otherwise write straight
        // into the ORT output buffer attached in Bind().
        if (primitive_dst_desc_ != ort_source_desc_) {
          primitive_dst_mem_ = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(matmul_pd_.get()->dst_desc(), cpu_engine));
        } else {
          primitive_dst_mem_ = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(matmul_pd_.get()->dst_desc(), cpu_engine, nullptr));
        }
      }
    } else {  // gpu_available_
      // GPU path: source data arrives in CPU memory and must be reordered into
      // a GPU-engine memory object before the matmul can consume it.
      if (primitive_src_desc_ != source_desc_) {
        if (mklnode_ptr_->parent_nodes.empty()) {
          dnnl::memory::dims src_dims(x_shape.GetDims().begin(), x_shape.GetDims().end());
          auto pd = dnnl::memory::desc({{src_dims}, DnnnType<T>(), ort_source_format_});
          src_mem_from_ = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(pd, cpu_engine, nullptr));
        } else {
          src_mem_from_ = parents_[0].get()->primitive_dst_mem_;
        }
        // Single reorder does both the layout change and the CPU->GPU copy.
        src_mem_gpu_ = onnxruntime::make_unique<dnnl::memory>(
            dnnl::memory(matmul_pd_->src_desc(), gpu_engine));
        net.push_back(dnnl::reorder(*src_mem_from_, *src_mem_gpu_));
        net_args.push_back({{DNNL_ARG_FROM, *src_mem_from_},
                            {DNNL_ARG_TO, *src_mem_gpu_}});
      } else {
        if (mklnode_ptr_->parent_nodes.empty()) {
          // Layout matches but data is on CPU: still need the CPU->GPU copy.
          src_mem_ = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(matmul_pd_->src_desc(), cpu_engine, nullptr));
          src_mem_gpu_ = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(matmul_pd_->src_desc(), gpu_engine));
          net.push_back(dnnl::reorder(*src_mem_, *src_mem_gpu_));
          net_args.push_back({{DNNL_ARG_SRC, *src_mem_},
                              {DNNL_ARG_DST, *src_mem_gpu_}});
        } else {
          // Parent already produced GPU-resident output.
          src_mem_gpu_ = parents_[0].get()->primitive_dst_mem_;
        }
      }

      primitive_dst_mem_ = onnxruntime::make_unique<dnnl::memory>(
          dnnl::memory(matmul_pd_.get()->dst_desc(), gpu_engine));
    }

    // Enqueue the matmul itself with the engine-appropriate arguments.
    net.push_back(*matmul_);
    if (!gpu_available_) {
      net_args.push_back({{DNNL_ARG_SRC, *src_mem_},
                          {DNNL_ARG_WEIGHTS, *weights_mem_},
                          {DNNL_ARG_DST, *primitive_dst_mem_}});
    } else {  // gpu_available_
      net_args.push_back({{DNNL_ARG_SRC, *src_mem_gpu_},
                          {DNNL_ARG_WEIGHTS, *weights_mem_gpu_},
                          {DNNL_ARG_DST, *primitive_dst_mem_}});
    }

    if (mklnode_ptr_->output_index >= 0) {
      // Appends the final dst->ORT-layout reorder (base-class helper) when needed.
      dnnl::memory::data_type t = DnnnType<T>();
      InitDstReorderOutput(cpu_engine, t, net, net_args, gpu_available_);
    }
  }

  // Reorders the constant weight tensor into the layout chosen by the matmul
  // primitive and caches the result in the provider keyed by weight_name, so
  // the (potentially expensive) reorder runs only once across threads/runs.
  virtual void ReorderWeights(const OrtCustomOpApi* api, OrtKernelContext* context, const dnnl::engine& cpu_engine) override {
    Ort::CustomOpApi ort{*api};
    int input_index = mklnode_ptr_->input_start_index < 0 ? 0 : mklnode_ptr_->input_start_index;

    // Weights are the second input (input_index + 1).
    const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index + 1);
    auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
    auto tensor_shape = ort.GetTensorShape(tensor_info);
    ort.ReleaseTensorTypeAndShapeInfo(tensor_info);

    const T* weights_data = const_cast<T*>(ort.GetTensorData<T>(input_tensor));

    dnnl::memory::dims weights_dims_dnnl;
    weights_dims_dnnl.assign(weights_shape_.GetDims().begin(), weights_shape_.GetDims().end());

    {
      // lock to make sure reordering is done only once
      std::lock_guard<OrtMutex> lock(provider_->GetMutex());
      std::shared_ptr<dnnl::memory> weights_dst_mem = provider_->GetWeightsMemoryBuffer(mklnode_ptr_->weight_name);

      if (weights_dst_mem == nullptr) {
        // Wrap the raw ORT weight data in its plain layout, then reorder into
        // the primitive's preferred layout.
        dnnl::memory src = dnnl::memory({{weights_dims_dnnl}, DnnnType<T>(), weights_format_}, cpu_engine, (void*)weights_data);
        IAllocatorUniquePtr<void> weights_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc_, weights_size_);
        if (!gpu_available_) {
          weights_dst_mem = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(matmul_pd_->weights_desc(), cpu_engine, weights_reorder_buffer.get()));

          dnnl::reorder(src, *weights_dst_mem)
              .execute(cpu_engine, src, *weights_dst_mem);

          // Provider keeps the buffer alive for the lifetime of the cached memory.
          provider_->SaveAllocatedMemory(std::move(weights_reorder_buffer));
          weights_data = static_cast<T*>(weights_dst_mem->get_data_handle());
        } else {  // gpu_available_
          // GPU destination: the engine allocates device memory itself, so the
          // CPU reorder buffer above is left unused on this branch.
          weights_dst_mem = onnxruntime::make_unique<dnnl::memory>(
              dnnl::memory(matmul_pd_->weights_desc(), dnnl_engine_gpu_));

          dnnl::reorder(src, *weights_dst_mem)
              .execute(dnnl_engine_gpu_, src, *weights_dst_mem);
        }

        // Publish into the shared cache (still under the provider mutex).
        provider_->SetWeightsMemoryBuffer(mklnode_ptr_->weight_name, weights_dst_mem);
      }
    }
  }

  // Attaches the per-invocation data pointers (inputs, cached weights, output
  // buffer) to the memory objects created in CreatePrimitives(). Called each
  // run before the net is executed.
  Status Bind(const OrtCustomOpApi* api, OrtKernelContext* context) override {
    Ort::CustomOpApi ort{*api};

    ORT_RETURN_IF_ERROR(primitive_created_status_);

    int input_index = mklnode_ptr_->input_start_index < 0 ? 0 : mklnode_ptr_->input_start_index;
    const OrtValue* winput_tensor = ort.KernelContext_GetInput(context, input_index + 1);
    const T* weights_data = const_cast<T*>(ort.GetTensorData<T>(winput_tensor));

    // Lazily reorder + cache the weights on first use.
    std::shared_ptr<dnnl::memory> weights_dst_mem = provider_->GetWeightsMemoryBuffer(mklnode_ptr_->weight_name);
    if (weights_dst_mem == nullptr) {
      ReorderWeights(api, context, dnnl_engine_cpu_);
      weights_dst_mem = provider_->GetWeightsMemoryBuffer(mklnode_ptr_->weight_name);
    }
    if (!gpu_available_) {
      weights_data = static_cast<T*>(weights_dst_mem->get_data_handle());
      weights_mem_->set_data_handle(static_cast<void*>(const_cast<T*>(weights_data)));
    } else {  // gpu_available_
      weights_mem_gpu_->set_data_handle(weights_dst_mem->get_data_handle());
    }

    if (primitive_src_desc_ != source_desc_) {
      // Reorder path: attach the raw input to src_mem_from_; the reorder
      // queued in CreatePrimitives() will fill src_mem_ / src_mem_gpu_.
      if (mklnode_ptr_->parent_nodes.empty()) {
        const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index);
        const T* src_data = const_cast<T*>(ort.GetTensorData<T>(input_tensor));
        src_mem_from_->set_data_handle(static_cast<void*>(const_cast<T*>(src_data)));
      } else {
        src_mem_from_ = parents_[0].get()->primitive_dst_mem_;
      }

      if (!gpu_available_) {
        // Scratch buffer that receives the reordered source each run.
        auto src_size = matmul_pd_.get()->src_desc().get_size();
        src_reorder_buffer_ = IAllocator::MakeUniquePtr<void>(alloc_, src_size);
        src_mem_->set_data_handle(src_reorder_buffer_.get());
      }
    } else {
      // No reorder needed: point src_mem_ directly at the input data.
      if (mklnode_ptr_->parent_nodes.empty()) {
        const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index);
        const T* src_data = const_cast<T*>(ort.GetTensorData<T>(input_tensor));
        src_mem_->set_data_handle(static_cast<void*>(const_cast<T*>(src_data)));
      } else {
        src_mem_ = parents_[0].get()->primitive_dst_mem_;
      }
    }

    if (mklnode_ptr_->output_index >= 0) {
      auto& y_dims = primitive_dst_shape_.GetDims();
      // Allocate memory for output buffer
      OrtValue* output = ort.KernelContext_GetOutput(context, mklnode_ptr_->output_index, &y_dims[0], static_cast<int>(primitive_dst_shape_.GetDims().size()));
      T* dst_data = ort.GetTensorMutableData<T>(output);

      // Point the final memory object (post-reorder when layouts differ, the
      // primitive dst otherwise) at the ORT-allocated output buffer.
      if (!gpu_available_) {
        if (primitive_dst_desc_ != ort_source_desc_) {
          reorder_dst_mem_to_->set_data_handle(dst_data);
        } else {
          primitive_dst_mem_->set_data_handle(dst_data);
        }
      } else {  // gpu_available_
        reorder_dst_mem_to_->set_data_handle(dst_data);
      }
    }
    return Status::OK();
  }

 private:
  // Plain (ORT-side) layout tag of the weight tensor.
  dnnl::memory::format_tag weights_format_;

  // Un-reordered source memory (ORT input or parent output) feeding the src reorder.
  std::shared_ptr<dnnl::memory> src_mem_from_;

  // Byte sizes of the primitive's chosen weight/dst layouts.
  size_t weights_size_;
  size_t dst_size_;

  // Weight shape after rank-padding in AdjustSrcWeightsShape().
  TensorShape weights_shape_;

  // CPU-engine memory objects (src_mem_gpu_/weights_mem_gpu_ are their GPU counterparts).
  std::shared_ptr<dnnl::memory> src_mem_;
  std::shared_ptr<dnnl::memory> src_mem_gpu_;
  std::shared_ptr<dnnl::memory> weights_mem_;
  std::unique_ptr<dnnl::memory> weights_mem_gpu_;

  std::unique_ptr<dnnl::matmul::primitive_desc> matmul_pd_;
  std::unique_ptr<dnnl::matmul::primitive> matmul_;

  dnnl::engine dnnl_engine_cpu_;
  dnnl::engine dnnl_engine_gpu_;

  // True when a GPU engine was found in CreatePrimitives().
  bool gpu_available_;

  // Scratch buffer for the reordered source; reallocated on every Bind().
  IAllocatorUniquePtr<void> src_reorder_buffer_;

  // Output shape = input batch/row dims with the last dim replaced by the
  // weights' last dim (shapes are assumed already rank-matched).
  void InferOutputShape(const TensorShape& input_shape, const TensorShape& weight_shape, std::vector<int64_t>& output_shape) const {
    output_shape = input_shape.GetDims();
    output_shape.pop_back();
    output_shape.emplace_back(weight_shape.GetDims().back());
  }

  // Pads the lower-rank operand with leading 1s so src and weights have equal
  // rank, as required by the dnnl::matmul primitive (batch dims broadcast).
  void AdjustSrcWeightsShape(TensorShape& input_shape, TensorShape& weights_shape) const {
    if (input_shape.NumDimensions() > weights_shape.NumDimensions()) {
      auto dims = weights_shape.GetDims();
      for (size_t i = 0; i < input_shape.NumDimensions() - weights_shape.NumDimensions(); i++) {
        dims.insert(dims.begin(), 1);
      }
      weights_shape = TensorShape(dims);
    } else if (input_shape.NumDimensions() < weights_shape.NumDimensions()) {
      auto dims = input_shape.GetDims();
      for (size_t i = 0; i < weights_shape.NumDimensions() - input_shape.NumDimensions(); i++) {
        dims.insert(dims.begin(), 1);
      }
      input_shape = TensorShape(dims);
    }
  }
};
}  // namespace ort_dnnl
}  // namespace onnxruntime
|
||||
4
setup.py
4
setup.py
|
|
@ -157,7 +157,7 @@ except ImportError as error:
|
|||
|
||||
# Additional binaries
|
||||
if platform.system() == 'Linux':
|
||||
libs = ['onnxruntime_pybind11_state.so', 'libdnnl.so.1', 'libmklml_intel.so', 'libmklml_gnu.so', 'libiomp5.so', 'mimalloc.so']
|
||||
libs = ['onnxruntime_pybind11_state.so', 'libdnnl.so.2', 'libmklml_intel.so', 'libmklml_gnu.so', 'libiomp5.so', 'mimalloc.so']
|
||||
# DNNL, TensorRT & OpenVINO EPs are built as shared libs
|
||||
libs.extend(['libonnxruntime_providers_shared.so'])
|
||||
libs.extend(['libonnxruntime_providers_dnnl.so'])
|
||||
|
|
@ -168,7 +168,7 @@ if platform.system() == 'Linux':
|
|||
if nightly_build:
|
||||
libs.extend(['libonnxruntime_pywrapper.so'])
|
||||
elif platform.system() == "Darwin":
|
||||
libs = ['onnxruntime_pybind11_state.so', 'libdnnl.1.dylib', 'mimalloc.so'] # TODO add libmklml and libiomp5 later.
|
||||
libs = ['onnxruntime_pybind11_state.so', 'libdnnl.2.dylib', 'mimalloc.so'] # TODO add libmklml and libiomp5 later.
|
||||
# DNNL & TensorRT EPs are built as shared libs
|
||||
libs.extend(['libonnxruntime_providers_shared.dylib'])
|
||||
libs.extend(['libonnxruntime_providers_dnnl.dylib'])
|
||||
|
|
|
|||
Loading…
Reference in a new issue