mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
[On-Device Training] Expose Parameters through the Training API (#17364)
This commit is contained in:
parent
95e8dfaea5
commit
ccb73fd827
16 changed files with 942 additions and 165 deletions
|
|
@ -40,20 +40,16 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
String = 2
|
||||
}
|
||||
|
||||
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue)
|
||||
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
T[] value = new T[1];
|
||||
value[0] = propertyValue;
|
||||
Memory<T> memory = value;
|
||||
using (var memHandle = memory.Pin())
|
||||
T[] value = { propertyValue };
|
||||
unsafe
|
||||
{
|
||||
IntPtr memPtr;
|
||||
unsafe
|
||||
fixed (T* memPtr = value)
|
||||
{
|
||||
memPtr = (IntPtr)memHandle.Pointer;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr));
|
||||
}
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -103,13 +99,13 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given int property to the checkpoint state.
|
||||
/// Adds or updates the given int property to/in the checkpoint state.
|
||||
///
|
||||
/// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint
|
||||
/// state by the user if they desire by calling this function with the appropriate property name and
|
||||
/// value. The given property name must be unique to be able to successfully add the property.
|
||||
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
/// state by the user by calling this function with the corresponding property name and value.
|
||||
/// If a property with the given name already exists in the checkpoint state, its value is updated.
|
||||
/// </summary>
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyName">Name of the property being added or updated.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
public void AddProperty(string propertyName, long propertyValue)
|
||||
{
|
||||
|
|
@ -117,13 +113,13 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given float property to the checkpoint state.
|
||||
/// Adds or updates the given float property to/in the checkpoint state.
|
||||
///
|
||||
/// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint
|
||||
/// state by the user if they desire by calling this function with the appropriate property name and
|
||||
/// value. The given property name must be unique to be able to successfully add the property.
|
||||
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
/// state by the user by calling this function with the corresponding property name and value.
|
||||
/// If a property with the given name already exists in the checkpoint state, its value is updated.
|
||||
/// </summary>
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyName">Name of the property being added or updated.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
public void AddProperty(string propertyName, float propertyValue)
|
||||
{
|
||||
|
|
@ -131,28 +127,25 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given string property to the checkpoint state.
|
||||
/// Adds or updates the given string property to/in the checkpoint state.
|
||||
///
|
||||
/// Runtime properties that are strings such as parameter names, custom strings, and others can be added
|
||||
/// to the checkpoint state by the user if they desire by calling this function with the appropriate property
|
||||
/// name and value. The given property name must be unique to be able to successfully add the property.
|
||||
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
/// state by the user by calling this function with the corresponding property name and value.
|
||||
/// If a property with the given name already exists in the checkpoint state, its value is updated.
|
||||
/// </summary>
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyName">Name of the property being added or updated.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
public void AddProperty(string propertyName, string propertyValue)
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
|
||||
|
||||
IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
|
||||
try
|
||||
unsafe
|
||||
{
|
||||
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
|
||||
}
|
||||
finally
|
||||
{
|
||||
Marshal.FreeHGlobal(unmanagedPointer);
|
||||
fixed (byte* p = propertyValueUtf8)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -162,34 +155,86 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// Gets the property value from an existing entry in the checkpoint state. The property must
|
||||
/// exist in the checkpoint state to be able to retrieve it successfully.
|
||||
/// </summary>
|
||||
/// <param name="propertyName">Unique name of the property being retrieved.</param>
|
||||
/// <param name="propertyName">Name of the property being retrieved.</param>
|
||||
/// <returns>Property value associated with the given property name.</returns>
|
||||
public object GetProperty(string propertyName)
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
var allocator = OrtAllocator.DefaultInstance;
|
||||
IntPtr propertyValue = IntPtr.Zero;
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));
|
||||
|
||||
if (propertyType == PropertyType.Int)
|
||||
try
|
||||
{
|
||||
var longPropertyValue = Marshal.ReadInt64(propertyValue);
|
||||
allocator.FreeMemory(propertyValue);
|
||||
return longPropertyValue;
|
||||
if (propertyType == PropertyType.Int)
|
||||
{
|
||||
Int64 value;
|
||||
unsafe
|
||||
{
|
||||
value = *(Int64*)propertyValue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
else if (propertyType == PropertyType.Float)
|
||||
{
|
||||
float value;
|
||||
unsafe
|
||||
{
|
||||
value = *(float*)propertyValue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
else if (propertyType == PropertyType.String)
|
||||
{
|
||||
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue);
|
||||
}
|
||||
|
||||
// Report the unrecognized property *type*: the message promises the type that was retrieved,
// and the raw native pointer held in propertyValue carries no useful information for the caller.
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyType.ToString());
|
||||
}
|
||||
else if (propertyType == PropertyType.Float)
|
||||
finally
|
||||
{
|
||||
float[] value = new float[1];
|
||||
Marshal.Copy(propertyValue, value, 0, 1);
|
||||
allocator.FreeMemory(propertyValue);
|
||||
return value[0];
|
||||
}
|
||||
else if (propertyType == PropertyType.String)
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||
///
|
||||
/// This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||
/// The training session must be already created with the checkpoint state that contains the parameter
|
||||
/// being updated. The given parameter is copied over to the registered device for the training session.
|
||||
/// The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||
/// </summary>
|
||||
/// <param name="parameterName">Name of the parameter being updated.</param>
|
||||
/// <param name="parameter">The parameter data that should replace the existing parameter data.</param>
|
||||
public void UpdateParameter(string parameterName, OrtValue parameter)
|
||||
{
|
||||
if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
|
||||
{
|
||||
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
|
||||
throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter.");
|
||||
}
|
||||
|
||||
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
|
||||
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||
///
|
||||
/// This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||
/// The parameter is copied over to the provided OrtValue. The training session must be already created
|
||||
/// with the checkpoint state that contains the parameter being retrieved.
|
||||
/// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||
/// </summary>
|
||||
/// <param name="parameterName">Name of the parameter being retrieved.</param>
|
||||
/// <returns>The parameter data that is retrieved from the checkpoint state.</returns>
|
||||
public OrtValue GetParameter(string parameterName)
|
||||
{
|
||||
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));
|
||||
|
||||
return new OrtValue(parameterHandle);
|
||||
}
|
||||
|
||||
#region SafeHandle
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public IntPtr AddProperty;
|
||||
public IntPtr GetProperty;
|
||||
public IntPtr LoadCheckpointFromBuffer;
|
||||
public IntPtr GetParameterTypeAndShape;
|
||||
public IntPtr UpdateParameter;
|
||||
public IntPtr GetParameter;
|
||||
}
|
||||
|
||||
internal static class NativeTrainingMethods
|
||||
|
|
@ -97,6 +100,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName));
|
||||
OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty));
|
||||
OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty));
|
||||
OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape));
|
||||
OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter));
|
||||
OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -359,6 +365,34 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
public static DOrtGetProperty OrtGetProperty;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape(
|
||||
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
|
||||
byte[] /*(const char*)*/ parameterName,
|
||||
out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape
|
||||
);
|
||||
|
||||
public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter(
|
||||
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
|
||||
byte[] /*(const char*)*/ parameterName,
|
||||
IntPtr /*(OrtValue*)*/ parameter
|
||||
);
|
||||
|
||||
public static DOrtUpdateParameter OrtUpdateParameter;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter(
|
||||
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
|
||||
byte[] /*(const char*)*/ parameterName,
|
||||
IntPtr /*(OrtAllocator*)*/ allocator,
|
||||
out IntPtr /*(OrtValue**)*/ parameter
|
||||
);
|
||||
|
||||
public static DOrtGetParameter OrtGetParameter;
|
||||
|
||||
#endregion TrainingSession API
|
||||
|
||||
public static bool TrainingEnabled()
|
||||
|
|
|
|||
|
|
@ -358,13 +358,14 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
|
||||
IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
|
||||
{
|
||||
if (!_evalOutputCount.Equals(outputValues.Count))
|
||||
if (_evalOutputCount != (ulong)outputValues.Count())
|
||||
{
|
||||
throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount}).");
|
||||
throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount}).");
|
||||
}
|
||||
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
|
||||
const bool isInput = true;
|
||||
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput);
|
||||
|
||||
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
|
||||
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
|
||||
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
|
||||
}
|
||||
|
|
@ -509,18 +510,17 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// Returns a contiguous buffer that holds a copy of all training state parameters
|
||||
/// </summary>
|
||||
/// <param name="onlyTrainable">Whether to only copy trainable parameters or to copy all parameters.</param>
|
||||
public FixedBufferOnnxValue ToBuffer(bool onlyTrainable)
|
||||
public OrtValue ToBuffer(bool onlyTrainable)
|
||||
{
|
||||
UIntPtr bufferSize = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable));
|
||||
|
||||
float[] bufferMemory = new float[bufferSize.ToUInt64()];
|
||||
|
||||
var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
|
||||
var shape = new long[] { (long)bufferSize.ToUInt64() };
|
||||
var buffer = FixedBufferOnnxValue.CreateFromMemory<float>(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float));
|
||||
var shape = new long[] { (long)bufferSize };
|
||||
var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape);
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable));
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable));
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
|
@ -528,45 +528,30 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// <summary>
|
||||
/// Loads the training session model parameters from a contiguous buffer
|
||||
/// </summary>
|
||||
/// <param name="buffer">Contiguous buffer to load the parameters from.</param>
|
||||
public void FromBuffer(FixedBufferOnnxValue buffer)
|
||||
/// <param name="ortValue">Contiguous buffer to load the parameters from.</param>
|
||||
/// <param name="onlyTrainable">Whether to only load trainable parameters or to load all parameters.</param>
|
||||
public void FromBuffer(OrtValue ortValue, bool onlyTrainable)
|
||||
{
|
||||
if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
|
||||
if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
|
||||
{
|
||||
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
|
||||
}
|
||||
|
||||
IntPtr typeAndShapeInfo = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo));
|
||||
UIntPtr numDimensions = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
|
||||
if (numDimensions.ToUInt64() != 1)
|
||||
var tensorInfo = ortValue.GetTensorTypeAndShape();
|
||||
if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float)
|
||||
{
|
||||
string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString();
|
||||
throw new ArgumentException(errorMessage);
|
||||
}
|
||||
|
||||
// Here buffer size represents the number of elements in the buffer
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize));
|
||||
|
||||
// OrtGetParametersSize returns the total number of elements in the model's parameters.
|
||||
UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true));
|
||||
if ((ulong)bufferSize == (ulong)numElementsTrainingOnly)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
|
||||
return;
|
||||
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float.");
|
||||
}
|
||||
|
||||
UIntPtr numElements = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false));
|
||||
if ((ulong)bufferSize != (ulong)numElements)
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable));
|
||||
if ((ulong)tensorInfo.ElementCount != (ulong)numElements)
|
||||
{
|
||||
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
|
||||
string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString();
|
||||
throw new ArgumentException(errorMessage);
|
||||
}
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false));
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
|||
|
|
@ -484,20 +484,23 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
public void TestToBuffer()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
|
||||
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var buffer = trainingSession.ToBuffer(true);
|
||||
cleanUp.Add(buffer);
|
||||
using (var buffer = trainingSession.ToBuffer(true))
|
||||
{
|
||||
Assert.NotNull(buffer);
|
||||
var typeShape = buffer.GetTensorTypeAndShape();
|
||||
Assert.Equal(1, typeShape.DimensionsCount);
|
||||
var fetchedShape = typeShape.Shape;
|
||||
Assert.Equal(397510, fetchedShape[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -505,22 +508,25 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
public void TestFromBuffer()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
|
||||
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
using (var buffer = trainingSession.ToBuffer(true))
|
||||
{
|
||||
Assert.NotNull(buffer);
|
||||
var typeShape = buffer.GetTensorTypeAndShape();
|
||||
Assert.Equal(1, typeShape.DimensionsCount);
|
||||
var fetchedShape = typeShape.Shape;
|
||||
Assert.Equal(397510, fetchedShape[0]);
|
||||
|
||||
var buffer = trainingSession.ToBuffer(true);
|
||||
cleanUp.Add(buffer);
|
||||
|
||||
trainingSession.FromBuffer(buffer);
|
||||
trainingSession.FromBuffer(buffer, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -530,6 +536,82 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
TrainingUtils.SetSeed(8888);
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestGetParameter")]
|
||||
public void TestGetParameter()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
|
||||
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
|
||||
using (var parameter = state.GetParameter("fc1.weight"))
|
||||
{
|
||||
Assert.NotNull(state);
|
||||
Assert.NotNull(parameter);
|
||||
|
||||
var typeShape = parameter.GetTensorTypeAndShape();
|
||||
Assert.Equal(2, typeShape.DimensionsCount);
|
||||
var fetchedShape = typeShape.Shape;
|
||||
Assert.Equal(500, fetchedShape[0]);
|
||||
Assert.Equal(784, fetchedShape[1]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestUpdateParameter")]
|
||||
public void TestUpdateParameter()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
|
||||
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
|
||||
{
|
||||
Assert.NotNull(state);
|
||||
|
||||
using (var parameter = state.GetParameter("fc1.weight"))
|
||||
{
|
||||
Assert.NotNull(parameter);
|
||||
var typeShape = parameter.GetTensorTypeAndShape();
|
||||
|
||||
Assert.Equal(2, typeShape.DimensionsCount);
|
||||
var fetchedShape = typeShape.Shape;
|
||||
Assert.Equal(500, fetchedShape[0]);
|
||||
Assert.Equal(784, fetchedShape[1]);
|
||||
|
||||
float maxVal = 20;
|
||||
Random randNum = new Random();
|
||||
float[] updated_parameter_buffer = Enumerable
|
||||
.Repeat(0, 500 * 784)
|
||||
.Select(i => maxVal * (float)randNum.NextDouble())
|
||||
.ToArray();
|
||||
|
||||
using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape))
|
||||
{
|
||||
state.UpdateParameter("fc1.weight", updated_parameter);
|
||||
using (var current_parameter = state.GetParameter("fc1.weight"))
|
||||
{
|
||||
var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
|
||||
Assert.Equal(updated_parameter_buffer, current_parameter_tensor);
|
||||
Assert.NotEqual(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
|
||||
}
|
||||
|
||||
state.UpdateParameter("fc1.weight", parameter);
|
||||
|
||||
using (var current_parameter = state.GetParameter("fc1.weight"))
|
||||
{
|
||||
var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
|
||||
Assert.Equal(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
|
||||
Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class FloatComparer : IEqualityComparer<float>
|
||||
{
|
||||
private float atol = 1e-3f;
|
||||
|
|
|
|||
|
|
@ -1065,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc");
|
||||
checkpoint_state
|
||||
.def(py::init())
|
||||
.def("add_property", [](onnxruntime::training::api::CheckpointState* state,
|
||||
const std::string& property_name,
|
||||
const std::variant<int64_t, float, std::string>& property_value) {
|
||||
state->property_bag.AddProperty(property_name, property_value);
|
||||
})
|
||||
.def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
|
||||
return state->property_bag.GetProperty<onnxruntime::training::api::PropertyDataType>(property_name);
|
||||
})
|
||||
.def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
|
||||
return state->property_bag.HasProperty(property_name);
|
||||
});
|
||||
.def("add_property",
|
||||
[](onnxruntime::training::api::CheckpointState* state,
|
||||
const std::string& property_name,
|
||||
const std::variant<int64_t, float, std::string>& property_value) {
|
||||
state->property_bag.AddProperty(property_name, property_value);
|
||||
})
|
||||
.def("get_property",
|
||||
[](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
|
||||
return state->property_bag.GetProperty<onnxruntime::training::api::PropertyDataType>(property_name);
|
||||
})
|
||||
.def("has_property",
|
||||
[](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
|
||||
return state->property_bag.HasProperty(property_name);
|
||||
})
|
||||
.def("copy_parameter_from",
|
||||
[](onnxruntime::training::api::CheckpointState* state,
|
||||
const std::string& parameter_name, OrtValue& value) -> void {
|
||||
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
|
||||
if (it == state->module_checkpoint_state.named_parameters.end()) {
|
||||
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
|
||||
}
|
||||
ORT_THROW_IF_ERROR(it->second->CopyFrom(
|
||||
state->module_checkpoint_state.train_session_data_transfer_mgr, value));
|
||||
})
|
||||
.def("get_parameter",
|
||||
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
|
||||
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
|
||||
if (it == state->module_checkpoint_state.named_parameters.end()) {
|
||||
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
|
||||
}
|
||||
return it->second;
|
||||
})
|
||||
.def("has_parameter",
|
||||
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
|
||||
return state->module_checkpoint_state.named_parameters.count(parameter_name);
|
||||
})
|
||||
.def("parameter_names",
|
||||
[](onnxruntime::training::api::CheckpointState* state) {
|
||||
std::vector<std::string> names;
|
||||
for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) {
|
||||
names.push_back(name);
|
||||
}
|
||||
std::sort(names.begin(), names.end());
|
||||
return names;
|
||||
})
|
||||
.def("property_names",
|
||||
[](onnxruntime::training::api::CheckpointState* state) {
|
||||
std::vector<std::string> names;
|
||||
for ([[maybe_unused]] auto& [name, value] : state->property_bag) {
|
||||
names.push_back(name);
|
||||
}
|
||||
std::sort(names.begin(), names.end());
|
||||
return names;
|
||||
});
|
||||
|
||||
py::class_<PyOptimizer>
|
||||
training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
|
||||
|
|
@ -1111,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
ORT_THROW_IF_ERROR(scheduler->Step());
|
||||
});
|
||||
|
||||
py::class_<onnxruntime::training::api::Parameter,
|
||||
std::unique_ptr<onnxruntime::training::api::Parameter, py::nodelete>>
|
||||
parameter(m, "Parameter");
|
||||
parameter
|
||||
.def_property_readonly("name", &onnxruntime::training::api::Parameter::Name)
|
||||
.def_property_readonly("data", &onnxruntime::training::api::Parameter::Data)
|
||||
.def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient)
|
||||
.def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad)
|
||||
.def("copy_from",
|
||||
[](onnxruntime::training::api::Parameter* parameter,
|
||||
onnxruntime::training::api::CheckpointState* state,
|
||||
OrtValue& value) -> void {
|
||||
ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value));
|
||||
});
|
||||
|
||||
m.def(
|
||||
"save_checkpoint",
|
||||
[](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,198 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
|
||||
|
||||
|
||||
class Parameter:
    """Class that represents a model parameter

    This class represents a model parameter and provides access to its data,
    gradient and other properties. This class is not expected to be instantiated directly.
    Instead, it is returned by the `CheckpointState` object.

    Args:
        parameter: The C.Parameter object that holds the underlying parameter data.
        state: The C.CheckpointState object that holds the underlying session state.
    """

    def __init__(self, parameter: C.Parameter, state: C.CheckpointState):
        self._parameter = parameter
        self._state = state

    @property
    def name(self) -> str:
        """The name of the parameter"""
        return self._parameter.name

    @property
    def data(self) -> np.ndarray:
        """The data of the parameter as a numpy array"""
        return self._parameter.data.numpy()

    @data.setter
    def data(self, value: np.ndarray) -> None:
        """Sets the data of the parameter

        The numpy array is wrapped in an OrtValue and handed to the binding's
        `copy_from` together with the checkpoint state.
        """
        self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue)

    @property
    def grad(self) -> np.ndarray | None:
        """The gradient of the parameter, or None when the gradient holds no value"""
        # Read the binding property once so that the has_value() check and the
        # numpy conversion are performed on the same object.
        gradient = self._parameter.grad
        return gradient.numpy() if gradient.has_value() else None

    @property
    def requires_grad(self) -> bool:
        """Whether or not the parameter requires its gradient to be computed"""
        return self._parameter.requires_grad

    def __repr__(self) -> str:
        """Returns a string representation of the parameter"""
        return f"Parameter(name={self.name}, requires_grad={self.requires_grad})"
|
||||
|
||||
|
||||
class Parameters:
    """Class that holds all the model parameters

    This class holds all the model parameters and provides access to them.
    This class is not expected to be instantiated directly. Instead, it is returned by the
    `CheckpointState`'s parameters attribute.
    This class behaves like a dictionary and provides access to the parameters by name.

    Args:
        state: The C.CheckpointState object that holds the underlying session state.
    """

    def __init__(self, state: C.CheckpointState):
        self._state = state

    def __getitem__(self, name: str) -> Parameter:
        """Gets the parameter associated with the given name

        Searches for the name in the parameters of the checkpoint state.

        Args:
            name: The name of the parameter

        Returns:
            The value of the parameter

        Raises:
            KeyError: If the parameter is not found
        """

        if name not in self:
            raise KeyError(f"Parameter {name} not found.")

        return Parameter(self._state.get_parameter(name), self._state)

    def __setitem__(self, name: str, value: np.ndarray) -> None:
        """Sets the parameter value for the given name

        Searches for the name in the parameters of the checkpoint state.
        If the name is found in parameters, the value is updated.

        Args:
            name: The name of the parameter
            value: The value of the parameter as a numpy array

        Raises:
            KeyError: If the parameter is not found
        """
        if name not in self:
            raise KeyError(f"Parameter {name} not found.")

        self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue)

    def __contains__(self, name: str) -> bool:
        """Checks if the parameter exists in the state

        Args:
            name: The name of the parameter

        Returns:
            True if the name is a parameter False otherwise
        """

        return self._state.has_parameter(name)

    def __iter__(self):
        """Returns an iterator over (name, Parameter) pairs of the parameters"""
        for parameter_name in self._state.parameter_names():
            yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state)

    def __repr__(self) -> str:
        """Returns a string representation of the parameters (the list of parameter names)"""
        # __repr__ must return a str; returning the list itself makes repr() raise TypeError.
        return str(self._state.parameter_names())

    def __len__(self) -> int:
        """Returns the number of parameters"""
        return len(self._state.parameter_names())
