C# training api updates for on device training (#15720)

This commit is contained in:
Baiju Meswani 2023-05-01 10:01:38 -07:00 committed by GitHub
parent c10a6a9d17
commit bb33285ec2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 556 additions and 52 deletions

View file

@ -19,21 +19,32 @@ namespace Microsoft.ML.OnnxRuntime
}
}
/// <summary>
/// Creates CheckpointState by loading state from path.
/// <param name="checkpointPath"> absolute path to checkpoint file.</param>
/// </summary>
public CheckpointState(string checkpointPath)
: base(IntPtr.Zero, true)
private CheckpointState(IntPtr checkpointHandle)
: base(checkpointHandle, true)
{
if (NativeTrainingMethods.TrainingEnabled())
}
internal enum PropertyType : long
{
Int = 0,
Float = 1,
String = 2
}
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
T[] value = new T[1];
value[0] = propertyValue;
Memory<T> memory = value;
using (var memHandle = memory.Pin())
{
var envHandle = OrtEnv.Instance().Handle; // just so it is initialized
LoadCheckpoint(checkpointPath);
}
else
{
throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training_apis. \n");
IntPtr memPtr;
unsafe
{
memPtr = (IntPtr)memHandle.Pointer;
}
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
}
}
@ -47,9 +58,18 @@ namespace Microsoft.ML.OnnxRuntime
/// Loads Checkpoint state from path
/// </summary>
/// <param name="checkpointPath"> absolute path to checkpoint</param>
private void LoadCheckpoint(string checkpointPath)
public static CheckpointState LoadCheckpoint(string checkpointPath)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out handle));
if (!NativeTrainingMethods.TrainingEnabled())
{
throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n");
}
var envHandle = OrtEnv.Instance().Handle; // just so it is initialized
IntPtr checkpointHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out checkpointHandle));
return new CheckpointState(checkpointHandle);
}
/// <summary>
@ -57,9 +77,83 @@ namespace Microsoft.ML.OnnxRuntime
/// <param name="checkpointPath"> absolute path to the checkpoint file.</param>
/// <param name="includeOptimizerState"> absolute path to the checkpoint file.</param>
/// </summary>
public void SaveCheckpoint(string checkpointPath, bool includeOptimizerState = false)
public static void SaveCheckpoint(CheckpointState state, string checkpointPath, bool includeOptimizerState = false)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(state.Handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState));
}
/// <summary>
/// Adds the given int property to the checkpoint state.
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
/// </summary>
public void AddProperty(string propertyName, long propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Int, propertyValue);
}
/// <summary>
/// Adds the given float property to the checkpoint state.
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
/// </summary>
public void AddProperty(string propertyName, float propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Float, propertyValue);
}
/// <summary>
/// Adds the given string property to the checkpoint state.
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
/// </summary>
public void AddProperty(string propertyName, string propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
try
{
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
}
finally
{
Marshal.FreeHGlobal(unmanagedPointer);
}
}
/// <summary>
/// Gets the property value associated with the given name from the checkpoint state.
/// <param name="propertyName">Unique name of the property being retrieved.</param>
/// </summary>
public object GetProperty(string propertyName)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var allocator = OrtAllocator.DefaultInstance;
IntPtr propertyValue = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));
if (propertyType == PropertyType.Int)
{
var longPropertyValue = Marshal.ReadInt64(propertyValue);
allocator.FreeMemory(propertyValue);
return longPropertyValue;
}
else if (propertyType == PropertyType.Float)
{
float[] value = new float[1];
Marshal.Copy(propertyValue, value, 0, 1);
allocator.FreeMemory(propertyValue);
return value[0];
}
else if (propertyType == PropertyType.String)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
}
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
}
#region SafeHandle

View file

