Add patch for ONNX 1.16.0 shape inference bug (#20316)

### Description
- Adds a patch that fixes a shape inference bug that caused a segfault:
https://github.com/onnx/onnx/pull/6080
- Fix documentation describing why QLinearMatMul tests are currently
being skipped.



### Motivation and Context
The [PR for integrating with ONNX
1.16.0](https://github.com/microsoft/onnxruntime/pull/19745) disabled
various python quantization tests due to a shape inference bug. This PR
applies the ONNX fix as a patch. We still can't enable the tests because
some of our CIs pip install onnx-1.16.0, which doesn't include the fix.
This commit is contained in:
Adrian Lizarraga 2024-04-17 10:23:22 -07:00 committed by GitHub
parent bb1972264b
commit 0a1902525f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 75 additions and 39 deletions

View file

@@ -1,8 +1,8 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4dd56b6e..018da488 100644
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6d7ca846..69aa622f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -397,6 +397,7 @@ if (MSVC)
@@ -499,6 +499,7 @@ if (MSVC)
endif()
else()
# On non-Windows, hide all symbols we don't need
@@ -10,7 +10,7 @@ index 4dd56b6e..018da488 100644
set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)")
set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden)
set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1)
@@ -548,20 +549,9 @@ endif()
@@ -653,20 +654,9 @@ endif()
if(MSVC)
target_compile_options(onnx_proto
PRIVATE /MP
@@ -31,14 +31,72 @@ index 4dd56b6e..018da488 100644
${EXTRA_FLAGS})
if(ONNX_USE_PROTOBUF_SHARED_LIBS)
target_compile_options(onnx_proto
diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h
index b847798e..a6c31904 100644
--- a/onnx/common/file_utils.h
+++ b/onnx/common/file_utils.h
@@ -6,7 +6,6 @@
#pragma once
-#include <filesystem>
#include <fstream>
#include <string>
@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE {
template <typename T>
void LoadProtoFromPath(const std::string proto_path, T& proto) {
- std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path);
- std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary);
+ std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary);
if (!proto_stream.good()) {
fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. ");
}
diff --git a/onnx/defs/quantization/defs.cc b/onnx/defs/quantization/defs.cc
index 70b4a4db..98c11545 100644
--- a/onnx/defs/quantization/defs.cc
+++ b/onnx/defs/quantization/defs.cc
@@ -200,6 +200,9 @@ ONNX_OPERATOR_SET_SCHEMA(
.SetDoc(DequantizeLinear_ver21_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 1, 0);
+ if (!hasInputShape(ctx, 0)) {
+ return;
+ }
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
diff --git a/onnx/defs/quantization/old.cc b/onnx/defs/quantization/old.cc
index 3f2d6384..d2f7cfd8 100644
--- a/onnx/defs/quantization/old.cc
+++ b/onnx/defs/quantization/old.cc
@@ -130,6 +130,9 @@ ONNX_OPERATOR_SET_SCHEMA(
.SetDoc(DequantizeLinear_ver19_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 1, 0);
+ if (!hasInputShape(ctx, 0)) {
+ return;
+ }
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
@@ -181,7 +184,6 @@ ONNX_OPERATOR_SET_SCHEMA(
if (!hasInputShape(ctx, 0)) {
return;
}
-
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h
index 0aab3e26..0f859267 100644
index 0aab3e26..398ac2d6 100644
--- a/onnx/onnx_pb.h
+++ b/onnx/onnx_pb.h
@@ -47,10 +47,28 @@
#define ONNX_API ONNX_IMPORT
#endif
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+
@@ -58,34 +116,12 @@ index 0aab3e26..0f859267 100644
#else
#include "onnx/onnx.pb.h"
#endif
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
+
#endif // ! ONNX_ONNX_PB_H
diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h
index b847798e..a6c31904 100644
--- a/onnx/common/file_utils.h
+++ b/onnx/common/file_utils.h
@@ -6,7 +6,6 @@
#pragma once
-#include <filesystem>
#include <fstream>
#include <string>
@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE {
template <typename T>
void LoadProtoFromPath(const std::string proto_path, T& proto) {
- std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path);
- std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary);
+ std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary);
if (!proto_stream.good()) {
fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. ");
}
diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc
index fab1faf2..8723dcd4 100644
--- a/onnx/shape_inference/implementation.cc

View file

@@ -318,7 +318,7 @@ class TestOpGemm(unittest.TestCase):
weight_type=QuantType.QUInt8,
)
@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
def test_quantize_qop_gemm_s8s8(self):
np.random.seed(1)
model_fp32_path = "gemm_fp32.onnx"
@@ -366,7 +366,7 @@ class TestOpGemm(unittest.TestCase):
calibrate_method=CalibrationMethod.Distribution,
)
@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
def test_quantize_qop_gemm_e4m3fn_same(self):
np.random.seed(1)
model_fp32_path = "gemm_fp32.onnx"
@@ -397,7 +397,7 @@ class TestOpGemm(unittest.TestCase):
calibrate_method=CalibrationMethod.Distribution,
)
@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
def test_quantize_qop_gemm_e4m3fn_p3(self):
np.random.seed(1)
model_fp32_path = "gemm_fp32.onnx"

View file

@@ -347,7 +347,7 @@ class TestOpMatMul(unittest.TestCase):
def test_quantize_matmul_u8u8(self):
self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT, 18, 8)
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_u8u8_f16(self):
self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 21, 9)
@@ -393,22 +393,22 @@ class TestOpMatMul(unittest.TestCase):
def test_quantize_matmul_s8s8_distribution(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Distribution)
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_s8s8_f16(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9)
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_s8s8_f16_entropy(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Entropy)
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_s8s8_f16_percentile(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Percentile)
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
@skip_if_new_opset_exception_raised
def test_quantize_matmul_s8s8_f16_distribution(self):
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Distribution)

View file

@@ -194,7 +194,7 @@ class TestOpRelu(unittest.TestCase):
weight_type=QuantType.QUInt8,
)
@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
def test_quantize_qop_relu_s8s8(self):
np.random.seed(1)
model_fp32_path = "relu_fp32.onnx"

View file

@@ -348,7 +348,7 @@ def test_get_input_output_names():
# Fails in ONNX 1.16.0 due to potential shape inference bug for custom ops.
# Potential ONNX fix: https://github.com/onnx/onnx/pull/6049
# Potential ONNX fix: https://github.com/onnx/onnx/pull/6080
# Error log: LookupError: The provided name onnx::linear.output::171 is not a graph value info or a graph output.
@pytest.mark.skipif(
pv.Version(onnx.__version__) == pv.Version("1.16.0"), reason="Shape inference bug for custom ops in ONNX 1.16.0"