From d5338da1f556eddff636293c97e0921443fe0fc4 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Wed, 29 Jan 2025 13:35:25 -0800 Subject: [PATCH] Fix tensor external data info length parsing issue. (#23526) Fix tensor external data info length parsing issue. The old implementation was parsing a `size_t` value with `strtol` (via `OrtStrToPtrDiff`) on ARM64 MSVC. https://github.com/microsoft/onnxruntime/blob/bf023ab3d565668c13a5334b505df0eb6acf3625/onnxruntime/core/platform/path_lib.h#L74 If we have `sizeof(size_t) == 8` and `sizeof(long) == 4` (as is the case for x64 and ARM64 MSVC), `strtol` will return a maximum value of `2^31-1` even for a larger, valid `size_t` value. `strtol` will also set `errno` to `ERANGE`, but we weren't checking that. Updated to use `ParseStringWithClassicLocale` which will parse directly to the target type. Added some tests. --- .../framework/tensor_external_data_info.cc | 44 +++++------- .../framework/tensor_external_data_info.h | 2 - .../test/framework/tensorutils_test.cc | 72 +++++++++++++++++++ 3 files changed, 88 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index ec8b25e9f4..971851db62 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -4,6 +4,7 @@ #include "tensor_external_data_info.h" #include "core/common/common.h" #include "core/common/narrow.h" +#include "core/common/parse_string.h" #include "core/common/safeint.h" #include "core/common/string_utils.h" #include "core/platform/path_lib.h" @@ -18,21 +19,8 @@ using ::ONNX_NAMESPACE::StringStringEntryProto; namespace onnxruntime { Status ExternalDataInfo::Create(const RepeatedPtrField& input, - std::unique_ptr& out) { - auto str_to_int = [](const std::string& s, OFFSET_TYPE& result) -> Status { - char* end; -#ifdef _WIN32 - result = _strtoi64(s.c_str(), &end, 10); -#else - result = OrtStrToPtrDiff(s.c_str(), &end); -#endif - if (end != s.c_str() + s.length()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", s, " failed"); - } - return Status::OK(); - }; - - out = std::make_unique(); + std::unique_ptr& external_data_info_result) { + auto external_data_info = std::make_unique(); PrepackedInfos prepacked_infos; const int input_size = input.size(); @@ -43,17 +31,15 @@ Status ExternalDataInfo::Create(const RepeatedPtrField& return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Need a key for the external data info"); if (!stringmap.has_value()) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Need a value for the external data info"); + if (stringmap.key() == "location" && !stringmap.value().empty()) { - out->rel_path_ = ToWideString(stringmap.value()); + external_data_info->rel_path_ = ToWideString(stringmap.value()); } else if (stringmap.key() == "offset" && !stringmap.value().empty()) { - ORT_RETURN_IF_ERROR(str_to_int(stringmap.value(), out->offset_)); + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(stringmap.value(), external_data_info->offset_)); } else if (stringmap.key() == "length" && !stringmap.value().empty()) { - char* end; - out->length_ = narrow(OrtStrToPtrDiff(stringmap.value().c_str(), &end)); - if (end != stringmap.value().c_str() + stringmap.value().length()) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", stringmap.value(), " failed"); + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(stringmap.value(), external_data_info->length_)); } else if (stringmap.key() == "checksum" && !stringmap.value().empty()) { - out->checksum_ = stringmap.value(); + external_data_info->checksum_ = stringmap.value(); } else if (stringmap.key().find("prepacked", 0) == 0) { // Starts with 'prepacked', each has its own key. // Each prepacked entry may have multiple blobs with the same key @@ -72,10 +58,11 @@ Status ExternalDataInfo::Create(const RepeatedPtrField& const auto& blob = split_fields[f]; auto blob_fields = utils::SplitString(blob, ";", false); if (blob_fields.size() == 3) { - OFFSET_TYPE offset, len; - ORT_RETURN_IF_ERROR(str_to_int(std::string(blob_fields[0]), offset)); - ORT_RETURN_IF_ERROR(str_to_int(std::string(blob_fields[1]), len)); - blob_infos.push_back(std::make_tuple(offset, narrow(len), std::string(blob_fields[2]))); + OFFSET_TYPE offset; + size_t len; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(blob_fields[0], offset)); + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(blob_fields[1], len)); + blob_infos.push_back(std::make_tuple(offset, len, std::string(blob_fields[2]))); } } if (blob_infos.empty()) { @@ -88,14 +75,15 @@ Status ExternalDataInfo::Create(const RepeatedPtrField& } } - if (out->rel_path_.empty()) { + if (external_data_info->rel_path_.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Missing 'location'"); } if (!prepacked_infos.empty()) { - out->prepacked_infos_ = std::move(prepacked_infos); + external_data_info->prepacked_infos_ = std::move(prepacked_infos); } + external_data_info_result = std::move(external_data_info); return Status::OK(); } void ExternalDataInfo::SetExternalLocationToProto(const std::filesystem::path& external_file_path, diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index 1b185b8c5d..2de1e01f38 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -32,8 +32,6 @@ class ExternalDataInfo { const std::string& GetChecksum() const { return checksum_; } - // If the value of 'offset' or 'length' field is larger the max value of ssize_t, this function will treat it as a - // wrong value and return FAIL. static common::Status Create( const ::google::protobuf::RepeatedPtrField<::ONNX_NAMESPACE::StringStringEntryProto>& input, std::unique_ptr& out); diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 229f4f95b8..931a507c53 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/common/inlined_containers.h" +#include "core/common/parse_string.h" #include "core/framework/prepacked_weights.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/tensorprotoutils.h" @@ -9,6 +10,9 @@ #include "test/util/include/asserts.h" #include "file_util.h" +#include +#include + #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -22,6 +26,74 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { +// if `expected_error_message_substring` is nullptr, parsing is expected to be successful +static void TestExternalDataInfoParsingOffsetAndLengthWithStrings( + std::string_view offset_str, + std::string_view length_str, + const char* expected_error_message_substring = nullptr) { + SCOPED_TRACE(MakeString("offset: \"", offset_str, "\", length: \"", length_str, "\"")); + + ONNX_NAMESPACE::TensorProto tensor_proto; + const std::filesystem::path kExternalDataPath("test.bin"); + + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); + + auto* location_entry = tensor_proto.add_external_data(); + location_entry->set_key("location"); + location_entry->set_value(ToUTF8String(kExternalDataPath.native())); + + auto* offset_entry = tensor_proto.add_external_data(); + offset_entry->set_key("offset"); + offset_entry->set_value(offset_str.data(), offset_str.size()); + + auto* length_entry = tensor_proto.add_external_data(); + length_entry->set_key("length"); + length_entry->set_value(length_str.data(), length_str.size()); + + std::unique_ptr external_data_info{}; + const auto create_status = ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info); + if (expected_error_message_substring) { + ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(create_status, expected_error_message_substring); + return; + } + ASSERT_STATUS_OK(create_status); + + // if we got this far, assume that offset_str and length_str are able to be parsed. + const auto expected_offset = ParseStringWithClassicLocale(offset_str); + const auto expected_length = ParseStringWithClassicLocale(length_str); + + ASSERT_EQ(external_data_info->GetOffset(), expected_offset); + ASSERT_EQ(external_data_info->GetLength(), expected_length); +} + +// if `expected_error_message_substring` is nullptr, parsing is expected to be successful +static void TestExternalDataInfoParsingOffsetAndLength(intmax_t offset, + uintmax_t length, + const char* expected_error_message_substring = nullptr) { + TestExternalDataInfoParsingOffsetAndLengthWithStrings(std::to_string(offset), std::to_string(length), + expected_error_message_substring); +} + +TEST(TensorProtoUtilsTest, ParseExternalDataInfoOffsetAndLength) { + TestExternalDataInfoParsingOffsetAndLength(0, 0); + + TestExternalDataInfoParsingOffsetAndLength(0, 1024); + TestExternalDataInfoParsingOffsetAndLength(0, std::numeric_limits::max()); + + TestExternalDataInfoParsingOffsetAndLength(1024, 1024); + TestExternalDataInfoParsingOffsetAndLength(std::numeric_limits::max(), 1024); + + { + // assuming that this value is too large to fit in either size_t or ExternalDataInfo::OFFSET_TYPE + const std::string_view two_to_the_65th_power = "36893488147419103232"; + const std::string_view zero = "0"; + TestExternalDataInfoParsingOffsetAndLengthWithStrings(two_to_the_65th_power, zero, "Failed to parse value"); + TestExternalDataInfoParsingOffsetAndLengthWithStrings(zero, two_to_the_65th_power, "Failed to parse value"); + } + + // TODO should ExternalDataInfo::Create() also reject negative offset values? +} + // Test ExternalData functionality TEST(TensorProtoUtilsTest, SetExternalDataInformation) { ONNX_NAMESPACE::TensorProto tensor_proto;