Fix C# handling of unicode strings (#2697)

* Fix C# handling of unicode strings

* more tests

* check for handle before freesing

* variable reuse efficiency

* refactor and cleanup utf8 o utf16 conversion block
This commit is contained in:
jignparm 2019-12-19 21:02:54 -08:00 committed by GitHub
parent 233bdd268b
commit 64112db346
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 36 deletions

View file

@ -218,12 +218,31 @@ namespace Microsoft.ML.OnnxRuntime
));
// fill the native tensor, using GetValue(index) from the Tensor<string>
string[] stringsInTensor = new string[tensorValue.Length];
for (int i = 0; i < tensorValue.Length; i++)
var len = tensorValue.Length;
var stringsInTensor = new IntPtr[len];
var pinnedHandles = new GCHandle[len + 1];
pinnedHandles[len] = GCHandle.Alloc(stringsInTensor, GCHandleType.Pinned);
try
{
stringsInTensor[i] = tensorValue.GetValue(i);
for (int i = 0; i < len; i++)
{
var utf8str = UTF8Encoding.UTF8.GetBytes(tensorValue.GetValue(i) + "\0");
pinnedHandles[i] = GCHandle.Alloc(utf8str, GCHandleType.Pinned);
stringsInTensor[i] = pinnedHandles[i].AddrOfPinnedObject();
}
NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(nativeTensor, stringsInTensor, (UIntPtr)len));
}
finally
{
foreach (var handle in pinnedHandles)
{
if (handle.IsAllocated)
{
handle.Free();
}
}
}
NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(nativeTensor, stringsInTensor, (UIntPtr)tensorValue.Length));
}
catch (OnnxRuntimeException e)
{

View file

@ -591,7 +591,7 @@ namespace Microsoft.ML.OnnxRuntime
/// \param len total data length, not including the trailing '\0' chars.
public delegate IntPtr /*(OrtStatus*)*/ DOrtFillStringTensor(
IntPtr /* OrtValue */ value,
string[] /* const char* const* */s,
IntPtr[] /* const char* const* */s,
UIntPtr /* size_t */ s_len);
public static DOrtFillStringTensor OrtFillStringTensor;

View file

@ -79,40 +79,38 @@ namespace Microsoft.ML.OnnxRuntime
var offsets = new UIntPtr[_elementCount];
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetStringTensorDataLength(_onnxValueHandle, out strLen));
var dataBuffer = new byte[strLen.ToUInt64()];
var dataBufferMemory = new Memory<byte>(dataBuffer);
var dataBufferHandle = dataBufferMemory.Pin();
IntPtr dataBufferPointer = IntPtr.Zero;
var offsetMemory = new Memory<UIntPtr>(offsets);
var offsetMemoryHandle = offsetMemory.Pin();
IntPtr offsetBufferPointer = IntPtr.Zero;
unsafe
using (var dataBufferHandle = new Memory<byte>(dataBuffer).Pin())
using (var offsetMemoryHandle = new Memory<UIntPtr>(offsets).Pin())
{
dataBufferPointer = (IntPtr)dataBufferHandle.Pointer;
offsetBufferPointer = (IntPtr)offsetMemoryHandle.Pointer;
}
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetStringTensorContent(_onnxValueHandle, dataBufferPointer, strLen, offsetBufferPointer, (UIntPtr)_elementCount));
_dataBufferPointer = dataBufferPointer;
_dataBufferAsString = new string[_elementCount];
unsafe
{
_dataBufferPointer = (IntPtr)dataBufferHandle.Pointer;
NativeApiStatus.VerifySuccess(
NativeMethods.OrtGetStringTensorContent(
_onnxValueHandle, _dataBufferPointer, strLen,
(IntPtr)offsetMemoryHandle.Pointer,
(UIntPtr)_elementCount));
}
_dataBufferAsString = new string[_elementCount];
for (var i = 0; i < offsets.Length; i++)
{
var length = (i == offsets.Length - 1)
? strLen.ToUInt64() - offsets[i].ToUInt64()
: offsets[i + 1].ToUInt64() - offsets[i].ToUInt64();
// Onnx specifies strings always in UTF-8, no trailing null, no leading BOM
_dataBufferAsString[i] = Encoding.UTF8.GetString(dataBuffer, (int)offsets[i], (int)length);
for (var i = 0; i < offsets.Length; i++)
{
var length = (i == offsets.Length - 1)
? strLen.ToUInt64() - offsets[i].ToUInt64()
: offsets[i + 1].ToUInt64() - offsets[i].ToUInt64();
// Onnx specifies strings always in UTF-8, no trailing null, no leading BOM
_dataBufferAsString[i] = Encoding.UTF8.GetString(dataBuffer, (int)offsets[i], (int)length);
}
}
// unpin memory
offsetMemoryHandle.Dispose();
dataBufferHandle.Dispose();
}
}
catch (Exception e)
{
//TODO: cleanup any partially created state
//Do not call ReleaseTensor here. If the constructor has thrown exception, then this NativeOnnxTensorWrapper is not created, so caller should take appropriate action to dispose
//Do not call ReleaseTensor here. If the constructor has thrown exception,
//then this NativeOnnxTensorWrapper is not created, so caller should take
//appropriate action to dispose
throw e;
}
finally

View file

@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
@ -601,7 +601,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
option.RegisterCustomOpLibrary(libFullPath);
}
catch(Exception ex)
catch (Exception ex)
{
var msg = $"Failed to load custom op library {libFullPath}, error = {ex.Message}";
throw new Exception(msg + "\n" + ex.StackTrace);
@ -619,7 +619,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
11.1f, 12.2f, 13.3f, 14.4f, 15.5f
},
new int[]{3, 5 }
new int[] { 3, 5 }
)));
inputContainer.Add(NamedOnnxValue.CreateFromTensor<float>("input_2",
@ -645,12 +645,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
17, 18, 18, 18, 17,
17, 17, 17, 17, 17
},
new int[] { 3, 5}
new int[] { 3, 5 }
);
Assert.True(tensorOut.SequenceEqual(expectedOut));
}
}
}
}
}
[Fact]
@ -773,7 +773,16 @@ namespace Microsoft.ML.OnnxRuntime.Tests
using (var session = new InferenceSession(modelPath))
{
var container = new List<NamedOnnxValue>();
var tensorIn = new DenseTensor<string>(new string[] { "abc", "ced", "def", "", "frozen" }, new int[] { 1, 5 });
var tensorIn = new DenseTensor<string>(new string[] {
"hello",
"École élémentaire",
"mit freundlichen grüßen",
"Понедельник",
"最好的问候,"+
"नमस्ते," +
"こんにちは," +
"안녕하세요"
}, new int[] { 1, 5 });
var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
container.Add(nov);
using (var res = session.Run(container))