// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_APIS__
///
/// Holds the Checkpoint State as generated/consumed by on-device training APIs
///
public class CheckpointState : SafeHandle
{
internal IntPtr Handle
{
get
{
return handle;
}
}
private CheckpointState(IntPtr checkpointHandle)
: base(checkpointHandle, true)
{
}
internal enum PropertyType : long
{
Int = 0,
Float = 1,
String = 2
}
private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
T[] value = new T[1];
value[0] = propertyValue;
Memory memory = value;
using (var memHandle = memory.Pin())
{
IntPtr memPtr;
unsafe
{
memPtr = (IntPtr)memHandle.Pointer;
}
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
}
}
///
/// Overrides SafeHandle.IsInvalid
///
/// returns true if handle is equal to Zero
public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
///
/// Loads Checkpoint state from path
///
/// absolute path to checkpoint
public static CheckpointState LoadCheckpoint(string checkpointPath)
{
if (!NativeTrainingMethods.TrainingEnabled())
{
throw new InvalidOperationException("This package does not contain the training API. Please install the Microsoft.ML.OnnxRuntime.Training NuGet package.\n");
}
var envHandle = OrtEnv.Instance().Handle; // just so it is initialized
IntPtr checkpointHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out checkpointHandle));
return new CheckpointState(checkpointHandle);
}
///
/// Saves the checkpoint
/// absolute path to the checkpoint file.
/// absolute path to the checkpoint file.
///
public static void SaveCheckpoint(CheckpointState state, string checkpointPath, bool includeOptimizerState = false)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(state.Handle, NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), includeOptimizerState));
}
///
/// Adds the given int property to the checkpoint state.
/// Unique name of the property being added.
/// Property value associated with the given name.
///
public void AddProperty(string propertyName, long propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Int, propertyValue);
}
///
/// Adds the given float property to the checkpoint state.
/// Unique name of the property being added.
/// Property value associated with the given name.
///
public void AddProperty(string propertyName, float propertyValue)
{
AddPropertyImpl(propertyName, PropertyType.Float, propertyValue);
}
///
/// Adds the given string property to the checkpoint state.
/// Unique name of the property being added.
/// Property value associated with the given name.
///
public void AddProperty(string propertyName, string propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
try
{
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
}
finally
{
Marshal.FreeHGlobal(unmanagedPointer);
}
}
///
/// Gets the property value associated with the given name from the checkpoint state.
/// Unique name of the property being retrieved.
///
public object GetProperty(string propertyName)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var allocator = OrtAllocator.DefaultInstance;
IntPtr propertyValue = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));
if (propertyType == PropertyType.Int)
{
var longPropertyValue = Marshal.ReadInt64(propertyValue);
allocator.FreeMemory(propertyValue);
return longPropertyValue;
}
else if (propertyType == PropertyType.Float)
{
float[] value = new float[1];
Marshal.Copy(propertyValue, value, 0, 1);
allocator.FreeMemory(propertyValue);
return value[0];
}
else if (propertyType == PropertyType.String)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
}
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
}
#region SafeHandle
///
/// Overrides SafeHandle.ReleaseHandle() to properly dispose of
/// the native instance of CheckpointState
///
/// always returns true
protected override bool ReleaseHandle()
{
NativeTrainingMethods.OrtReleaseCheckpointState(handle);
handle = IntPtr.Zero;
return true;
}
#endregion
}
#endif
}