@ -82,6 +82,9 @@ namespace Microsoft.ML.OnnxRuntime
OrtOptimizerStep = (DOrtOptimizerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.OptimizerStep, typeof(DOrtOptimizerStep));
OrtRegisterLinearLRScheduler = (DOrtRegisterLinearLRScheduler)Marshal.GetDelegateForFunctionPointer(trainingApi_.RegisterLinearLRScheduler, typeof(DOrtRegisterLinearLRScheduler));
OrtSchedulerStep = (DOrtSchedulerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.SchedulerStep, typeof(DOrtSchedulerStep));
OrtGetParametersSize = (DOrtGetParametersSize)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParametersSize, typeof(DOrtGetParametersSize));
OrtCopyParametersToBuffer = (DOrtCopyParametersToBuffer)Marshal.GetDelegateForFunctionPointer(trainingApi_.CopyParametersToBuffer, typeof(DOrtCopyParametersToBuffer));
OrtCopyBufferToParameters = (DOrtCopyBufferToParameters)Marshal.GetDelegateForFunctionPointer(trainingApi_.CopyBufferToParameters, typeof(DOrtCopyBufferToParameters));
OrtReleaseTrainingSession = (DOrtReleaseTrainingSession)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseTrainingSession, typeof(DOrtReleaseTrainingSession));
OrtReleaseCheckpointState = (DOrtReleaseCheckpointState)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseCheckpointState, typeof(DOrtReleaseCheckpointState));
OrtExportModelForInferencing = (DOrtExportModelForInferencing)Marshal.GetDelegateForFunctionPointer(trainingApi_.ExportModelForInferencing, typeof(DOrtExportModelForInferencing));
@ -248,6 +251,30 @@ namespace Microsoft.ML.OnnxRuntime
);
public static DOrtSchedulerStep OrtSchedulerStep;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParametersSize(
IntPtr /*(OrtTrainingSession*)*/ session,
out UIntPtr buffer_size,
bool only_trainable
);
public static DOrtGetParametersSize OrtGetParametersSize;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtCopyParametersToBuffer(
IntPtr /*(OrtTrainingSession*)*/ session,
IntPtr /*(OrtValue*)*/ buffer,
bool only_trainable
);
public static DOrtCopyParametersToBuffer OrtCopyParametersToBuffer;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtCopyBufferToParameters(
IntPtr /*(OrtTrainingSession*)*/ session,
IntPtr /*(OrtValue*)*/ buffer,
bool only_trainable
);
public static DOrtCopyBufferToParameters OrtCopyBufferToParameters;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtReleaseTrainingSession(IntPtr /*(OrtTrainingSession*)*/session);
public static DOrtReleaseTrainingSession OrtReleaseTrainingSession;
@ -312,8 +339,8 @@ namespace Microsoft.ML.OnnxRuntime
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtAddProperty(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
IntPtr /*(const char*)*/ propertyName,
OrtPropertyType propertyType,
byte[] /*(const char*)*/ propertyName,
CheckpointState.PropertyType propertyType,
IntPtr /*(const void*)*/ propertyValue
);
@ -322,9 +349,9 @@ namespace Microsoft.ML.OnnxRuntime
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetProperty(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
IntPtr /*(const char*)*/ propertyName,
byte[] /*(const char*)*/ propertyName,
IntPtr /*(OrtAllocator*)*/ allocator,
out OrtPropertyType propertyType,
out CheckpointState.PropertyType propertyType,
out IntPtr /*(const void**)*/ propertyValue
);

View file

@ -1,17 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_APIS__
/// <summary>
/// Property types
/// </summary>
public enum OrtPropertyType
{
OrtIntProperty = 0,
OrtFloatProperty = 1,
OrtStringProperty = 2,
}
#endif
}

View file

