mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
C# training api updates for on device training (#15720)
This commit is contained in:
parent
c10a6a9d17
commit
bb33285ec2
5 changed files with 556 additions and 52 deletions
|
|
@ -19,21 +19,32 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates CheckpointState by loading state from path.
|
||||
/// <param name="checkpointPath"> absolute path to checkpoint file.</param>
|
||||
/// </summary>
|
||||
public CheckpointState(string checkpointPath)
|
||||
: base(IntPtr.Zero, true)
|
||||
private CheckpointState(IntPtr checkpointHandle)
|
||||
: base(checkpointHandle, true)
|
||||
{
|
||||
if (NativeTrainingMethods.TrainingEnabled())
|
||||
}
|
||||
|
||||
internal enum PropertyType : long
|
||||
{
|
||||
Int = 0,
|
||||
Float = 1,
|
||||
String = 2
|
||||
}
|
||||
|
||||
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue)
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
T[] value = new T[1];
|
||||
value[0] = propertyValue;
|
||||
Memory<T> memory = value;
|
||||
using (var memHandle = memory.Pin())
|
||||
{
|
||||
var envHandle = OrtEnv.Instance().Handle; // just so it is initialized
|
||||
LoadCheckpoint(checkpointPath);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training_apis. \n");
|
||||
IntPtr memPtr;
|
||||
unsafe
|
||||
{
|
||||
memPtr = (IntPtr)memHandle.Pointer;
|
||||
}
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -47,9 +58,18 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// Loads Checkpoint state from path
|
||||
/// </summary>
|
||||
/// <param name="checkpointPath"> absolute path to checkpoint</param>
|
||||
private void LoadCheckpoint(string checkpointPath)
|
||||
public static CheckpointState LoadCheckpoint(string checkpointPath)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out handle));
|
||||
if (!NativeTrainingMethods.TrainingEnabled())
|
||||
{
|
||||
throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n");
|
||||
}
|
||||
|
||||
var envHandle = OrtEnv.Instance().Handle; // just so it is initialized
|
||||
IntPtr checkpointHandle = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out checkpointHandle));
|
||||
|
||||
return new CheckpointState(checkpointHandle);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -57,9 +77,83 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// <param name="checkpointPath"> absolute path to the checkpoint file.</param>
|
||||
/// <param name="includeOptimizerState"> absolute path to the checkpoint file.</param>
|
||||
/// </summary>
|
||||
public void SaveCheckpoint(string checkpointPath, bool includeOptimizerState = false)
|
||||
public static void SaveCheckpoint(CheckpointState state, string checkpointPath, bool includeOptimizerState = false)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState));
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(state.Handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given int property to the checkpoint state.
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
/// </summary>
|
||||
public void AddProperty(string propertyName, long propertyValue)
|
||||
{
|
||||
AddPropertyImpl(propertyName, PropertyType.Int, propertyValue);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given float property to the checkpoint state.
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
/// </summary>
|
||||
public void AddProperty(string propertyName, float propertyValue)
|
||||
{
|
||||
AddPropertyImpl(propertyName, PropertyType.Float, propertyValue);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given string property to the checkpoint state.
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
/// </summary>
|
||||
public void AddProperty(string propertyName, string propertyValue)
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
|
||||
|
||||
IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
|
||||
try
|
||||
{
|
||||
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
|
||||
}
|
||||
finally
|
||||
{
|
||||
Marshal.FreeHGlobal(unmanagedPointer);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the property value associated with the given name from the checkpoint state.
|
||||
/// <param name="propertyName">Unique name of the property being retrieved.</param>
|
||||
/// </summary>
|
||||
public object GetProperty(string propertyName)
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
var allocator = OrtAllocator.DefaultInstance;
|
||||
IntPtr propertyValue = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));
|
||||
|
||||
if (propertyType == PropertyType.Int)
|
||||
{
|
||||
var longPropertyValue = Marshal.ReadInt64(propertyValue);
|
||||
allocator.FreeMemory(propertyValue);
|
||||
return longPropertyValue;
|
||||
}
|
||||
else if (propertyType == PropertyType.Float)
|
||||
{
|
||||
float[] value = new float[1];
|
||||
Marshal.Copy(propertyValue, value, 0, 1);
|
||||
allocator.FreeMemory(propertyValue);
|
||||
return value[0];
|
||||
}
|
||||
else if (propertyType == PropertyType.String)
|
||||
{
|
||||
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
|
||||
}
|
||||
|
||||
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
|
||||
}
|
||||
|
||||
#region SafeHandle
|
||||
|
|
|
|||
|
|
@ -82,6 +82,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtOptimizerStep = (DOrtOptimizerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.OptimizerStep, typeof(DOrtOptimizerStep));
|
||||
OrtRegisterLinearLRScheduler = (DOrtRegisterLinearLRScheduler)Marshal.GetDelegateForFunctionPointer(trainingApi_.RegisterLinearLRScheduler, typeof(DOrtRegisterLinearLRScheduler));
|
||||
OrtSchedulerStep = (DOrtSchedulerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.SchedulerStep, typeof(DOrtSchedulerStep));
|
||||
OrtGetParametersSize = (DOrtGetParametersSize)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParametersSize, typeof(DOrtGetParametersSize));
|
||||
OrtCopyParametersToBuffer = (DOrtCopyParametersToBuffer)Marshal.GetDelegateForFunctionPointer(trainingApi_.CopyParametersToBuffer, typeof(DOrtCopyParametersToBuffer));
|
||||
OrtCopyBufferToParameters = (DOrtCopyBufferToParameters)Marshal.GetDelegateForFunctionPointer(trainingApi_.CopyBufferToParameters, typeof(DOrtCopyBufferToParameters));
|
||||
OrtReleaseTrainingSession = (DOrtReleaseTrainingSession)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseTrainingSession, typeof(DOrtReleaseTrainingSession));
|
||||
OrtReleaseCheckpointState = (DOrtReleaseCheckpointState)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseCheckpointState, typeof(DOrtReleaseCheckpointState));
|
||||
OrtExportModelForInferencing = (DOrtExportModelForInferencing)Marshal.GetDelegateForFunctionPointer(trainingApi_.ExportModelForInferencing, typeof(DOrtExportModelForInferencing));
|
||||
|
|
@ -248,6 +251,30 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
);
|
||||
public static DOrtSchedulerStep OrtSchedulerStep;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParametersSize(
|
||||
IntPtr /*(OrtTrainingSession*)*/ session,
|
||||
out UIntPtr buffer_size,
|
||||
bool only_trainable
|
||||
);
|
||||
public static DOrtGetParametersSize OrtGetParametersSize;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtCopyParametersToBuffer(
|
||||
IntPtr /*(OrtTrainingSession*)*/ session,
|
||||
IntPtr /*(OrtValue*)*/ buffer,
|
||||
bool only_trainable
|
||||
);
|
||||
public static DOrtCopyParametersToBuffer OrtCopyParametersToBuffer;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtCopyBufferToParameters(
|
||||
IntPtr /*(OrtTrainingSession*)*/ session,
|
||||
IntPtr /*(OrtValue*)*/ buffer,
|
||||
bool only_trainable
|
||||
);
|
||||
public static DOrtCopyBufferToParameters OrtCopyBufferToParameters;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate void DOrtReleaseTrainingSession(IntPtr /*(OrtTrainingSession*)*/session);
|
||||
public static DOrtReleaseTrainingSession OrtReleaseTrainingSession;
|
||||
|
|
@ -312,8 +339,8 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtAddProperty(
|
||||
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
|
||||
IntPtr /*(const char*)*/ propertyName,
|
||||
OrtPropertyType propertyType,
|
||||
byte[] /*(const char*)*/ propertyName,
|
||||
CheckpointState.PropertyType propertyType,
|
||||
IntPtr /*(const void*)*/ propertyValue
|
||||
);
|
||||
|
||||
|
|
@ -322,9 +349,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetProperty(
|
||||
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
|
||||
IntPtr /*(const char*)*/ propertyName,
|
||||
byte[] /*(const char*)*/ propertyName,
|
||||
IntPtr /*(OrtAllocator*)*/ allocator,
|
||||
out OrtPropertyType propertyType,
|
||||
out CheckpointState.PropertyType propertyType,
|
||||
out IntPtr /*(const void**)*/ propertyValue
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace Microsoft.ML.OnnxRuntime
|
||||
{
|
||||
#if __ENABLE_TRAINING_APIS__
|
||||
/// <summary>
|
||||
/// Property types
|
||||
/// </summary>
|
||||
public enum OrtPropertyType
|
||||
{
|
||||
OrtIntProperty = 0,
|
||||
OrtFloatProperty = 1,
|
||||
OrtStringProperty = 2,
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
|
@ -9,12 +9,29 @@ using System.Runtime.InteropServices;
|
|||
namespace Microsoft.ML.OnnxRuntime
|
||||
{
|
||||
#if __ENABLE_TRAINING_APIS__
|
||||
/// <summary>
|
||||
/// This class defines utility methods for training.
|
||||
/// </summary>
|
||||
public class TrainingUtils
|
||||
{
|
||||
/// <summary>
|
||||
/// Use this function to generate reproducible results. It should be noted that completely
|
||||
/// reproducible results are not guaranteed.
|
||||
/// </summary>
|
||||
/// <param name="seed">Manual seed to use for random number generation.</param>
|
||||
public static void SetSeed(long seed)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSetSeed(seed));
|
||||
}
|
||||
}
|
||||
|
||||
enum LRScheduler
|
||||
{
|
||||
None = 0,
|
||||
Constant = 1,
|
||||
Linear = 2
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Represents a Training Session on an ONNX Model.
|
||||
/// This is a IDisposable class and it must be disposed of
|
||||
|
|
@ -34,6 +51,8 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
private ulong _evalOutputCount;
|
||||
private List<string> _trainOutputNames;
|
||||
private List<string> _evalOutputNames;
|
||||
private List<string> _trainInputNames;
|
||||
private List<string> _evalInputNames;
|
||||
|
||||
private SessionOptions _builtInSessionOptions = null;
|
||||
private RunOptions _builtInRunOptions = null;
|
||||
|
|
@ -240,7 +259,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
|
||||
|
||||
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
|
||||
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
|
||||
}
|
||||
|
||||
|
|
@ -319,6 +338,110 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Export a model that can be used for inferencing.
|
||||
/// If the training session was provided with an eval model, the training session can generate
|
||||
/// an inference model if it knows the inference graph outputs. The input inference graph outputs
|
||||
/// are used to prune the eval model so that the inference model's outputs align with the provided outputs.
|
||||
/// The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
|
||||
/// Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
|
||||
/// and expects that this path still be valid.
|
||||
/// </summary>
|
||||
/// <param name="inference_model_path">Path where the inference model should be serialized to.</param>
|
||||
/// <param name="graphOutputNames">Names of the outputs that are needed in the inference model.</param>
|
||||
public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollection<string> graphOutputNames)
|
||||
{
|
||||
using (var cleanupList = new DisposableList<IDisposable>())
|
||||
{
|
||||
var outputNamesArray = ConvertNamesToUtf8(graphOutputNames, cleanupList);
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtExportModelForInferencing(
|
||||
_nativeHandle, NativeOnnxValueHelper.GetPlatformSerializedString(inferenceModelPath),
|
||||
(UIntPtr)graphOutputNames.Count, outputNamesArray));
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a contiguous buffer that holds a copy of all training state parameters
|
||||
/// </summary>
|
||||
/// <param name="only_trainable">Whether to only copy trainable parameters or to copy all parameters.</param>
|
||||
public FixedBufferOnnxValue ToBuffer(bool only_trainable)
|
||||
{
|
||||
UIntPtr bufferSize = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, only_trainable));
|
||||
|
||||
float[] bufferMemory = new float[bufferSize.ToUInt64()];
|
||||
|
||||
var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
|
||||
var shape = new long[] {(long)bufferSize.ToUInt64()};
|
||||
var buffer = FixedBufferOnnxValue.CreateFromMemory<float>(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float));
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, only_trainable));
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Loads the training session model parameters from a contiguous buffer
|
||||
/// </summary>
|
||||
/// <param name="buffer">Contiguous buffer to load the parameters from.</param>
|
||||
public void FromBuffer(FixedBufferOnnxValue buffer)
|
||||
{
|
||||
if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
|
||||
{
|
||||
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
|
||||
}
|
||||
|
||||
IntPtr typeAndShapeInfo = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo));
|
||||
UIntPtr numDimensions = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
|
||||
if (numDimensions.ToUInt64() != 1)
|
||||
{
|
||||
string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString();
|
||||
throw new ArgumentException(errorMessage);
|
||||
}
|
||||
|
||||
IntPtr numElementsTrainingOnly = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly));
|
||||
|
||||
UIntPtr bufferSize = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true));
|
||||
if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64())
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
|
||||
return;
|
||||
}
|
||||
|
||||
IntPtr numElements = IntPtr.Zero;
|
||||
bufferSize = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false));
|
||||
if ((long)bufferSize.ToUInt64() != numElements.ToInt64())
|
||||
{
|
||||
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
|
||||
throw new ArgumentException(errorMessage);
|
||||
}
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Retrieves the names of the user outputs for the training and eval models.
|
||||
/// </summary>
|
||||
/// <param name="training">Whether the training model output names are requested or eval model output names.</param>
|
||||
public List<string> OutputNames(bool training)
|
||||
{
|
||||
return training ? _trainOutputNames : _evalOutputNames;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Retrieves the names of the user inputs for the training and eval models.
|
||||
/// </summary>
|
||||
/// <param name="training">Whether the training model input names are requested or eval model input names.</param>
|
||||
public List<string> InputNames(bool training)
|
||||
{
|
||||
return training ? _trainInputNames : _evalInputNames;
|
||||
}
|
||||
|
||||
#endregion
|
||||
#region private methods
|
||||
|
||||
|
|
@ -326,7 +449,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
if (!NativeTrainingMethods.TrainingEnabled())
|
||||
{
|
||||
throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training_apis. \n");
|
||||
throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n");
|
||||
}
|
||||
var options = sessOptions;
|
||||
if (sessOptions == null)
|
||||
|
|
@ -351,6 +474,14 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
_trainOutputNames.Add(GetOutputName(i, true));
|
||||
}
|
||||
|
||||
_trainInputNames = new List<string>();
|
||||
UIntPtr inputCount = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelInputCount(_nativeHandle, out inputCount));
|
||||
for (ulong i = 0; i < inputCount.ToUInt64(); i++)
|
||||
{
|
||||
_trainInputNames.Add(GetInputName(i, true));
|
||||
}
|
||||
|
||||
if (evalModelPath != null)
|
||||
{
|
||||
outputCount = UIntPtr.Zero;
|
||||
|
|
@ -361,6 +492,14 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
_evalOutputNames.Add(GetOutputName(i, false));
|
||||
}
|
||||
|
||||
_evalInputNames = new List<string>();
|
||||
inputCount = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelInputCount(_nativeHandle, out inputCount));
|
||||
for (ulong i = 0; i < inputCount.ToUInt64(); i++)
|
||||
{
|
||||
_evalInputNames.Add(GetInputName(i, false));
|
||||
}
|
||||
}
|
||||
|
||||
_builtInRunOptions = new RunOptions(); // create a default built-in run option, and avoid creating a new one every run() call
|
||||
|
|
@ -395,6 +534,29 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator);
|
||||
}
|
||||
|
||||
private string GetInputName(ulong index, bool training)
|
||||
{
|
||||
var allocator = OrtAllocator.DefaultInstance;
|
||||
IntPtr nameHandle;
|
||||
if (training)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelInputName(
|
||||
_nativeHandle,
|
||||
(UIntPtr)index,
|
||||
allocator.Pointer,
|
||||
out nameHandle));
|
||||
}
|
||||
else
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelInputName(
|
||||
_nativeHandle,
|
||||
(UIntPtr)index,
|
||||
allocator.Pointer,
|
||||
out nameHandle));
|
||||
}
|
||||
return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator);
|
||||
}
|
||||
|
||||
private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> values, bool input)
|
||||
{
|
||||
var valuesArray = new IntPtr[values.Count];
|
||||
|
|
@ -410,6 +572,24 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
return valuesArray;
|
||||
}
|
||||
|
||||
private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection<string> names, DisposableList<IDisposable> cleanupList)
|
||||
{
|
||||
cleanupList.Capacity += names.Count;
|
||||
var result = new IntPtr[names.Count];
|
||||
for (int i = 0; i < names.Count; ++i)
|
||||
{
|
||||
var name = names.ElementAt(i);
|
||||
var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name);
|
||||
var pinnedHandle = new Memory<byte>(utf8Name).Pin();
|
||||
unsafe
|
||||
{
|
||||
result[i] = (IntPtr)pinnedHandle.Pointer;
|
||||
}
|
||||
cleanupList.Add(pinnedHandle);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Other classes access
|
||||
/// </summary>
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
public void TestLoadCheckpointThrows()
|
||||
{
|
||||
string path = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
var ex = Assert.Throws<InvalidOperationException>(() => { var opt = new CheckpointState(path); });
|
||||
Assert.Contains("Training is disabled in the current build.", ex.Message);
|
||||
var ex = Assert.Throws<InvalidOperationException>(() => { var opt = CheckpointState.LoadCheckpoint(path); });
|
||||
Assert.Contains("Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.", ex.Message);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
public void TestLoadCheckpoint()
|
||||
{
|
||||
string path = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var opt = new CheckpointState(path))
|
||||
using (var opt = CheckpointState.LoadCheckpoint(path))
|
||||
{
|
||||
Assert.NotNull(opt);
|
||||
}
|
||||
|
|
@ -50,7 +50,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = new CheckpointState(checkpointPath);
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
|
|
@ -66,7 +66,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = new CheckpointState(checkpointPath);
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
|
|
@ -149,7 +149,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = new CheckpointState(checkpointPath);
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
|
|
@ -167,7 +167,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = new CheckpointState(checkpointPath);
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
|
|
@ -176,10 +176,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
|
||||
// Save checkpoint
|
||||
string savedCheckpointPath = Path.Combine(Directory.GetCurrentDirectory(), "saved_checkpoint.ckpt");
|
||||
state.SaveCheckpoint(savedCheckpointPath, true);
|
||||
CheckpointState.SaveCheckpoint(state, savedCheckpointPath, true);
|
||||
|
||||
// Load checkpoint and run train step
|
||||
var loadedState = new CheckpointState(savedCheckpointPath);
|
||||
var loadedState = CheckpointState.LoadCheckpoint(savedCheckpointPath);
|
||||
cleanUp.Add(loadedState);
|
||||
var newTrainingSession = new TrainingSession(loadedState, trainingPath);
|
||||
cleanUp.Add(newTrainingSession);
|
||||
|
|
@ -193,7 +193,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = new CheckpointState(checkpointPath);
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
|
|
@ -252,7 +252,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = new CheckpointState(checkpointPath);
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
|
|
@ -274,7 +274,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = new CheckpointState(checkpointPath);
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
|
|
@ -301,6 +301,226 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestTrainingSessionExportModelForInferencing")]
|
||||
public void TestTrainingSessionExportModelForInferencing()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var graphOutputs = new List<string>(){"output-0"};
|
||||
|
||||
string inferencePath = Path.Combine(Directory.GetCurrentDirectory(), "inference_model.onnx");
|
||||
|
||||
trainingSession.ExportModelForInferencing(inferencePath, graphOutputs);
|
||||
Assert.True(File.Exists(inferencePath));
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestCheckpointStateAddProperty")]
|
||||
public void TestCheckpointStateAddProperty()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
|
||||
string propertyName = "days in a week";
|
||||
state.AddProperty(propertyName, (long)7);
|
||||
|
||||
var value = state.GetProperty(propertyName);
|
||||
Assert.True(value is long);
|
||||
Assert.Equal((long)7, value);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestCheckpointStateAddFloatProperty")]
|
||||
public void TestCheckpointStateAddFloatProperty()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
|
||||
string propertyName = "pi";
|
||||
state.AddProperty(propertyName, (float)3.14);
|
||||
|
||||
var value = state.GetProperty(propertyName);
|
||||
Assert.True(value is float);
|
||||
Assert.Equal((float)3.14, value);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestCheckpointStateAddStringProperty")]
|
||||
public void TestCheckpointStateAddStringProperty()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
|
||||
string propertyName = "best ai framework";
|
||||
state.AddProperty(propertyName, "onnxruntime");
|
||||
|
||||
var value = state.GetProperty(propertyName);
|
||||
Assert.True(value is string);
|
||||
Assert.Equal("onnxruntime", value);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestTrainModelInputNames")]
|
||||
public void TestTrainModelInputNames()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
var trainingSession = new TrainingSession(state, trainingPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var inputNames = trainingSession.InputNames(true);
|
||||
|
||||
Assert.True(inputNames.Count == 2);
|
||||
Assert.Equal("input-0", inputNames[0]);
|
||||
Assert.Equal("labels", inputNames[1]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestEvalModelInputNames")]
|
||||
public void TestEvalModelInputNames()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var inputNames = trainingSession.InputNames(false);
|
||||
|
||||
Assert.True(inputNames.Count == 2);
|
||||
Assert.Equal("input-0", inputNames[0]);
|
||||
Assert.Equal("labels", inputNames[1]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestTrainModelOutputNames")]
|
||||
public void TestTrainModelOutputNames()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
var trainingSession = new TrainingSession(state, trainingPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var outputNames = trainingSession.OutputNames(true);
|
||||
|
||||
Assert.Single(outputNames);
|
||||
Assert.Equal("onnx::loss::21273", outputNames[0]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestEvalModelOutputNames")]
|
||||
public void TestEvalModelOutputNames()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var outputNames = trainingSession.OutputNames(false);
|
||||
|
||||
Assert.Single(outputNames);
|
||||
Assert.Equal("onnx::loss::21273", outputNames[0]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestToBuffer")]
|
||||
public void TestToBuffer()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var buffer = trainingSession.ToBuffer(true);
|
||||
cleanUp.Add(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestFromBuffer")]
|
||||
public void TestFromBuffer()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var buffer = trainingSession.ToBuffer(true);
|
||||
cleanUp.Add(buffer);
|
||||
|
||||
trainingSession.FromBuffer(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestSetSeed")]
|
||||
public void TestSetSeed()
|
||||
{
|
||||
TrainingUtils.SetSeed(8888);
|
||||
}
|
||||
|
||||
internal class FloatComparer : IEqualityComparer<float>
|
||||
{
|
||||
private float atol = 1e-3f;
|
||||
|
|
|
|||
Loading…
Reference in a new issue