Some cherry-picks for the 1.16.2 release (#18218)

Cherry-pick PRs: 
#18026 
#17912 
#17901 “2 lines added whitespace errors when cherry-picking"
#17293 
#17364 
#17505 
#17885

This PR contains all the cherry-picks for the patch release except:
1. The PRs marked with sdxl_llama
2. #17772 which has a merge conflict.

---------

Co-authored-by: Chi Lo <Chi.Lo@microsoft.com>
Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Co-authored-by: Scott McKay <Scott.McKay@microsoft.com>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
Co-authored-by: Kaz Nishimura <kazssym@linuxfront.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
This commit is contained in:
Changming Sun 2023-11-02 10:01:53 -07:00 committed by GitHub
parent bc533a6723
commit 2f57f1e4d7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 1176 additions and 294 deletions

View file

@ -1860,54 +1860,61 @@ namespace Microsoft.ML.OnnxRuntime
public static DOrtFillStringTensor OrtFillStringTensor;
/// \param value A tensor created from OrtCreateTensor... function.
/// \param index The index of the entry in the tensor to resize. <summary>
/// \param length_in_bytes Length to resize the string to.
/// \param buffer The resized buffer.
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer(
IntPtr /* OrtValue */ value,
UIntPtr /* size_t */ index,
UIntPtr /* size_t */ length_in_bytes,
out IntPtr /* char** */ buffer
);
IntPtr /* OrtValue */ value,
UIntPtr /* size_t */ index,
UIntPtr /* size_t */ length_in_bytes,
out IntPtr /* char** */ buffer);
public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent(
IntPtr /*(OrtValue*)*/ value,
byte[] /*(void*)*/ dst_buffer,
UIntPtr dst_buffer_len,
UIntPtr[] offsets,
UIntPtr offsets_len);
IntPtr /*(OrtValue*)*/ value,
byte[] /*(void*)*/ dst_buffer,
UIntPtr dst_buffer_len,
UIntPtr[] offsets,
UIntPtr offsets_len);
public static DOrtGetStringTensorContent OrtGetStringTensorContent;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value,
out UIntPtr /*(size_t*)*/ len);
out UIntPtr /*(size_t*)*/ len);
public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value,
UIntPtr /*(size_t)*/ index,
out UIntPtr /*(size_t*)*/ len);
UIntPtr /*(size_t)*/ index,
out UIntPtr /*(size_t*)*/ len);
public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value,
UIntPtr /*(size_t)*/ bufferLength,
UIntPtr /*(size_t)*/ elementIndex,
byte[] buffer);
UIntPtr /*(size_t)*/ bufferLength,
UIntPtr /*(size_t)*/ elementIndex,
byte[] buffer);
public static DOrtGetStringTensorElement OrtGetStringTensorElement;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/
DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);
public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo(
IntPtr /*(struct OrtTypeInfo*)*/ typeInfo,
out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);
public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(
IntPtr /*(OrtValue*)*/ value,
out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape;
@ -1917,12 +1924,16 @@ namespace Microsoft.ML.OnnxRuntime
public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
out IntPtr /*(TensorElementType*)*/ output);
public static DOrtGetTensorElementType OrtGetTensorElementType;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output);
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
out UIntPtr output);
public static DOrtGetDimensionsCount OrtGetDimensionsCount;

View file

@ -40,20 +40,16 @@ namespace Microsoft.ML.OnnxRuntime
String = 2
}
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue)
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
T[] value = new T[1];
value[0] = propertyValue;
Memory<T> memory = value;
using (var memHandle = memory.Pin())
T[] value = { propertyValue };
unsafe
{
IntPtr memPtr;
unsafe
fixed (T* memPtr = value)
{
memPtr = (IntPtr)memHandle.Pointer;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr));
}
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
}
}
@ -103,13 +99,13 @@ namespace Microsoft.ML.OnnxRuntime
}
/// <summary>
/// Adds the given int property to the checkpoint state.
/// Adds or updates the given int property to/in the checkpoint state.
///
/// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint
/// state by the user if they desire by calling this function with the appropriate property name and
/// value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, long propertyValue)
{
@ -117,13 +113,13 @@ namespace Microsoft.ML.OnnxRuntime
}
/// <summary>
/// Adds the given float property to the checkpoint state.
/// Adds or updates the given float property to/in the checkpoint state.
///
/// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint
/// state by the user if they desire by calling this function with the appropriate property name and
/// value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, float propertyValue)
{
@ -131,28 +127,25 @@ namespace Microsoft.ML.OnnxRuntime
}
/// <summary>
/// Adds the given string property to the checkpoint state.
/// Adds or updates the given string property to/in the checkpoint state.
///
/// Runtime properties that are strings such as parameter names, custom strings, and others can be added
/// to the checkpoint state by the user if they desire by calling this function with the appropriate property
/// name and value. The given property name must be unique to be able to successfully add the property.
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
/// state by the user by calling this function with the corresponding property name and value.
/// The given property name must be unique to be able to successfully add the property.
/// </summary>
/// <param name="propertyName">Unique name of the property being added.</param>
/// <param name="propertyName">Name of the property being added or updated.</param>
/// <param name="propertyValue">Property value associated with the given name.</param>
public void AddProperty(string propertyName, string propertyValue)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
try
unsafe
{
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
}
finally
{
Marshal.FreeHGlobal(unmanagedPointer);
fixed (byte* p = propertyValueUtf8)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p));
}
}
}
@ -162,34 +155,86 @@ namespace Microsoft.ML.OnnxRuntime
/// Gets the property value from an existing entry in the checkpoint state. The property must
/// exist in the checkpoint state to be able to retrieve it successfully.
/// </summary>
/// <param name="propertyName">Unique name of the property being retrieved.</param>
/// <param name="propertyName">Name of the property being retrieved.</param>
/// <returns>Property value associated with the given property name.</returns>
public object GetProperty(string propertyName)
{
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var allocator = OrtAllocator.DefaultInstance;
IntPtr propertyValue = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));
if (propertyType == PropertyType.Int)
try
{
var longPropertyValue = Marshal.ReadInt64(propertyValue);
allocator.FreeMemory(propertyValue);
return longPropertyValue;
if (propertyType == PropertyType.Int)
{
Int64 value;
unsafe
{
value = *(Int64*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.Float)
{
float value;
unsafe
{
value = *(float*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.String)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue);
}
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
}
else if (propertyType == PropertyType.Float)
finally
{
float[] value = new float[1];
Marshal.Copy(propertyValue, value, 0, 1);
allocator.FreeMemory(propertyValue);
return value[0];
}
else if (propertyType == PropertyType.String)
}
/// <summary>
/// Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
///
/// This function updates a model parameter in the checkpoint state with the given parameter data.
/// The training session must be already created with the checkpoint state that contains the parameter
/// being updated. The given parameter is copied over to the registered device for the training session.
/// The parameter must exist in the checkpoint state to be able to update it successfully.
/// </summary>
/// <param name="parameterName">Name of the parameter being updated.</param>
/// <param name="parameter">The parameter data that should replace the existing parameter data.</param>
public void UpdateParameter(string parameterName, OrtValue parameter)
{
if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter.");
}
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
}
/// <summary>
/// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
///
/// This function retrieves the model parameter data from the checkpoint state for the given parameter name.
/// The parameter is copied over to the provided OrtValue. The training session must be already created
/// with the checkpoint state that contains the parameter being retrieved.
/// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
/// </summary>
/// <param name="parameterName">Name of the parameter being updated.</param>
/// <returns>The parameter data that is retrieved from the checkpoint state.</returns>
public OrtValue GetParameter(string parameterName)
{
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));
return new OrtValue(parameterHandle);
}
#region SafeHandle

