From 4d31076d687560f7cfdada19f6ba9ad5a86612f2 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Mon, 18 Mar 2024 08:54:24 -0700 Subject: [PATCH] [objc] Add check for ORTValue being a tensor in ORTValue methods that should only be used with tensors. (#19946) Add check to report error instead of crashing. --- objectivec/ort_value.mm | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/objectivec/ort_value.mm b/objectivec/ort_value.mm index b9dc1a9885..c61a7ea809 100644 --- a/objectivec/ort_value.mm +++ b/objectivec/ort_value.mm @@ -148,6 +148,9 @@ bool SafeMultiply(size_t a, size_t b, size_t& out) { - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } return CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo); } ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) @@ -156,6 +159,9 @@ bool SafeMultiply(size_t a, size_t b, size_t& out) { - (nullable NSMutableData*)tensorDataWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } if (tensorTypeAndShapeInfo.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { ORT_CXX_API_THROW( "This ORTValue holds string data. Please call tensorStringDataWithError: " @@ -182,6 +188,9 @@ bool SafeMultiply(size_t a, size_t b, size_t& out) { - (nullable NSArray*)tensorStringDataWithError:(NSError**)error { try { const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo(); + if (!tensorTypeAndShapeInfo) { + ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION); + } const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount(); const size_t tensorStringDataLength = _value->GetStringTensorDataLength(); std::vector tensorStringData(tensorStringDataLength, '\0');