mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Fix C# handling of unicode strings (#2697)
* Fix C# handling of unicode strings * more tests * check for handle before freesing * variable reuse efficiency * refactor and cleanup utf8 o utf16 conversion block
This commit is contained in:
parent
233bdd268b
commit
64112db346
4 changed files with 62 additions and 36 deletions
|
|
@ -218,12 +218,31 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
));
|
||||
|
||||
// fill the native tensor, using GetValue(index) from the Tensor<string>
|
||||
string[] stringsInTensor = new string[tensorValue.Length];
|
||||
for (int i = 0; i < tensorValue.Length; i++)
|
||||
var len = tensorValue.Length;
|
||||
var stringsInTensor = new IntPtr[len];
|
||||
var pinnedHandles = new GCHandle[len + 1];
|
||||
pinnedHandles[len] = GCHandle.Alloc(stringsInTensor, GCHandleType.Pinned);
|
||||
try
|
||||
{
|
||||
stringsInTensor[i] = tensorValue.GetValue(i);
|
||||
for (int i = 0; i < len; i++)
|
||||
{
|
||||
var utf8str = UTF8Encoding.UTF8.GetBytes(tensorValue.GetValue(i) + "\0");
|
||||
pinnedHandles[i] = GCHandle.Alloc(utf8str, GCHandleType.Pinned);
|
||||
stringsInTensor[i] = pinnedHandles[i].AddrOfPinnedObject();
|
||||
}
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(nativeTensor, stringsInTensor, (UIntPtr)len));
|
||||
}
|
||||
finally
|
||||
{
|
||||
foreach (var handle in pinnedHandles)
|
||||
{
|
||||
if (handle.IsAllocated)
|
||||
{
|
||||
handle.Free();
|
||||
}
|
||||
}
|
||||
}
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(nativeTensor, stringsInTensor, (UIntPtr)tensorValue.Length));
|
||||
}
|
||||
catch (OnnxRuntimeException e)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -591,7 +591,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// \param len total data length, not including the trailing '\0' chars.
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtFillStringTensor(
|
||||
IntPtr /* OrtValue */ value,
|
||||
string[] /* const char* const* */s,
|
||||
IntPtr[] /* const char* const* */s,
|
||||
UIntPtr /* size_t */ s_len);
|
||||
public static DOrtFillStringTensor OrtFillStringTensor;
|
||||
|
||||
|
|
|
|||
|
|
@ -79,40 +79,38 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
var offsets = new UIntPtr[_elementCount];
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetStringTensorDataLength(_onnxValueHandle, out strLen));
|
||||
var dataBuffer = new byte[strLen.ToUInt64()];
|
||||
var dataBufferMemory = new Memory<byte>(dataBuffer);
|
||||
var dataBufferHandle = dataBufferMemory.Pin();
|
||||
IntPtr dataBufferPointer = IntPtr.Zero;
|
||||
|
||||
var offsetMemory = new Memory<UIntPtr>(offsets);
|
||||
var offsetMemoryHandle = offsetMemory.Pin();
|
||||
IntPtr offsetBufferPointer = IntPtr.Zero;
|
||||
unsafe
|
||||
using (var dataBufferHandle = new Memory<byte>(dataBuffer).Pin())
|
||||
using (var offsetMemoryHandle = new Memory<UIntPtr>(offsets).Pin())
|
||||
{
|
||||
dataBufferPointer = (IntPtr)dataBufferHandle.Pointer;
|
||||
offsetBufferPointer = (IntPtr)offsetMemoryHandle.Pointer;
|
||||
}
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetStringTensorContent(_onnxValueHandle, dataBufferPointer, strLen, offsetBufferPointer, (UIntPtr)_elementCount));
|
||||
_dataBufferPointer = dataBufferPointer;
|
||||
_dataBufferAsString = new string[_elementCount];
|
||||
unsafe
|
||||
{
|
||||
_dataBufferPointer = (IntPtr)dataBufferHandle.Pointer;
|
||||
NativeApiStatus.VerifySuccess(
|
||||
NativeMethods.OrtGetStringTensorContent(
|
||||
_onnxValueHandle, _dataBufferPointer, strLen,
|
||||
(IntPtr)offsetMemoryHandle.Pointer,
|
||||
(UIntPtr)_elementCount));
|
||||
}
|
||||
_dataBufferAsString = new string[_elementCount];
|
||||
|
||||
for (var i = 0; i < offsets.Length; i++)
|
||||
{
|
||||
var length = (i == offsets.Length - 1)
|
||||
? strLen.ToUInt64() - offsets[i].ToUInt64()
|
||||
: offsets[i + 1].ToUInt64() - offsets[i].ToUInt64();
|
||||
// Onnx specifies strings always in UTF-8, no trailing null, no leading BOM
|
||||
_dataBufferAsString[i] = Encoding.UTF8.GetString(dataBuffer, (int)offsets[i], (int)length);
|
||||
for (var i = 0; i < offsets.Length; i++)
|
||||
{
|
||||
var length = (i == offsets.Length - 1)
|
||||
? strLen.ToUInt64() - offsets[i].ToUInt64()
|
||||
: offsets[i + 1].ToUInt64() - offsets[i].ToUInt64();
|
||||
// Onnx specifies strings always in UTF-8, no trailing null, no leading BOM
|
||||
_dataBufferAsString[i] = Encoding.UTF8.GetString(dataBuffer, (int)offsets[i], (int)length);
|
||||
}
|
||||
}
|
||||
|
||||
// unpin memory
|
||||
offsetMemoryHandle.Dispose();
|
||||
dataBufferHandle.Dispose();
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
//TODO: cleanup any partially created state
|
||||
//Do not call ReleaseTensor here. If the constructor has thrown exception, then this NativeOnnxTensorWrapper is not created, so caller should take appropriate action to dispose
|
||||
//Do not call ReleaseTensor here. If the constructor has thrown exception,
|
||||
//then this NativeOnnxTensorWrapper is not created, so caller should take
|
||||
//appropriate action to dispose
|
||||
throw e;
|
||||
}
|
||||
finally
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
using System;
|
||||
|
|
@ -601,7 +601,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
{
|
||||
option.RegisterCustomOpLibrary(libFullPath);
|
||||
}
|
||||
catch(Exception ex)
|
||||
catch (Exception ex)
|
||||
{
|
||||
var msg = $"Failed to load custom op library {libFullPath}, error = {ex.Message}";
|
||||
throw new Exception(msg + "\n" + ex.StackTrace);
|
||||
|
|
@ -619,7 +619,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
|
||||
11.1f, 12.2f, 13.3f, 14.4f, 15.5f
|
||||
},
|
||||
new int[]{3, 5 }
|
||||
new int[] { 3, 5 }
|
||||
)));
|
||||
|
||||
inputContainer.Add(NamedOnnxValue.CreateFromTensor<float>("input_2",
|
||||
|
|
@ -645,12 +645,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
17, 18, 18, 18, 17,
|
||||
17, 17, 17, 17, 17
|
||||
},
|
||||
new int[] { 3, 5}
|
||||
new int[] { 3, 5 }
|
||||
);
|
||||
Assert.True(tensorOut.SequenceEqual(expectedOut));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
@ -773,7 +773,16 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
var container = new List<NamedOnnxValue>();
|
||||
var tensorIn = new DenseTensor<string>(new string[] { "abc", "ced", "def", "", "frozen" }, new int[] { 1, 5 });
|
||||
var tensorIn = new DenseTensor<string>(new string[] {
|
||||
"hello",
|
||||
"École élémentaire",
|
||||
"mit freundlichen grüßen",
|
||||
"Понедельник",
|
||||
"最好的问候,"+
|
||||
"नमस्ते," +
|
||||
"こんにちは," +
|
||||
"안녕하세요"
|
||||
}, new int[] { 1, 5 });
|
||||
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
|
||||
container.Add(nov);
|
||||
using (var res = session.Run(container))
|
||||
|
|
|
|||
Loading…
Reference in a new issue