View file

@ -42,6 +42,9 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr AddProperty;
public IntPtr GetProperty;
public IntPtr LoadCheckpointFromBuffer;
public IntPtr GetParameterTypeAndShape;
public IntPtr UpdateParameter;
public IntPtr GetParameter;
}
internal static class NativeTrainingMethods
@ -97,6 +100,9 @@ namespace Microsoft.ML.OnnxRuntime
OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName));
OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty));
OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty));
OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape));
OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter));
OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter));
}
}
@ -359,6 +365,34 @@ namespace Microsoft.ML.OnnxRuntime
public static DOrtGetProperty OrtGetProperty;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
byte[] /*(const char*)*/ parameterName,
out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape
);
public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
byte[] /*(const char*)*/ parameterName,
IntPtr /*(OrtValue*)*/ parameter
);
public static DOrtUpdateParameter OrtUpdateParameter;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter(
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
byte[] /*(const char*)*/ parameterName,
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(OrtValue**)*/ parameter
);
public static DOrtGetParameter OrtGetParameter;
#endregion TrainingSession API
public static bool TrainingEnabled()

View file

@ -358,13 +358,14 @@ namespace Microsoft.ML.OnnxRuntime
IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
{
if (!_evalOutputCount.Equals(outputValues.Count))
if (_evalOutputCount != (ulong)outputValues.Count())
{
throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount}).");
throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount}).");
}
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
const bool isInput = true;
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput);
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
}
@ -509,18 +510,17 @@ namespace Microsoft.ML.OnnxRuntime
/// Returns a contiguous buffer that holds a copy of all training state parameters
/// </summary>
/// <param name="onlyTrainable">Whether to only copy trainable parameters or to copy all parameters.</param>
public FixedBufferOnnxValue ToBuffer(bool onlyTrainable)
public OrtValue ToBuffer(bool onlyTrainable)
{
UIntPtr bufferSize = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable));
float[] bufferMemory = new float[bufferSize.ToUInt64()];
var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
var shape = new long[] { (long)bufferSize.ToUInt64() };
var buffer = FixedBufferOnnxValue.CreateFromMemory<float>(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float));
var shape = new long[] { (long)bufferSize };
var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable));
return buffer;
}
@ -528,45 +528,30 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Loads the training session model parameters from a contiguous buffer
/// </summary>
/// <param name="buffer">Contiguous buffer to load the parameters from.</param>
public void FromBuffer(FixedBufferOnnxValue buffer)
/// <param name="ortValue">Contiguous buffer to load the parameters from.</param>
/// <param name="onlyTrainable">Whether to only load trainable parameters or to load all parameters.</param>
public void FromBuffer(OrtValue ortValue, bool onlyTrainable)
{
if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
}
IntPtr typeAndShapeInfo = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo));
UIntPtr numDimensions = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
if (numDimensions.ToUInt64() != 1)
var tensorInfo = ortValue.GetTensorTypeAndShape();
if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float)
{
string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString();
throw new ArgumentException(errorMessage);
}
// Here buffer size represents the number of elements in the buffer
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize));
// OrtGetParametersSize returns the total number of elements in the model's parameters.
UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true));
if ((ulong)bufferSize == (ulong)numElementsTrainingOnly)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
return;
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float.");
}
UIntPtr numElements = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false));
if ((ulong)bufferSize != (ulong)numElements)
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable));
if ((ulong)tensorInfo.ElementCount != (ulong)numElements)
{
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString();
throw new ArgumentException(errorMessage);
}
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable));
}
/// <summary>

View file

