From b9cbbc41ff786dbb781d01df196c00303d88f10c Mon Sep 17 00:00:00 2001 From: jeyblu Date: Fri, 23 Apr 2021 23:17:22 -0700 Subject: [PATCH] dnnl matmul tensor dimension check (#7383) --- .../providers/dnnl/dnnl_execution_provider.h | 20 +++++++++++++++++ .../providers/dnnl/subgraph/dnnl_matmul.h | 22 +++---------------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index 12506fd70b..0178c3876e 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -194,6 +194,26 @@ class DNNLExecutionProvider : public IExecutionProvider { supported = false; #endif // ENABLE_TRAINING } + if (node->OpType().find("MatMul") != std::string::npos) { + auto node_inputs = node->InputDefs(); + if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) && + (node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) && + (node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) { + supported = true; + for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) { + if (utils::HasDimValue(dim) && dim.dim_value() == 0) { + supported = false; + } + } + for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) { + if (utils::HasDimValue(dim) && dim.dim_value() == 0) { + supported = false; + } + } + } else { + supported = false; + } + } return supported; } diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.h b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.h index 92597ea1bc..2d6a776fe7 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.h +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.h @@ -75,7 +75,6 @@ class DnnlMatmul : public DnnlKernel { auto wdim = wtensor_shape.size(); TensorShape w_shape(wshape, wdim); - AdjustSrcWeightsShape(x_shape, w_shape); weights_shape_ = w_shape; weights_format_ = GetSourceFormat(static_cast(w_shape.NumDimensions())); @@ -303,12 +302,10 @@ class DnnlMatmul : public DnnlKernel { private: dnnl::memory::format_tag weights_format_; - std::shared_ptr src_mem_from_; size_t weights_size_; size_t dst_size_; - TensorShape weights_shape_; std::shared_ptr src_mem_; @@ -330,25 +327,12 @@ class DnnlMatmul : public DnnlKernel { output_shape = input_shape.GetDims(); output_shape.pop_back(); output_shape.emplace_back(weight_shape.GetDims().back()); - } - - void AdjustSrcWeightsShape(TensorShape& input_shape, TensorShape& weights_shape) const { - - if (input_shape.NumDimensions() > weights_shape.NumDimensions()) { - auto dims = weights_shape.GetDims(); - for (size_t i = 0; i < input_shape.NumDimensions() - weights_shape.NumDimensions(); i++) { - dims.insert(dims.begin(), 1); + for (size_t i = 0; i < output_shape.size() - 2; i++) { + if (output_shape[i] == 1) { + output_shape[i] = weight_shape[i]; } - weights_shape = TensorShape(dims); - } else if (input_shape.NumDimensions() < weights_shape.NumDimensions()) { - auto dims = input_shape.GetDims(); - for (size_t i = 0; i < weights_shape.NumDimensions() - input_shape.NumDimensions(); i++) { - dims.insert(dims.begin(), 1); - } - input_shape = TensorShape(dims); } } - }; } // namespace ort_dnnl } // namespace onnxruntime \ No newline at end of file