diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index 634548de0d..95ff657451 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -14,6 +14,7 @@ import numpy as np import onnx from onnx import helper, numpy_helper from onnx import onnx_pb as onnx_proto +from packaging import version logger = logging.getLogger(__name__) @@ -170,7 +171,7 @@ def convert_float_to_float16( assert max_finite_val <= float(np.finfo(np.float16).max), "invalid max_finite_val. largest float16 value: 65504" func_infer_shape = None - if not disable_shape_infer and onnx.__version__ >= "1.2": + if not disable_shape_infer and version.parse(onnx.__version__) >= version.parse("1.2.0"): try: from onnx.shape_inference import infer_shapes