// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Buffers;

namespace Microsoft.ML.OnnxRuntime
{
    /// <summary>
    /// Provides access to the OrtValue held by the underlying object that owns it.
    /// <see cref="Value"/> exposes that OrtValue without transferring ownership:
    /// the implementing object remains responsible for disposing it.
    /// </summary>
    internal interface IOrtValueOwner : IDisposable
    {
        OrtValue Value { get; }
    }

    /// <summary>
    /// This class is used in conjunction with DisposableNamedOnnxValue to
    /// own a native collection OrtValue and dispose of it along with any DisposableNamedOnnxValues.
    /// </summary>
    internal class NativeOrtValueCollectionOwner : IOrtValueOwner, IDisposable
    {
        private OrtValue _ortValue;
        private IDisposable _disposables;
        private bool _disposed = false;

        /// <summary>
        /// Takes ownership of <paramref name="ortValue"/> and sets the caller's reference to null.
        /// </summary>
        /// <param name="ortValue">The collection OrtValue to own; becomes null on return.</param>
        /// <param name="disposables">A collection of disposables that support composed types.
        /// We stick them here and dispose them when this instance is disposed.</param>
        internal NativeOrtValueCollectionOwner(ref OrtValue ortValue, IDisposable disposables)
        {
            _ortValue = ortValue;
            ortValue = null;
            _disposables = disposables;
        }

        #region IOrtValueOwner
        /// <summary>
        /// Returns OrtValue that is owned by this instance
        /// </summary>
        public OrtValue Value { get { return _ortValue; } }
        #endregion IOrtValueOwner

        #region Disposable
        /// <summary>
        /// Releases the owned disposables and the owned OrtValue (in that order)
        /// when <paramref name="disposing"/> is true. Subsequent calls are no-ops.
        /// </summary>
        /// <param name="disposing">true when called from <see cref="Dispose()"/>.</param>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed)
            {
                return;
            }

            // dispose managed state (managed objects).
            if (disposing)
            {
                // Release the composed-type disposables before the collection OrtValue itself.
                if (_disposables != null)
                {
                    _disposables.Dispose();
                    _disposables = null;
                }

                // _ortValue can be null when no native memory is involved
                if (_ortValue != null)
                {
                    _ortValue.Dispose();
                    _ortValue = null;
                }
            }

            // Standard dispose pattern: mark disposed regardless of which path invoked us.
            _disposed = true;
        }

        public void Dispose()
        {
            // Do not change this code. Put cleanup code in Dispose(bool disposing) above.
            Dispose(true);
            GC.SuppressFinalize(this);
        }
        #endregion Disposable
    }

    /// <summary>
    /// This helper class owns the underlying OrtValue that is assumed to be a Tensor.
    /// It does not support any other OrtValues, and it caches Tensor properties.
    ///
    /// It is easy to expose as a Tensor{T} because DenseTensor can take a Memory mapping
    /// from this MemoryManager.
    ///
    /// This class is disposable because of the MemoryManager inheritance. Because this class
    /// always backs exactly one DenseTensor{T} instance, it does not implement
    /// ref-counting for Pin/Unpin.
    /// </summary>
    /// <typeparam name="T">Managed element type; must match the element type of the supplied OrtValue.</typeparam>
    internal class OrtValueTensor<T> : MemoryManager<T>, IOrtValueOwner
    {
        private OrtValue _ortValue; // Disposable
        private readonly IntPtr _dataBufferPointer;    // pointer to mutable tensor data in native memory

        /// <summary>
        /// Constructs an instance and takes ownership of ortValue on success.
        /// </summary>
        /// <param name="ortValue">ortValue that is a Tensor. It becomes null on successful return.</param>
        /// <exception cref="OnnxRuntimeException">The element type is unknown or does not match <typeparamref name="T"/>.</exception>
        /// <exception cref="OverflowException">The element count or a dimension does not fit in an int.</exception>
        public OrtValueTensor(ref OrtValue ortValue)
        {
            var typeAndShapeInfo = ortValue.GetTensorTypeAndShape();
            TensorElementType elemType = typeAndShapeInfo.ElementDataType;

            var typeInfo = TensorBase.GetElementTypeInfo(elemType) ??
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Unable to query type information for data type: {elemType}");

            if (typeof(T) != typeInfo.TensorType)
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                    $"The OrtValueTensor type being instantiated for T = [{typeof(T)}] while supplied OrtValue contains T = [{typeInfo.TensorType}]");
            }

            ElementType = elemType;
            ElementWidth = typeInfo.TypeSize;
            // Span<T> lengths are int: fail loudly rather than silently truncating a huge element count
            // (Convert.ToInt32 below already behaves this way for the dimensions).
            Count = checked((int)typeAndShapeInfo.ElementCount);
            Dimensions = Array.ConvertAll(typeAndShapeInfo.Shape, Convert.ToInt32);

            NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorMutableData(ortValue.Handle, out _dataBufferPointer));

            // Transfer ownership only after every check above has succeeded.
            _ortValue = ortValue;
            ortValue = null;
        }

        /// <summary>
        /// Returns OrtValue that is owned by this instance
        /// </summary>
        public OrtValue Value { get { return _ortValue; } }

        public bool IsDisposed { get; private set; } = false;

        public int[] Dimensions { get; }

        public int Rank => Dimensions.Length;

        public int Count { get; }

        public int ElementWidth { get; }

        public Tensors.TensorElementType ElementType { get; }

        /// <summary>
        /// Returns a Span that is a view into the underlying native Tensor memory.
        /// </summary>
        /// <returns>Span{T} of length <see cref="Count"/> over the native buffer.</returns>
        public override Span<T> GetSpan()
        {
            unsafe
            {
                return new Span<T>((void*)_dataBufferPointer, Count);
            }
        }

        /// <summary>
        /// Satisfies the MemoryManager abstract implementation.
        /// </summary>
        /// <param name="elementIndex">Index of the element whose address is returned.</param>
        /// <returns>A MemoryHandle pointing at the requested element.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="elementIndex"/> is negative or not less than <see cref="Count"/>.</exception>
        public override MemoryHandle Pin(int elementIndex = 0)
        {
            // A negative index would yield a pointer before the start of the native buffer.
            if (elementIndex < 0 || elementIndex >= Count)
            {
                throw new ArgumentOutOfRangeException(nameof(elementIndex));
            }

            unsafe
            {
                // The data lives in native memory (see _dataBufferPointer), so the handle only carries
                // the computed element address; no GCHandle is allocated.
                return new MemoryHandle(new IntPtr(_dataBufferPointer.ToInt64() + (long)elementIndex * ElementWidth).ToPointer());
            }
        }

        // MemoryHandle returned above by Pin() should be disposed.
        // Unpin() is purely to satisfy the interface.
        public override void Unpin() { }

        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Disposes the owned OrtValue when <paramref name="disposing"/> is true. Subsequent calls are no-ops.
        /// </summary>
        /// <param name="disposing">true when called from <see cref="Dispose()"/>.</param>
        protected override void Dispose(bool disposing)
        {
            if (IsDisposed)
            {
                return;
            }

            // Only touch other managed objects on the explicit Dispose path (standard dispose pattern).
            if (disposing && _ortValue != null)
            {
                _ortValue.Dispose();
                _ortValue = null;
            }

            IsDisposed = true;
        }
    }
}