mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Fix tensor external data info length parsing issue. (#23526)
Fix tensor external data info length parsing issue.
The old implementation was parsing a `size_t` value with `strtol` (via `OrtStrToPtrDiff`) on ARM64 MSVC.
See `onnxruntime/core/platform/path_lib.h`, line 74, at commit bf023ab3d5.
If we have `sizeof(size_t) == 8` and `sizeof(long) == 4` (as is the case for x64 and ARM64 MSVC), `strtol` will return a maximum value of `2^31-1` even for a larger, valid `size_t` value. `strtol` will also set `errno` to `ERANGE`, but we weren't checking that.
Updated to use `ParseStringWithClassicLocale` which will parse directly to the target type.
Added some tests.
This commit is contained in:
parent
e3e41739a7
commit
d5338da1f5
3 changed files with 88 additions and 30 deletions
|
|
@ -4,6 +4,7 @@
|
|||
#include "tensor_external_data_info.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/narrow.h"
|
||||
#include "core/common/parse_string.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/common/string_utils.h"
|
||||
#include "core/platform/path_lib.h"
|
||||
|
|
@ -18,21 +19,8 @@ using ::ONNX_NAMESPACE::StringStringEntryProto;
|
|||
|
||||
namespace onnxruntime {
|
||||
Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>& input,
|
||||
std::unique_ptr<ExternalDataInfo>& out) {
|
||||
auto str_to_int = [](const std::string& s, OFFSET_TYPE& result) -> Status {
|
||||
char* end;
|
||||
#ifdef _WIN32
|
||||
result = _strtoi64(s.c_str(), &end, 10);
|
||||
#else
|
||||
result = OrtStrToPtrDiff(s.c_str(), &end);
|
||||
#endif
|
||||
if (end != s.c_str() + s.length()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", s, " failed");
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
out = std::make_unique<ExternalDataInfo>();
|
||||
std::unique_ptr<ExternalDataInfo>& external_data_info_result) {
|
||||
auto external_data_info = std::make_unique<ExternalDataInfo>();
|
||||
PrepackedInfos prepacked_infos;
|
||||
|
||||
const int input_size = input.size();
|
||||
|
|
@ -43,17 +31,15 @@ Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>&
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Need a key for the external data info");
|
||||
if (!stringmap.has_value())
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Need a value for the external data info");
|
||||
|
||||
if (stringmap.key() == "location" && !stringmap.value().empty()) {
|
||||
out->rel_path_ = ToWideString(stringmap.value());
|
||||
external_data_info->rel_path_ = ToWideString(stringmap.value());
|
||||
} else if (stringmap.key() == "offset" && !stringmap.value().empty()) {
|
||||
ORT_RETURN_IF_ERROR(str_to_int(stringmap.value(), out->offset_));
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(stringmap.value(), external_data_info->offset_));
|
||||
} else if (stringmap.key() == "length" && !stringmap.value().empty()) {
|
||||
char* end;
|
||||
out->length_ = narrow<size_t>(OrtStrToPtrDiff(stringmap.value().c_str(), &end));
|
||||
if (end != stringmap.value().c_str() + stringmap.value().length())
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", stringmap.value(), " failed");
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(stringmap.value(), external_data_info->length_));
|
||||
} else if (stringmap.key() == "checksum" && !stringmap.value().empty()) {
|
||||
out->checksum_ = stringmap.value();
|
||||
external_data_info->checksum_ = stringmap.value();
|
||||
} else if (stringmap.key().find("prepacked", 0) == 0) {
|
||||
// Starts with 'prepacked', each has its own key.
|
||||
// Each prepacked entry may have multiple blobs with the same key
|
||||
|
|
@ -72,10 +58,11 @@ Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>&
|
|||
const auto& blob = split_fields[f];
|
||||
auto blob_fields = utils::SplitString(blob, ";", false);
|
||||
if (blob_fields.size() == 3) {
|
||||
OFFSET_TYPE offset, len;
|
||||
ORT_RETURN_IF_ERROR(str_to_int(std::string(blob_fields[0]), offset));
|
||||
ORT_RETURN_IF_ERROR(str_to_int(std::string(blob_fields[1]), len));
|
||||
blob_infos.push_back(std::make_tuple(offset, narrow<size_t>(len), std::string(blob_fields[2])));
|
||||
OFFSET_TYPE offset;
|
||||
size_t len;
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(blob_fields[0], offset));
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(blob_fields[1], len));
|
||||
blob_infos.push_back(std::make_tuple(offset, len, std::string(blob_fields[2])));
|
||||
}
|
||||
}
|
||||
if (blob_infos.empty()) {
|
||||
|
|
@ -88,14 +75,15 @@ Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>&
|
|||
}
|
||||
}
|
||||
|
||||
if (out->rel_path_.empty()) {
|
||||
if (external_data_info->rel_path_.empty()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Missing 'location'");
|
||||
}
|
||||
|
||||
if (!prepacked_infos.empty()) {
|
||||
out->prepacked_infos_ = std::move(prepacked_infos);
|
||||
external_data_info->prepacked_infos_ = std::move(prepacked_infos);
|
||||
}
|
||||
|
||||
external_data_info_result = std::move(external_data_info);
|
||||
return Status::OK();
|
||||
}
|
||||
void ExternalDataInfo::SetExternalLocationToProto(const std::filesystem::path& external_file_path,
|
||||
|
|
|
|||
|
|
@ -32,8 +32,6 @@ class ExternalDataInfo {
|
|||
|
||||
const std::string& GetChecksum() const { return checksum_; }
|
||||
|
||||
// If the value of 'offset' or 'length' field is larger than the max value of ssize_t, this function will treat it as a
|
||||
// wrong value and return FAIL.
|
||||
static common::Status Create(
|
||||
const ::google::protobuf::RepeatedPtrField<::ONNX_NAMESPACE::StringStringEntryProto>& input,
|
||||
std::unique_ptr<ExternalDataInfo>& out);
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/common/parse_string.h"
|
||||
#include "core/framework/prepacked_weights.h"
|
||||
#include "core/framework/prepacked_weights_container.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
|
@ -9,6 +10,9 @@
|
|||
#include "test/util/include/asserts.h"
|
||||
#include "file_util.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
|
|
@ -22,6 +26,74 @@ using namespace ONNX_NAMESPACE;
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// if `expected_error_message_substring` is nullptr, parsing is expected to be successful
|
||||
static void TestExternalDataInfoParsingOffsetAndLengthWithStrings(
|
||||
std::string_view offset_str,
|
||||
std::string_view length_str,
|
||||
const char* expected_error_message_substring = nullptr) {
|
||||
SCOPED_TRACE(MakeString("offset: \"", offset_str, "\", length: \"", length_str, "\""));
|
||||
|
||||
ONNX_NAMESPACE::TensorProto tensor_proto;
|
||||
const std::filesystem::path kExternalDataPath("test.bin");
|
||||
|
||||
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
|
||||
|
||||
auto* location_entry = tensor_proto.add_external_data();
|
||||
location_entry->set_key("location");
|
||||
location_entry->set_value(ToUTF8String(kExternalDataPath.native()));
|
||||
|
||||
auto* offset_entry = tensor_proto.add_external_data();
|
||||
offset_entry->set_key("offset");
|
||||
offset_entry->set_value(offset_str.data(), offset_str.size());
|
||||
|
||||
auto* length_entry = tensor_proto.add_external_data();
|
||||
length_entry->set_key("length");
|
||||
length_entry->set_value(length_str.data(), length_str.size());
|
||||
|
||||
std::unique_ptr<ExternalDataInfo> external_data_info{};
|
||||
const auto create_status = ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info);
|
||||
if (expected_error_message_substring) {
|
||||
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(create_status, expected_error_message_substring);
|
||||
return;
|
||||
}
|
||||
ASSERT_STATUS_OK(create_status);
|
||||
|
||||
// if we got this far, assume that offset_str and length_str are able to be parsed.
|
||||
const auto expected_offset = ParseStringWithClassicLocale<ExternalDataInfo::OFFSET_TYPE>(offset_str);
|
||||
const auto expected_length = ParseStringWithClassicLocale<size_t>(length_str);
|
||||
|
||||
ASSERT_EQ(external_data_info->GetOffset(), expected_offset);
|
||||
ASSERT_EQ(external_data_info->GetLength(), expected_length);
|
||||
}
|
||||
|
||||
// if `expected_error_message_substring` is nullptr, parsing is expected to be successful
|
||||
static void TestExternalDataInfoParsingOffsetAndLength(intmax_t offset,
|
||||
uintmax_t length,
|
||||
const char* expected_error_message_substring = nullptr) {
|
||||
TestExternalDataInfoParsingOffsetAndLengthWithStrings(std::to_string(offset), std::to_string(length),
|
||||
expected_error_message_substring);
|
||||
}
|
||||
|
||||
TEST(TensorProtoUtilsTest, ParseExternalDataInfoOffsetAndLength) {
|
||||
TestExternalDataInfoParsingOffsetAndLength(0, 0);
|
||||
|
||||
TestExternalDataInfoParsingOffsetAndLength(0, 1024);
|
||||
TestExternalDataInfoParsingOffsetAndLength(0, std::numeric_limits<size_t>::max());
|
||||
|
||||
TestExternalDataInfoParsingOffsetAndLength(1024, 1024);
|
||||
TestExternalDataInfoParsingOffsetAndLength(std::numeric_limits<ExternalDataInfo::OFFSET_TYPE>::max(), 1024);
|
||||
|
||||
{
|
||||
// assuming that this value is too large to fit in either size_t or ExternalDataInfo::OFFSET_TYPE
|
||||
const std::string_view two_to_the_65th_power = "36893488147419103232";
|
||||
const std::string_view zero = "0";
|
||||
TestExternalDataInfoParsingOffsetAndLengthWithStrings(two_to_the_65th_power, zero, "Failed to parse value");
|
||||
TestExternalDataInfoParsingOffsetAndLengthWithStrings(zero, two_to_the_65th_power, "Failed to parse value");
|
||||
}
|
||||
|
||||
// TODO should ExternalDataInfo::Create() also reject negative offset values?
|
||||
}
|
||||
|
||||
// Test ExternalData functionality
|
||||
TEST(TensorProtoUtilsTest, SetExternalDataInformation) {
|
||||
ONNX_NAMESPACE::TensorProto tensor_proto;
|
||||
|
|
|
|||
Loading…
Reference in a new issue