diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
index f775f3ad49..30fa46ccfc 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs
@@ -19,21 +19,32 @@ namespace Microsoft.ML.OnnxRuntime
}
}
- ///
- /// Creates CheckpointState by loading state from path.
- /// absolute path to checkpoint file.
- ///
- public CheckpointState(string checkpointPath)
- : base(IntPtr.Zero, true)
+ private CheckpointState(IntPtr checkpointHandle)
+ : base(checkpointHandle, true)
{
- if (NativeTrainingMethods.TrainingEnabled())
+ }
+
+ internal enum PropertyType : long
+ {
+ Int = 0,
+ Float = 1,
+ String = 2
+ }
+
+ private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue)
+ {
+ var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
+ T[] value = new T[1];
+ value[0] = propertyValue;
+ Memory memory = value;
+ using (var memHandle = memory.Pin())
{
- var envHandle = OrtEnv.Instance().Handle; // just so it is initialized
- LoadCheckpoint(checkpointPath);
- }
- else
- {
- throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training_apis. \n");
+ IntPtr memPtr;
+ unsafe
+ {
+ memPtr = (IntPtr)memHandle.Pointer;
+ }
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
}
}
@@ -47,9 +58,18 @@ namespace Microsoft.ML.OnnxRuntime
/// Loads Checkpoint state from path
///
/// absolute path to checkpoint
- private void LoadCheckpoint(string checkpointPath)
+ public static CheckpointState LoadCheckpoint(string checkpointPath)
{
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out handle));
+ if (!NativeTrainingMethods.TrainingEnabled())
+ {
+ throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n");
+ }
+
+ var envHandle = OrtEnv.Instance().Handle; // just so it is initialized
+ IntPtr checkpointHandle = IntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out checkpointHandle));
+
+ return new CheckpointState(checkpointHandle);
}
///
@@ -57,9 +77,83 @@ namespace Microsoft.ML.OnnxRuntime
/// absolute path to the checkpoint file.
/// absolute path to the checkpoint file.
///
- public void SaveCheckpoint(string checkpointPath, bool includeOptimizerState = false)
+ public static void SaveCheckpoint(CheckpointState state, string checkpointPath, bool includeOptimizerState = false)
{
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState));
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(state.Handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState));
+ }
+
+ ///
+ /// Adds the given int property to the checkpoint state.
+ /// Unique name of the property being added.
+ /// Property value associated with the given name.
+ ///
+ public void AddProperty(string propertyName, long propertyValue)
+ {
+ AddPropertyImpl(propertyName, PropertyType.Int, propertyValue);
+ }
+
+ ///
+ /// Adds the given float property to the checkpoint state.
+ /// Unique name of the property being added.
+ /// Property value associated with the given name.
+ ///
+ public void AddProperty(string propertyName, float propertyValue)
+ {
+ AddPropertyImpl(propertyName, PropertyType.Float, propertyValue);
+ }
+
+ ///
+ /// Adds the given string property to the checkpoint state.
+ /// Unique name of the property being added.
+ /// Property value associated with the given name.
+ ///
+ public void AddProperty(string propertyName, string propertyValue)
+ {
+ var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
+ var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
+
+ IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
+ try
+ {
+ Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
+ }
+ finally
+ {
+ Marshal.FreeHGlobal(unmanagedPointer);
+ }
+ }
+
+ ///
+ /// Gets the property value associated with the given name from the checkpoint state.
+ /// Unique name of the property being retrieved.
+ ///
+ public object GetProperty(string propertyName)
+ {
+ var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
+ var allocator = OrtAllocator.DefaultInstance;
+ IntPtr propertyValue = IntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));
+
+ if (propertyType == PropertyType.Int)
+ {
+ var longPropertyValue = Marshal.ReadInt64(propertyValue);
+ allocator.FreeMemory(propertyValue);
+ return longPropertyValue;
+ }
+ else if (propertyType == PropertyType.Float)
+ {
+ float[] value = new float[1];
+ Marshal.Copy(propertyValue, value, 0, 1);
+ allocator.FreeMemory(propertyValue);
+ return value[0];
+ }
+ else if (propertyType == PropertyType.String)
+ {
+ return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
+ }
+
+ throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
}
#region SafeHandle
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
index 5df0720022..35c44c44a0 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
@@ -82,6 +82,9 @@ namespace Microsoft.ML.OnnxRuntime
OrtOptimizerStep = (DOrtOptimizerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.OptimizerStep, typeof(DOrtOptimizerStep));
OrtRegisterLinearLRScheduler = (DOrtRegisterLinearLRScheduler)Marshal.GetDelegateForFunctionPointer(trainingApi_.RegisterLinearLRScheduler, typeof(DOrtRegisterLinearLRScheduler));
OrtSchedulerStep = (DOrtSchedulerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.SchedulerStep, typeof(DOrtSchedulerStep));
+ OrtGetParametersSize = (DOrtGetParametersSize)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParametersSize, typeof(DOrtGetParametersSize));
+ OrtCopyParametersToBuffer = (DOrtCopyParametersToBuffer)Marshal.GetDelegateForFunctionPointer(trainingApi_.CopyParametersToBuffer, typeof(DOrtCopyParametersToBuffer));
+ OrtCopyBufferToParameters = (DOrtCopyBufferToParameters)Marshal.GetDelegateForFunctionPointer(trainingApi_.CopyBufferToParameters, typeof(DOrtCopyBufferToParameters));
OrtReleaseTrainingSession = (DOrtReleaseTrainingSession)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseTrainingSession, typeof(DOrtReleaseTrainingSession));
OrtReleaseCheckpointState = (DOrtReleaseCheckpointState)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseCheckpointState, typeof(DOrtReleaseCheckpointState));
OrtExportModelForInferencing = (DOrtExportModelForInferencing)Marshal.GetDelegateForFunctionPointer(trainingApi_.ExportModelForInferencing, typeof(DOrtExportModelForInferencing));
@@ -248,6 +251,30 @@ namespace Microsoft.ML.OnnxRuntime
);
public static DOrtSchedulerStep OrtSchedulerStep;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParametersSize(
+ IntPtr /*(OrtTrainingSession*)*/ session,
+ out UIntPtr buffer_size,
+ bool only_trainable
+ );
+ public static DOrtGetParametersSize OrtGetParametersSize;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtCopyParametersToBuffer(
+ IntPtr /*(OrtTrainingSession*)*/ session,
+ IntPtr /*(OrtValue*)*/ buffer,
+ bool only_trainable
+ );
+ public static DOrtCopyParametersToBuffer OrtCopyParametersToBuffer;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtCopyBufferToParameters(
+ IntPtr /*(OrtTrainingSession*)*/ session,
+ IntPtr /*(OrtValue*)*/ buffer,
+ bool only_trainable
+ );
+ public static DOrtCopyBufferToParameters OrtCopyBufferToParameters;
+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtReleaseTrainingSession(IntPtr /*(OrtTrainingSession*)*/session);
public static DOrtReleaseTrainingSession OrtReleaseTrainingSession;
@@ -312,8 +339,8 @@ namespace Microsoft.ML.OnnxRuntime
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtAddProperty(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
- IntPtr /*(const char*)*/ propertyName,
- OrtPropertyType propertyType,
+ byte[] /*(const char*)*/ propertyName,
+ CheckpointState.PropertyType propertyType,
IntPtr /*(const void*)*/ propertyValue
);
@@ -322,9 +349,9 @@ namespace Microsoft.ML.OnnxRuntime
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetProperty(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
- IntPtr /*(const char*)*/ propertyName,
+ byte[] /*(const char*)*/ propertyName,
IntPtr /*(OrtAllocator*)*/ allocator,
- out OrtPropertyType propertyType,
+ out CheckpointState.PropertyType propertyType,
out IntPtr /*(const void**)*/ propertyValue
);
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/OrtPropertyType.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/OrtPropertyType.shared.cs
deleted file mode 100644
index 17505d6a27..0000000000
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/OrtPropertyType.shared.cs
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-namespace Microsoft.ML.OnnxRuntime
-{
-#if __ENABLE_TRAINING_APIS__
- ///
- /// Property types
- ///
- public enum OrtPropertyType
- {
- OrtIntProperty = 0,
- OrtFloatProperty = 1,
- OrtStringProperty = 2,
- }
-#endif
-}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
index 9874863e91..5fb983ab37 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
@@ -9,12 +9,29 @@ using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_APIS__
+ ///
+ /// This class defines utility methods for training.
+ ///
+ public class TrainingUtils
+ {
+ ///
+ /// Use this function to generate reproducible results. It should be noted that completely
+ /// reproducible results are not guaranteed.
+ ///
+ /// Manual seed to use for random number generation.
+ public static void SetSeed(long seed)
+ {
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSetSeed(seed));
+ }
+ }
+
enum LRScheduler
{
None = 0,
Constant = 1,
Linear = 2
}
+
///
/// Represents a Training Session on an ONNX Model.
/// This is a IDisposable class and it must be disposed of
@@ -34,6 +51,8 @@ namespace Microsoft.ML.OnnxRuntime
private ulong _evalOutputCount;
private List _trainOutputNames;
private List _evalOutputNames;
+ private List _trainInputNames;
+ private List _evalInputNames;
private SessionOptions _builtInSessionOptions = null;
private RunOptions _builtInRunOptions = null;
@@ -240,7 +259,7 @@ namespace Microsoft.ML.OnnxRuntime
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
}
@@ -319,6 +338,110 @@ namespace Microsoft.ML.OnnxRuntime
}
+ ///
+ /// Export a model that can be used for inferencing.
+ /// If the training session was provided with an eval model, the training session can generate
+ /// an inference model if it knows the inference graph outputs. The input inference graph outputs
+ /// are used to prune the eval model so that the inference model's outputs align with the provided outputs.
+ /// The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
+ /// Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
+ /// and expects that this path still be valid.
+ ///
+ /// Path where the inference model should be serialized to.
+ /// Names of the outputs that are needed in the inference model.
+ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollection graphOutputNames)
+ {
+ using (var cleanupList = new DisposableList())
+ {
+ var outputNamesArray = ConvertNamesToUtf8(graphOutputNames, cleanupList);
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtExportModelForInferencing(
+ _nativeHandle, NativeOnnxValueHelper.GetPlatformSerializedString(inferenceModelPath),
+ (UIntPtr)graphOutputNames.Count, outputNamesArray));
+ }
+ }
+
+ ///
+ /// Returns a contiguous buffer that holds a copy of all training state parameters
+ ///
+ /// Whether to only copy trainable parameters or to copy all parameters.
+ public FixedBufferOnnxValue ToBuffer(bool only_trainable)
+ {
+ UIntPtr bufferSize = UIntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, only_trainable));
+
+ float[] bufferMemory = new float[bufferSize.ToUInt64()];
+
+ var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
+ var shape = new long[] {(long)bufferSize.ToUInt64()};
+ var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float));
+
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, only_trainable));
+
+ return buffer;
+ }
+
+ ///
+ /// Loads the training session model parameters from a contiguous buffer
+ ///
+ /// Contiguous buffer to load the parameters from.
+ public void FromBuffer(FixedBufferOnnxValue buffer)
+ {
+ if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
+ {
+ throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
+ }
+
+ IntPtr typeAndShapeInfo = IntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo));
+ UIntPtr numDimensions = UIntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
+ if (numDimensions.ToUInt64() != 1)
+ {
+ string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString();
+ throw new ArgumentException(errorMessage);
+ }
+
+ IntPtr numElementsTrainingOnly = IntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly));
+
+ UIntPtr bufferSize = UIntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true));
+ if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64())
+ {
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
+ return;
+ }
+
+ IntPtr numElements = IntPtr.Zero;
+ bufferSize = UIntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false));
+ if ((long)bufferSize.ToUInt64() != numElements.ToInt64())
+ {
+ string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
+ throw new ArgumentException(errorMessage);
+ }
+
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false));
+ }
+
+ ///
+ /// Retrieves the names of the user outputs for the training and eval models.
+ ///
+ /// Whether the training model output names are requested or eval model output names.
+ public List OutputNames(bool training)
+ {
+ return training ? _trainOutputNames : _evalOutputNames;
+ }
+
+ ///
+ /// Retrieves the names of the user inputs for the training and eval models.
+ ///
+ /// Whether the training model input names are requested or eval model input names.
+ public List InputNames(bool training)
+ {
+ return training ? _trainInputNames : _evalInputNames;
+ }
+
#endregion
#region private methods
@@ -326,7 +449,7 @@ namespace Microsoft.ML.OnnxRuntime
{
if (!NativeTrainingMethods.TrainingEnabled())
{
- throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training_apis. \n");
+ throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n");
}
var options = sessOptions;
if (sessOptions == null)
@@ -351,6 +474,14 @@ namespace Microsoft.ML.OnnxRuntime
_trainOutputNames.Add(GetOutputName(i, true));
}
+ _trainInputNames = new List();
+ UIntPtr inputCount = UIntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelInputCount(_nativeHandle, out inputCount));
+ for (ulong i = 0; i < inputCount.ToUInt64(); i++)
+ {
+ _trainInputNames.Add(GetInputName(i, true));
+ }
+
if (evalModelPath != null)
{
outputCount = UIntPtr.Zero;
@@ -361,6 +492,14 @@ namespace Microsoft.ML.OnnxRuntime
{
_evalOutputNames.Add(GetOutputName(i, false));
}
+
+ _evalInputNames = new List();
+ inputCount = UIntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelInputCount(_nativeHandle, out inputCount));
+ for (ulong i = 0; i < inputCount.ToUInt64(); i++)
+ {
+ _evalInputNames.Add(GetInputName(i, false));
+ }
}
_builtInRunOptions = new RunOptions(); // create a default built-in run option, and avoid creating a new one every run() call
@@ -395,6 +534,29 @@ namespace Microsoft.ML.OnnxRuntime
return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator);
}
+ private string GetInputName(ulong index, bool training)
+ {
+ var allocator = OrtAllocator.DefaultInstance;
+ IntPtr nameHandle;
+ if (training)
+ {
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelInputName(
+ _nativeHandle,
+ (UIntPtr)index,
+ allocator.Pointer,
+ out nameHandle));
+ }
+ else
+ {
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelInputName(
+ _nativeHandle,
+ (UIntPtr)index,
+ allocator.Pointer,
+ out nameHandle));
+ }
+ return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator);
+ }
+
private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, bool input)
{
var valuesArray = new IntPtr[values.Count];
@@ -410,6 +572,24 @@ namespace Microsoft.ML.OnnxRuntime
return valuesArray;
}
+ private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection names, DisposableList cleanupList)
+ {
+ cleanupList.Capacity += names.Count;
+ var result = new IntPtr[names.Count];
+ for (int i = 0; i < names.Count; ++i)
+ {
+ var name = names.ElementAt(i);
+ var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name);
+ var pinnedHandle = new Memory(utf8Name).Pin();
+ unsafe
+ {
+ result[i] = (IntPtr)pinnedHandle.Pointer;
+ }
+ cleanupList.Add(pinnedHandle);
+ }
+ return result;
+ }
+
///
/// Other classes access
///
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
index 0d16203dbc..7babed7f42 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
@@ -28,8 +28,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests
public void TestLoadCheckpointThrows()
{
string path = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
- var ex = Assert.Throws(() => { var opt = new CheckpointState(path); });
- Assert.Contains("Training is disabled in the current build.", ex.Message);
+ var ex = Assert.Throws(() => { var opt = CheckpointState.LoadCheckpoint(path); });
+ Assert.Contains("Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.", ex.Message);
}
#endif
@@ -38,7 +38,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
public void TestLoadCheckpoint()
{
string path = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
- using (var opt = new CheckpointState(path))
+ using (var opt = CheckpointState.LoadCheckpoint(path))
{
Assert.NotNull(opt);
}
@@ -50,7 +50,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest())
{
- var state = new CheckpointState(checkpointPath);
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@@ -66,7 +66,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest())
{
- var state = new CheckpointState(checkpointPath);
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@@ -149,7 +149,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest())
{
- var state = new CheckpointState(checkpointPath);
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@@ -167,7 +167,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest())
{
- var state = new CheckpointState(checkpointPath);
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@@ -176,10 +176,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
// Save checkpoint
string savedCheckpointPath = Path.Combine(Directory.GetCurrentDirectory(), "saved_checkpoint.ckpt");
- state.SaveCheckpoint(savedCheckpointPath, true);
+ CheckpointState.SaveCheckpoint(state, savedCheckpointPath, true);
// Load checkpoint and run train step
- var loadedState = new CheckpointState(savedCheckpointPath);
+ var loadedState = CheckpointState.LoadCheckpoint(savedCheckpointPath);
cleanUp.Add(loadedState);
var newTrainingSession = new TrainingSession(loadedState, trainingPath);
cleanUp.Add(newTrainingSession);
@@ -193,7 +193,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest())
{
- var state = new CheckpointState(checkpointPath);
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@@ -252,7 +252,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest())
{
- var state = new CheckpointState(checkpointPath);
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@@ -274,7 +274,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest())
{
- var state = new CheckpointState(checkpointPath);
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
@@ -301,6 +301,226 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
+ [Fact(DisplayName = "TestTrainingSessionExportModelForInferencing")]
+ public void TestTrainingSessionExportModelForInferencing()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
+ cleanUp.Add(trainingSession);
+
+ var graphOutputs = new List(){"output-0"};
+
+ string inferencePath = Path.Combine(Directory.GetCurrentDirectory(), "inference_model.onnx");
+
+ trainingSession.ExportModelForInferencing(inferencePath, graphOutputs);
+ Assert.True(File.Exists(inferencePath));
+ }
+ }
+
+ [Fact(DisplayName = "TestCheckpointStateAddProperty")]
+ public void TestCheckpointStateAddProperty()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+
+ string propertyName = "days in a week";
+ state.AddProperty(propertyName, (long)7);
+
+ var value = state.GetProperty(propertyName);
+ Assert.True(value is long);
+ Assert.Equal((long)7, value);
+ }
+ }
+
+ [Fact(DisplayName = "TestCheckpointStateAddFloatProperty")]
+ public void TestCheckpointStateAddFloatProperty()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+
+ string propertyName = "pi";
+ state.AddProperty(propertyName, (float)3.14);
+
+ var value = state.GetProperty(propertyName);
+ Assert.True(value is float);
+ Assert.Equal((float)3.14, value);
+ }
+ }
+
+ [Fact(DisplayName = "TestCheckpointStateAddStringProperty")]
+ public void TestCheckpointStateAddStringProperty()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+
+ string propertyName = "best ai framework";
+ state.AddProperty(propertyName, "onnxruntime");
+
+ var value = state.GetProperty(propertyName);
+ Assert.True(value is string);
+ Assert.Equal("onnxruntime", value);
+ }
+ }
+
+ [Fact(DisplayName = "TestTrainModelInputNames")]
+ public void TestTrainModelInputNames()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ var trainingSession = new TrainingSession(state, trainingPath);
+ cleanUp.Add(trainingSession);
+
+ var inputNames = trainingSession.InputNames(true);
+
+ Assert.True(inputNames.Count == 2);
+ Assert.Equal("input-0", inputNames[0]);
+ Assert.Equal("labels", inputNames[1]);
+ }
+ }
+
+ [Fact(DisplayName = "TestEvalModelInputNames")]
+ public void TestEvalModelInputNames()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
+ cleanUp.Add(trainingSession);
+
+ var inputNames = trainingSession.InputNames(false);
+
+ Assert.True(inputNames.Count == 2);
+ Assert.Equal("input-0", inputNames[0]);
+ Assert.Equal("labels", inputNames[1]);
+ }
+ }
+
+ [Fact(DisplayName = "TestTrainModelOutputNames")]
+ public void TestTrainModelOutputNames()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ var trainingSession = new TrainingSession(state, trainingPath);
+ cleanUp.Add(trainingSession);
+
+ var outputNames = trainingSession.OutputNames(true);
+
+ Assert.Single(outputNames);
+ Assert.Equal("onnx::loss::21273", outputNames[0]);
+ }
+ }
+
+ [Fact(DisplayName = "TestEvalModelOutputNames")]
+ public void TestEvalModelOutputNames()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
+ cleanUp.Add(trainingSession);
+
+ var outputNames = trainingSession.OutputNames(false);
+
+ Assert.Single(outputNames);
+ Assert.Equal("onnx::loss::21273", outputNames[0]);
+ }
+ }
+
+ [Fact(DisplayName = "TestToBuffer")]
+ public void TestToBuffer()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
+ cleanUp.Add(trainingSession);
+
+ var buffer = trainingSession.ToBuffer(true);
+ cleanUp.Add(buffer);
+ }
+ }
+
+ [Fact(DisplayName = "TestFromBuffer")]
+ public void TestFromBuffer()
+ {
+ string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
+ using (var cleanUp = new DisposableListTest())
+ {
+ var state = CheckpointState.LoadCheckpoint(checkpointPath);
+ cleanUp.Add(state);
+ Assert.NotNull(state);
+ string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
+ string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
+ string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
+
+ var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
+ cleanUp.Add(trainingSession);
+
+ var buffer = trainingSession.ToBuffer(true);
+ cleanUp.Add(buffer);
+
+ trainingSession.FromBuffer(buffer);
+ }
+ }
+
+ [Fact(DisplayName = "TestSetSeed")]
+ public void TestSetSeed()
+ {
+ TrainingUtils.SetSeed(8888);
+ }
+
internal class FloatComparer : IEqualityComparer
{
private float atol = 1e-3f;