@ -484,20 +484,23 @@ namespace Microsoft.ML.OnnxRuntime.Tests
public void TestToBuffer()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);
var buffer = trainingSession.ToBuffer(true);
cleanUp.Add(buffer);
using (var buffer = trainingSession.ToBuffer(true))
{
Assert.NotNull(buffer);
var typeShape = buffer.GetTensorTypeAndShape();
Assert.Equal(1, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(397510, fetchedShape[0]);
}
}
}
@ -505,22 +508,25 @@ namespace Microsoft.ML.OnnxRuntime.Tests
public void TestFromBuffer()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);
using (var buffer = trainingSession.ToBuffer(true))
{
Assert.NotNull(buffer);
var typeShape = buffer.GetTensorTypeAndShape();
Assert.Equal(1, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(397510, fetchedShape[0]);
var buffer = trainingSession.ToBuffer(true);
cleanUp.Add(buffer);
trainingSession.FromBuffer(buffer);
trainingSession.FromBuffer(buffer, true);
}
}
}
@ -530,6 +536,82 @@ namespace Microsoft.ML.OnnxRuntime.Tests
TrainingUtils.SetSeed(8888);
}
[Fact(DisplayName = "TestGetParameter")]
public void TestGetParameter()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
using (var parameter = state.GetParameter("fc1.weight"))
{
Assert.NotNull(state);
Assert.NotNull(parameter);
var typeShape = parameter.GetTensorTypeAndShape();
Assert.Equal(2, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(500, fetchedShape[0]);
Assert.Equal(784, fetchedShape[1]);
}
}
[Fact(DisplayName = "TestUpdateParameter")]
public void TestUpdateParameter()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
{
Assert.NotNull(state);
using (var parameter = state.GetParameter("fc1.weight"))
{
Assert.NotNull(parameter);
var typeShape = parameter.GetTensorTypeAndShape();
Assert.Equal(2, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(500, fetchedShape[0]);
Assert.Equal(784, fetchedShape[1]);
float maxVal = 20;
Random randNum = new Random();
float[] updated_parameter_buffer = Enumerable
.Repeat(0, 500 * 784)
.Select(i => maxVal * (float)randNum.NextDouble())
.ToArray();
using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape))
{
state.UpdateParameter("fc1.weight", updated_parameter);
using (var current_parameter = state.GetParameter("fc1.weight"))
{
var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(updated_parameter_buffer, current_parameter_tensor);
Assert.NotEqual(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
}
state.UpdateParameter("fc1.weight", parameter);
using (var current_parameter = state.GetParameter("fc1.weight"))
{
var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor);
}
}
}
}
}
internal class FloatComparer : IEqualityComparer<float>
{
private float atol = 1e-3f;

View file

@ -763,6 +763,7 @@ Do not modify directly.*
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(double), tensor(float), tensor(float16)|
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
|Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|

View file

@ -451,12 +451,16 @@ Return Value:
#if defined(_WIN32)
HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
#elif !defined(__APPLE__) // The next few lines result in an EXC_BAD_INSTRUCTION runtime error on a M1 Mac so we
// disable it there.
uint64_t isar0_el1;
asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :);
HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u;
#else
// Use the cpuinfo value which is read from sysctl and has some additional special cases.
// https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379
// Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips
// as well as failing on other ARM chips as it is an EL1 level register that requires extra
// privileges to read.
//
// uint64_t isar0_el1;
// asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :);
// HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u;
HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot();
#endif

View file

@ -20,7 +20,7 @@ namespace cuda {
// float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions
// CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12 ) && (__CUDACC_VER_MINOR__ < 2)))
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
__device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); }
__device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); }
__device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); }
@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
template <typename T>
__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
template <typename T>
__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
template <typename T>
__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }
template <>
__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
template <typename T>
__device__ __inline__ T _Normcdf(T a);

View file

@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub)>,

View file

@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13)
UNARY_OP_HFD(Log, 13)
UNARY_OP_HFD(Exp, 13)
UNARY_OP_HFD(Erf, 13)
UNARY_OP_BWUZCSILHFD(Sign, 13)
UNARY_LOGICALOP_NOT_TYPED(1, bool)
UNARY_OP_HFD(Round, 11)

View file

@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise {
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Sign final : public UnaryElementwise {
public:
Sign(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign)
// When casting, half needs to be converted via float type from most other types
template <typename T>
@ -119,52 +120,52 @@ struct OP_Cast {
}
};
#define IMPL_CAST_IMPL(InT, OutT) \
#define IMPL_CAST_IMPL(InT, OutT) \
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count); \
UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count); \
}
#define IMPL_CAST_IMPL_THROW(InT, OutT) \
#define IMPL_CAST_IMPL_THROW(InT, OutT) \
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
}
#if !defined(DISABLE_FLOAT8_TYPES)
#define IMPL_CAST_IMPL_FROM(T) \
IMPL_CAST_IMPL(T, half) \
IMPL_CAST_IMPL(T, float) \
IMPL_CAST_IMPL(T, double) \
IMPL_CAST_IMPL(T, int8_t) \
IMPL_CAST_IMPL(T, int16_t) \
IMPL_CAST_IMPL(T, int32_t) \
IMPL_CAST_IMPL(T, int64_t) \
IMPL_CAST_IMPL(T, uint8_t) \
IMPL_CAST_IMPL(T, uint16_t) \
IMPL_CAST_IMPL(T, uint32_t) \
IMPL_CAST_IMPL(T, uint64_t) \
IMPL_CAST_IMPL(T, bool) \
IMPL_CAST_IMPL(T, BFloat16) \
IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
#define IMPL_CAST_IMPL_FROM(T) \
IMPL_CAST_IMPL(T, half) \
IMPL_CAST_IMPL(T, float) \
IMPL_CAST_IMPL(T, double) \
IMPL_CAST_IMPL(T, int8_t) \
IMPL_CAST_IMPL(T, int16_t) \
IMPL_CAST_IMPL(T, int32_t) \
IMPL_CAST_IMPL(T, int64_t) \
IMPL_CAST_IMPL(T, uint8_t) \
IMPL_CAST_IMPL(T, uint16_t) \
IMPL_CAST_IMPL(T, uint32_t) \
IMPL_CAST_IMPL(T, uint64_t) \
IMPL_CAST_IMPL(T, bool) \
IMPL_CAST_IMPL(T, BFloat16) \
IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \
IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ)
#else
#define IMPL_CAST_IMPL_FROM(T) \
IMPL_CAST_IMPL(T, half) \
IMPL_CAST_IMPL(T, float) \
IMPL_CAST_IMPL(T, double) \
IMPL_CAST_IMPL(T, int8_t) \
IMPL_CAST_IMPL(T, int16_t) \
IMPL_CAST_IMPL(T, int32_t) \
IMPL_CAST_IMPL(T, int64_t) \
IMPL_CAST_IMPL(T, uint8_t) \
IMPL_CAST_IMPL(T, uint16_t) \
IMPL_CAST_IMPL(T, uint32_t) \
IMPL_CAST_IMPL(T, uint64_t) \
IMPL_CAST_IMPL(T, bool) \
#define IMPL_CAST_IMPL_FROM(T) \
IMPL_CAST_IMPL(T, half) \
IMPL_CAST_IMPL(T, float) \
IMPL_CAST_IMPL(T, double) \
IMPL_CAST_IMPL(T, int8_t) \
IMPL_CAST_IMPL(T, int16_t) \
IMPL_CAST_IMPL(T, int32_t) \
IMPL_CAST_IMPL(T, int64_t) \
IMPL_CAST_IMPL(T, uint8_t) \
IMPL_CAST_IMPL(T, uint16_t) \
IMPL_CAST_IMPL(T, uint32_t) \
IMPL_CAST_IMPL(T, uint64_t) \
IMPL_CAST_IMPL(T, bool) \
IMPL_CAST_IMPL(T, BFloat16)
#endif
@ -199,58 +200,58 @@ struct OP_CastNoSat {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
#define OP_CAST(T, NVT) \
template <> \
struct OP_CastSat<half, T> { \
__device__ __inline__ T operator()(const half& v) const { \
#define OP_CAST(T, NVT) \
template <> \
struct OP_CastSat<half, T> { \
__device__ __inline__ T operator()(const half& v) const { \
return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
} \
}; \
template <> \
struct OP_CastNoSat<half, T> { \
__device__ __inline__ T operator()(const half& v) const { \
return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
} \
}; \
template <> \
struct OP_CastSat<float, T> { \
__device__ __inline__ T operator()(const float& v) const { \
return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
} \
}; \
template <> \
struct OP_CastNoSat<float, T> { \
__device__ __inline__ T operator()(const float& v) const { \
return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
} \
} \
}; \
template <> \
struct OP_CastNoSat<half, T> { \
__device__ __inline__ T operator()(const half& v) const { \
return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
} \
}; \
template <> \
struct OP_CastSat<float, T> { \
__device__ __inline__ T operator()(const float& v) const { \
return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
} \
}; \
template <> \
struct OP_CastNoSat<float, T> { \
__device__ __inline__ T operator()(const float& v) const { \
return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
} \
};
#else
#define OP_CAST(T, NVT) \
template <> \
struct OP_CastSat<half, T> { \
__device__ __inline__ T operator()(const half& v) const { \
return T(__half2float(v), true); \
} \
}; \
template <> \
struct OP_CastNoSat<half, T> { \
__device__ __inline__ T operator()(const half& v) const { \
return T(__half2float(v), false); \
} \
}; \
template <> \
struct OP_CastSat<float, T> { \
#define OP_CAST(T, NVT) \
template <> \
struct OP_CastSat<half, T> { \
__device__ __inline__ T operator()(const half& v) const { \
return T(__half2float(v), true); \
} \
}; \
template <> \
struct OP_CastNoSat<half, T> { \
__device__ __inline__ T operator()(const half& v) const { \
return T(__half2float(v), false); \
} \
}; \
template <> \
struct OP_CastSat<float, T> { \
__device__ __inline__ T operator()(const float& v) const { \
return T(v, true); \
} \
}; \
template <> \
struct OP_CastNoSat<float, T> { \
return T(v, true); \
} \
}; \
template <> \
struct OP_CastNoSat<float, T> { \
__device__ __inline__ T operator()(const float& v) const { \
return T(v, false); \
} \
return T(v, false); \
} \
};
#endif
@ -260,14 +261,13 @@ struct OP_CastNoSat {
OP_CAST(Float8E4M3FN, __NV_E4M3)
OP_CAST(Float8E5M2, __NV_E5M2)
#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \
if (saturate) { \
UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat<InT, OutT>(), count); \
} else { \
UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat<InT, OutT>(), count); \
} \
if (saturate) { \
UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat<InT, OutT>(), count); \
} else { \
UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat<InT, OutT>(), count); \
} \
}
EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN)