@ -9,12 +9,29 @@ using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_APIS__
/// <summary>
/// This class defines utility methods for training.
/// </summary>
public class TrainingUtils
{
/// <summary>
/// Use this function to generate reproducible results. It should be noted that completely
/// reproducible results are not guaranteed.
/// </summary>
/// <param name="seed">Manual seed to use for random number generation.</param>
public static void SetSeed(long seed)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSetSeed(seed));
}
}
enum LRScheduler
{
None = 0,
Constant = 1,
Linear = 2
}
/// <summary>
/// Represents a Training Session on an ONNX Model.
/// This is a IDisposable class and it must be disposed of
@ -34,6 +51,8 @@ namespace Microsoft.ML.OnnxRuntime
private ulong _evalOutputCount;
private List<string> _trainOutputNames;
private List<string> _evalOutputNames;
private List<string> _trainInputNames;
private List<string> _evalInputNames;
private SessionOptions _builtInSessionOptions = null;
private RunOptions _builtInRunOptions = null;
@ -240,7 +259,7 @@ namespace Microsoft.ML.OnnxRuntime
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
}
@ -319,6 +338,110 @@ namespace Microsoft.ML.OnnxRuntime
}
/// <summary>
/// Export a model that can be used for inferencing.
/// If the training session was provided with an eval model, the training session can generate
/// an inference model if it knows the inference graph outputs. The input inference graph outputs
/// are used to prune the eval model so that the inference model's outputs align with the provided outputs.
/// The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
/// Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
/// and expects that this path still be valid.
/// </summary>
/// <param name="inference_model_path">Path where the inference model should be serialized to.</param>
/// <param name="graphOutputNames">Names of the outputs that are needed in the inference model.</param>
public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollection<string> graphOutputNames)
{
using (var cleanupList = new DisposableList<IDisposable>())
{
var outputNamesArray = ConvertNamesToUtf8(graphOutputNames, cleanupList);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtExportModelForInferencing(
_nativeHandle, NativeOnnxValueHelper.GetPlatformSerializedString(inferenceModelPath),
(UIntPtr)graphOutputNames.Count, outputNamesArray));
}
}
/// <summary>
/// Returns a contiguous buffer that holds a copy of all training state parameters
/// </summary>
/// <param name="only_trainable">Whether to only copy trainable parameters or to copy all parameters.</param>
public FixedBufferOnnxValue ToBuffer(bool only_trainable)
{
UIntPtr bufferSize = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, only_trainable));
float[] bufferMemory = new float[bufferSize.ToUInt64()];
var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
var shape = new long[] {(long)bufferSize.ToUInt64()};
var buffer = FixedBufferOnnxValue.CreateFromMemory<float>(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, only_trainable));
return buffer;
}
/// <summary>
/// Loads the training session model parameters from a contiguous buffer
/// </summary>
/// <param name="buffer">Contiguous buffer to load the parameters from.</param>
public void FromBuffer(FixedBufferOnnxValue buffer)
{
if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
{
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
}
IntPtr typeAndShapeInfo = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo));
UIntPtr numDimensions = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
if (numDimensions.ToUInt64() != 1)
{
string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString();
throw new ArgumentException(errorMessage);
}
IntPtr numElementsTrainingOnly = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly));
UIntPtr bufferSize = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true));
if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64())
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
return;
}
IntPtr numElements = IntPtr.Zero;
bufferSize = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false));
if ((long)bufferSize.ToUInt64() != numElements.ToInt64())
{
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
throw new ArgumentException(errorMessage);
}
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false));
}
/// <summary>
/// Retrieves the names of the user outputs for the training and eval models.
/// </summary>
/// <param name="training">Whether the training model output names are requested or eval model output names.</param>
public List<string> OutputNames(bool training)
{
return training ? _trainOutputNames : _evalOutputNames;
}
/// <summary>
/// Retrieves the names of the user inputs for the training and eval models.
/// </summary>
/// <param name="training">Whether the training model input names are requested or eval model input names.</param>
public List<string> InputNames(bool training)
{
return training ? _trainInputNames : _evalInputNames;
}
#endregion
#region private methods
@ -326,7 +449,7 @@ namespace Microsoft.ML.OnnxRuntime
{
if (!NativeTrainingMethods.TrainingEnabled())
{
throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training_apis. \n");
throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n");
}
var options = sessOptions;
if (sessOptions == null)
@ -351,6 +474,14 @@ namespace Microsoft.ML.OnnxRuntime
_trainOutputNames.Add(GetOutputName(i, true));
}
_trainInputNames = new List<string>();
UIntPtr inputCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelInputCount(_nativeHandle, out inputCount));
for (ulong i = 0; i < inputCount.ToUInt64(); i++)
{
_trainInputNames.Add(GetInputName(i, true));
}
if (evalModelPath != null)
{
outputCount = UIntPtr.Zero;
@ -361,6 +492,14 @@ namespace Microsoft.ML.OnnxRuntime
{
_evalOutputNames.Add(GetOutputName(i, false));
}
_evalInputNames = new List<string>();
inputCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelInputCount(_nativeHandle, out inputCount));
for (ulong i = 0; i < inputCount.ToUInt64(); i++)
{
_evalInputNames.Add(GetInputName(i, false));
}
}
_builtInRunOptions = new RunOptions(); // create a default built-in run option, and avoid creating a new one every run() call
@ -395,6 +534,29 @@ namespace Microsoft.ML.OnnxRuntime
return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator);
}
private string GetInputName(ulong index, bool training)
{
var allocator = OrtAllocator.DefaultInstance;
IntPtr nameHandle;
if (training)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelInputName(
_nativeHandle,
(UIntPtr)index,
allocator.Pointer,
out nameHandle));
}
else
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelInputName(
_nativeHandle,
(UIntPtr)index,
allocator.Pointer,
out nameHandle));
}
return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator);
}
private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> values, bool input)
{
var valuesArray = new IntPtr[values.Count];
@ -410,6 +572,24 @@ namespace Microsoft.ML.OnnxRuntime
return valuesArray;
}
private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection<string> names, DisposableList<IDisposable> cleanupList)
{
cleanupList.Capacity += names.Count;
var result = new IntPtr[names.Count];
for (int i = 0; i < names.Count; ++i)
{
var name = names.ElementAt(i);
var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name);
var pinnedHandle = new Memory<byte>(utf8Name).Pin();
unsafe
{
result[i] = (IntPtr)pinnedHandle.Pointer;
}
cleanupList.Add(pinnedHandle);
}
return result;
}
/// <summary>
/// Other classes access
/// </summary>