|
||||
|
||||
|
||||
class Properties:
    """Class that holds the user defined properties of the checkpoint state

    This class behaves like a dictionary and provides access to the properties by name.

    Args:
        state: The C.CheckpointState object that holds the underlying session state.
    """

    def __init__(self, state: C.CheckpointState):
        self._state = state

    def __getitem__(self, name: str) -> int | float | str:
        """Gets the property associated with the given name

        Searches for the name in the properties of the checkpoint state.

        Args:
            name: The name of the property

        Returns:
            The value of the property

        Raises:
            KeyError: If the property is not found
        """

        if name not in self:
            raise KeyError(f"Property {name} not found.")

        return self._state.get_property(name)

    def __setitem__(self, name: str, value: int | float | str) -> None:
        """Sets the property value for the given name

        Searches for the name in the properties of the checkpoint state.
        The value is added or updated in the properties.

        Args:
            name: The name of the property
            value: The value of the property
                   Properties only support int, float and str values.
        """
        self._state.add_property(name, value)

    def __contains__(self, name: str) -> bool:
        """Checks if the property exists in the state

        Args:
            name: The name of the property

        Returns:
            True if the name is a property, False otherwise
        """

        return self._state.has_property(name)

    def __iter__(self):
        """Returns an iterator over (name, value) pairs of the properties"""
        for property_name in self._state.property_names():
            yield property_name, self._state.get_property(property_name)

    def __repr__(self) -> str:
        """Returns a string representation of the properties (the list of property names)"""
        # __repr__ must return a str; returning the list itself makes repr() raise TypeError.
        return str(self._state.property_names())

    def __len__(self) -> int:
        """Returns the number of properties"""
        return len(self._state.property_names())
