diff --git a/onnxruntime/lora/adapter_format_utils.cc b/onnxruntime/lora/adapter_format_utils.cc index 9a6f8f3b7b..7986082da0 100644 --- a/onnxruntime/lora/adapter_format_utils.cc +++ b/onnxruntime/lora/adapter_format_utils.cc @@ -6,6 +6,8 @@ #include "core/framework/allocator.h" #include "core/common/common.h" +#include "core/framework/endian.h" +#include "core/framework/endian_utils.h" #include "core/common/span_utils.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" @@ -75,17 +77,75 @@ const Adapter* ValidateAndGetAdapterFromBytes(gsl::span bytes) { return adapter; } +template +struct WriteDataForLittleEndian { + Status operator()(gsl::span src, gsl::span dest) const { + auto src_span = ReinterpretAsSpan(src); + return onnxruntime::utils::WriteLittleEndian(src_span, dest); + } +}; + void SaveLoraParameter(flatbuffers::FlatBufferBuilder& flat_builder, std::string_view name, TensorDataType data_type, gsl::span shape, gsl::span data, flatbuffers::Offset& fbs_tensor) { auto name_str = (name.empty()) ? 
0 : flat_builder.CreateString(name.data(), name.size()); auto shape_vec = flat_builder.CreateVector(shape.data(), shape.size()); - auto data_vec = flat_builder.CreateVector(data.data(), data.size()); + flatbuffers::Offset> data_vec; + if constexpr (endian::native == endian::big) { + const auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast(data_type))->GetElementType(); + if (elem_type->Size() > 1) { + InlinedVector be_data(data.size()); + auto be_data_span = ReinterpretAsSpan(AsSpan(be_data)); + + onnxruntime::utils::MLTypeCallDispatcher + disp(static_cast(data_type)); + + ORT_THROW_IF_ERROR((disp.InvokeRet(data, be_data_span))); + data_vec = flat_builder.CreateVector(be_data.data(), be_data.size()); + } else { + data_vec = flat_builder.CreateVector(data.data(), data.size()); + } + } else { + data_vec = flat_builder.CreateVector(data.data(), data.size()); + } fbs_tensor = CreateParameter(flat_builder, name_str, shape_vec, data_type, data_vec); } +template +struct ReadDataForBigEndian { + Status operator()(gsl::span src, Tensor& dst) const { + auto dst_span = dst.MutableDataAsSpan(); + return onnxruntime::utils::ReadLittleEndian(src, dst_span); + } +}; + +// If BE, we allocate memory within the tensor and copy there swapping bytes +[[maybe_unused]] static Status CreateOrtValueForBePlatforms(const Parameter& param, const MLDataType elem_type, + gsl::span shape, OrtValue& result) { + static const AllocatorPtr cpu_allocator = std::make_shared(); + + auto src_span = ReinterpretAsSpan( + gsl::make_span(param.raw_data()->data(), param.raw_data()->size())); + + const auto data_type = param.data_type(); + + Tensor tensor(elem_type, shape, cpu_allocator); + onnxruntime::utils::MLTypeCallDispatcher + disp(static_cast(data_type)); + + ORT_RETURN_IF_ERROR((disp.InvokeRet(src_span, tensor))); + Tensor::InitOrtValue(std::move(tensor), result); + return Status::OK(); +} + std::pair CreateOrtValueOverLoraParameter(const Parameter& param) { OrtValue result; @@ 
-93,17 +153,32 @@ std::pair CreateOrtValueOverLoraParameter(const Parameter LoadStringFromLoraFormat(name, param.name()); const auto data_type = param.data_type(); - gsl::span shape_span(param.dims()->data(), param.dims()->size()); - + // Copying shape takes care of endianness using flatbuffers accessors + TensorShapeVector shape(param.dims()->begin(), param.dims()->end()); + const auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast(data_type))->GetElementType(); static const OrtMemoryInfo cpu_meminfo(CPU, OrtAllocatorType::OrtDeviceAllocator); - auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast(data_type))->GetElementType(); - // const_cast is necessery due to Tensor class API - Tensor::InitOrtValue(elem_type, - TensorShape(shape_span), - const_cast(param.raw_data()->data()), - cpu_meminfo, - result); + if constexpr (endian::native == endian::big) { + if (elem_type->Size() > 1) { + ORT_THROW_IF_ERROR(CreateOrtValueForBePlatforms(param, elem_type, shape, result)); + } else { + // Single byte elements allow us to create OrtValue directly on top + // of raw data + // const_cast is necessary due to Tensor class API + Tensor::InitOrtValue(elem_type, + TensorShape(shape), + const_cast(param.raw_data()->data()), + cpu_meminfo, + result); + } + } else { + // const_cast is necessary due to Tensor class API + Tensor::InitOrtValue(elem_type, + TensorShape(shape), + const_cast(param.raw_data()->data()), + cpu_meminfo, + result); + } return std::make_pair(std::move(name), std::move(result)); }