From bb33285ec294da2e702bd8c9285a7bfeca0114bf Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 1 May 2023 10:01:38 -0700 Subject: [PATCH] C# training api updates for on device training (#15720) --- .../Training/CheckpointState.shared.cs | 128 +++++++-- .../Training/NativeTrainingMethods.shared.cs | 35 ++- .../Training/OrtPropertyType.shared.cs | 17 -- .../Training/TrainingSession.shared.cs | 184 ++++++++++++- .../TrainingTest.cs | 244 +++++++++++++++++- 5 files changed, 556 insertions(+), 52 deletions(-) delete mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/Training/OrtPropertyType.shared.cs diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index f775f3ad49..30fa46ccfc 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -19,21 +19,32 @@ namespace Microsoft.ML.OnnxRuntime } } - /// - /// Creates CheckpointState by loading state from path. - /// absolute path to checkpoint file. - /// - public CheckpointState(string checkpointPath) - : base(IntPtr.Zero, true) + private CheckpointState(IntPtr checkpointHandle) + : base(checkpointHandle, true) { - if (NativeTrainingMethods.TrainingEnabled()) + } + + internal enum PropertyType : long + { + Int = 0, + Float = 1, + String = 2 + } + + private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) + { + var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); + T[] value = new T[1]; + value[0] = propertyValue; + Memory memory = value; + using (var memHandle = memory.Pin()) { - var envHandle = OrtEnv.Instance().Handle; // just so it is initialized - LoadCheckpoint(checkpointPath); - } - else - { - throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training_apis. \n"); + IntPtr memPtr; + unsafe + { + memPtr = (IntPtr)memHandle.Pointer; + } + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr)); } } @@ -47,9 +58,18 @@ namespace Microsoft.ML.OnnxRuntime /// Loads Checkpoint state from path /// /// absolute path to checkpoint - private void LoadCheckpoint(string checkpointPath) + public static CheckpointState LoadCheckpoint(string checkpointPath) { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out handle)); + if (!NativeTrainingMethods.TrainingEnabled()) + { + throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n"); + } + + var envHandle = OrtEnv.Instance().Handle; // just so it is initialized + IntPtr checkpointHandle = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out checkpointHandle)); + + return new CheckpointState(checkpointHandle); } /// @@ -57,9 +77,83 @@ namespace Microsoft.ML.OnnxRuntime /// absolute path to the checkpoint file. /// absolute path to the checkpoint file. /// - public void SaveCheckpoint(string checkpointPath, bool includeOptimizerState = false) + public static void SaveCheckpoint(CheckpointState state, string checkpointPath, bool includeOptimizerState = false) { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(state.Handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState)); + } + + /// + /// Adds the given int property to the checkpoint state. + /// Unique name of the property being added. + /// Property value associated with the given name. + /// + public void AddProperty(string propertyName, long propertyValue) + { + AddPropertyImpl(propertyName, PropertyType.Int, propertyValue); + } + + /// + /// Adds the given float property to the checkpoint state. + /// Unique name of the property being added. + /// Property value associated with the given name. + /// + public void AddProperty(string propertyName, float propertyValue) + { + AddPropertyImpl(propertyName, PropertyType.Float, propertyValue); + } + + /// + /// Adds the given string property to the checkpoint state. + /// Unique name of the property being added. + /// Property value associated with the given name. + /// + public void AddProperty(string propertyName, string propertyValue) + { + var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); + var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue); + + IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length); + try + { + Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer)); + } + finally + { + Marshal.FreeHGlobal(unmanagedPointer); + } + } + + /// + /// Gets the property value associated with the given name from the checkpoint state. + /// Unique name of the property being retrieved. + /// + public object GetProperty(string propertyName) + { + var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); + var allocator = OrtAllocator.DefaultInstance; + IntPtr propertyValue = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); + + if (propertyType == PropertyType.Int) + { + var longPropertyValue = Marshal.ReadInt64(propertyValue); + allocator.FreeMemory(propertyValue); + return longPropertyValue; + } + else if (propertyType == PropertyType.Float) + { + float[] value = new float[1]; + Marshal.Copy(propertyValue, value, 0, 1); + allocator.FreeMemory(propertyValue); + return value[0]; + } + else if (propertyType == PropertyType.String) + { + return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); + } + + throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); } #region SafeHandle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 5df0720022..35c44c44a0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -82,6 +82,9 @@ namespace Microsoft.ML.OnnxRuntime OrtOptimizerStep = (DOrtOptimizerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.OptimizerStep, typeof(DOrtOptimizerStep)); OrtRegisterLinearLRScheduler = (DOrtRegisterLinearLRScheduler)Marshal.GetDelegateForFunctionPointer(trainingApi_.RegisterLinearLRScheduler, typeof(DOrtRegisterLinearLRScheduler)); OrtSchedulerStep = (DOrtSchedulerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.SchedulerStep, typeof(DOrtSchedulerStep)); + OrtGetParametersSize = (DOrtGetParametersSize)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParametersSize, typeof(DOrtGetParametersSize)); + OrtCopyParametersToBuffer = (DOrtCopyParametersToBuffer)Marshal.GetDelegateForFunctionPointer(trainingApi_.CopyParametersToBuffer, typeof(DOrtCopyParametersToBuffer)); + OrtCopyBufferToParameters = (DOrtCopyBufferToParameters)Marshal.GetDelegateForFunctionPointer(trainingApi_.CopyBufferToParameters, typeof(DOrtCopyBufferToParameters)); OrtReleaseTrainingSession = (DOrtReleaseTrainingSession)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseTrainingSession, typeof(DOrtReleaseTrainingSession)); OrtReleaseCheckpointState = (DOrtReleaseCheckpointState)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseCheckpointState, typeof(DOrtReleaseCheckpointState)); OrtExportModelForInferencing = (DOrtExportModelForInferencing)Marshal.GetDelegateForFunctionPointer(trainingApi_.ExportModelForInferencing, typeof(DOrtExportModelForInferencing)); @@ -248,6 +251,30 @@ namespace Microsoft.ML.OnnxRuntime ); public static DOrtSchedulerStep OrtSchedulerStep; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParametersSize( + IntPtr /*(OrtTrainingSession*)*/ session, + out UIntPtr buffer_size, + bool only_trainable + ); + public static DOrtGetParametersSize OrtGetParametersSize; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtCopyParametersToBuffer( + IntPtr /*(OrtTrainingSession*)*/ session, + IntPtr /*(OrtValue*)*/ buffer, + bool only_trainable + ); + public static DOrtCopyParametersToBuffer OrtCopyParametersToBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtCopyBufferToParameters( + IntPtr /*(OrtTrainingSession*)*/ session, + IntPtr /*(OrtValue*)*/ buffer, + bool only_trainable + ); + public static DOrtCopyBufferToParameters OrtCopyBufferToParameters; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseTrainingSession(IntPtr /*(OrtTrainingSession*)*/session); public static DOrtReleaseTrainingSession OrtReleaseTrainingSession; @@ -312,8 +339,8 @@ namespace Microsoft.ML.OnnxRuntime [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtAddProperty( IntPtr /*(OrtCheckpointState*)*/ checkpointState, - IntPtr /*(const char*)*/ propertyName, - OrtPropertyType propertyType, + byte[] /*(const char*)*/ propertyName, + CheckpointState.PropertyType propertyType, IntPtr /*(const void*)*/ propertyValue ); @@ -322,9 +349,9 @@ namespace Microsoft.ML.OnnxRuntime [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetProperty( IntPtr /*(OrtCheckpointState*)*/ checkpointState, - IntPtr /*(const char*)*/ propertyName, + byte[] /*(const char*)*/ propertyName, IntPtr /*(OrtAllocator*)*/ allocator, - out OrtPropertyType propertyType, + out CheckpointState.PropertyType propertyType, out IntPtr /*(const void**)*/ propertyValue ); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/OrtPropertyType.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/OrtPropertyType.shared.cs deleted file mode 100644 index 17505d6a27..0000000000 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/OrtPropertyType.shared.cs +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace Microsoft.ML.OnnxRuntime -{ -#if __ENABLE_TRAINING_APIS__ - /// - /// Property types - /// - public enum OrtPropertyType - { - OrtIntProperty = 0, - OrtFloatProperty = 1, - OrtStringProperty = 2, - } -#endif -} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 9874863e91..5fb983ab37 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -9,12 +9,29 @@ using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntime { #if __ENABLE_TRAINING_APIS__ + /// + /// This class defines utility methods for training. + /// + public class TrainingUtils + { + /// + /// Use this function to generate reproducible results. It should be noted that completely + /// reproducible results are not guaranteed. + /// + /// Manual seed to use for random number generation. + public static void SetSeed(long seed) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSetSeed(seed)); + } + } + enum LRScheduler { None = 0, Constant = 1, Linear = 2 } + /// /// Represents a Training Session on an ONNX Model. /// This is a IDisposable class and it must be disposed of @@ -34,6 +51,8 @@ namespace Microsoft.ML.OnnxRuntime private ulong _evalOutputCount; private List _trainOutputNames; private List _evalOutputNames; + private List _trainInputNames; + private List _evalInputNames; private SessionOptions _builtInSessionOptions = null; private RunOptions _builtInRunOptions = null; @@ -240,7 +259,7 @@ namespace Microsoft.ML.OnnxRuntime IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */ - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } @@ -319,6 +338,110 @@ namespace Microsoft.ML.OnnxRuntime } + /// + /// Export a model that can be used for inferencing. + /// If the training session was provided with an eval model, the training session can generate + /// an inference model if it knows the inference graph outputs. The input inference graph outputs + /// are used to prune the eval model so that the inference model's outputs align with the provided outputs. + /// The exported model is saved at the path provided and can be used for inferencing with Ort::Session. + /// Note that the function re-loads the eval model from the path provided to Ort::TrainingSession + /// and expects that this path still be valid. + /// + /// Path where the inference model should be serialized to. + /// Names of the outputs that are needed in the inference model. + public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollection graphOutputNames) + { + using (var cleanupList = new DisposableList()) + { + var outputNamesArray = ConvertNamesToUtf8(graphOutputNames, cleanupList); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtExportModelForInferencing( + _nativeHandle, NativeOnnxValueHelper.GetPlatformSerializedString(inferenceModelPath), + (UIntPtr)graphOutputNames.Count, outputNamesArray)); + } + } + + /// + /// Returns a contiguous buffer that holds a copy of all training state parameters + /// + /// Whether to only copy trainable parameters or to copy all parameters. + public FixedBufferOnnxValue ToBuffer(bool only_trainable) + { + UIntPtr bufferSize = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, only_trainable)); + + float[] bufferMemory = new float[bufferSize.ToUInt64()]; + + var memInfo = OrtMemoryInfo.DefaultInstance; // CPU + var shape = new long[] {(long)bufferSize.ToUInt64()}; + var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float)); + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, only_trainable)); + + return buffer; + } + + /// + /// Loads the training session model parameters from a contiguous buffer + /// + /// Contiguous buffer to load the parameters from. + public void FromBuffer(FixedBufferOnnxValue buffer) + { + if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + { + throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); + } + + IntPtr typeAndShapeInfo = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo)); + UIntPtr numDimensions = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); + if (numDimensions.ToUInt64() != 1) + { + string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); + throw new ArgumentException(errorMessage); + } + + IntPtr numElementsTrainingOnly = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly)); + + UIntPtr bufferSize = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true)); + if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64()) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); + return; + } + + IntPtr numElements = IntPtr.Zero; + bufferSize = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false)); + if ((long)bufferSize.ToUInt64() != numElements.ToInt64()) + { + string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + throw new ArgumentException(errorMessage); + } + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false)); + } + + /// + /// Retrieves the names of the user outputs for the training and eval models. + /// + /// Whether the training model output names are requested or eval model output names. + public List OutputNames(bool training) + { + return training ? _trainOutputNames : _evalOutputNames; + } + + /// + /// Retrieves the names of the user inputs for the training and eval models. + /// + /// Whether the training model input names are requested or eval model input names. + public List InputNames(bool training) + { + return training ? _trainInputNames : _evalInputNames; + } + #endregion #region private methods @@ -326,7 +449,7 @@ namespace Microsoft.ML.OnnxRuntime { if (!NativeTrainingMethods.TrainingEnabled()) { - throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training_apis. \n"); + throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n"); } var options = sessOptions; if (sessOptions == null) @@ -351,6 +474,14 @@ namespace Microsoft.ML.OnnxRuntime _trainOutputNames.Add(GetOutputName(i, true)); } + _trainInputNames = new List(); + UIntPtr inputCount = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelInputCount(_nativeHandle, out inputCount)); + for (ulong i = 0; i < inputCount.ToUInt64(); i++) + { + _trainInputNames.Add(GetInputName(i, true)); + } + if (evalModelPath != null) { outputCount = UIntPtr.Zero; @@ -361,6 +492,14 @@ namespace Microsoft.ML.OnnxRuntime { _evalOutputNames.Add(GetOutputName(i, false)); } + + _evalInputNames = new List(); + inputCount = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelInputCount(_nativeHandle, out inputCount)); + for (ulong i = 0; i < inputCount.ToUInt64(); i++) + { + _evalInputNames.Add(GetInputName(i, false)); + } } _builtInRunOptions = new RunOptions(); // create a default built-in run option, and avoid creating a new one every run() call @@ -395,6 +534,29 @@ namespace Microsoft.ML.OnnxRuntime return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator); } + private string GetInputName(ulong index, bool training) + { + var allocator = OrtAllocator.DefaultInstance; + IntPtr nameHandle; + if (training) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelInputName( + _nativeHandle, + (UIntPtr)index, + allocator.Pointer, + out nameHandle)); + } + else + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelInputName( + _nativeHandle, + (UIntPtr)index, + allocator.Pointer, + out nameHandle)); + } + return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle, allocator); + } + private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, bool input) { var valuesArray = new IntPtr[values.Count]; @@ -410,6 +572,24 @@ namespace Microsoft.ML.OnnxRuntime return valuesArray; } + private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection names, DisposableList cleanupList) + { + cleanupList.Capacity += names.Count; + var result = new IntPtr[names.Count]; + for (int i = 0; i < names.Count; ++i) + { + var name = names.ElementAt(i); + var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name); + var pinnedHandle = new Memory(utf8Name).Pin(); + unsafe + { + result[i] = (IntPtr)pinnedHandle.Pointer; + } + cleanupList.Add(pinnedHandle); + } + return result; + } + /// /// Other classes access /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index 0d16203dbc..7babed7f42 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -28,8 +28,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests public void TestLoadCheckpointThrows() { string path = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - var ex = Assert.Throws(() => { var opt = new CheckpointState(path); }); - Assert.Contains("Training is disabled in the current build.", ex.Message); + var ex = Assert.Throws(() => { var opt = CheckpointState.LoadCheckpoint(path); }); + Assert.Contains("Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.", ex.Message); } #endif @@ -38,7 +38,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests public void TestLoadCheckpoint() { string path = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var opt = new CheckpointState(path)) + using (var opt = CheckpointState.LoadCheckpoint(path)) { Assert.NotNull(opt); } @@ -50,7 +50,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); using (var cleanUp = new DisposableListTest()) { - var state = new CheckpointState(checkpointPath); + var state = CheckpointState.LoadCheckpoint(checkpointPath); cleanUp.Add(state); Assert.NotNull(state); string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); @@ -66,7 +66,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); using (var cleanUp = new DisposableListTest()) { - var state = new CheckpointState(checkpointPath); + var state = CheckpointState.LoadCheckpoint(checkpointPath); cleanUp.Add(state); Assert.NotNull(state); string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); @@ -149,7 +149,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); using (var cleanUp = new DisposableListTest()) { - var state = new CheckpointState(checkpointPath); + var state = CheckpointState.LoadCheckpoint(checkpointPath); cleanUp.Add(state); Assert.NotNull(state); string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); @@ -167,7 +167,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); using (var cleanUp = new DisposableListTest()) { - var state = new CheckpointState(checkpointPath); + var state = CheckpointState.LoadCheckpoint(checkpointPath); cleanUp.Add(state); Assert.NotNull(state); string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); @@ -176,10 +176,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests // Save checkpoint string savedCheckpointPath = Path.Combine(Directory.GetCurrentDirectory(), "saved_checkpoint.ckpt"); - state.SaveCheckpoint(savedCheckpointPath, true); + CheckpointState.SaveCheckpoint(state, savedCheckpointPath, true); // Load checkpoint and run train step - var loadedState = new CheckpointState(savedCheckpointPath); + var loadedState = CheckpointState.LoadCheckpoint(savedCheckpointPath); cleanUp.Add(loadedState); var newTrainingSession = new TrainingSession(loadedState, trainingPath); cleanUp.Add(newTrainingSession); @@ -193,7 +193,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); using (var cleanUp = new DisposableListTest()) { - var state = new CheckpointState(checkpointPath); + var state = CheckpointState.LoadCheckpoint(checkpointPath); cleanUp.Add(state); Assert.NotNull(state); string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); @@ -252,7 +252,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); using (var cleanUp = new DisposableListTest()) { - var state = new CheckpointState(checkpointPath); + var state = CheckpointState.LoadCheckpoint(checkpointPath); cleanUp.Add(state); Assert.NotNull(state); string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); @@ -274,7 +274,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); using (var cleanUp = new DisposableListTest()) { - var state = new CheckpointState(checkpointPath); + var state = CheckpointState.LoadCheckpoint(checkpointPath); cleanUp.Add(state); Assert.NotNull(state); string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); @@ -301,6 +301,226 @@ namespace Microsoft.ML.OnnxRuntime.Tests } } + [Fact(DisplayName = "TestTrainingSessionExportModelForInferencing")] + public void TestTrainingSessionExportModelForInferencing() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + var graphOutputs = new List(){"output-0"}; + + string inferencePath = Path.Combine(Directory.GetCurrentDirectory(), "inference_model.onnx"); + + trainingSession.ExportModelForInferencing(inferencePath, graphOutputs); + Assert.True(File.Exists(inferencePath)); + } + } + + [Fact(DisplayName = "TestCheckpointStateAddProperty")] + public void TestCheckpointStateAddProperty() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + + string propertyName = "days in a week"; + state.AddProperty(propertyName, (long)7); + + var value = state.GetProperty(propertyName); + Assert.True(value is long); + Assert.Equal((long)7, value); + } + } + + [Fact(DisplayName = "TestCheckpointStateAddFloatProperty")] + public void TestCheckpointStateAddFloatProperty() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + + string propertyName = "pi"; + state.AddProperty(propertyName, (float)3.14); + + var value = state.GetProperty(propertyName); + Assert.True(value is float); + Assert.Equal((float)3.14, value); + } + } + + [Fact(DisplayName = "TestCheckpointStateAddStringProperty")] + public void TestCheckpointStateAddStringProperty() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + + string propertyName = "best ai framework"; + state.AddProperty(propertyName, "onnxruntime"); + + var value = state.GetProperty(propertyName); + Assert.True(value is string); + Assert.Equal("onnxruntime", value); + } + } + + [Fact(DisplayName = "TestTrainModelInputNames")] + public void TestTrainModelInputNames() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + var trainingSession = new TrainingSession(state, trainingPath); + cleanUp.Add(trainingSession); + + var inputNames = trainingSession.InputNames(true); + + Assert.True(inputNames.Count == 2); + Assert.Equal("input-0", inputNames[0]); + Assert.Equal("labels", inputNames[1]); + } + } + + [Fact(DisplayName = "TestEvalModelInputNames")] + public void TestEvalModelInputNames() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + var inputNames = trainingSession.InputNames(false); + + Assert.True(inputNames.Count == 2); + Assert.Equal("input-0", inputNames[0]); + Assert.Equal("labels", inputNames[1]); + } + } + + [Fact(DisplayName = "TestTrainModelOutputNames")] + public void TestTrainModelOutputNames() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + var trainingSession = new TrainingSession(state, trainingPath); + cleanUp.Add(trainingSession); + + var outputNames = trainingSession.OutputNames(true); + + Assert.Single(outputNames); + Assert.Equal("onnx::loss::21273", outputNames[0]); + } + } + + [Fact(DisplayName = "TestEvalModelOutputNames")] + public void TestEvalModelOutputNames() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + var outputNames = trainingSession.OutputNames(false); + + Assert.Single(outputNames); + Assert.Equal("onnx::loss::21273", outputNames[0]); + } + } + + [Fact(DisplayName = "TestToBuffer")] + public void TestToBuffer() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + var buffer = trainingSession.ToBuffer(true); + cleanUp.Add(buffer); + } + } + + [Fact(DisplayName = "TestFromBuffer")] + public void TestFromBuffer() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + var buffer = trainingSession.ToBuffer(true); + cleanUp.Add(buffer); + + trainingSession.FromBuffer(buffer); + } + } + + [Fact(DisplayName = "TestSetSeed")] + public void TestSetSeed() + { + TrainingUtils.SetSeed(8888); + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f;