|
||||
|
||||
|
||||
class CheckpointState:
|
||||
|
|
@ -14,8 +205,6 @@ class CheckpointState:
|
|||
This class holds all the state information of the training session such as the model parameters,
|
||||
its gradients, the optimizer state and user defined properties.
|
||||
|
||||
User defined properties can be indexed by name from the `CheckpointState` object.
|
||||
|
||||
To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method.
|
||||
|
||||
Args:
|
||||
|
|
@ -26,6 +215,8 @@ class CheckpointState:
|
|||
if not isinstance(state, C.CheckpointState):
|
||||
raise TypeError(f"Invalid argument for CheckpointState received {type(state)}")
|
||||
self._state = state
|
||||
self._parameters = Parameters(self._state)
|
||||
self._properties = Properties(self._state)
|
||||
|
||||
@classmethod
|
||||
def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState:
|
||||
|
|
@ -52,33 +243,12 @@ class CheckpointState:
|
|||
"""
|
||||
C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state)
|
||||
|
||||
def __getitem__(self, name: str) -> int | float | str:
|
||||
"""Gets the property associated with the given name
|
||||
@property
|
||||
def parameters(self) -> Parameters:
|
||||
"""Returns the model parameters from the checkpoint state"""
|
||||
return self._parameters
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
|
||||
Returns:
|
||||
The value of the property
|
||||
"""
|
||||
return self._state.get_property(name)
|
||||
|
||||
def __setitem__(self, name: str, value: int | float | str) -> None:
|
||||
"""Sets the property value for the given name
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
value: The value of the property
|
||||
"""
|
||||
self._state.add_property(name, value)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""Checks if the property exists in the state
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
|
||||
Returns:
|
||||
True if the property exists, False otherwise
|
||||
"""
|
||||
return self._state.has_property(name)
|
||||
@property
|
||||
def properties(self) -> Properties:
|
||||
"""Returns the properties from the checkpoint state"""
|
||||
return self._properties
|
||||
|
|
|
|||
|
|
@ -360,14 +360,18 @@ def test_add_get_property(property_value):
|
|||
if isinstance(property_value, float):
|
||||
property_value = float(np.float32(property_value))
|
||||
|
||||
state["property"] = property_value
|
||||
assert "property" in state
|
||||
assert state["property"] == property_value
|
||||
assert len(state.properties) == 0
|
||||
|
||||
state.properties["property"] = property_value
|
||||
assert "property" in state.properties
|
||||
assert state.properties["property"] == property_value
|
||||
assert len(state.properties) == 1
|
||||
|
||||
CheckpointState.save_checkpoint(state, checkpoint_file_path)
|
||||
new_state = CheckpointState.load_checkpoint(checkpoint_file_path)
|
||||
assert "property" in new_state
|
||||
assert new_state["property"] == property_value
|
||||
assert "property" in new_state.properties
|
||||
assert new_state.properties["property"] == property_value
|
||||
assert len(new_state.properties) == 1
|
||||
|
||||
|
||||
def test_get_input_output_names():
|
||||
|
|
@ -563,3 +567,60 @@ def test_eval_step_with_ort_values():
|
|||
fetches = model(inputs, labels)
|
||||
assert isinstance(fetches, OrtValue)
|
||||
assert fetches
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_get_and_set_parameter_values(device):
    """Reads and writes parameters through `CheckpointState.parameters`.

    Covers: parameter enumeration against the torch state dict, data/grad/requires_grad
    before and after a training step, and both ways of overwriting a parameter
    (the `Parameter.data` setter and item assignment on `Parameters`).
    """
    frozen_names = ("fc1.weight", "fc1.bias")

    with tempfile.TemporaryDirectory() as temp_dir:
        artifacts = _create_training_artifacts(
            temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"]
        )
        checkpoint_file_path, training_model_file_path, eval_model_file_path, _, pt_model = artifacts

        state = CheckpointState.load_checkpoint(checkpoint_file_path)

        model = Module(training_model_file_path, state, eval_model_file_path, device=device)

        # The checkpoint must expose exactly the parameters known to the torch model.
        state_dict = pt_model.state_dict()
        assert len(state_dict) == len(state.parameters)
        for parameter_name, _ in state.parameters:
            assert parameter_name in state_dict

        # Before any training step: data matches torch, trainable gradients are all zero.
        for name, pt_param in pt_model.named_parameters():
            ort_param = state.parameters[name]
            assert ort_param.name == name
            assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data)
            if name in frozen_names:
                assert ort_param.requires_grad is False
                assert ort_param.grad is None
            else:
                assert ort_param.requires_grad is True
                assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32))

        # Overwrite a parameter through the Parameter.data setter.
        original_param = state.parameters["fc1.weight"].data
        all_ones = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32)
        state.parameters["fc1.weight"].data = all_ones
        updated_param = state.parameters["fc1.weight"].data
        assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32))

        # Run one training step.
        model.train()
        inputs = torch.randn(64, 784).numpy()
        labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy()
        loss = model(inputs, labels)
        assert loss is not None

        # After the step: trainable parameters must carry a non-zero gradient.
        for name, _ in pt_model.named_parameters():
            ort_param = state.parameters[name]
            assert ort_param.name == name
            if name in frozen_names:
                assert ort_param.requires_grad is False
                assert ort_param.grad is None
            else:
                assert ort_param.requires_grad is True
                assert ort_param.grad.any()

        # Restore the original value through item assignment on Parameters.
        state.parameters["fc1.weight"] = original_param
        assert np.allclose(state.parameters["fc1.weight"].data, original_param)