View file

@ -28,8 +28,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests
public void TestLoadCheckpointThrows()
{
string path = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
var ex = Assert.Throws<InvalidOperationException>(() => { var opt = new CheckpointState(path); });
Assert.Contains("Training is disabled in the current build.", ex.Message);
var ex = Assert.Throws<InvalidOperationException>(() => { var opt = CheckpointState.LoadCheckpoint(path); });
Assert.Contains("Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.", ex.Message);
}
#endif
@ -38,7 +38,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
public void TestLoadCheckpoint()
{
string path = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var opt = new CheckpointState(path))
using (var opt = CheckpointState.LoadCheckpoint(path))
{
Assert.NotNull(opt);
}
@ -50,7 +50,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = new CheckpointState(checkpointPath);
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@ -66,7 +66,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = new CheckpointState(checkpointPath);
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@ -149,7 +149,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = new CheckpointState(checkpointPath);
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@ -167,7 +167,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = new CheckpointState(checkpointPath);
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@ -176,10 +176,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
// Save checkpoint
string savedCheckpointPath = Path.Combine(Directory.GetCurrentDirectory(), "saved_checkpoint.ckpt");
state.SaveCheckpoint(savedCheckpointPath, true);
CheckpointState.SaveCheckpoint(state, savedCheckpointPath, true);
// Load checkpoint and run train step
var loadedState = new CheckpointState(savedCheckpointPath);
var loadedState = CheckpointState.LoadCheckpoint(savedCheckpointPath);
cleanUp.Add(loadedState);
var newTrainingSession = new TrainingSession(loadedState, trainingPath);
cleanUp.Add(newTrainingSession);
@ -193,7 +193,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = new CheckpointState(checkpointPath);
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@ -252,7 +252,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = new CheckpointState(checkpointPath);
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@ -274,7 +274,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = new CheckpointState(checkpointPath);
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@ -301,6 +301,226 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
[Fact(DisplayName = "TestTrainingSessionExportModelForInferencing")]
public void TestTrainingSessionExportModelForInferencing()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);
var graphOutputs = new List<string>(){"output-0"};
string inferencePath = Path.Combine(Directory.GetCurrentDirectory(), "inference_model.onnx");
trainingSession.ExportModelForInferencing(inferencePath, graphOutputs);
Assert.True(File.Exists(inferencePath));
}
}
[Fact(DisplayName = "TestCheckpointStateAddProperty")]
public void TestCheckpointStateAddProperty()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string propertyName = "days in a week";
state.AddProperty(propertyName, (long)7);
var value = state.GetProperty(propertyName);
Assert.True(value is long);
Assert.Equal((long)7, value);
}
}
[Fact(DisplayName = "TestCheckpointStateAddFloatProperty")]
public void TestCheckpointStateAddFloatProperty()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string propertyName = "pi";
state.AddProperty(propertyName, (float)3.14);
var value = state.GetProperty(propertyName);
Assert.True(value is float);
Assert.Equal((float)3.14, value);
}
}
[Fact(DisplayName = "TestCheckpointStateAddStringProperty")]
public void TestCheckpointStateAddStringProperty()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string propertyName = "best ai framework";
state.AddProperty(propertyName, "onnxruntime");
var value = state.GetProperty(propertyName);
Assert.True(value is string);
Assert.Equal("onnxruntime", value);
}
}
[Fact(DisplayName = "TestTrainModelInputNames")]
public void TestTrainModelInputNames()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
var trainingSession = new TrainingSession(state, trainingPath);
cleanUp.Add(trainingSession);
var inputNames = trainingSession.InputNames(true);
Assert.True(inputNames.Count == 2);
Assert.Equal("input-0", inputNames[0]);
Assert.Equal("labels", inputNames[1]);
}
}
[Fact(DisplayName = "TestEvalModelInputNames")]
public void TestEvalModelInputNames()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);
var inputNames = trainingSession.InputNames(false);
Assert.True(inputNames.Count == 2);
Assert.Equal("input-0", inputNames[0]);
Assert.Equal("labels", inputNames[1]);
}
}
[Fact(DisplayName = "TestTrainModelOutputNames")]
public void TestTrainModelOutputNames()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
var trainingSession = new TrainingSession(state, trainingPath);
cleanUp.Add(trainingSession);
var outputNames = trainingSession.OutputNames(true);
Assert.Single(outputNames);
Assert.Equal("onnx::loss::21273", outputNames[0]);
}
}
[Fact(DisplayName = "TestEvalModelOutputNames")]
public void TestEvalModelOutputNames()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);
var outputNames = trainingSession.OutputNames(false);
Assert.Single(outputNames);
Assert.Equal("onnx::loss::21273", outputNames[0]);
}
}
[Fact(DisplayName = "TestToBuffer")]
public void TestToBuffer()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);
var buffer = trainingSession.ToBuffer(true);
cleanUp.Add(buffer);
}
}
[Fact(DisplayName = "TestFromBuffer")]
public void TestFromBuffer()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);
var buffer = trainingSession.ToBuffer(true);
cleanUp.Add(buffer);
trainingSession.FromBuffer(buffer);
}
}
[Fact(DisplayName = "TestSetSeed")]
public void TestSetSeed()
{
TrainingUtils.SetSeed(8888);
}
internal class FloatComparer : IEqualityComparer<float>
{
private float atol = 1e-3f;