diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index 87feefa10c..d8822b3e45 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -208,6 +208,7 @@ class DataTypeImpl { static const std::vector& AllTensorTypes(); // up to IR4 (no float 8), deprecated static const std::vector& AllTensorTypesIRv4(); static const std::vector& AllTensorTypesIRv9(); + static const std::vector& AllTensorTypesIRv10(); static const std::vector& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated static const std::vector& AllFixedSizeTensorTypesIRv4(); diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 322afcf384..06aab16f4a 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -1301,6 +1301,12 @@ const std::vector& DataTypeImpl::AllTensorTypesIRv9() { return all_tensor_types; } +const std::vector& DataTypeImpl::AllTensorTypesIRv10() { + static std::vector all_tensor_types = + GetTensorTypesFromTypeList(); + return all_tensor_types; +} + const std::vector& DataTypeImpl::AllFixedSizeSequenceTensorTypes() { return AllFixedSizeSequenceTensorTypesIRv4(); } diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 33d2a0244b..8492391172 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -49,7 +49,7 @@ static constexpr uint32_t min_ort_version_with_shape_inference = 17; #endif #if !defined(DISABLE_FLOAT8_TYPES) -#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv9() +#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv10() #else #define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv4() #endif