|
||||
|
|
|
|||
|
|
@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) {
|
|||
testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL."));
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that a parameter fetched from the checkpoint state has the expected 2-D shape.
TEST(TrainingCApiTest, GetParameter) {
  const auto model_uri = MODEL_FOLDER "training_model.onnx";

  Ort::Env env;
  auto checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
  // The session is not used directly, but is created from the same checkpoint state
  // before the parameter is queried.
  Ort::TrainingSession training_session(env, Ort::SessionOptions(), checkpoint_state, model_uri);

  Ort::Value fc1_weight = checkpoint_state.GetParameter("fc1.weight");
  const auto dims = fc1_weight.GetTensorTypeAndShapeInfo().GetShape();
  ASSERT_EQ(dims.size(), 2U);
  ASSERT_EQ(dims.front(), static_cast<int64_t>(500));
  ASSERT_EQ(dims.back(), static_cast<int64_t>(784));
}
|
||||
|
||||
// Verifies that UpdateParameter replaces the checkpoint's copy of a parameter and that
// writing the original value back restores it.
TEST(TrainingCApiTest, UpdateParameter) {
  auto model_uri = MODEL_FOLDER "training_model.onnx";

  Ort::Env env;
  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
  Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);

  Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
  auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
  auto shape = tensor_info.GetShape();
  ASSERT_EQ(shape.size(), 2U);
  ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
  ASSERT_EQ(shape.back(), static_cast<int64_t>(784));

  // Keep the OrtValue owned by the unique_ptr until it has been filled, so that it is
  // not leaked if GenerateRandomInput throws; ownership then moves to the Ort::Value.
  auto random_value = std::make_unique<OrtValue>();
  GenerateRandomInput(std::array<int64_t, 2>{500, 784}, *random_value);
  Ort::Value updated_parameter{random_value.release()};
  checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);

  // Views the float contents of a tensor value as a span for element-wise comparison.
  const auto as_float_span = [](Ort::Value& value) {
    return gsl::span(value.GetTensorMutableData<float>(),
                     value.GetTensorTypeAndShapeInfo().GetElementCount());
  };

  // After the update the checkpoint must hold the new data, not the original.
  Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
  gsl::span actual = as_float_span(current_parameter);
  gsl::span expected = as_float_span(updated_parameter);
  gsl::span not_expected = as_float_span(parameter);
  ASSERT_EQ(actual, expected);
  ASSERT_NE(actual, not_expected);

  // Writing the original value back must restore it.
  checkpoint_state.UpdateParameter("fc1.weight", parameter);
  current_parameter = checkpoint_state.GetParameter("fc1.weight");
  actual = as_float_span(current_parameter);
  expected = as_float_span(parameter);
  not_expected = as_float_span(updated_parameter);
  ASSERT_EQ(actual, expected);
  ASSERT_NE(actual, not_expected);
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// Same scenario as UpdateParameter, but the training session is created with the CUDA
// execution provider appended to its session options.
TEST(TrainingCApiTest, UpdateParameterDifferentDevices) {
  auto model_uri = MODEL_FOLDER "training_model.onnx";

  Ort::Env env;
  Ort::SessionOptions session_options;
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
  Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri);

  Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
  auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
  auto shape = tensor_info.GetShape();
  ASSERT_EQ(shape.size(), 2U);
  ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
  ASSERT_EQ(shape.back(), static_cast<int64_t>(784));

  // Keep the OrtValue owned by the unique_ptr until it has been filled, so that it is
  // not leaked if GenerateRandomInput throws; ownership then moves to the Ort::Value.
  auto random_value = std::make_unique<OrtValue>();
  GenerateRandomInput(std::array<int64_t, 2>{500, 784}, *random_value);
  Ort::Value updated_parameter{random_value.release()};
  checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);

  // Views the float contents of a tensor value as a span for element-wise comparison.
  const auto as_float_span = [](Ort::Value& value) {
    return gsl::span(value.GetTensorMutableData<float>(),
                     value.GetTensorTypeAndShapeInfo().GetElementCount());
  };

  // After the update the checkpoint must hold the new data, not the original.
  Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
  gsl::span actual = as_float_span(current_parameter);
  gsl::span expected = as_float_span(updated_parameter);
  gsl::span not_expected = as_float_span(parameter);
  ASSERT_EQ(actual, expected);
  ASSERT_NE(actual, not_expected);

  // Writing the original value back must restore it.
  checkpoint_state.UpdateParameter("fc1.weight", parameter);
  current_parameter = checkpoint_state.GetParameter("fc1.weight");
  actual = as_float_span(current_parameter);
  expected = as_float_span(parameter);
  not_expected = as_float_span(updated_parameter);
  ASSERT_EQ(actual, expected);
  ASSERT_NE(actual, not_expected);
}
|
||||
#endif
|
||||
|
||||
} // namespace onnxruntime::training::test
|
||||
|
|
|
|||
|
|
@ -22,10 +22,12 @@ struct PropertyBag {
|
|||
PropertyBag() = default;
|
||||
|
||||
// Adds the property `name` with value `val`, or overwrites the stored value when a
// property with that name already exists.
void AddProperty(const std::string& name, const PropertyDataType& val) {
  auto it = named_properties_.find(name);
  if (it == named_properties_.end()) {
    named_properties_.insert({name, val});
  } else {
    // Existing entry: update in place instead of rejecting the duplicate name.
    it->second = val;
  }
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -608,14 +608,14 @@ struct OrtTrainingApi {
|
|||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
|
||||
/** \brief Adds the given property to the checkpoint state.
|
||||
/** \brief Adds or updates the given property to/in the checkpoint state.
|
||||
*
|
||||
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
* state by the user if they desire by calling this function with the appropriate property name and
|
||||
* value. The given property name must be unique to be able to successfully add the property.
|
||||
* state by the user by calling this function with the corresponding property name and value.
|
||||
* If a property with the given name already exists, its value is replaced with the given value.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state which should hold the property.
|
||||
* \param[in] property_name Unique name of the property being added.
|
||||
* \param[in] property_name Name of the property being added or updated.
|
||||
* \param[in] property_type Type of the property associated with the given name.
|
||||
* \param[in] property_value Property value associated with the given name.
|
||||
*
|
||||
|
|
@ -632,7 +632,7 @@ struct OrtTrainingApi {
|
|||
* exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
|
||||
* \param[in] property_name Unique name of the property being retrieved.
|
||||
* \param[in] property_name Name of the property being retrieved.
|
||||
* \param[in] allocator Allocator used to allocate the memory for the property_value.
|
||||
* \param[out] property_type Type of the property associated with the given name.
|
||||
* \param[out] property_value Property value associated with the given name.
|
||||
|
|
@ -669,6 +669,57 @@ struct OrtTrainingApi {
|
|||
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
|
||||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
/** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
|
||||
*
|
||||
* This function retrieves the type and shape of the parameter associated with the given parameter name.
|
||||
* The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||
* \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
|
||||
|
||||
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||
* The training session must be already created with the checkpoint state that contains the parameter
|
||||
* being updated. The given parameter is copied over to the registered device for the training session.
|
||||
* The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being updated.
|
||||
* \param[in] parameter The parameter data that should replace the existing parameter data.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _In_ OrtValue* parameter);
|
||||
|
||||
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||
* The parameter is copied over and returned as an OrtValue. The training session must be already created
|
||||
* with the checkpoint state that contains the parameter being retrieved.
|
||||
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||
* \param[in] allocator Allocator used to allocate the memory for the parameter.
|
||||
* \param[out] parameter The parameter data that is retrieved from the checkpoint state.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
|
||||
_Outptr_ OrtValue** parameter);
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -112,13 +112,13 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
|
|||
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
|
||||
const bool include_optimizer_state = false);
|
||||
|
||||
/** \brief Adds the given property to the checkpoint state.
|
||||
/** \brief Adds or updates the given property to/in the checkpoint state.
|
||||
*
|
||||
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
* state by the user if they desire by calling this function with the appropriate property name and
|
||||
* value. The given property name must be unique to be able to successfully add the property.
|
||||
* state by the user by calling this function with the corresponding property name and value.
|
||||
* If a property with the given name already exists, its value is replaced with the given value.
|
||||
*
|
||||
* \param[in] property_name Unique name of the property being added.
|
||||
* \param[in] property_name Name of the property being added or updated.
|
||||
* \param[in] property_value Property value associated with the given name.
|
||||
*
|
||||
*/
|
||||
|
|
@ -129,12 +129,38 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
|
|||
* Gets the property value from an existing entry in the checkpoint state. The property must
|
||||
* exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] property_name Unique name of the property being retrieved.
|
||||
* \param[in] property_name Name of the property being retrieved.
|
||||
* \return Property value associated with the given property name.
|
||||
*
|
||||
*/
|
||||
Property GetProperty(const std::string& property_name);
|
||||
|
||||
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||
* The training session must be already created with the checkpoint state that contains the parameter
|
||||
* being updated. The given parameter is copied over to the registered device for the training session.
|
||||
* The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||
*
|
||||
* \param[in] parameter_name Name of the parameter being updated.
|
||||
* \param[in] parameter The parameter data that should replace the existing parameter data.
|
||||
*
|
||||
*/
|
||||
void UpdateParameter(const std::string& parameter_name, const Value& parameter);
|
||||
|
||||
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||
* The parameter is copied over and returned as a Value. The training session must be already created
|
||||
* with the checkpoint state that contains the parameter being retrieved.
|
||||
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||
* \return The parameter data that is retrieved from the checkpoint state.
|
||||
*
|
||||
*/
|
||||
Value GetParameter(const std::string& parameter_name);
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -279,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) {
|
|||
return property;
|
||||
}
|
||||
|
||||
inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
|
||||
ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
|
||||
}
|
||||
|
||||
// Retrieves the parameter named `parameter_name` from this checkpoint state through the
// training C API, using the default allocator, and returns it wrapped in a Value.
// Throws (via ThrowOnError) when the call reports a failure status.
inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
  AllocatorWithDefaultOptions allocator;
  // Initialised so the out pointer never holds an indeterminate value.
  OrtValue* parameter = nullptr;
  ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, &parameter));

  return Value{parameter};
}
|
||||
|
||||
} // namespace Ort
|
||||
|
|
|
|||
|
|
@ -119,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph,
|
|||
#endif
|
||||
} // namespace
|
||||
|
||||
// Copies this parameter's tensor (data_) into the caller-supplied `data`.
//
// `data` must already be allocated, be a tensor, and match this parameter's shape and
// element type (and its strides when ENABLE_STRIDED_TENSORS is defined).
// `data_transfer_manager` must be non-null; it performs the tensor copy.
//
// NOTE(review): every precondition is checked with ORT_ENFORCE and the copy result with
// ORT_THROW_IF_ERROR, so failures appear to surface as exceptions rather than through the
// returned Status, which is Status::OK() whenever the function returns - confirm callers expect this.
Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const {
  ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it.");
  ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
  ORT_ENFORCE(data.Get<Tensor>().Shape() == data_.Get<Tensor>().Shape(),
              "Parameter data shape mismatch. Expected: ", data_.Get<Tensor>().Shape().ToString(),
              ", Got: ", data.Get<Tensor>().Shape().ToString());
#ifdef ENABLE_STRIDED_TENSORS
  // Strides must agree in both count and value before the copy is attempted.
  auto data_strides = data.Get<Tensor>().Strides();
  auto param_strides = data_.Get<Tensor>().Strides();
  ORT_ENFORCE(data_strides.size() == param_strides.size(),
              "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
              ", Got: ", data_strides.size());
  ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
              "Parameter data stride value mismatch.");
