diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake index 6770e66a29..5c42216ac0 100644 --- a/cmake/external/dnnl.cmake +++ b/cmake/external/dnnl.cmake @@ -2,7 +2,7 @@ include (ExternalProject) set(DNNL_URL https://github.com/oneapi-src/onednn.git) # If DNNL_TAG is updated, check if MKLML_VERSION and platform.cmake.patch need to be updated. -set(DNNL_TAG v2.4.4) +set(DNNL_TAG v2.6) if(WIN32) set(DNNL_SHARED_LIB dnnl.dll) diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc index 5496e4672c..07eed52fa4 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -489,10 +489,7 @@ bool DnnlSumNodeCapability::IsDimensionSupported(const Node* node) const { bool DnnlBinaryNodeCapability::Supported(const Node* node, const GraphViewer& graph_viewer) const { ORT_UNUSED_PARAMETER(graph_viewer); if (!IsTypeSupported(node)) return false; - //gpu broadcast for source 0 not supported - if (dnnl_engine_get_count(dnnl_engine_kind_t::dnnl_gpu)) { - return false; - } + return true; } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_binary.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_binary.cc index 579d2b2b8e..63c84b4288 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_binary.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_binary.cc @@ -24,8 +24,10 @@ void DnnlBinary::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { } // GetMemory in OrtFormat. Broadcasting and mix format binary ops can result in computation failure - auto src_0_ori_md = sp.GetMemoryInOrtFormat(node.Input(IN_A), eng).get_desc(); - auto src_1_ori_md = sp.GetMemoryInOrtFormat(node.Input(IN_B), eng).get_desc(); + auto binary_src0_mem = sp.GetMemoryInOrtFormat(node.Input(IN_A), eng); + auto binary_src1_mem = sp.GetMemoryInOrtFormat(node.Input(IN_B), eng); + auto src_0_ori_md = binary_src0_mem.get_desc(); + auto src_1_ori_md = binary_src1_mem.get_desc(); auto src_0_dims = src_0_ori_md.dims(); auto src_1_dims = src_1_ori_md.dims(); @@ -53,11 +55,6 @@ void DnnlBinary::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { auto binary_d = dnnl::binary::desc(algo, src_0_md, src_1_md, dst_md); auto binary_pd = dnnl::binary::primitive_desc(binary_d, eng); - auto binary_src0_mem = sp.GetMemoryAndReshape(node.Input(IN_A), binary_pd.src0_desc(), eng); - auto binary_src1_mem = sp.GetMemoryAndReshape(node.Input(IN_B), binary_pd.src1_desc(), eng); - - - auto binary_dst_mem = dnnl::memory(binary_pd.dst_desc(), eng); auto binary_prim = dnnl::binary(binary_pd); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_layernorm.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_layernorm.cc index 05e0c034b3..b05c91d2f4 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_layernorm.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_layernorm.cc @@ -108,7 +108,7 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { // X = LayerNornm(X) // Check if we are training and need the extra outputs for backprop dnnl::prop_kind prop_kind; -#if defined(ENABLE_TRAINING) +#if 0 //defined(ENABLE_TRAINING) prop_kind = dnnl::prop_kind::forward_training; #else prop_kind = dnnl::prop_kind::forward_inference; @@ -146,18 +146,20 @@ void DnnlLayerNorm::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { } // Check outputs used for training -#if defined(ENABLE_TRAINING) +#if 0 //defined(ENABLE_TRAINING) // If Mean exists - if (node.Output(OUT_MEAN).Exists()) { - auto mean_mem = dnnl::memory(lnorm_pd.mean_desc(), dnnl_engine); - lnorm_args.insert({DNNL_ARG_MEAN, mean_mem}); - sp.SetMemory(node.Output(OUT_MEAN), mean_mem); - } - // If Variance exists - if (node.Output(OUT_INV_STD_VAR).Exists()) { - auto variance_mem = dnnl::memory(lnorm_pd.variance_desc(), dnnl_engine); - lnorm_args.insert({DNNL_ARG_VARIANCE, variance_mem}); - sp.SetMemory(node.Output(OUT_INV_STD_VAR), variance_mem); + if (node.OutputCount() > 1) { + if (node.Output(OUT_MEAN).Exists()) { + auto mean_mem = dnnl::memory(lnorm_pd.mean_desc(), dnnl_engine); + lnorm_args.insert({DNNL_ARG_MEAN, mean_mem}); + sp.SetMemory(node.Output(OUT_MEAN), mean_mem); + } + // If Variance exists + if (node.Output(OUT_INV_STD_VAR).Exists()) { + auto variance_mem = dnnl::memory(lnorm_pd.variance_desc(), dnnl_engine); + lnorm_args.insert({DNNL_ARG_VARIANCE, variance_mem}); + sp.SetMemory(node.Output(OUT_INV_STD_VAR), variance_mem); + } } #endif // ENABLE_TRAINING diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc index 9cae1015e6..19310fa41b 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc @@ -66,85 +66,91 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { //Holds transposed matrices A and B. ToDo: Eliminate its usage if in place transpose is possbile for FusedMatmul - dnnl::memory::desc intermediateA_md; - dnnl::memory intermediateA_mem; + dnnl::memory::desc transposedA_md; + dnnl::memory transposedA_mem; - dnnl::memory::desc intermediateB_md; - dnnl::memory intermediateB_mem; + dnnl::memory::desc transposedB_md; + dnnl::memory transposedB_mem; if (is_fusedmatmul) { if (transA || transBatchA) { dnnl::memory::dims strides = GetStrides(dataA_dims, transA, transBatchA, transposedA_dims); - intermediateA_md = dnnl::memory::desc(dataA_dims, node.Input(IN_A).Type(), strides); - intermediateA_mem = dnnl::memory(intermediateA_md, eng); + dnnl::memory::desc intermediateA_md = dnnl::memory::desc(dataA_dims, node.Input(IN_A).Type(), strides); + dnnl::memory intermediateA_mem = dnnl::memory(intermediateA_md, eng); auto traspose_primitive = dnnl::reorder(dataA_mem, intermediateA_mem); - sp.AddPrimitive(traspose_primitive, {{DNNL_ARG_FROM, dataA_mem}, - {DNNL_ARG_TO, intermediateA_mem}}); + sp.AddPrimitive(traspose_primitive, {{DNNL_ARG_FROM, dataA_mem}, {DNNL_ARG_TO, intermediateA_mem}}); - // The reorder from above will get the memory in the right order. The next few lines will create a memory and memory descriptor - // that will have the correct dimentions and correct memory::format - dnnl::memory::desc transposed_md = dnnl::memory::desc(transposedA_dims, node.Input(IN_A).Type(), sp.GetDnnlFormat(dataA_dims.size())); - dnnl::memory transposed_mem = dnnl::memory(transposed_md, eng, nullptr); - void* handle = intermediateA_mem.get_data_handle(); - transposed_mem.set_data_handle(handle); while (transposedA_dims.size() < weights_dims.size()) { transposedA_dims.insert(transposedA_dims.begin(), 1); } + + // The reorder from above will get the memory in the right order. The next few lines will create a memory and memory descriptor + // that will have the correct dimentions and correct memory::format + transposedA_md = dnnl::memory::desc(transposedA_dims, node.Input(IN_A).Type(), sp.GetDnnlFormat(transposedA_dims.size())); + transposedA_mem = dnnl::memory(transposedA_md, eng, nullptr); + void* handle = intermediateA_mem.get_data_handle(); + transposedA_mem.set_data_handle(handle); } if (transB || transBatchB) { // Exact same logic for matrix B as used for matrix A dnnl::memory::dims strides = GetStrides(dataB_dims, transB, transBatchB, transposedB_dims); - intermediateB_md = dnnl::memory::desc(dataB_dims, node.Input(IN_B).Type(), strides); - intermediateB_mem = dnnl::memory(intermediateB_md, eng); + dnnl::memory::desc intermediateB_md = dnnl::memory::desc(dataB_dims, node.Input(IN_B).Type(), strides); + dnnl::memory intermediateB_mem = dnnl::memory(intermediateB_md, eng); auto traspose_primitive = dnnl::reorder(dataB_mem, intermediateB_mem); - sp.AddPrimitive(traspose_primitive, {{DNNL_ARG_FROM, dataB_mem}, - {DNNL_ARG_TO, intermediateB_mem}}); + sp.AddPrimitive(traspose_primitive, {{DNNL_ARG_FROM, dataB_mem}, {DNNL_ARG_TO, intermediateB_mem}}); - // The reorder from above will get the memory in the right order. The next few lines will create a memory and memory descriptor - // that will have the correct dimentions and correct memory::format - dnnl::memory::desc transposed_md = dnnl::memory::desc(transposedB_dims, node.Input(IN_A).Type(), sp.GetDnnlFormat(dataB_dims.size())); - dnnl::memory transposed_mem = dnnl::memory(transposed_md, eng, nullptr); - void* handle = intermediateB_mem.get_data_handle(); - transposed_mem.set_data_handle(handle); while (src_dims.size() > transposedB_dims.size()) { transposedB_dims.insert(transposedB_dims.begin(), 1); } + + // The reorder from above will get the memory in the right order. The next few lines will create a memory and memory descriptor + // that will have the correct dimentions and correct memory::format + transposedB_md = dnnl::memory::desc(transposedB_dims, node.Input(IN_B).Type(), sp.GetDnnlFormat(transposedB_dims.size())); + transposedB_mem = dnnl::memory(transposedB_md, eng, nullptr); + void* handle = intermediateB_mem.get_data_handle(); + transposedB_mem.set_data_handle(handle); + } } dnnl::memory::desc src_md; if (transA || transBatchA) { - src_md = dnnl::memory::desc(transposedA_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any); + src_md = transposedA_md; } else { src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any); } dnnl::memory::desc weights_md; if (transB || transBatchB) { - weights_md = dnnl::memory::desc(transposedB_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any); + weights_md = transposedB_md; } else { weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any); } auto output_shape = src_dims; - if (transA || transBatchA) + if (transA || transBatchA) { output_shape = transposedA_dims; + } output_shape.pop_back(); - if (transB || transBatchB) + if (transB || transBatchB) { output_shape.emplace_back(transposedB_dims.back()); - else + } else { output_shape.emplace_back(weights_dims.back()); + } + for (size_t i = 0; i < output_shape.size() - 2; i++) { if (output_shape[i] == 1) { - if (transB || transBatchB) + if (transB || transBatchB) { output_shape[i] = transposedB_dims[i]; - else - output_shape[i] = weights_dims[i]; + } else { + output_shape[i] = weights_dims[i]; + } + } } @@ -195,12 +201,12 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { auto matmul_prim = dnnl::matmul(matmul_pd); if (transA || transBatchA) { - matmul_src_mem = intermediateA_mem; + matmul_src_mem = transposedA_mem; } else { matmul_src_mem = sp.GetMemoryAndReshape(node.Input(IN_A), matmul_pd.src_desc(), eng); } if (transB || transBatchB) { - matmul_weights_mem = intermediateB_mem; + matmul_weights_mem = transposedB_mem; } else { matmul_weights_mem = sp.GetMemoryAndReshape(node.Input(IN_B), matmul_pd.weights_desc(), eng); } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc index 076959a983..341868a3c7 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_pool.cc @@ -12,8 +12,15 @@ DnnlPool::DnnlPool() {} void DnnlPool::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { auto dnnl_engine = sp.GetEngine(); - +#ifdef ENABLE_TRAINING + // When using training the memory needs to be in a format known to pool_forward and the + // pool_backward primitives. Since we don't currently have a way to pass the memory format + // from pool_forward to pool_backward; we are choosing to use Onnxruntime's memory format + // as the common memory format to be used by both forward and the backward primitives. + auto pool_src_mem = sp.GetMemoryInOrtFormat(node.Input(IN_X), dnnl_engine); +#else auto pool_src_mem = sp.GetMemory(node.Input(IN_X)); +#endif // ENABLE_TRAINING auto src_md = pool_src_mem.get_desc(); auto src_dims = pool_src_mem.get_desc().dims(); @@ -51,8 +58,10 @@ void DnnlPool::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { auto pool_pd = dnnl::pooling_forward::primitive_desc(pool_desc, dnnl_engine); +#ifndef ENABLE_TRAINING // If using GPU this will move the memory from the CPU to the GPU. pool_src_mem = sp.GetMemoryAndReshape(node.Input(IN_X), pool_pd.src_desc(), dnnl_engine); +#endif dnnl::memory pool_dst_mem = dnnl::memory(pool_pd.dst_desc(), dnnl_engine); auto pool_op = dnnl::pooling_forward(pool_pd); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_poolgrad.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_poolgrad.cc index dc2c497cc0..946d5a5543 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_poolgrad.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_poolgrad.cc @@ -74,6 +74,7 @@ void DnnlPoolGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { auto dx_dims = node.Output(OUT_DX).Dim(); dnnl::memory::desc dx_md(dx_dims, node.Input(IN_DY).Type(), dnnl::memory::format_tag::any); + dnnl::memory::desc fwd_dx_md(dx_dims, node.Input(IN_DY).Type(), sp.GetDnnlFormat(dx_dims.size())); //Read the attributes auto kernel_shape = GetKernelShape(node); @@ -92,7 +93,7 @@ void DnnlPoolGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { } dnnl::pooling_forward::desc pool_forward_desc(dnnl::prop_kind::forward, algo, - dx_md, dy_md, + fwd_dx_md, dy_md, strides, kernel_shape, padding_left, padding_right); dnnl::pooling_forward::primitive_desc pool_forward_pd(pool_forward_desc, dnnl_engine); @@ -110,9 +111,9 @@ void DnnlPoolGrad::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { dnnl::memory dx_mem(pool_backward_pd.diff_src_desc(), dnnl_engine); if (maxpoolgrad_optype) { - sp.AddPrimitive(pool_backward_op, {{DNNL_ARG_DIFF_DST, dy_mem}, - {DNNL_ARG_DIFF_SRC, dx_mem}, - {DNNL_ARG_WORKSPACE, indices_mem}}); + sp.AddPrimitive(pool_backward_op, {{DNNL_ARG_DIFF_DST, dy_mem}, + {DNNL_ARG_DIFF_SRC, dx_mem}, + {DNNL_ARG_WORKSPACE, indices_mem}}); } else { sp.AddPrimitive(pool_backward_op, {{DNNL_ARG_DIFF_DST, dy_mem}, {DNNL_ARG_DIFF_SRC, dx_mem}}); diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc index d3e6636875..92f788f4aa 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_primitive.cc @@ -36,7 +36,35 @@ #include #include +#include +#include +/* +* The DNNL_TENSOR_PRINT_MEMORY should always be 0 unless debugging +* +* These macros can be used to print the contents of a OneDNN tensor +* This can be used to debug and investigate the values inputs and outputs +* of OneDNN ops. +* +* To use set DNNL_TENSOR_PRINT_MEMORY to 1 +* Find the operator you want to investigate and add the memory you want to print when +* calling AddPrimitive() for example: +* Change this code: +* ``` +* sp.AddPrimitive(elemenwise_primitive, + {{DNNL_ARG_SRC, elementwise_src_mem}, {DNNL_ARG_DST, elementwise_dst_mem}}); +* ``` +* to +* ``` +* sp.AddPrimitive(elemenwise_primitive, + {{DNNL_ARG_SRC, elementwise_src_mem}, {DNNL_ARG_DST, elementwise_dst_mem}}, + {DNNL_ARG_SRC, DNNL_ARG_DST}); +* ``` +* Then rebuild and run the code. +* This is a developer only solution to investigating contents of OneDNN's tensors. +*/ +#define DNNL_TENSOR_PRINT_MEMORY 0 +#define DNNL_TENSOR_PRINT_MEMORY_MAX_TENSOR_ELEMENTS 150 namespace onnxruntime { namespace ort_dnnl { @@ -46,12 +74,12 @@ inline bool Contains(const Map& map, const Key& key) { return map.find(key) != map.end(); } - +#if DNNL_TENSOR_PRINT_MEMORY void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) { auto md = mem.get_desc(); auto dt = md.data_type(); auto dims = md.dims(); - if (Product(dims) > 50) { + if (Product(dims) > DNNL_TENSOR_PRINT_MEMORY_MAX_TENSOR_ELEMENTS) { printf("tensor too long ignore printing \n"); return; } @@ -75,10 +103,12 @@ void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) { ((char*)data_vec.data())[i] = ((char*)dh)[i]; } + std::cout << "["; for (auto& data : data_vec) { - printf("%.6f ", data); + std::cout << std::setprecision(6) << data; + if (&data != &data_vec.back()) std::cout << ", "; } - printf("\n"); + std::cout << "]\n"; } else if (dt == dnnl::memory::data_type::u8) { std::vector data_vec(Product(dims)); @@ -87,10 +117,12 @@ void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) { ((char*)data_vec.data())[i] = ((char*)dh)[i]; } + std::cout << "["; for (auto& data : data_vec) { - printf("%" PRIu8 "\n", data); + std::cout << +data; + if (&data != &data_vec.back()) std::cout << ", "; } - printf("\n"); + std::cout << "]\n"; } else if (dt == dnnl::memory::data_type::s8) { std::vector data_vec(Product(dims)); auto dh = to_mem.get_data_handle(); @@ -98,10 +130,12 @@ void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) { ((char*)data_vec.data())[i] = ((char*)dh)[i]; } + std::cout << "["; for (auto& data : data_vec) { - printf("%" PRIi8 "\n", data); + std::cout << +data; + if (&data != &data_vec.back()) std::cout << ", "; } - printf("\n"); + std::cout << "]\n"; } else if (dt == dnnl::memory::data_type::s32) { std::vector data_vec(Product(dims)); auto dh = to_mem.get_data_handle(); @@ -109,14 +143,17 @@ void DnnlSubgraphPrimitive::PrintMemory(const dnnl::memory& mem) { ((char*)data_vec.data())[i] = ((char*)dh)[i]; } + std::cout << "["; for (auto& data : data_vec) { - printf("%" PRIi32 "\n", data); + std::cout << data; + if (&data != &data_vec.back()) std::cout << ", "; } - printf("\n"); + std::cout << "]\n"; } else { ORT_THROW("Cannot print such data type"); } } +#endif // DNNL_TENSOR_PRINT_MEMORY int Product(dnnl::memory::dims d) { int result = 1; @@ -647,9 +684,8 @@ onnxruntime::common::Status DnnlSubgraphPrimitive::Predict(const std::unordered_ for (size_t i = 0; i < net_.size(); ++i) { net_.at(i).execute(stream, net_args_.at(i)); stream.wait(); - +#if DNNL_TENSOR_PRINT_MEMORY //for debug memory purpose - /* for (auto e : items_to_print_) { auto net_index = e.first; auto net_arg_index = e.second; @@ -657,8 +693,7 @@ onnxruntime::common::Status DnnlSubgraphPrimitive::Predict(const std::unordered_ PrintMemory(net_args_.at(i)[net_arg_index]); } } - */ - +#endif //DNNL_TENSOR_PRINT_MEMORY } return Status::OK(); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 6cad4cea25..688ca21f13 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -2749,7 +2749,7 @@ TEST(ReductionOpTest, ReduceInfLogSum) { {FLOAT_INF, FLOAT_INF, -std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider}); } TEST(ReductionOpTest, ReduceInfLogSumExp) {