dnnl matmul tensor dimension check (#7383)

This commit is contained in:
jeyblu 2021-04-23 23:17:22 -07:00 committed by GitHub
parent afe912d47c
commit b9cbbc41ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 19 deletions

View file

@ -194,6 +194,26 @@ class DNNLExecutionProvider : public IExecutionProvider {
supported = false;
#endif // ENABLE_TRAINING
}
if (node->OpType().find("MatMul") != std::string::npos) {
auto node_inputs = node->InputDefs();
if ((node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() >= 2) &&
(node_inputs[1]->Shape() != nullptr && node_inputs[1]->Shape()->dim_size() >= 2) &&
(node_inputs[0]->Shape()->dim_size() == node_inputs[1]->Shape()->dim_size())) {
supported = true;
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[0]->Shape()->dim()) {
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
supported = false;
}
}
for (const onnx::TensorShapeProto_Dimension& dim : node_inputs[1]->Shape()->dim()) {
if (utils::HasDimValue(dim) && dim.dim_value() == 0) {
supported = false;
}
}
} else {
supported = false;
}
}
return supported;
}

View file

@ -75,7 +75,6 @@ class DnnlMatmul : public DnnlKernel {
auto wdim = wtensor_shape.size();
TensorShape w_shape(wshape, wdim);
AdjustSrcWeightsShape(x_shape, w_shape);
weights_shape_ = w_shape;
weights_format_ = GetSourceFormat(static_cast<int>(w_shape.NumDimensions()));
@ -303,12 +302,10 @@ class DnnlMatmul : public DnnlKernel {
private:
dnnl::memory::format_tag weights_format_;
std::shared_ptr<dnnl::memory> src_mem_from_;
size_t weights_size_;
size_t dst_size_;
TensorShape weights_shape_;
std::shared_ptr<dnnl::memory> src_mem_;
@ -330,25 +327,12 @@ class DnnlMatmul : public DnnlKernel {
output_shape = input_shape.GetDims();
output_shape.pop_back();
output_shape.emplace_back(weight_shape.GetDims().back());
}
void AdjustSrcWeightsShape(TensorShape& input_shape, TensorShape& weights_shape) const {
if (input_shape.NumDimensions() > weights_shape.NumDimensions()) {
auto dims = weights_shape.GetDims();
for (size_t i = 0; i < input_shape.NumDimensions() - weights_shape.NumDimensions(); i++) {
dims.insert(dims.begin(), 1);
for (size_t i = 0; i < output_shape.size() - 2; i++) {
if (output_shape[i] == 1) {
output_shape[i] = weight_shape[i];
}
weights_shape = TensorShape(dims);
} else if (input_shape.NumDimensions() < weights_shape.NumDimensions()) {
auto dims = input_shape.GetDims();
for (size_t i = 0; i < weights_shape.NumDimensions() - input_shape.NumDimensions(); i++) {
dims.insert(dims.begin(), 1);
}
input_shape = TensorShape(dims);
}
}
};
} // namespace ort_dnnl
} // namespace onnxruntime