#endif
  ORT_ENFORCE(data.Get<Tensor>().DataType() == data_.Get<Tensor>().DataType(),
              "Parameter data type mismatch. Expected: ", data_.Get<Tensor>().DataType(),
              ", Got: ", data.Get<Tensor>().DataType());
  ORT_ENFORCE(data_transfer_manager != nullptr,
              "Data transfer manager must be provided to copy data to the parameter. "
              "Please create the TrainingSession before trying to update the parameter.");

  // Source is the checkpoint parameter, destination is the caller's tensor.
  ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get<Tensor>(), *data.GetMutable<Tensor>()));

  return Status::OK();
}
|
||||
|
||||
// Copies the caller-supplied `data` into this parameter's tensor (data_).
//
// This parameter must already be allocated; `data` must be a tensor and match this
// parameter's shape and element type (and its strides when ENABLE_STRIDED_TENSORS is defined).
// `data_transfer_manager` must be non-null; it performs the tensor copy.
//
// NOTE(review): every precondition is checked with ORT_ENFORCE and the copy result with
// ORT_THROW_IF_ERROR, so failures appear to surface as exceptions rather than through the
// returned Status, which is Status::OK() whenever the function returns - confirm callers expect this.
Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) {
  ORT_ENFORCE(data_.IsAllocated(),
              "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it.");
  ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
  ORT_ENFORCE(data.Get<Tensor>().Shape() == data_.Get<Tensor>().Shape(),
              "Parameter data shape mismatch. Expected: ", data_.Get<Tensor>().Shape().ToString(),
              ", Got: ", data.Get<Tensor>().Shape().ToString());
#ifdef ENABLE_STRIDED_TENSORS
  // Strides must agree in both count and value before the copy is attempted.
  auto data_strides = data.Get<Tensor>().Strides();
  auto param_strides = data_.Get<Tensor>().Strides();
  ORT_ENFORCE(data_strides.size() == param_strides.size(),
              "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
              ", Got: ", data_strides.size());
  ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
              "Parameter data stride value mismatch.");
