2022-09-02 20:13:48 +00:00
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System ;
using System.Runtime.InteropServices ;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_APIS__
    /// <summary>
    /// Holds the state of the training session.
    /// This class holds the entire training session state that includes model parameters, their gradients,
    /// optimizer parameters, and user properties. The TrainingSession leverages the CheckpointState
    /// by accessing and updating the contained training state.
    /// <note type="note">
    /// Note that the training session created with a checkpoint state uses this state to store the entire
    /// training state (including model parameters, their gradients, the optimizer states and the properties).
    /// The TrainingSession does not hold a copy of the CheckpointState and as a result, it is required
    /// that the checkpoint state outlives the lifetime of the training session.
    /// </note>
    /// </summary>
    public class CheckpointState : SafeHandle
    {
        /// <summary>
        /// The raw native checkpoint state handle, passed to the native training APIs.
        /// </summary>
        internal IntPtr Handle => handle;

        /// <summary>
        /// Wraps a native checkpoint state handle. This instance owns the handle and releases it
        /// in <see cref="ReleaseHandle"/>.
        /// </summary>
        /// <param name="checkpointHandle">Native handle returned by the checkpoint loading API.</param>
        private CheckpointState(IntPtr checkpointHandle)
            : base(checkpointHandle, true)
        {
        }

        /// <summary>
        /// Type tag exchanged with the native property APIs to describe a property's value.
        /// </summary>
        internal enum PropertyType : long
        {
            Int = 0,
            Float = 1,
            String = 2
        }

        /// <summary>
        /// Adds or updates a fixed-size property (int64 or float) by handing the native API a
        /// pointer to the value.
        /// </summary>
        /// <typeparam name="T">Unmanaged value type matching <paramref name="propertyType"/>.</typeparam>
        /// <param name="propertyName">Name of the property being added or updated.</param>
        /// <param name="propertyType">Native type tag describing <paramref name="propertyValue"/>.</param>
        /// <param name="propertyValue">Property value associated with the given name.</param>
        /// <exception cref="ArgumentNullException"><paramref name="propertyName"/> is null.</exception>
        private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
        {
            if (propertyName == null)
            {
                throw new ArgumentNullException(nameof(propertyName));
            }

            var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
            unsafe
            {
                // A value parameter of an unmanaged type is a fixed variable, so its address can be
                // taken directly; no temporary heap array or pinning is required.
                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(
                    handle, propertyNameUtf8, propertyType, (IntPtr)(&propertyValue)));
            }
        }

        /// <summary>
        /// Overrides SafeHandle.IsInvalid
        /// </summary>
        /// <value>returns true if handle is equal to Zero</value>
        public override bool IsInvalid => handle == IntPtr.Zero;

        /// <summary>
        /// Load a checkpoint state from a directory on disk into checkpoint_state.
        ///
        /// This function will parse a checkpoint directory, pull relevant files and load the training
        /// state into the checkpoint_state. This checkpoint state can then be used to create the
        /// training session by instantiating the TrainingSession. By doing so, the training
        /// session will begin or resume training from the given checkpoint state.
        /// </summary>
        /// <param name="checkpointPath"> Absolute path to the checkpoint directory.</param>
        /// <returns>CheckpointState object which holds the state of the training session parameters.</returns>
        /// <exception cref="InvalidOperationException">The package does not contain the training API.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="checkpointPath"/> is null.</exception>
        public static CheckpointState LoadCheckpoint(string checkpointPath)
        {
            if (!NativeTrainingMethods.TrainingEnabled())
            {
                throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n");
            }

            if (checkpointPath == null)
            {
                throw new ArgumentNullException(nameof(checkpointPath));
            }

            // Touch the environment singleton so it is initialized before any native training call.
            _ = OrtEnv.Instance().Handle;

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(
                NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out IntPtr checkpointHandle));
            return new CheckpointState(checkpointHandle);
        }

        /// <summary>
        /// Save the given state to a checkpoint directory on disk.
        ///
        /// This function serializes the provided checkpoint state to a directory on disk.
        /// This checkpoint can later be loaded by invoking CheckpointState.LoadCheckpoint to begin or resume
        /// training from this snapshot of the state.
        /// </summary>
        /// <param name="state"> The checkpoint state to save.</param>
        /// <param name="checkpointPath"> Absolute path to the checkpoint directory.</param>
        /// <param name="includeOptimizerState"> Flag to indicate whether to save the optimizer state or not.</param>
        /// <exception cref="ArgumentNullException"><paramref name="state"/> or <paramref name="checkpointPath"/> is null.</exception>
        public static void SaveCheckpoint(CheckpointState state, string checkpointPath, bool includeOptimizerState = false)
        {
            if (state == null)
            {
                throw new ArgumentNullException(nameof(state));
            }

            if (checkpointPath == null)
            {
                throw new ArgumentNullException(nameof(checkpointPath));
            }

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(
                state.Handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState));
        }

        /// <summary>
        /// Adds the given 64-bit integer property to the checkpoint state, or updates it if a
        /// property with the same name is already present.
        ///
        /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
        /// state by the user by calling this function with the corresponding property name and value.
        /// </summary>
        /// <param name="propertyName">Name of the property being added or updated.</param>
        /// <param name="propertyValue">Property value associated with the given name.</param>
        /// <exception cref="ArgumentNullException"><paramref name="propertyName"/> is null.</exception>
        public void AddProperty(string propertyName, long propertyValue)
        {
            AddPropertyImpl(propertyName, PropertyType.Int, propertyValue);
        }

        /// <summary>
        /// Adds the given float property to the checkpoint state, or updates it if a property with
        /// the same name is already present.
        ///
        /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
        /// state by the user by calling this function with the corresponding property name and value.
        /// </summary>
        /// <param name="propertyName">Name of the property being added or updated.</param>
        /// <param name="propertyValue">Property value associated with the given name.</param>
        /// <exception cref="ArgumentNullException"><paramref name="propertyName"/> is null.</exception>
        public void AddProperty(string propertyName, float propertyValue)
        {
            AddPropertyImpl(propertyName, PropertyType.Float, propertyValue);
        }

        /// <summary>
        /// Adds the given string property to the checkpoint state, or updates it if a property with
        /// the same name is already present.
        ///
        /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
        /// state by the user by calling this function with the corresponding property name and value.
        /// </summary>
        /// <param name="propertyName">Name of the property being added or updated.</param>
        /// <param name="propertyValue">Property value associated with the given name.</param>
        /// <exception cref="ArgumentNullException"><paramref name="propertyName"/> or <paramref name="propertyValue"/> is null.</exception>
        public void AddProperty(string propertyName, string propertyValue)
        {
            if (propertyName == null)
            {
                throw new ArgumentNullException(nameof(propertyName));
            }

            if (propertyValue == null)
            {
                throw new ArgumentNullException(nameof(propertyValue));
            }

            var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
            var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
            unsafe
            {
                // Pin the zero-terminated UTF-8 bytes for the duration of the native call.
                fixed (byte* valuePtr = propertyValueUtf8)
                {
                    NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(
                        handle, propertyNameUtf8, PropertyType.String, (IntPtr)valuePtr));
                }
            }
        }

        /// <summary>
        /// Gets the property value associated with the given name from the checkpoint state.
        ///
        /// Gets the property value from an existing entry in the checkpoint state. The property must
        /// exist in the checkpoint state to be able to retrieve it successfully.
        /// </summary>
        /// <param name="propertyName">Name of the property being retrieved.</param>
        /// <returns>
        /// Property value associated with the given property name: a boxed <see cref="long"/>,
        /// a boxed <see cref="float"/>, or a <see cref="string"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="propertyName"/> is null.</exception>
        /// <exception cref="ArgumentException">The native API reported an unrecognized property type.</exception>
        public object GetProperty(string propertyName)
        {
            if (propertyName == null)
            {
                throw new ArgumentNullException(nameof(propertyName));
            }

            var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
            var allocator = OrtAllocator.DefaultInstance;

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(
                handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out IntPtr propertyValue));

            try
            {
                unsafe
                {
                    switch (propertyType)
                    {
                        case PropertyType.Int:
                            return *(long*)propertyValue;
                        case PropertyType.Float:
                            return *(float*)propertyValue;
                        case PropertyType.String:
                            return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue);
                        default:
                            // Report the offending type tag (previously the pointer address was printed).
                            throw new ArgumentException(
                                $"Expected the property type to be one of long, float or string. Unknown type retrieved {propertyType}");
                    }
                }
            }
            finally
            {
                // The value buffer was allocated by the native side with this allocator.
                allocator.FreeMemory(propertyValue);
            }
        }

        /// <summary>
        /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
        ///
        /// This function updates a model parameter in the checkpoint state with the given parameter data.
        /// The training session must be already created with the checkpoint state that contains the parameter
        /// being updated. The given parameter is copied over to the registered device for the training session.
        /// The parameter must exist in the checkpoint state to be able to update it successfully.
        /// </summary>
        /// <param name="parameterName">Name of the parameter being updated.</param>
        /// <param name="parameter">The parameter data that should replace the existing parameter data.</param>
        /// <exception cref="ArgumentNullException"><paramref name="parameterName"/> or <paramref name="parameter"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="parameter"/> is not a tensor.</exception>
        public void UpdateParameter(string parameterName, OrtValue parameter)
        {
            if (parameterName == null)
            {
                throw new ArgumentNullException(nameof(parameterName));
            }

            if (parameter == null)
            {
                throw new ArgumentNullException(nameof(parameter));
            }

            if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
            {
                throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter.");
            }

            var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
        }

        /// <summary>
        /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
        ///
        /// This function retrieves the model parameter data from the checkpoint state for the given parameter name
        /// and returns it as a newly created OrtValue, allocated with the default allocator. The training session
        /// must be already created with the checkpoint state that contains the parameter being retrieved.
        /// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
        /// </summary>
        /// <param name="parameterName">Name of the parameter being retrieved.</param>
        /// <returns>The parameter data that is retrieved from the checkpoint state.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="parameterName"/> is null.</exception>
        public OrtValue GetParameter(string parameterName)
        {
            if (parameterName == null)
            {
                throw new ArgumentNullException(nameof(parameterName));
            }

            var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(
                handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));
            return new OrtValue(parameterHandle);
        }

        #region SafeHandle
        /// <summary>
        /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
        /// the native instance of CheckpointState
        /// </summary>
        /// <returns>always returns true</returns>
        protected override bool ReleaseHandle()
        {
            NativeTrainingMethods.OrtReleaseCheckpointState(handle);
            handle = IntPtr.Zero;
            return true;
        }
        #endregion
    }
#endif
}