2018-11-20 00:48:22 +00:00
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System ;
using System.Collections.Generic ;
using System.Text ;
using System.Buffers ;
using System.Runtime.CompilerServices ;
using System.Runtime.InteropServices ;
using System.Threading ;
namespace Microsoft.ML.OnnxRuntime
{
    /// <summary>
    /// A <see cref="MemoryManager{T}"/> over the data of a native ONNX Runtime tensor value (OrtValue).
    /// For non-string tensors the managed view points directly at the tensor's native data buffer.
    /// For string tensors the values are copied into managed strings during construction and are
    /// exposed through <see cref="GetBytesAsStringMemory"/>.
    /// Once construction succeeds this object owns the OrtValue handle and releases it with
    /// NativeMethods.OrtReleaseValue when it is disposed or finalized.
    /// </summary>
    /// <typeparam name="T">Managed element type; must match the tensor's element type.</typeparam>
    internal class NativeOnnxTensorMemory<T> : MemoryManager<T>
    {
        private bool _disposed;
        private int _referenceCount;          // number of outstanding Pin() calls not yet matched by Unpin()

        private IntPtr _onnxValueHandle;      // pointer to the OrtValue object in native memory
        private IntPtr _dataBufferPointer;    // pointer to mutable tensor data in native memory; IntPtr.Zero for string tensors
        private string[] _dataBufferAsString; // string tensor values copied into managed memory; null for non-string tensors

        private int _elementCount;
        private int _elementWidth;
        private int[] _dimensions;

        /// <summary>
        /// Wraps an existing native OrtValue tensor handle.
        /// </summary>
        /// <param name="onnxValueHandle">
        /// Handle to a native OrtValue holding a tensor. Ownership passes to this object only if the
        /// constructor completes; if it throws, the handle is NOT released and the caller remains
        /// responsible for it.
        /// </param>
        /// <exception cref="NotSupportedException">
        /// <typeparamref name="T"/> does not match the tensor element type, the tensor has symbolic
        /// dimensions (negative element count), or the element count does not fit in an <see cref="int"/>.
        /// </exception>
        /// <remarks>
        /// Failures reported by native calls propagate as thrown by NativeApiStatus.VerifySuccess.
        /// </remarks>
        public NativeOnnxTensorMemory(IntPtr onnxValueHandle)
        {
            IntPtr typeAndShape = IntPtr.Zero;
            try
            {
                Type type = null;
                int width = 0;

                _onnxValueHandle = onnxValueHandle;

                NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(onnxValueHandle, out typeAndShape));

                TensorElementType elemType;
                unsafe
                {
                    NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, new IntPtr(&elemType)));
                }

                TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width);

                if (typeof(T) != type)
                {
                    throw new NotSupportedException(nameof(NativeOnnxTensorMemory<T>) + " does not support T = " + typeof(T));
                }

                _elementWidth = width;

                UIntPtr dimension;
                long count;
                NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShape, out dimension));
                unsafe
                {
                    NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShape, new IntPtr(&count))); // count can be negative.
                }

                if (count < 0)
                {
                    throw new NotSupportedException("Symbolic dimensions in the tensor is not supported");
                }
                if (count > int.MaxValue)
                {
                    // _elementCount is an int; a silent (int) cast would wrap to a wrong size.
                    throw new NotSupportedException("Tensors with more than " + int.MaxValue + " elements are not supported");
                }

                long[] shape = new long[dimension.ToUInt64()];
                // Pass the buffer's capacity (the dimension count), not the address of the local.
                NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(typeAndShape, shape, dimension)); //Note: shape must be alive during the call

                _elementCount = (int)count;
                _dimensions = new int[shape.Length];
                for (int i = 0; i < shape.Length; i++)
                {
                    _dimensions[i] = (int)shape[i];
                }

                if (typeof(T) != typeof(string))
                {
                    NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorMutableData(_onnxValueHandle, out _dataBufferPointer));
                }
                else
                {
                    // String data is copied into managed memory; there is no native buffer to expose,
                    // so _dataBufferPointer intentionally stays IntPtr.Zero.
                    _dataBufferAsString = ReadStringTensor(_onnxValueHandle, _elementCount);
                }
            }
            catch
            {
                // Do not call OrtReleaseValue here: if the constructor throws, the caller still owns the
                // handle and must dispose it. Suppress finalization as well, otherwise the finalizer of
                // this partially constructed object would release the same handle a second time.
                GC.SuppressFinalize(this);
                throw;
            }
            finally
            {
                if (typeAndShape != IntPtr.Zero)
                {
                    NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape);
                }
            }
        }

        /// <summary>
        /// Copies the contents of a native string tensor into an array of managed strings.
        /// </summary>
        /// <param name="onnxValueHandle">Handle to a native OrtValue string tensor.</param>
        /// <param name="elementCount">Number of strings in the tensor.</param>
        /// <returns>One decoded string per tensor element.</returns>
        private static string[] ReadStringTensor(IntPtr onnxValueHandle, int elementCount)
        {
            UIntPtr strLen;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtGetStringTensorDataLength(onnxValueHandle, out strLen));

            ulong totalLength = strLen.ToUInt64();
            var dataBuffer = new byte[totalLength];
            var offsets = new UIntPtr[elementCount];

            // Pin both managed buffers only for the native call; the using blocks guarantee they are
            // unpinned even if the call fails.
            using (MemoryHandle dataBufferHandle = new Memory<byte>(dataBuffer).Pin())
            using (MemoryHandle offsetsHandle = new Memory<UIntPtr>(offsets).Pin())
            {
                IntPtr dataBufferPointer;
                IntPtr offsetsPointer;
                unsafe
                {
                    dataBufferPointer = (IntPtr)dataBufferHandle.Pointer;
                    offsetsPointer = (IntPtr)offsetsHandle.Pointer;
                }
                NativeApiStatus.VerifySuccess(NativeMethods.OrtGetStringTensorContent(onnxValueHandle, dataBufferPointer, strLen, offsetsPointer, (UIntPtr)elementCount));
            }

            var result = new string[elementCount];
            for (int i = 0; i < elementCount; i++)
            {
                ulong start = offsets[i].ToUInt64();
                // Each string runs to the next offset; the last one runs to the end of the buffer.
                ulong end = (i == elementCount - 1) ? totalLength : offsets[i + 1].ToUInt64();
                // Onnx specifies strings always in UTF-8, no trailing null, no leading BOM
                result[i] = Encoding.UTF8.GetString(dataBuffer, (int)start, (int)(end - start));
            }
            return result;
        }

        /// <summary>Finalizer: releases the native OrtValue if Dispose was never called.</summary>
        ~NativeOnnxTensorMemory()
        {
            Dispose(false);
        }

        /// <summary>Releases the native OrtValue handle and suppresses finalization.</summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>True once the native handle has been released.</summary>
        public bool IsDisposed => _disposed;

        /// <summary>True while at least one Pin() has not been matched by Unpin().</summary>
        protected bool IsRetained => _referenceCount > 0;

        /// <summary>Tensor shape, one entry per dimension.</summary>
        public int[] Dimensions => _dimensions;

        /// <summary>Number of dimensions of the tensor.</summary>
        public int Rank => _dimensions.Length;

        /// <summary>Total number of elements in the tensor.</summary>
        public int Count => _elementCount;

        /// <summary>
        /// Element width reported by TensorElementTypeConverter.GetTypeAndWidth; used as the
        /// per-element byte stride when pinning.
        /// </summary>
        public int ElementWidth => _elementWidth;

        /// <summary>
        /// Returns a span over the native tensor data.
        /// </summary>
        /// <exception cref="ObjectDisposedException">The object has been disposed.</exception>
        /// <remarks>
        /// For string tensors the Span constructor rejects reference-type elements;
        /// use <see cref="GetBytesAsStringMemory"/> instead.
        /// </remarks>
        public override Span<T> GetSpan()
        {
            if (IsDisposed)
                throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory<T>));
            unsafe
            {
                return new Span<T>((void*)_dataBufferPointer, _elementCount);
            }
        }

        /// <summary>
        /// Returns the managed copy of a string tensor's values.
        /// </summary>
        /// <exception cref="ObjectDisposedException">The object has been disposed.</exception>
        /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not <see cref="string"/>.</exception>
        public Memory<String> GetBytesAsStringMemory()
        {
            if (IsDisposed)
                throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory<T>));
            if (typeof(T) != typeof(string))
                throw new NotSupportedException(nameof(NativeOnnxTensorMemory<T>.GetBytesAsStringMemory) + ": T must be string");
            return (_dataBufferAsString == null) ? new Memory<string>() : new Memory<string>(_dataBufferAsString);
        }

        /// <summary>
        /// Returns a handle to the element at <paramref name="elementIndex"/> and increments the reference count.
        /// Native memory is not moved by the GC, so no actual pinning is needed.
        /// Disposing the returned handle calls <see cref="Unpin"/>.
        /// </summary>
        /// <exception cref="NotSupportedException">The tensor is a string tensor (no native buffer is exposed).</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="elementIndex"/> is negative or not less than <see cref="Count"/>.</exception>
        /// <exception cref="ObjectDisposedException">The object has been disposed.</exception>
        public override MemoryHandle Pin(int elementIndex = 0)
        {
            if (typeof(T) == typeof(string))
            {
                throw new NotSupportedException(nameof(NativeOnnxTensorMemory<T>) + " cannot pin string tensor data; use " + nameof(GetBytesAsStringMemory));
            }
            if (elementIndex < 0 || elementIndex >= _elementCount)
            {
                throw new ArgumentOutOfRangeException(nameof(elementIndex));
            }
            Retain();
            unsafe
            {
                // Compute the address in pointer arithmetic with a 64-bit offset. Casting the pointer to int
                // would truncate or overflow on 64-bit processes.
                byte* basePointer = (byte*)_dataBufferPointer.ToPointer();
                void* elementPointer = basePointer + (long)elementIndex * _elementWidth;
                // Passing 'this' as the pinnable makes MemoryHandle.Dispose call Unpin(), keeping Retain/Release balanced.
                return new MemoryHandle(elementPointer, pinnable: this);
            }
        }

        /// <summary>Decrements the reference count taken by <see cref="Pin"/>.</summary>
        public override void Unpin()
        {
            Release();
        }

        // Decrements the reference count; returns true while references remain.
        // Throws InvalidOperationException if called more times than Retain().
        private bool Release()
        {
            int newRefCount = Interlocked.Decrement(ref _referenceCount);
            if (newRefCount < 0)
            {
                throw new InvalidOperationException("Unmatched Release/Retain");
            }
            return newRefCount != 0;
        }

        // Increments the reference count; throws ObjectDisposedException if already disposed.
        private void Retain()
        {
            if (IsDisposed)
            {
                throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory<T>));
            }
            Interlocked.Increment(ref _referenceCount);
        }

        /// <summary>
        /// Releases the native OrtValue handle exactly once. Called from both Dispose() and the finalizer.
        /// </summary>
        /// <param name="disposing">True when called from Dispose(); false from the finalizer.</param>
        protected override void Dispose(bool disposing)
        {
            if (_disposed)
            {
                return;
            }
            if (disposing)
            {
                // do managed objects cleanup
            }
            NativeMethods.OrtReleaseValue(_onnxValueHandle);
            _disposed = true;
        }

        /// <summary>Always returns false: the data lives in native memory, not in a managed array.</summary>
        protected override bool TryGetArray(out ArraySegment<T> arraySegment)
        {
            // cannot expose managed array
            arraySegment = default(ArraySegment<T>);
            return false;
        }
    }
}