#endif
  ORT_ENFORCE(data.Get<Tensor>().DataType() == data_.Get<Tensor>().DataType(),
              "Parameter data type mismatch. Expected: ", data_.Get<Tensor>().DataType(),
              ", Got: ", data.Get<Tensor>().DataType());
  ORT_ENFORCE(data_transfer_manager != nullptr,
              "Data transfer manager must be provided to copy data to the parameter. "
              "Please create the TrainingSession before trying to update the parameter.");

  // Source is the caller's tensor, destination is the checkpoint parameter.
  ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get<Tensor>(), *data_.GetMutable<Tensor>()));

  return Status::OK();
}
|
||||
|
||||
Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) {
|
||||
// assert param is allocated
|
||||
ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient.");
|
||||
|
|
@ -334,6 +389,10 @@ Module::Module(const ModelIdentifiers& model_identifiers,
|
|||
}
|
||||
}
|
||||
|
||||
// On destruction, clears the data transfer manager pointer recorded in the checkpoint state
// reachable through state_, so the state no longer refers to it once this module is gone.
// NOTE(review): presumably the manager belongs to this module's training session - confirm.
Module::~Module() {
  auto& checkpoint_state = state_->module_checkpoint_state;
  checkpoint_state.train_session_data_transfer_mgr = nullptr;
}
|
||||
|
||||
size_t Module::GetTrainingModelOutputCount() const noexcept {
|
||||
return train_output_names_.size();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ struct Parameter {
|
|||
|
||||
// Return the mutable data.
|
||||
OrtValue& Data() { return data_; }
|
||||
Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const;
|
||||
Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data);
|
||||
const std::string& Name() const { return name_; }
|
||||
|
||||
// Returns whether this parameter is trainable or not.
|
||||
|
|
@ -34,7 +36,6 @@ struct Parameter {
|
|||
// Reset and release the gradient buffer of this Parameter greedily.
|
||||
Status ResetGrad();
|
||||
|
||||
protected:
|
||||
Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad);
|
||||
|
||||
private:
|
||||
|
|
@ -83,6 +84,8 @@ struct Module {
|
|||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
|
||||
gsl::span<OrtCustomOpDomain* const> op_domains = gsl::span<OrtCustomOpDomain* const>());
|
||||
|
||||
~Module();
|
||||
|
||||
// Return the trainable/nontrainable parameters
|
||||
std::vector<std::shared_ptr<Parameter>> Parameters() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void*
|
|||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
if (checkpoint_buffer == nullptr) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr.");
|
||||
}
|
||||
|
||||
*checkpoint_state = nullptr;
|
||||
auto chkpt_state = std::make_unique<onnxruntime::training::api::CheckpointState>();
|
||||
const auto* checkpoint_bytes = reinterpret_cast<const uint8_t*>(checkpoint_buffer);
|
||||
|
|
@ -559,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState*
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
// Looks up `parameter_name` among the named parameters of the given checkpoint state and, when
// found, reports the type and shape of its data through OrtApis::GetTensorTypeAndShape.
//
// checkpoint_state: checkpoint state to query; must not be null.
// parameter_name: null-terminated name of the parameter; must not be null.
// parameter_type_and_shape: receives the result; must not be null. It is reset to nullptr up
//   front so that it never holds an indeterminate value when an error status is returned.
//
// Returns ORT_INVALID_ARGUMENT when an argument is null or the name is not present in the
// checkpoint state; otherwise returns whatever OrtApis::GetTensorTypeAndShape returns.
ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
                    _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) {
  API_IMPL_BEGIN

  if (parameter_type_and_shape == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Expected a valid parameter type and shape output. Actual: nullptr.");
  }
  *parameter_type_and_shape = nullptr;

  if (checkpoint_state == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint state. Actual: nullptr.");
  }

  // Constructing a std::string from a null pointer is undefined behaviour, so reject it before
  // the name is used for the lookup or for the error message.
  if (parameter_name == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter name. Actual: nullptr.");
  }

  auto chkpt_state = reinterpret_cast<const onnxruntime::training::api::CheckpointState*>(checkpoint_state);
  const auto& named_parameters = chkpt_state->module_checkpoint_state.named_parameters;
  auto it = named_parameters.find(parameter_name);
  if (it == named_parameters.end()) {
    std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
  }

  return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape);
  API_IMPL_END
}
|
||||
|
||||
// Replaces the data of the parameter called `parameter_name` in the given checkpoint state with
// the contents of `parameter`, by calling Parameter::CopyFrom with the data transfer manager
// recorded in the checkpoint state.
//
// checkpoint_state: checkpoint state that holds the parameter; must not be null.
// parameter_name: null-terminated name of the parameter to update; must not be null.
// parameter: value to copy from; must not be null.
//
// Returns ORT_INVALID_ARGUMENT when an argument is null or the name is not present in the
// checkpoint state, the status produced by the copy when it fails, and nullptr on success.
ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
                    _In_ const char* parameter_name, _In_ OrtValue* parameter) {
  API_IMPL_BEGIN
  if (parameter == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr.");
  }

  if (checkpoint_state == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint state. Actual: nullptr.");
  }

  // Constructing a std::string from a null pointer is undefined behaviour, so reject it before
  // the name is used for the lookup or for the error message.
  if (parameter_name == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter name. Actual: nullptr.");
  }

  // The state is an in/out argument of this call, so it is not cast to a const view.
  auto chkpt_state = reinterpret_cast<onnxruntime::training::api::CheckpointState*>(checkpoint_state);
  auto& module_state = chkpt_state->module_checkpoint_state;
  auto it = module_state.named_parameters.find(parameter_name);
  if (it == module_state.named_parameters.end()) {
    std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
  }

  ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom(
      module_state.train_session_data_transfer_mgr, *parameter));

  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