View file

@ -31,7 +31,8 @@ namespace cuda {
UNARY_OP_NAME_EXPR(Not, !a) \
UNARY_OP_NAME_EXPR(Round, _Round(a)) \
UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \
UNARY_OP_NAME_EXPR(Cos, _Cos(a))
UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \
UNARY_OP_NAME_EXPR(Sign, _Sign(a))
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \

View file

@ -83,7 +83,7 @@ namespace Windows::AI::MachineLearning::Adapter
// Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size.
struct DmlGraphNodeCreateInfo
{
uint32_t nodeCount;
uint32_t nodeCount = 0;
std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;
std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;

View file

@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
template <typename T>
__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
template <typename T>
__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
template <typename T>
__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }
template <>
__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
template <typename T>
__device__ __inline__ T _Normcdf(T a);
@ -337,7 +349,7 @@ struct GridDim {
};
// aligned vector generates vectorized load/store
template<typename T, int vec_size>
template <typename T, int vec_size>
struct alignas(sizeof(T) * vec_size) aligned_vector {
T val[vec_size];
};
@ -350,11 +362,11 @@ struct alignas(sizeof(T) * vec_size) aligned_vector {
// HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels.
// TODO ROCM added support recently, should verify.
#define HIP_KERNEL_ASSERT(...)
//#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
// WARP related definitions and functions
constexpr int GPU_WARP_SIZE = warpSize;
inline int GPU_WARP_SIZE_HOST= warpSizeDynamic();
inline int GPU_WARP_SIZE_HOST = warpSizeDynamic();
template <typename T>
__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) {

View file

@ -1105,6 +1105,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
// OpSet 14
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum);
@ -2067,6 +2078,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign)>,
// OpSet 14
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum)>,

View file

@ -792,6 +792,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (info.has_user_compute_stream) {
external_stream_ = true;
stream_ = static_cast<cudaStream_t>(info.user_compute_stream);
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_)));
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_)));
}
std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
@ -1860,6 +1864,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
} else if (number_of_trt_nodes == number_of_ort_nodes) {
LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
} else {
sync_stream_after_enqueue_ = true;
LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
}
@ -2372,7 +2377,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
*p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
@ -2400,6 +2405,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
auto fused_node_name = trt_state->fused_node_name;
auto& shape_ranges = trt_state->input_shape_ranges;
auto trt_builder = trt_state->builder->get();
@ -3001,6 +3007,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}
if (sync_stream_after_enqueue) {
cudaStreamSynchronize(stream);
}
// Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
const std::string& output_name = output_binding_names[i];

View file

