mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add patch for ONNX 1.16.0 shape inference bug (#20316)
### Description - Adds a patch that fixes a shape inference bug that caused a segfault: https://github.com/onnx/onnx/pull/6080 - Fix documentation describing why QLinearMatMul tests are currently being skipped. ### Motivation and Context The [PR for integrating with ONNX 1.16.0](https://github.com/microsoft/onnxruntime/pull/19745) disabled various python quantization tests due to a shape inference bug. This PR applies the ONNX fix as a patch. We still can't enable the tests because some of our CIs pip install onnx-1.16.0, which doesn't include the fix.
This commit is contained in:
parent
bb1972264b
commit
0a1902525f
5 changed files with 75 additions and 39 deletions
|
|
@ -1,8 +1,8 @@
|
|||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 4dd56b6e..018da488 100644
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 6d7ca846..69aa622f 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -397,6 +397,7 @@ if (MSVC)
|
||||
@@ -499,6 +499,7 @@ if (MSVC)
|
||||
endif()
|
||||
else()
|
||||
# On non-Windows, hide all symbols we don't need
|
||||
|
|
@ -10,7 +10,7 @@ index 4dd56b6e..018da488 100644
|
|||
set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)")
|
||||
set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden)
|
||||
set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1)
|
||||
@@ -548,20 +549,9 @@ endif()
|
||||
@@ -653,20 +654,9 @@ endif()
|
||||
if(MSVC)
|
||||
target_compile_options(onnx_proto
|
||||
PRIVATE /MP
|
||||
|
|
@ -31,14 +31,72 @@ index 4dd56b6e..018da488 100644
|
|||
${EXTRA_FLAGS})
|
||||
if(ONNX_USE_PROTOBUF_SHARED_LIBS)
|
||||
target_compile_options(onnx_proto
|
||||
diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h
|
||||
index b847798e..a6c31904 100644
|
||||
--- a/onnx/common/file_utils.h
|
||||
+++ b/onnx/common/file_utils.h
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
-#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE {
|
||||
|
||||
template <typename T>
|
||||
void LoadProtoFromPath(const std::string proto_path, T& proto) {
|
||||
- std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path);
|
||||
- std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary);
|
||||
+ std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary);
|
||||
if (!proto_stream.good()) {
|
||||
fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. ");
|
||||
}
|
||||
diff --git a/onnx/defs/quantization/defs.cc b/onnx/defs/quantization/defs.cc
|
||||
index 70b4a4db..98c11545 100644
|
||||
--- a/onnx/defs/quantization/defs.cc
|
||||
+++ b/onnx/defs/quantization/defs.cc
|
||||
@@ -200,6 +200,9 @@ ONNX_OPERATOR_SET_SCHEMA(
|
||||
.SetDoc(DequantizeLinear_ver21_doc)
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 1, 0);
|
||||
+ if (!hasInputShape(ctx, 0)) {
|
||||
+ return;
|
||||
+ }
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
updateOutputShape(ctx, 0, input_shape);
|
||||
}));
|
||||
diff --git a/onnx/defs/quantization/old.cc b/onnx/defs/quantization/old.cc
|
||||
index 3f2d6384..d2f7cfd8 100644
|
||||
--- a/onnx/defs/quantization/old.cc
|
||||
+++ b/onnx/defs/quantization/old.cc
|
||||
@@ -130,6 +130,9 @@ ONNX_OPERATOR_SET_SCHEMA(
|
||||
.SetDoc(DequantizeLinear_ver19_doc)
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 1, 0);
|
||||
+ if (!hasInputShape(ctx, 0)) {
|
||||
+ return;
|
||||
+ }
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
updateOutputShape(ctx, 0, input_shape);
|
||||
}));
|
||||
@@ -181,7 +184,6 @@ ONNX_OPERATOR_SET_SCHEMA(
|
||||
if (!hasInputShape(ctx, 0)) {
|
||||
return;
|
||||
}
|
||||
-
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
updateOutputShape(ctx, 0, input_shape);
|
||||
}));
|
||||
diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h
|
||||
index 0aab3e26..0f859267 100644
|
||||
index 0aab3e26..398ac2d6 100644
|
||||
--- a/onnx/onnx_pb.h
|
||||
+++ b/onnx/onnx_pb.h
|
||||
@@ -47,10 +47,28 @@
|
||||
#define ONNX_API ONNX_IMPORT
|
||||
#endif
|
||||
|
||||
|
||||
+#if defined(__GNUC__)
|
||||
+#pragma GCC diagnostic push
|
||||
+
|
||||
|
|
@ -58,34 +116,12 @@ index 0aab3e26..0f859267 100644
|
|||
#else
|
||||
#include "onnx/onnx.pb.h"
|
||||
#endif
|
||||
|
||||
|
||||
+#if defined(__GNUC__)
|
||||
+#pragma GCC diagnostic pop
|
||||
+#endif
|
||||
+
|
||||
#endif // ! ONNX_ONNX_PB_H
|
||||
diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h
|
||||
index b847798e..a6c31904 100644
|
||||
--- a/onnx/common/file_utils.h
|
||||
+++ b/onnx/common/file_utils.h
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
-#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE {
|
||||
|
||||
template <typename T>
|
||||
void LoadProtoFromPath(const std::string proto_path, T& proto) {
|
||||
- std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path);
|
||||
- std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary);
|
||||
+ std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary);
|
||||
if (!proto_stream.good()) {
|
||||
fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. ");
|
||||
}
|
||||
diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc
|
||||
index fab1faf2..8723dcd4 100644
|
||||
--- a/onnx/shape_inference/implementation.cc
|
||||
|
|
|
|||
|
|
@ -318,7 +318,7 @@ class TestOpGemm(unittest.TestCase):
|
|||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
|
||||
def test_quantize_qop_gemm_s8s8(self):
|
||||
np.random.seed(1)
|
||||
model_fp32_path = "gemm_fp32.onnx"
|
||||
|
|
@ -366,7 +366,7 @@ class TestOpGemm(unittest.TestCase):
|
|||
calibrate_method=CalibrationMethod.Distribution,
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
|
||||
def test_quantize_qop_gemm_e4m3fn_same(self):
|
||||
np.random.seed(1)
|
||||
model_fp32_path = "gemm_fp32.onnx"
|
||||
|
|
@ -397,7 +397,7 @@ class TestOpGemm(unittest.TestCase):
|
|||
calibrate_method=CalibrationMethod.Distribution,
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
|
||||
def test_quantize_qop_gemm_e4m3fn_p3(self):
|
||||
np.random.seed(1)
|
||||
model_fp32_path = "gemm_fp32.onnx"
|
||||
|
|
|
|||
|
|
@ -347,7 +347,7 @@ class TestOpMatMul(unittest.TestCase):
|
|||
def test_quantize_matmul_u8u8(self):
|
||||
self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT, 18, 8)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
|
||||
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
|
||||
@skip_if_new_opset_exception_raised
|
||||
def test_quantize_matmul_u8u8_f16(self):
|
||||
self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 21, 9)
|
||||
|
|
@ -393,22 +393,22 @@ class TestOpMatMul(unittest.TestCase):
|
|||
def test_quantize_matmul_s8s8_distribution(self):
|
||||
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Distribution)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
|
||||
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
|
||||
@skip_if_new_opset_exception_raised
|
||||
def test_quantize_matmul_s8s8_f16(self):
|
||||
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
|
||||
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
|
||||
@skip_if_new_opset_exception_raised
|
||||
def test_quantize_matmul_s8s8_f16_entropy(self):
|
||||
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Entropy)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
|
||||
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
|
||||
@skip_if_new_opset_exception_raised
|
||||
def test_quantize_matmul_s8s8_f16_percentile(self):
|
||||
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Percentile)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #5709")
|
||||
@unittest.skip(reason="QLinearMatMul(21), which supports float16, is not implemented in ORT.")
|
||||
@skip_if_new_opset_exception_raised
|
||||
def test_quantize_matmul_s8s8_f16_distribution(self):
|
||||
self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Distribution)
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ class TestOpRelu(unittest.TestCase):
|
|||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #6049")
|
||||
@unittest.skip(reason="Shape inference bug, see onnx PR #6080")
|
||||
def test_quantize_qop_relu_s8s8(self):
|
||||
np.random.seed(1)
|
||||
model_fp32_path = "relu_fp32.onnx"
|
||||
|
|
|
|||
|
|
@ -348,7 +348,7 @@ def test_get_input_output_names():
|
|||
|
||||
|
||||
# Fails in ONNX 1.16.0 due to potential shape inference bug for custom ops.
|
||||
# Potential ONNX fix: https://github.com/onnx/onnx/pull/6049
|
||||
# Potential ONNX fix: https://github.com/onnx/onnx/pull/6080
|
||||
# Error log: LookupError: The provided name onnx::linear.output::171 is not a graph value info or a graph output.
|
||||
@pytest.mark.skipif(
|
||||
pv.Version(onnx.__version__) == pv.Version("1.16.0"), reason="Shape inference bug for custom ops in ONNX 1.16.0"
|
||||
|
|
|
|||
Loading…
Reference in a new issue