// Creates a new tensor OrtValue with `allocator` and fills it with a copy of the data of the
// parameter called `parameter_name` from the given checkpoint state.
//
// checkpoint_state: checkpoint state to read from; must not be null.
// parameter_name: null-terminated name of the parameter; must not be null.
// allocator: passed to OrtApis::CreateTensorAsOrtValue to allocate the returned tensor.
// parameter: receives the newly created value on success; must not be null. It is set to
//   nullptr on every failure path, so the caller never sees a pointer to a released value.
//
// The new tensor uses the shape and the element type of the stored parameter tensor, so the copy
// is not restricted to float parameters.
//
// Returns ORT_INVALID_ARGUMENT when an argument is null or the name is not present in the
// checkpoint state, ORT_FAIL when the stored parameter is not a tensor, the status of the tensor
// creation or of the copy when either fails, and nullptr on success.
ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
                    _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
                    _Outptr_ OrtValue** parameter) {
  API_IMPL_BEGIN

  if (parameter == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr.");
  }
  *parameter = nullptr;

  if (checkpoint_state == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint state. Actual: nullptr.");
  }

  // Constructing a std::string from a null pointer is undefined behaviour, so reject it before
  // the name is used for the lookup or for the error message.
  if (parameter_name == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter name. Actual: nullptr.");
  }

  auto chkpt_state = reinterpret_cast<const onnxruntime::training::api::CheckpointState*>(checkpoint_state);
  const auto& module_state = chkpt_state->module_checkpoint_state;
  auto it = module_state.named_parameters.find(parameter_name);
  if (it == module_state.named_parameters.end()) {
    std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
  }

  if (!it->second->Data().IsTensor()) {
    return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type.");
  }

  const auto& parameter_tensor = it->second->Data().Get<onnxruntime::Tensor>();

  // Build the copy in a local first; it is only handed to the caller once it is fully populated.
  OrtValue* parameter_copy = nullptr;
  ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue(
      allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(),
      static_cast<ONNXTensorElementDataType>(parameter_tensor.GetElementType()), &parameter_copy));

  // Owns the new value until the copy has succeeded, so it is released on every early exit from
  // this scope, whether CopyTo reports failure through its Status or by throwing.
  std::unique_ptr<OrtValue, decltype(&OrtApis::ReleaseValue)> parameter_guard(parameter_copy,
                                                                              &OrtApis::ReleaseValue);

  ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyTo(
      module_state.train_session_data_transfer_mgr, *parameter_copy));

  *parameter = parameter_guard.release();
  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
static constexpr OrtTrainingApi ort_training_api = {
|
||||
// NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially
|
||||
// released, it is OK to change the order here, however a corresponding matching change should also be done in the
|
||||
|
|
@ -592,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = {
|
|||
&OrtTrainingApis::TrainingSessionGetEvalModelInputName,
|
||||
&OrtTrainingApis::AddProperty,
|
||||
&OrtTrainingApis::GetProperty,
|
||||
&OrtTrainingApis::LoadCheckpointFromBuffer};
|
||||
&OrtTrainingApis::LoadCheckpointFromBuffer,
|
||||
&OrtTrainingApis::GetParameterTypeAndShape,
|
||||
&OrtTrainingApis::UpdateParameter,
|
||||
&OrtTrainingApis::GetParameter};
|
||||
|
||||
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
|
||||
// No constraints on the API version yet.
|
||||
|
|
|
|||
|
|
@ -94,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state
|
|||
ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
|
||||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
|
||||
|
||||
ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _In_ OrtValue* parameter);
|
||||
|
||||
ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
|
||||
_Outptr_ OrtValue** parameter);
|
||||
|
||||
} // namespace OrtTrainingApis
|
||||
|
|
|
|||
Loading…
Reference in a new issue