@ -111,6 +111,7 @@ struct TensorrtFuncState {
std::vector<std::unordered_map<std::string, size_t>> input_info;
std::vector<std::unordered_map<std::string, size_t>> output_info;
std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
bool sync_stream_after_enqueue = false;
OrtMutex* tensorrt_mu_ptr = nullptr;
bool fp16_enable = false;
bool int8_enable = false;
@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
cudnnHandle_t external_cudnn_handle_ = nullptr;
cublasHandle_t external_cublas_handle_ = nullptr;
// Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3()
mutable bool sync_stream_after_enqueue_ = false;
CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;

View file

@ -113,7 +113,7 @@ TestImpl(ForwardIter first, ForwardIter last, OutputIter out) {
TEST(MathOpTest, Sign_uint64) {
using namespace test_sign_internal;
OpTester test("Sign", 9);
OpTester test("Sign", 13);
std::vector<int64_t> input_dims{7};
std::vector<uint64_t> input;
@ -129,7 +129,7 @@ TEST(MathOpTest, Sign_uint64) {
// we disable this test for openvino as openvino ep supports only FP32 Precision
TEST(MathOpTest, Sign_int64) {
using namespace test_sign_internal;
OpTester test("Sign", 9);
OpTester test("Sign", 13);
std::vector<int64_t> input_dims{7};
std::vector<int64_t> input;
@ -146,7 +146,7 @@ TEST(MathOpTest, Sign_int64) {
TEST(MathOpTest, Sign_float) {
using namespace test_sign_internal;
OpTester test("Sign", 9);
OpTester test("Sign", 13);
std::vector<int64_t> input_dims{7};
std::vector<float> input;
@ -162,7 +162,7 @@ TEST(MathOpTest, Sign_float) {
TEST(MathOpTest, Sign_double) {
using namespace test_sign_internal;
OpTester test("Sign", 9);
OpTester test("Sign", 13);
std::vector<int64_t> input_dims{7};
std::vector<double> input;
@ -177,7 +177,7 @@ TEST(MathOpTest, Sign_double) {
}
TEST(MathOpTest, Sign_MLFloat16) {
using namespace test_sign_internal;
OpTester test("Sign", 9);
OpTester test("Sign", 13);
std::vector<int64_t> input_dims{7};
std::vector<MLFloat16> input;

View file

@ -1065,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc");
checkpoint_state
.def(py::init())
.def("add_property", [](onnxruntime::training::api::CheckpointState* state,
const std::string& property_name,
const std::variant<int64_t, float, std::string>& property_value) {
state->property_bag.AddProperty(property_name, property_value);
})
.def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
return state->property_bag.GetProperty<onnxruntime::training::api::PropertyDataType>(property_name);
})
.def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
return state->property_bag.HasProperty(property_name);
});
.def("add_property",
[](onnxruntime::training::api::CheckpointState* state,
const std::string& property_name,
const std::variant<int64_t, float, std::string>& property_value) {
state->property_bag.AddProperty(property_name, property_value);
})
.def("get_property",
[](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
return state->property_bag.GetProperty<onnxruntime::training::api::PropertyDataType>(property_name);
})
.def("has_property",
[](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
return state->property_bag.HasProperty(property_name);
})
.def("copy_parameter_from",
[](onnxruntime::training::api::CheckpointState* state,
const std::string& parameter_name, OrtValue& value) -> void {
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == state->module_checkpoint_state.named_parameters.end()) {
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
}
ORT_THROW_IF_ERROR(it->second->CopyFrom(
state->module_checkpoint_state.train_session_data_transfer_mgr, value));
})
.def("get_parameter",
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == state->module_checkpoint_state.named_parameters.end()) {
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
}
return it->second;
})
.def("has_parameter",
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
return state->module_checkpoint_state.named_parameters.count(parameter_name);
})
.def("parameter_names",
[](onnxruntime::training::api::CheckpointState* state) {
std::vector<std::string> names;
for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) {
names.push_back(name);
}
std::sort(names.begin(), names.end());
return names;
})
.def("property_names",
[](onnxruntime::training::api::CheckpointState* state) {
std::vector<std::string> names;
for ([[maybe_unused]] auto& [name, value] : state->property_bag) {
names.push_back(name);
}
std::sort(names.begin(), names.end());
return names;
});
py::class_<PyOptimizer>
training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
@ -1111,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
ORT_THROW_IF_ERROR(scheduler->Step());
});
py::class_<onnxruntime::training::api::Parameter,
std::unique_ptr<onnxruntime::training::api::Parameter, py::nodelete>>
parameter(m, "Parameter");
parameter
.def_property_readonly("name", &onnxruntime::training::api::Parameter::Name)
.def_property_readonly("data", &onnxruntime::training::api::Parameter::Data)
.def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient)
.def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad)
.def("copy_from",
[](onnxruntime::training::api::Parameter* parameter,
onnxruntime::training::api::CheckpointState* state,
OrtValue& value) -> void {
ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value));
});
m.def(
"save_checkpoint",
[](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,

View file

@ -5,7 +5,198 @@ from __future__ import annotations
import os
import numpy as np
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
class Parameter:
"""Class that represents a model parameter
This class represents a model parameter and provides access to its data,
gradient and other properties. This class is not expected to be instantiated directly.
Instead, it is returned by the `CheckpointState` object.
Args:
parameter: The C.Parameter object that holds the underlying parameter data.
state: The C.CheckpointState object that holds the underlying session state.
"""
def __init__(self, parameter: C.Parameter, state: C.CheckpointState):
self._parameter = parameter
self._state = state
@property
def name(self) -> str:
"""The name of the parameter"""
return self._parameter.name
@property
def data(self) -> np.ndarray:
"""The data of the parameter"""
return self._parameter.data.numpy()
@data.setter
def data(self, value: np.ndarray) -> None:
"""Sets the data of the parameter"""
self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue)
@property
def grad(self) -> np.ndarray:
"""The gradient of the parameter"""
return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None
@property
def requires_grad(self) -> bool:
"""Whether or not the parameter requires its gradient to be computed"""
return self._parameter.requires_grad
def __repr__(self) -> str:
"""Returns a string representation of the parameter"""
return f"Parameter(name={self.name}, requires_grad={self.requires_grad})"
class Parameters:
"""Class that holds all the model parameters
This class holds all the model parameters and provides access to them.
This class is not expected to be instantiated directly. Instead, it is returned by the
`CheckpointState`'s parameters attribute.
This class behaves like a dictionary and provides access to the parameters by name.
Args:
state: The C.CheckpointState object that holds the underlying session state.
"""
def __init__(self, state: C.CheckpointState):
self._state = state
def __getitem__(self, name: str) -> Parameter:
"""Gets the parameter associated with the given name
Searches for the name in the parameters of the checkpoint state.
Args:
name: The name of the parameter
Returns:
The value of the parameter
Raises:
KeyError: If the parameter is not found
"""
if name not in self:
raise KeyError(f"Parameter {name} not found.")
return Parameter(self._state.get_parameter(name), self._state)
def __setitem__(self, name: str, value: np.ndarray) -> None:
"""Sets the parameter value for the given name
Searches for the name in the parameters of the checkpoint state.
If the name is found in parameters, the value is updated.
Args:
name: The name of the parameter
value: The value of the parameter as a numpy array
Raises:
KeyError: If the parameter is not found
"""
if name not in self:
raise KeyError(f"Parameter {name} not found.")
self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue)
def __contains__(self, name: str) -> bool:
"""Checks if the parameter exists in the state
Args:
name: The name of the parameter
Returns:
True if the name is a parameter False otherwise
"""
return self._state.has_parameter(name)
def __iter__(self):
"""Returns an iterator over the properties"""
for parameter_name in self._state.parameter_names():
yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state)
def __repr__(self) -> str:
"""Returns a string representation of the parameters"""
return self._state.parameter_names()
def __len__(self) -> int:
"""Returns the number of parameters"""
return len(self._state.parameter_names())
class Properties:
def __init__(self, state: C.CheckpointState):
self._state = state
def __getitem__(self, name: str) -> int | float | str:
"""Gets the property associated with the given name
Searches for the name in the properties of the checkpoint state.
Args:
name: The name of the property
Returns:
The value of the property
Raises:
KeyError: If the property is not found
"""
if name not in self:
raise KeyError(f"Property {name} not found.")
return self._state.get_property(name)
def __setitem__(self, name: str, value: int | float | str) -> None:
"""Sets the property value for the given name
Searches for the name in the properties of the checkpoint state.
The value is added or updated in the properties.
Args:
name: The name of the property
value: The value of the property
Properties only support int, float and str values.
"""
self._state.add_property(name, value)
def __contains__(self, name: str) -> bool:
"""Checks if the property exists in the state
Args:
name: The name of the property
Returns:
True if the name is a property, False otherwise
"""
return self._state.has_property(name)
def __iter__(self):
"""Returns an iterator over the properties"""
for property_name in self._state.property_names():
yield property_name, self._state.get_property(property_name)
def __repr__(self) -> str:
"""Returns a string representation of the properties"""
return self._state.property_names()
def __len__(self) -> int:
"""Returns the number of properties"""
return len(self._state.property_names())
class CheckpointState:
@ -14,8 +205,6 @@ class CheckpointState:
This class holds all the state information of the training session such as the model parameters,
its gradients, the optimizer state and user defined properties.
User defined properties can be indexed by name from the `CheckpointState` object.
To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method.
Args:
@ -26,6 +215,8 @@ class CheckpointState:
if not isinstance(state, C.CheckpointState):
raise TypeError(f"Invalid argument for CheckpointState received {type(state)}")
self._state = state
self._parameters = Parameters(self._state)
self._properties = Properties(self._state)
@classmethod
def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState:
@ -52,33 +243,12 @@ class CheckpointState:
"""
C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state)
def __getitem__(self, name: str) -> int | float | str:
"""Gets the property associated with the given name
@property
def parameters(self) -> Parameters:
"""Returns the model parameters from the checkpoint state"""
return self._parameters
Args:
name: The name of the property
Returns:
The value of the property
"""
return self._state.get_property(name)
def __setitem__(self, name: str, value: int | float | str) -> None:
"""Sets the property value for the given name
Args:
name: The name of the property
value: The value of the property
"""
self._state.add_property(name, value)
def __contains__(self, name: str) -> bool:
"""Checks if the property exists in the state
Args:
name: The name of the property
Returns:
True if the property exists, False otherwise
"""
return self._state.has_property(name)
@property
def properties(self) -> Properties:
"""Returns the properties from the checkpoint state"""
return self._properties

View file

@ -360,14 +360,18 @@ def test_add_get_property(property_value):
if isinstance(property_value, float):
property_value = float(np.float32(property_value))
state["property"] = property_value
assert "property" in state
assert state["property"] == property_value
assert len(state.properties) == 0
state.properties["property"] = property_value
assert "property" in state.properties
assert state.properties["property"] == property_value
assert len(state.properties) == 1
CheckpointState.save_checkpoint(state, checkpoint_file_path)
new_state = CheckpointState.load_checkpoint(checkpoint_file_path)
assert "property" in new_state
assert new_state["property"] == property_value
assert "property" in new_state.properties
assert new_state.properties["property"] == property_value
assert len(new_state.properties) == 1
def test_get_input_output_names():
@ -563,3 +567,60 @@ def test_eval_step_with_ort_values():
fetches = model(inputs, labels)
assert isinstance(fetches, OrtValue)
assert fetches
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_get_and_set_parameter_values(device):
with tempfile.TemporaryDirectory() as temp_dir:
(
checkpoint_file_path,
training_model_file_path,
eval_model_file_path,
_,
pt_model,
) = _create_training_artifacts(
temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"]
)
state = CheckpointState.load_checkpoint(checkpoint_file_path)
model = Module(training_model_file_path, state, eval_model_file_path, device=device)
state_dict = pt_model.state_dict()
assert len(state_dict) == len(state.parameters)
for parameter_name, _ in state.parameters:
assert parameter_name in state_dict
for name, pt_param in pt_model.named_parameters():
ort_param = state.parameters[name]
assert ort_param.name == name
assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data)
if name in ["fc1.weight", "fc1.bias"]:
assert ort_param.requires_grad is False
assert ort_param.grad is None
else:
assert ort_param.requires_grad is True
assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32))
original_param = state.parameters["fc1.weight"].data
state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32)
updated_param = state.parameters["fc1.weight"].data
assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32))
model.train()
inputs = torch.randn(64, 784).numpy()
labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy()
loss = model(inputs, labels)
assert loss is not None
for name, _ in pt_model.named_parameters():
ort_param = state.parameters[name]
assert ort_param.name == name
if name in ["fc1.weight", "fc1.bias"]:
assert ort_param.requires_grad is False
assert ort_param.grad is None
else:
assert ort_param.requires_grad is True
assert ort_param.grad.any()
state.parameters["fc1.weight"] = original_param
assert np.allclose(state.parameters["fc1.weight"].data, original_param)

View file

@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) {
testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL."));
}
}
TEST(TrainingCApiTest, GetParameter) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
auto shape = tensor_info.GetShape();
ASSERT_EQ(shape.size(), 2U);
ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
}
TEST(TrainingCApiTest, UpdateParameter) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
auto shape = tensor_info.GetShape();
ASSERT_EQ(shape.size(), 2U);
ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
OrtValue* updated_param_value = std::make_unique<OrtValue>().release();
GenerateRandomInput(std::array<int64_t, 2>{500, 784}, *updated_param_value);
Ort::Value updated_parameter{updated_param_value};
checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);
Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
gsl::span actual = gsl::span(current_parameter.GetTensorMutableData<float>(),
current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData<float>(),
updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
gsl::span not_expected = gsl::span(parameter.GetTensorMutableData<float>(),
parameter.GetTensorTypeAndShapeInfo().GetElementCount());
ASSERT_EQ(actual, expected);
ASSERT_NE(actual, not_expected);
checkpoint_state.UpdateParameter("fc1.weight", parameter);
current_parameter = checkpoint_state.GetParameter("fc1.weight");
actual = gsl::span(current_parameter.GetTensorMutableData<float>(),
current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
expected = gsl::span(parameter.GetTensorMutableData<float>(),
parameter.GetTensorTypeAndShapeInfo().GetElementCount());
not_expected = gsl::span(updated_parameter.GetTensorMutableData<float>(),
updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
ASSERT_EQ(actual, expected);
ASSERT_NE(actual, not_expected);
}
#ifdef USE_CUDA
TEST(TrainingCApiTest, UpdateParameterDifferentDevices) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
Ort::Env env;
Ort::SessionOptions session_options;
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri);
Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
auto shape = tensor_info.GetShape();
ASSERT_EQ(shape.size(), 2U);
ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
OrtValue* updated_param_value = std::make_unique<OrtValue>().release();
GenerateRandomInput(std::array<int64_t, 2>{500, 784}, *updated_param_value);
Ort::Value updated_parameter{updated_param_value};
checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);
Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
gsl::span actual = gsl::span(current_parameter.GetTensorMutableData<float>(),
current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData<float>(),
updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
gsl::span not_expected = gsl::span(parameter.GetTensorMutableData<float>(),
parameter.GetTensorTypeAndShapeInfo().GetElementCount());
ASSERT_EQ(actual, expected);
ASSERT_NE(actual, not_expected);
checkpoint_state.UpdateParameter("fc1.weight", parameter);
current_parameter = checkpoint_state.GetParameter("fc1.weight");
actual = gsl::span(current_parameter.GetTensorMutableData<float>(),
current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
expected = gsl::span(parameter.GetTensorMutableData<float>(),
parameter.GetTensorTypeAndShapeInfo().GetElementCount());
not_expected = gsl::span(updated_parameter.GetTensorMutableData<float>(),
updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
ASSERT_EQ(actual, expected);
ASSERT_NE(actual, not_expected);
}
#endif
} // namespace onnxruntime::training::test

View file

@ -22,10 +22,12 @@ struct PropertyBag {
PropertyBag() = default;
void AddProperty(const std::string& name, const PropertyDataType& val) {
ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(),
"Duplicated property named ", name);
named_properties_.insert({name, val});
auto it = named_properties_.find(name);
if (it == named_properties_.end()) {
named_properties_.insert({name, val});
} else {
it->second = val;
}
}
template <typename T>

View file

@ -608,14 +608,14 @@ struct OrtTrainingApi {
/// \name Accessing The Training Session State
/// @{
/** \brief Adds the given property to the checkpoint state.
/** \brief Adds or updates the given property to/in the checkpoint state.
*
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
* state by the user if they desire by calling this function with the appropriate property name and
* value. The given property name must be unique to be able to successfully add the property.
* state by the user by calling this function with the corresponding property name and value.
* The given property name must be unique to be able to successfully add the property.
*
* \param[in] checkpoint_state The checkpoint state which should hold the property.
* \param[in] property_name Unique name of the property being added.
* \param[in] property_name Name of the property being added or updated.
* \param[in] property_type Type of the property associated with the given name.
* \param[in] property_value Property value associated with the given name.
*
@ -632,7 +632,7 @@ struct OrtTrainingApi {
* exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
* \param[in] property_name Unique name of the property being retrieved.
* \param[in] property_name Name of the property being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the property_value.
* \param[out] property_type Type of the property associated with the given name.
* \param[out] property_value Property value associated with the given name.
@ -669,6 +669,57 @@ struct OrtTrainingApi {
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
/** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
*
* This function retrieves the type and shape of the parameter associated with the given parameter name.
* The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being retrieved.
* \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
*
* This function updates a model parameter in the checkpoint state with the given parameter data.
* The training session must be already created with the checkpoint state that contains the parameter
* being updated. The given parameter is copied over to the registered device for the training session.
* The parameter must exist in the checkpoint state to be able to update it successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being updated.
* \param[in] parameter The parameter data that should replace the existing parameter data.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _In_ OrtValue* parameter);
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
*
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
* The parameter is copied over and returned as an OrtValue. The training session must be already created
* with the checkpoint state that contains the parameter being retrieved.
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the parameter.
* \param[out] parameter The parameter data that is retrieved from the checkpoint state.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** parameter);
/// @}
};

View file

@ -112,13 +112,13 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
const bool include_optimizer_state = false);
/** \brief Adds the given property to the checkpoint state.
/** \brief Adds or updates the given property to/in the checkpoint state.
*
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
* state by the user if they desire by calling this function with the appropriate property name and
* value. The given property name must be unique to be able to successfully add the property.
* state by the user by calling this function with the corresponding property name and value.
* The given property name must be unique to be able to successfully add the property.
*
* \param[in] property_name Unique name of the property being added.
* \param[in] property_name Name of the property being added or updated.
* \param[in] property_value Property value associated with the given name.
*
*/
@ -129,12 +129,38 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
* Gets the property value from an existing entry in the checkpoint state. The property must
* exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] property_name Unique name of the property being retrieved.
* \param[in] property_name Name of the property being retrieved.
* \return Property value associated with the given property name.
*
*/
Property GetProperty(const std::string& property_name);
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
*
* This function updates a model parameter in the checkpoint state with the given parameter data.
* The training session must be already created with the checkpoint state that contains the parameter
* being updated. The given parameter is copied over to the registered device for the training session.
* The parameter must exist in the checkpoint state to be able to update it successfully.
*
* \param[in] parameter_name Name of the parameter being updated.
* \param[in] parameter The parameter data that should replace the existing parameter data.
*
*/
void UpdateParameter(const std::string& parameter_name, const Value& parameter);
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
*
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
* The parameter is copied over to the provided OrtValue. The training session must be already created
* with the checkpoint state that contains the parameter being retrieved.
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] parameter_name Name of the parameter being retrieved.
* \return The parameter data that is retrieved from the checkpoint state.
*
*/
Value GetParameter(const std::string& parameter_name);
/// @}
};

View file

@ -279,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) {
return property;
}
inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
}
inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
AllocatorWithDefaultOptions allocator;
OrtValue* parameter;
ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, &parameter));
return Value{parameter};
}
} // namespace Ort

View file

@ -119,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph,
#endif
} // namespace
Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const {
ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it.");
ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
ORT_ENFORCE(data.Get<Tensor>().Shape() == data_.Get<Tensor>().Shape(),
"Parameter data shape mismatch. Expected: ", data_.Get<Tensor>().Shape().ToString(),
", Got: ", data.Get<Tensor>().Shape().ToString());
#ifdef ENABLE_STRIDED_TENSORS
auto data_strides = data.Get<Tensor>().Strides();
auto param_strides = data_.Get<Tensor>().Strides();
ORT_ENFORCE(data_strides.size() == param_strides.size(),
"Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
", Got: ", data_strides.size());
ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
"Parameter data stride value mismatch.");
#endif
ORT_ENFORCE(data.Get<Tensor>().DataType() == data_.Get<Tensor>().DataType(),
"Parameter data type mismatch. Expected: ", data_.Get<Tensor>().DataType(),
", Got: ", data.Get<Tensor>().DataType());
ORT_ENFORCE(data_transfer_manager != nullptr,
"Data transfer manager must be provided to copy data to the parameter. "
"Please create the TrainingSession before trying to update the parameter.");
ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get<Tensor>(), *data.GetMutable<Tensor>()));
return Status::OK();
}
Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) {
ORT_ENFORCE(data_.IsAllocated(),
"The checkpoint parameter is not allocated. Cannot copy the given parameter data to it.");
ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
ORT_ENFORCE(data.Get<Tensor>().Shape() == data_.Get<Tensor>().Shape(),
"Parameter data shape mismatch. Expected: ", data_.Get<Tensor>().Shape().ToString(),
", Got: ", data.Get<Tensor>().Shape().ToString());
#ifdef ENABLE_STRIDED_TENSORS
auto data_strides = data.Get<Tensor>().Strides();
auto param_strides = data_.Get<Tensor>().Strides();
ORT_ENFORCE(data_strides.size() == param_strides.size(),
"Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
", Got: ", data_strides.size());
ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
"Parameter data stride value mismatch.");
#endif
ORT_ENFORCE(data.Get<Tensor>().DataType() == data_.Get<Tensor>().DataType(),
"Parameter data type mismatch. Expected: ", data_.Get<Tensor>().DataType(),
", Got: ", data.Get<Tensor>().DataType());
ORT_ENFORCE(data_transfer_manager != nullptr,
"Data transfer manager must be provided to copy data to the parameter. "
"Please create the TrainingSession before trying to update the parameter.");
ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get<Tensor>(), *data_.GetMutable<Tensor>()));
return Status::OK();
}
Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) {
// assert param is allocated
ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient.");
@ -334,6 +389,10 @@ Module::Module(const ModelIdentifiers& model_identifiers,
}
}
Module::~Module() {
state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr;
}
size_t Module::GetTrainingModelOutputCount() const noexcept {
return train_output_names_.size();
}

View file

@ -21,6 +21,8 @@ struct Parameter {
// Return the mutable data.
OrtValue& Data() { return data_; }
Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const;
Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data);
const std::string& Name() const { return name_; }
// Returns whether this parameter is trainable or not.
@ -34,7 +36,6 @@ struct Parameter {
// Reset and release the gradient buffer of this Parameter greedily.
Status ResetGrad();
protected:
Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad);
private:
@ -83,6 +84,8 @@ struct Module {
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
gsl::span<OrtCustomOpDomain* const> op_domains = gsl::span<OrtCustomOpDomain* const>());
~Module();
// Return the trainable/nontrainable parameters
std::vector<std::shared_ptr<Parameter>> Parameters() const;

View file

@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void*
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) {
API_IMPL_BEGIN
if (checkpoint_buffer == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr.");
}
*checkpoint_state = nullptr;
auto chkpt_state = std::make_unique<onnxruntime::training::api::CheckpointState>();
const auto* checkpoint_bytes = reinterpret_cast<const uint8_t*>(checkpoint_buffer);
@ -559,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState*
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) {
API_IMPL_BEGIN
auto chkpt_state = reinterpret_cast<const onnxruntime::training::api::CheckpointState*>(checkpoint_state);
auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
}
return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape);
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _In_ OrtValue* parameter) {
API_IMPL_BEGIN
if (parameter == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr.");
}
auto chkpt_state = reinterpret_cast<const onnxruntime::training::api::CheckpointState*>(checkpoint_state);
auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
}
ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom(
chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter));
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** parameter) {
API_IMPL_BEGIN
if (parameter == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr.");
}
auto chkpt_state = reinterpret_cast<const onnxruntime::training::api::CheckpointState*>(checkpoint_state);
auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
}
if (!it->second->Data().IsTensor()) {
return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type.");
}
const auto& parameter_tensor = it->second->Data().Get<onnxruntime::Tensor>();
ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue(
allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(),
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter));
auto status = it->second->CopyTo(
chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter);
if (!status.IsOK()) {
OrtApis::ReleaseValue(*parameter);
return onnxruntime::ToOrtStatus(status);
}
return nullptr;
API_IMPL_END
}
static constexpr OrtTrainingApi ort_training_api = {
// NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially
// released, it is OK to change the order here, however a corresponding matching change should also be done in the
@ -592,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = {
&OrtTrainingApis::TrainingSessionGetEvalModelInputName,
&OrtTrainingApis::AddProperty,
&OrtTrainingApis::GetProperty,
&OrtTrainingApis::LoadCheckpointFromBuffer};
&OrtTrainingApis::LoadCheckpointFromBuffer,
&OrtTrainingApis::GetParameterTypeAndShape,
&OrtTrainingApis::UpdateParameter,
&OrtTrainingApis::GetParameter};
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
// No constraints on the API version yet.

View file

@ -94,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state
ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _In_ OrtValue* parameter);
ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** parameter);
} // namespace OrtTrainingApis

View file

@ -163,10 +163,8 @@ std::unordered_map<std::string, std::pair<std::string, std::string>> disabledGpu
test name -> absolute difference sampleTolerance
*/
std::unordered_map<std::string, double> sampleTolerancePerTests({
{"fp16_inception_v1_opset7_GPU",0.005 },
{"fp16_inception_v1_opset8_GPU", 0.005},
{ "candy_opset9_GPU",
0.00150000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
{ "fp16_tiny_yolov2_opset8_GPU",
0.109000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
{"fp16_inception_v1_opset7_GPU", 0.005},
{"fp16_inception_v1_opset8_GPU", 0.005},
{ "candy_opset9_GPU", 0.00150000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
{ "fp16_tiny_yolov2_opset8_GPU", 0.109000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
});