mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
dnnl matmul tensor dimension check (#7383)
This commit is contained in:
parent
afe912d47c
commit
b9cbbc41ff
2 changed files with 23 additions and 19 deletions
|
|
@ -194,6 +194,26 @@ class DNNLExecutionProvider : public IExecutionProvider {
|
|||
supported = false;
|
||||
#endif // ENABLE_TRAINING
|
||||
}
|
||||
if (node->OpType().find("MatMul") != std::string::npos) {
|
||||
auto node_inputs = node->InputDefs();
|
||||
if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) &&
|
||||
(node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) &&
|
||||
(node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) {
|
||||
supported = true;
|
||||
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) {
|
||||
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) {
|
||||
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
return supported;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -75,7 +75,6 @@ class DnnlMatmul : public DnnlKernel {
|
|||
auto wdim = wtensor_shape.size();
|
||||
TensorShape w_shape(wshape, wdim);
|
||||
|
||||
AdjustSrcWeightsShape(x_shape, w_shape);
|
||||
weights_shape_ = w_shape;
|
||||
weights_format_ = GetSourceFormat(static_cast<int>(w_shape.NumDimensions()));
|
||||
|
||||
|
|
@ -303,12 +302,10 @@ class DnnlMatmul : public DnnlKernel {
|
|||
|
||||
private:
|
||||
dnnl::memory::format_tag weights_format_;
|
||||
|
||||
std::shared_ptr<dnnl::memory> src_mem_from_;
|
||||
|
||||
size_t weights_size_;
|
||||
size_t dst_size_;
|
||||
|
||||
TensorShape weights_shape_;
|
||||
|
||||
std::shared_ptr<dnnl::memory> src_mem_;
|
||||
|
|
@ -330,25 +327,12 @@ class DnnlMatmul : public DnnlKernel {
|
|||
output_shape = input_shape.GetDims();
|
||||
output_shape.pop_back();
|
||||
output_shape.emplace_back(weight_shape.GetDims().back());
|
||||
}
|
||||
|
||||
void AdjustSrcWeightsShape(TensorShape& input_shape, TensorShape& weights_shape) const {
|
||||
|
||||
if (input_shape.NumDimensions() > weights_shape.NumDimensions()) {
|
||||
auto dims = weights_shape.GetDims();
|
||||
for (size_t i = 0; i < input_shape.NumDimensions() - weights_shape.NumDimensions(); i++) {
|
||||
dims.insert(dims.begin(), 1);
|
||||
for (size_t i = 0; i < output_shape.size() - 2; i++) {
|
||||
if (output_shape[i] == 1) {
|
||||
output_shape[i] = weight_shape[i];
|
||||
}
|
||||
weights_shape = TensorShape(dims);
|
||||
} else if (input_shape.NumDimensions() < weights_shape.NumDimensions()) {
|
||||
auto dims = input_shape.GetDims();
|
||||
for (size_t i = 0; i < weights_shape.NumDimensions() - input_shape.NumDimensions(); i++) {
|
||||
dims.insert(dims.begin(), 1);
|
||||
}
|
||||
input_shape = TensorShape(dims);
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
} // namespace ort_dnnl
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue