mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Some cherry-picks for the 1.16.2 release (#18218)
Cherry-pick PRs: #18026 #17912 #17901 “2 lines added whitespace errors when cherry-picking" #17293 #17364 #17505 #17885 This PR contains all the cherry-picks for the patch release except: 1. The PRs marked with sdxl_llama 2. #17772 which has a merge conflict. --------- Co-authored-by: Chi Lo <Chi.Lo@microsoft.com> Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: Scott McKay <Scott.McKay@microsoft.com> Co-authored-by: Baiju Meswani <bmeswani@microsoft.com> Co-authored-by: Kaz Nishimura <kazssym@linuxfront.com> Co-authored-by: Scott McKay <skottmckay@gmail.com>
This commit is contained in:
parent
bc533a6723
commit
2f57f1e4d7
32 changed files with 1176 additions and 294 deletions
|
|
@ -1860,54 +1860,61 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
public static DOrtFillStringTensor OrtFillStringTensor;
|
||||
|
||||
/// \param value A tensor created from OrtCreateTensor... function.
|
||||
/// \param index The index of the entry in the tensor to resize. <summary>
|
||||
/// \param length_in_bytes Length to resize the string to.
|
||||
/// \param buffer The resized buffer.
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer(
|
||||
IntPtr /* OrtValue */ value,
|
||||
UIntPtr /* size_t */ index,
|
||||
UIntPtr /* size_t */ length_in_bytes,
|
||||
out IntPtr /* char** */ buffer
|
||||
);
|
||||
IntPtr /* OrtValue */ value,
|
||||
UIntPtr /* size_t */ index,
|
||||
UIntPtr /* size_t */ length_in_bytes,
|
||||
out IntPtr /* char** */ buffer);
|
||||
|
||||
public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent(
|
||||
IntPtr /*(OrtValue*)*/ value,
|
||||
byte[] /*(void*)*/ dst_buffer,
|
||||
UIntPtr dst_buffer_len,
|
||||
UIntPtr[] offsets,
|
||||
UIntPtr offsets_len);
|
||||
IntPtr /*(OrtValue*)*/ value,
|
||||
byte[] /*(void*)*/ dst_buffer,
|
||||
UIntPtr dst_buffer_len,
|
||||
UIntPtr[] offsets,
|
||||
UIntPtr offsets_len);
|
||||
|
||||
public static DOrtGetStringTensorContent OrtGetStringTensorContent;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value,
|
||||
out UIntPtr /*(size_t*)*/ len);
|
||||
out UIntPtr /*(size_t*)*/ len);
|
||||
|
||||
public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value,
|
||||
UIntPtr /*(size_t)*/ index,
|
||||
out UIntPtr /*(size_t*)*/ len);
|
||||
UIntPtr /*(size_t)*/ index,
|
||||
out UIntPtr /*(size_t*)*/ len);
|
||||
|
||||
public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value,
|
||||
UIntPtr /*(size_t)*/ bufferLength,
|
||||
UIntPtr /*(size_t)*/ elementIndex,
|
||||
byte[] buffer);
|
||||
UIntPtr /*(size_t)*/ bufferLength,
|
||||
UIntPtr /*(size_t)*/ elementIndex,
|
||||
byte[] buffer);
|
||||
|
||||
public static DOrtGetStringTensorElement OrtGetStringTensorElement;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/
|
||||
DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo(
|
||||
IntPtr /*(struct OrtTypeInfo*)*/ typeInfo,
|
||||
out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo);
|
||||
|
||||
public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(
|
||||
IntPtr /*(OrtValue*)*/ value,
|
||||
out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo);
|
||||
|
||||
public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape;
|
||||
|
||||
|
|
@ -1917,12 +1924,16 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output);
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(
|
||||
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
|
||||
out IntPtr /*(TensorElementType*)*/ output);
|
||||
|
||||
public static DOrtGetTensorElementType OrtGetTensorElementType;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output);
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(
|
||||
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
|
||||
out UIntPtr output);
|
||||
|
||||
public static DOrtGetDimensionsCount OrtGetDimensionsCount;
|
||||
|
||||
|
|
|
|||
|
|
@ -40,20 +40,16 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
String = 2
|
||||
}
|
||||
|
||||
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue)
|
||||
private void AddPropertyImpl<T>(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
T[] value = new T[1];
|
||||
value[0] = propertyValue;
|
||||
Memory<T> memory = value;
|
||||
using (var memHandle = memory.Pin())
|
||||
T[] value = { propertyValue };
|
||||
unsafe
|
||||
{
|
||||
IntPtr memPtr;
|
||||
unsafe
|
||||
fixed (T* memPtr = value)
|
||||
{
|
||||
memPtr = (IntPtr)memHandle.Pointer;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr));
|
||||
}
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -103,13 +99,13 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given int property to the checkpoint state.
|
||||
/// Adds or updates the given int property to/in the checkpoint state.
|
||||
///
|
||||
/// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint
|
||||
/// state by the user if they desire by calling this function with the appropriate property name and
|
||||
/// value. The given property name must be unique to be able to successfully add the property.
|
||||
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
/// state by the user by calling this function with the corresponding property name and value.
|
||||
/// The given property name must be unique to be able to successfully add the property.
|
||||
/// </summary>
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyName">Name of the property being added or updated.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
public void AddProperty(string propertyName, long propertyValue)
|
||||
{
|
||||
|
|
@ -117,13 +113,13 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given float property to the checkpoint state.
|
||||
/// Adds or updates the given float property to/in the checkpoint state.
|
||||
///
|
||||
/// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint
|
||||
/// state by the user if they desire by calling this function with the appropriate property name and
|
||||
/// value. The given property name must be unique to be able to successfully add the property.
|
||||
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
/// state by the user by calling this function with the corresponding property name and value.
|
||||
/// The given property name must be unique to be able to successfully add the property.
|
||||
/// </summary>
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyName">Name of the property being added or updated.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
public void AddProperty(string propertyName, float propertyValue)
|
||||
{
|
||||
|
|
@ -131,28 +127,25 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the given string property to the checkpoint state.
|
||||
/// Adds or updates the given string property to/in the checkpoint state.
|
||||
///
|
||||
/// Runtime properties that are strings such as parameter names, custom strings, and others can be added
|
||||
/// to the checkpoint state by the user if they desire by calling this function with the appropriate property
|
||||
/// name and value. The given property name must be unique to be able to successfully add the property.
|
||||
/// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
/// state by the user by calling this function with the corresponding property name and value.
|
||||
/// The given property name must be unique to be able to successfully add the property.
|
||||
/// </summary>
|
||||
/// <param name="propertyName">Unique name of the property being added.</param>
|
||||
/// <param name="propertyName">Name of the property being added or updated.</param>
|
||||
/// <param name="propertyValue">Property value associated with the given name.</param>
|
||||
public void AddProperty(string propertyName, string propertyValue)
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);
|
||||
|
||||
IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
|
||||
try
|
||||
unsafe
|
||||
{
|
||||
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
|
||||
}
|
||||
finally
|
||||
{
|
||||
Marshal.FreeHGlobal(unmanagedPointer);
|
||||
fixed (byte* p = propertyValueUtf8)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -162,34 +155,86 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// Gets the property value from an existing entry in the checkpoint state. The property must
|
||||
/// exist in the checkpoint state to be able to retrieve it successfully.
|
||||
/// </summary>
|
||||
/// <param name="propertyName">Unique name of the property being retrieved.</param>
|
||||
/// <param name="propertyName">Name of the property being retrieved.</param>
|
||||
/// <returns>Property value associated with the given property name.</returns>
|
||||
public object GetProperty(string propertyName)
|
||||
{
|
||||
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
|
||||
var allocator = OrtAllocator.DefaultInstance;
|
||||
IntPtr propertyValue = IntPtr.Zero;
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue));
|
||||
|
||||
if (propertyType == PropertyType.Int)
|
||||
try
|
||||
{
|
||||
var longPropertyValue = Marshal.ReadInt64(propertyValue);
|
||||
allocator.FreeMemory(propertyValue);
|
||||
return longPropertyValue;
|
||||
if (propertyType == PropertyType.Int)
|
||||
{
|
||||
Int64 value;
|
||||
unsafe
|
||||
{
|
||||
value = *(Int64*)propertyValue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
else if (propertyType == PropertyType.Float)
|
||||
{
|
||||
float value;
|
||||
unsafe
|
||||
{
|
||||
value = *(float*)propertyValue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
else if (propertyType == PropertyType.String)
|
||||
{
|
||||
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue);
|
||||
}
|
||||
|
||||
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
|
||||
}
|
||||
else if (propertyType == PropertyType.Float)
|
||||
finally
|
||||
{
|
||||
float[] value = new float[1];
|
||||
Marshal.Copy(propertyValue, value, 0, 1);
|
||||
allocator.FreeMemory(propertyValue);
|
||||
return value[0];
|
||||
}
|
||||
else if (propertyType == PropertyType.String)
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||
///
|
||||
/// This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||
/// The training session must be already created with the checkpoint state that contains the parameter
|
||||
/// being updated. The given parameter is copied over to the registered device for the training session.
|
||||
/// The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||
/// </summary>
|
||||
/// <param name="parameterName">Name of the parameter being updated.</param>
|
||||
/// <param name="parameter">The parameter data that should replace the existing parameter data.</param>
|
||||
public void UpdateParameter(string parameterName, OrtValue parameter)
|
||||
{
|
||||
if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
|
||||
{
|
||||
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
|
||||
throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter.");
|
||||
}
|
||||
|
||||
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
|
||||
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||
///
|
||||
/// This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||
/// The parameter is copied over to the provided OrtValue. The training session must be already created
|
||||
/// with the checkpoint state that contains the parameter being retrieved.
|
||||
/// The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||
/// </summary>
|
||||
/// <param name="parameterName">Name of the parameter being updated.</param>
|
||||
/// <returns>The parameter data that is retrieved from the checkpoint state.</returns>
|
||||
public OrtValue GetParameter(string parameterName)
|
||||
{
|
||||
var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName);
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle));
|
||||
|
||||
return new OrtValue(parameterHandle);
|
||||
}
|
||||
|
||||
#region SafeHandle
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public IntPtr AddProperty;
|
||||
public IntPtr GetProperty;
|
||||
public IntPtr LoadCheckpointFromBuffer;
|
||||
public IntPtr GetParameterTypeAndShape;
|
||||
public IntPtr UpdateParameter;
|
||||
public IntPtr GetParameter;
|
||||
}
|
||||
|
||||
internal static class NativeTrainingMethods
|
||||
|
|
@ -97,6 +100,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName));
|
||||
OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty));
|
||||
OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty));
|
||||
OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape));
|
||||
OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter));
|
||||
OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -359,6 +365,34 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
|
||||
public static DOrtGetProperty OrtGetProperty;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape(
|
||||
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
|
||||
byte[] /*(const char*)*/ parameterName,
|
||||
out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape
|
||||
);
|
||||
|
||||
public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter(
|
||||
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
|
||||
byte[] /*(const char*)*/ parameterName,
|
||||
IntPtr /*(OrtValue*)*/ parameter
|
||||
);
|
||||
|
||||
public static DOrtUpdateParameter OrtUpdateParameter;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter(
|
||||
IntPtr /*(OrtCheckpointState*)*/ checkpointState,
|
||||
byte[] /*(const char*)*/ parameterName,
|
||||
IntPtr /*(OrtAllocator*)*/ allocator,
|
||||
out IntPtr /*(OrtValue**)*/ parameter
|
||||
);
|
||||
|
||||
public static DOrtGetParameter OrtGetParameter;
|
||||
|
||||
#endregion TrainingSession API
|
||||
|
||||
public static bool TrainingEnabled()
|
||||
|
|
|
|||
|
|
@ -358,13 +358,14 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
|
||||
IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
|
||||
{
|
||||
if (!_evalOutputCount.Equals(outputValues.Count))
|
||||
if (_evalOutputCount != (ulong)outputValues.Count())
|
||||
{
|
||||
throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount}).");
|
||||
throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount}).");
|
||||
}
|
||||
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
|
||||
const bool isInput = true;
|
||||
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput);
|
||||
|
||||
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
|
||||
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
|
||||
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
|
||||
}
|
||||
|
|
@ -509,18 +510,17 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// Returns a contiguous buffer that holds a copy of all training state parameters
|
||||
/// </summary>
|
||||
/// <param name="onlyTrainable">Whether to only copy trainable parameters or to copy all parameters.</param>
|
||||
public FixedBufferOnnxValue ToBuffer(bool onlyTrainable)
|
||||
public OrtValue ToBuffer(bool onlyTrainable)
|
||||
{
|
||||
UIntPtr bufferSize = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable));
|
||||
|
||||
float[] bufferMemory = new float[bufferSize.ToUInt64()];
|
||||
|
||||
var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
|
||||
var shape = new long[] { (long)bufferSize.ToUInt64() };
|
||||
var buffer = FixedBufferOnnxValue.CreateFromMemory<float>(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float));
|
||||
var shape = new long[] { (long)bufferSize };
|
||||
var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape);
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable));
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable));
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
|
@ -528,45 +528,30 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// <summary>
|
||||
/// Loads the training session model parameters from a contiguous buffer
|
||||
/// </summary>
|
||||
/// <param name="buffer">Contiguous buffer to load the parameters from.</param>
|
||||
public void FromBuffer(FixedBufferOnnxValue buffer)
|
||||
/// <param name="ortValue">Contiguous buffer to load the parameters from.</param>
|
||||
/// <param name="onlyTrainable">Whether to only load trainable parameters or to load all parameters.</param>
|
||||
public void FromBuffer(OrtValue ortValue, bool onlyTrainable)
|
||||
{
|
||||
if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
|
||||
if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
|
||||
{
|
||||
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
|
||||
}
|
||||
|
||||
IntPtr typeAndShapeInfo = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo));
|
||||
UIntPtr numDimensions = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
|
||||
if (numDimensions.ToUInt64() != 1)
|
||||
var tensorInfo = ortValue.GetTensorTypeAndShape();
|
||||
if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float)
|
||||
{
|
||||
string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString();
|
||||
throw new ArgumentException(errorMessage);
|
||||
}
|
||||
|
||||
// Here buffer size represents the number of elements in the buffer
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize));
|
||||
|
||||
// OrtGetParametersSize returns the total number of elements in the model's parameters.
|
||||
UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true));
|
||||
if ((ulong)bufferSize == (ulong)numElementsTrainingOnly)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
|
||||
return;
|
||||
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float.");
|
||||
}
|
||||
|
||||
UIntPtr numElements = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false));
|
||||
if ((ulong)bufferSize != (ulong)numElements)
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable));
|
||||
if ((ulong)tensorInfo.ElementCount != (ulong)numElements)
|
||||
{
|
||||
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
|
||||
string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString();
|
||||
throw new ArgumentException(errorMessage);
|
||||
}
|
||||
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false));
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
|||
|
|
@ -484,20 +484,23 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
public void TestToBuffer()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
|
||||
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
|
||||
var buffer = trainingSession.ToBuffer(true);
|
||||
cleanUp.Add(buffer);
|
||||
using (var buffer = trainingSession.ToBuffer(true))
|
||||
{
|
||||
Assert.NotNull(buffer);
|
||||
var typeShape = buffer.GetTensorTypeAndShape();
|
||||
Assert.Equal(1, typeShape.DimensionsCount);
|
||||
var fetchedShape = typeShape.Shape;
|
||||
Assert.Equal(397510, fetchedShape[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -505,22 +508,25 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
public void TestFromBuffer()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
using (var cleanUp = new DisposableListTest<IDisposable>())
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
|
||||
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
|
||||
{
|
||||
var state = CheckpointState.LoadCheckpoint(checkpointPath);
|
||||
cleanUp.Add(state);
|
||||
Assert.NotNull(state);
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
|
||||
cleanUp.Add(trainingSession);
|
||||
using (var buffer = trainingSession.ToBuffer(true))
|
||||
{
|
||||
Assert.NotNull(buffer);
|
||||
var typeShape = buffer.GetTensorTypeAndShape();
|
||||
Assert.Equal(1, typeShape.DimensionsCount);
|
||||
var fetchedShape = typeShape.Shape;
|
||||
Assert.Equal(397510, fetchedShape[0]);
|
||||
|
||||
var buffer = trainingSession.ToBuffer(true);
|
||||
cleanUp.Add(buffer);
|
||||
|
||||
trainingSession.FromBuffer(buffer);
|
||||
trainingSession.FromBuffer(buffer, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -530,6 +536,82 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
TrainingUtils.SetSeed(8888);
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestGetParameter")]
|
||||
public void TestGetParameter()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
|
||||
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
|
||||
using (var parameter = state.GetParameter("fc1.weight"))
|
||||
{
|
||||
Assert.NotNull(state);
|
||||
Assert.NotNull(parameter);
|
||||
|
||||
var typeShape = parameter.GetTensorTypeAndShape();
|
||||
Assert.Equal(2, typeShape.DimensionsCount);
|
||||
var fetchedShape = typeShape.Shape;
|
||||
Assert.Equal(500, fetchedShape[0]);
|
||||
Assert.Equal(784, fetchedShape[1]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(DisplayName = "TestUpdateParameter")]
|
||||
public void TestUpdateParameter()
|
||||
{
|
||||
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
|
||||
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
|
||||
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
|
||||
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
|
||||
|
||||
using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
|
||||
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
|
||||
{
|
||||
Assert.NotNull(state);
|
||||
|
||||
using (var parameter = state.GetParameter("fc1.weight"))
|
||||
{
|
||||
Assert.NotNull(parameter);
|
||||
var typeShape = parameter.GetTensorTypeAndShape();
|
||||
|
||||
Assert.Equal(2, typeShape.DimensionsCount);
|
||||
var fetchedShape = typeShape.Shape;
|
||||
Assert.Equal(500, fetchedShape[0]);
|
||||
Assert.Equal(784, fetchedShape[1]);
|
||||
|
||||
float maxVal = 20;
|
||||
Random randNum = new Random();
|
||||
float[] updated_parameter_buffer = Enumerable
|
||||
.Repeat(0, 500 * 784)
|
||||
.Select(i => maxVal * (float)randNum.NextDouble())
|
||||
.ToArray();
|
||||
|
||||
using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape))
|
||||
{
|
||||
state.UpdateParameter("fc1.weight", updated_parameter);
|
||||
using (var current_parameter = state.GetParameter("fc1.weight"))
|
||||
{
|
||||
var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
|
||||
Assert.Equal(updated_parameter_buffer, current_parameter_tensor);
|
||||
Assert.NotEqual(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
|
||||
}
|
||||
|
||||
state.UpdateParameter("fc1.weight", parameter);
|
||||
|
||||
using (var current_parameter = state.GetParameter("fc1.weight"))
|
||||
{
|
||||
var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
|
||||
Assert.Equal(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
|
||||
Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class FloatComparer : IEqualityComparer<float>
|
||||
{
|
||||
private float atol = 1e-3f;
|
||||
|
|
|
|||
|
|
@ -763,6 +763,7 @@ Do not modify directly.*
|
|||
|Shrink|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|
|
|
|||
|
|
@ -451,12 +451,16 @@ Return Value:
|
|||
|
||||
#if defined(_WIN32)
|
||||
HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
|
||||
#elif !defined(__APPLE__) // The next few lines result in an EXC_BAD_INSTRUCTION runtime error on a M1 Mac so we
|
||||
// disable it there.
|
||||
uint64_t isar0_el1;
|
||||
asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :);
|
||||
HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u;
|
||||
#else
|
||||
// Use the cpuinfo value which is read from sysctl and has some additional special cases.
|
||||
// https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379
|
||||
// Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips
|
||||
// as well as failing on other ARM chips as it is an EL1 level register that requires extra
|
||||
// privileges to read.
|
||||
//
|
||||
// uint64_t isar0_el1;
|
||||
// asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :);
|
||||
// HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u;
|
||||
HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot();
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ namespace cuda {
|
|||
|
||||
// float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions
|
||||
// CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12 ) && (__CUDACC_VER_MINOR__ < 2)))
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
|
||||
__device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); }
|
||||
__device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); }
|
||||
__device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); }
|
||||
|
|
@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
|
|||
template <typename T>
|
||||
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Normcdf(T a);
|
||||
|
||||
|
|
|
|||
|
|
@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
|
||||
|
|
@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub)>,
|
||||
|
|
|
|||
|
|
@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13)
|
|||
UNARY_OP_HFD(Log, 13)
|
||||
UNARY_OP_HFD(Exp, 13)
|
||||
UNARY_OP_HFD(Erf, 13)
|
||||
UNARY_OP_BWUZCSILHFD(Sign, 13)
|
||||
|
||||
UNARY_LOGICALOP_NOT_TYPED(1, bool)
|
||||
UNARY_OP_HFD(Round, 11)
|
||||
|
|
|
|||
|
|
@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Sign final : public UnaryElementwise {
|
||||
public:
|
||||
Sign(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
|
|||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
|
||||
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign)
|
||||
|
||||
// When casting, half needs to be converted via float type from most other types
|
||||
template <typename T>
|
||||
|
|
@ -119,52 +120,52 @@ struct OP_Cast {
|
|||
}
|
||||
};
|
||||
|
||||
#define IMPL_CAST_IMPL(InT, OutT) \
|
||||
#define IMPL_CAST_IMPL(InT, OutT) \
|
||||
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
|
||||
UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count); \
|
||||
UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count); \
|
||||
}
|
||||
|
||||
#define IMPL_CAST_IMPL_THROW(InT, OutT) \
|
||||
#define IMPL_CAST_IMPL_THROW(InT, OutT) \
|
||||
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
|
||||
ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
|
||||
ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
|
||||
#define IMPL_CAST_IMPL_FROM(T) \
|
||||
IMPL_CAST_IMPL(T, half) \
|
||||
IMPL_CAST_IMPL(T, float) \
|
||||
IMPL_CAST_IMPL(T, double) \
|
||||
IMPL_CAST_IMPL(T, int8_t) \
|
||||
IMPL_CAST_IMPL(T, int16_t) \
|
||||
IMPL_CAST_IMPL(T, int32_t) \
|
||||
IMPL_CAST_IMPL(T, int64_t) \
|
||||
IMPL_CAST_IMPL(T, uint8_t) \
|
||||
IMPL_CAST_IMPL(T, uint16_t) \
|
||||
IMPL_CAST_IMPL(T, uint32_t) \
|
||||
IMPL_CAST_IMPL(T, uint64_t) \
|
||||
IMPL_CAST_IMPL(T, bool) \
|
||||
IMPL_CAST_IMPL(T, BFloat16) \
|
||||
IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
|
||||
IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
|
||||
#define IMPL_CAST_IMPL_FROM(T) \
|
||||
IMPL_CAST_IMPL(T, half) \
|
||||
IMPL_CAST_IMPL(T, float) \
|
||||
IMPL_CAST_IMPL(T, double) \
|
||||
IMPL_CAST_IMPL(T, int8_t) \
|
||||
IMPL_CAST_IMPL(T, int16_t) \
|
||||
IMPL_CAST_IMPL(T, int32_t) \
|
||||
IMPL_CAST_IMPL(T, int64_t) \
|
||||
IMPL_CAST_IMPL(T, uint8_t) \
|
||||
IMPL_CAST_IMPL(T, uint16_t) \
|
||||
IMPL_CAST_IMPL(T, uint32_t) \
|
||||
IMPL_CAST_IMPL(T, uint64_t) \
|
||||
IMPL_CAST_IMPL(T, bool) \
|
||||
IMPL_CAST_IMPL(T, BFloat16) \
|
||||
IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
|
||||
IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
|
||||
IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \
|
||||
IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ)
|
||||
|
||||
#else
|
||||
|
||||
#define IMPL_CAST_IMPL_FROM(T) \
|
||||
IMPL_CAST_IMPL(T, half) \
|
||||
IMPL_CAST_IMPL(T, float) \
|
||||
IMPL_CAST_IMPL(T, double) \
|
||||
IMPL_CAST_IMPL(T, int8_t) \
|
||||
IMPL_CAST_IMPL(T, int16_t) \
|
||||
IMPL_CAST_IMPL(T, int32_t) \
|
||||
IMPL_CAST_IMPL(T, int64_t) \
|
||||
IMPL_CAST_IMPL(T, uint8_t) \
|
||||
IMPL_CAST_IMPL(T, uint16_t) \
|
||||
IMPL_CAST_IMPL(T, uint32_t) \
|
||||
IMPL_CAST_IMPL(T, uint64_t) \
|
||||
IMPL_CAST_IMPL(T, bool) \
|
||||
#define IMPL_CAST_IMPL_FROM(T) \
|
||||
IMPL_CAST_IMPL(T, half) \
|
||||
IMPL_CAST_IMPL(T, float) \
|
||||
IMPL_CAST_IMPL(T, double) \
|
||||
IMPL_CAST_IMPL(T, int8_t) \
|
||||
IMPL_CAST_IMPL(T, int16_t) \
|
||||
IMPL_CAST_IMPL(T, int32_t) \
|
||||
IMPL_CAST_IMPL(T, int64_t) \
|
||||
IMPL_CAST_IMPL(T, uint8_t) \
|
||||
IMPL_CAST_IMPL(T, uint16_t) \
|
||||
IMPL_CAST_IMPL(T, uint32_t) \
|
||||
IMPL_CAST_IMPL(T, uint64_t) \
|
||||
IMPL_CAST_IMPL(T, bool) \
|
||||
IMPL_CAST_IMPL(T, BFloat16)
|
||||
|
||||
#endif
|
||||
|
|
@ -199,58 +200,58 @@ struct OP_CastNoSat {
|
|||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
|
||||
|
||||
#define OP_CAST(T, NVT) \
|
||||
template <> \
|
||||
struct OP_CastSat<half, T> { \
|
||||
__device__ __inline__ T operator()(const half& v) const { \
|
||||
#define OP_CAST(T, NVT) \
|
||||
template <> \
|
||||
struct OP_CastSat<half, T> { \
|
||||
__device__ __inline__ T operator()(const half& v) const { \
|
||||
return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastNoSat<half, T> { \
|
||||
__device__ __inline__ T operator()(const half& v) const { \
|
||||
return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastSat<float, T> { \
|
||||
__device__ __inline__ T operator()(const float& v) const { \
|
||||
return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastNoSat<float, T> { \
|
||||
__device__ __inline__ T operator()(const float& v) const { \
|
||||
return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
|
||||
} \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastNoSat<half, T> { \
|
||||
__device__ __inline__ T operator()(const half& v) const { \
|
||||
return T(static_cast<unsigned char>(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastSat<float, T> { \
|
||||
__device__ __inline__ T operator()(const float& v) const { \
|
||||
return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastNoSat<float, T> { \
|
||||
__device__ __inline__ T operator()(const float& v) const { \
|
||||
return T(static_cast<unsigned char>(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
|
||||
} \
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
#define OP_CAST(T, NVT) \
|
||||
template <> \
|
||||
struct OP_CastSat<half, T> { \
|
||||
__device__ __inline__ T operator()(const half& v) const { \
|
||||
return T(__half2float(v), true); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastNoSat<half, T> { \
|
||||
__device__ __inline__ T operator()(const half& v) const { \
|
||||
return T(__half2float(v), false); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastSat<float, T> { \
|
||||
#define OP_CAST(T, NVT) \
|
||||
template <> \
|
||||
struct OP_CastSat<half, T> { \
|
||||
__device__ __inline__ T operator()(const half& v) const { \
|
||||
return T(__half2float(v), true); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastNoSat<half, T> { \
|
||||
__device__ __inline__ T operator()(const half& v) const { \
|
||||
return T(__half2float(v), false); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastSat<float, T> { \
|
||||
__device__ __inline__ T operator()(const float& v) const { \
|
||||
return T(v, true); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastNoSat<float, T> { \
|
||||
return T(v, true); \
|
||||
} \
|
||||
}; \
|
||||
template <> \
|
||||
struct OP_CastNoSat<float, T> { \
|
||||
__device__ __inline__ T operator()(const float& v) const { \
|
||||
return T(v, false); \
|
||||
} \
|
||||
return T(v, false); \
|
||||
} \
|
||||
};
|
||||
|
||||
#endif
|
||||
|
|
@ -260,14 +261,13 @@ struct OP_CastNoSat {
|
|||
OP_CAST(Float8E4M3FN, __NV_E4M3)
|
||||
OP_CAST(Float8E5M2, __NV_E5M2)
|
||||
|
||||
|
||||
#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
|
||||
#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
|
||||
void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \
|
||||
if (saturate) { \
|
||||
UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat<InT, OutT>(), count); \
|
||||
} else { \
|
||||
UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat<InT, OutT>(), count); \
|
||||
} \
|
||||
if (saturate) { \
|
||||
UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat<InT, OutT>(), count); \
|
||||
} else { \
|
||||
UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat<InT, OutT>(), count); \
|
||||
} \
|
||||
}
|
||||
|
||||
EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ namespace cuda {
|
|||
UNARY_OP_NAME_EXPR(Not, !a) \
|
||||
UNARY_OP_NAME_EXPR(Round, _Round(a)) \
|
||||
UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \
|
||||
UNARY_OP_NAME_EXPR(Cos, _Cos(a))
|
||||
UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \
|
||||
UNARY_OP_NAME_EXPR(Sign, _Sign(a))
|
||||
|
||||
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
// Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size.
|
||||
struct DmlGraphNodeCreateInfo
|
||||
{
|
||||
uint32_t nodeCount;
|
||||
uint32_t nodeCount = 0;
|
||||
std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;
|
||||
std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
|
||||
|
|
|
|||
|
|
@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
|
|||
template <typename T>
|
||||
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed<T>()); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Normcdf(T a);
|
||||
|
||||
|
|
@ -337,7 +349,7 @@ struct GridDim {
|
|||
};
|
||||
|
||||
// aligned vector generates vectorized load/store
|
||||
template<typename T, int vec_size>
|
||||
template <typename T, int vec_size>
|
||||
struct alignas(sizeof(T) * vec_size) aligned_vector {
|
||||
T val[vec_size];
|
||||
};
|
||||
|
|
@ -350,11 +362,11 @@ struct alignas(sizeof(T) * vec_size) aligned_vector {
|
|||
// HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels.
|
||||
// TODO ROCM added support recently, should verify.
|
||||
#define HIP_KERNEL_ASSERT(...)
|
||||
//#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
|
||||
// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
|
||||
|
||||
// WARP related definitions and functions
|
||||
constexpr int GPU_WARP_SIZE = warpSize;
|
||||
inline int GPU_WARP_SIZE_HOST= warpSizeDynamic();
|
||||
inline int GPU_WARP_SIZE_HOST = warpSizeDynamic();
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) {
|
||||
|
|
|
|||
|
|
@ -1105,6 +1105,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
|
||||
|
||||
// OpSet 14
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum);
|
||||
|
|
@ -2067,6 +2078,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign)>,
|
||||
|
||||
// OpSet 14
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum)>,
|
||||
|
|
|
|||
|
|
@ -792,6 +792,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
|
|||
if (info.has_user_compute_stream) {
|
||||
external_stream_ = true;
|
||||
stream_ = static_cast<cudaStream_t>(info.user_compute_stream);
|
||||
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
|
||||
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_)));
|
||||
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
|
||||
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_)));
|
||||
}
|
||||
|
||||
std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
|
||||
|
|
@ -1860,6 +1864,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
|
|||
} else if (number_of_trt_nodes == number_of_ort_nodes) {
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
|
||||
} else {
|
||||
sync_stream_after_enqueue_ = true;
|
||||
LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
|
||||
}
|
||||
|
||||
|
|
@ -2372,7 +2377,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
*p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
|
||||
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
|
||||
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
|
||||
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
|
||||
input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
|
||||
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
|
||||
runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
|
||||
dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
|
||||
|
|
@ -2400,6 +2405,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
|
||||
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
|
||||
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
|
||||
bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
|
||||
auto fused_node_name = trt_state->fused_node_name;
|
||||
auto& shape_ranges = trt_state->input_shape_ranges;
|
||||
auto trt_builder = trt_state->builder->get();
|
||||
|
|
@ -3001,6 +3007,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
|
||||
}
|
||||
|
||||
if (sync_stream_after_enqueue) {
|
||||
cudaStreamSynchronize(stream);
|
||||
}
|
||||
|
||||
// Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
|
||||
for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
|
||||
const std::string& output_name = output_binding_names[i];
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ struct TensorrtFuncState {
|
|||
std::vector<std::unordered_map<std::string, size_t>> input_info;
|
||||
std::vector<std::unordered_map<std::string, size_t>> output_info;
|
||||
std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
|
||||
bool sync_stream_after_enqueue = false;
|
||||
OrtMutex* tensorrt_mu_ptr = nullptr;
|
||||
bool fp16_enable = false;
|
||||
bool int8_enable = false;
|
||||
|
|
@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
cudnnHandle_t external_cudnn_handle_ = nullptr;
|
||||
cublasHandle_t external_cublas_handle_ = nullptr;
|
||||
|
||||
// Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3()
|
||||
mutable bool sync_stream_after_enqueue_ = false;
|
||||
|
||||
CUDAGraph cuda_graph_;
|
||||
bool is_graph_captured_ = false;
|
||||
int regular_run_count_before_graph_capture_ = 0;
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ TestImpl(ForwardIter first, ForwardIter last, OutputIter out) {
|
|||
|
||||
TEST(MathOpTest, Sign_uint64) {
|
||||
using namespace test_sign_internal;
|
||||
OpTester test("Sign", 9);
|
||||
OpTester test("Sign", 13);
|
||||
|
||||
std::vector<int64_t> input_dims{7};
|
||||
std::vector<uint64_t> input;
|
||||
|
|
@ -129,7 +129,7 @@ TEST(MathOpTest, Sign_uint64) {
|
|||
// we disable this test for openvino as openvino ep supports only FP32 Precision
|
||||
TEST(MathOpTest, Sign_int64) {
|
||||
using namespace test_sign_internal;
|
||||
OpTester test("Sign", 9);
|
||||
OpTester test("Sign", 13);
|
||||
|
||||
std::vector<int64_t> input_dims{7};
|
||||
std::vector<int64_t> input;
|
||||
|
|
@ -146,7 +146,7 @@ TEST(MathOpTest, Sign_int64) {
|
|||
|
||||
TEST(MathOpTest, Sign_float) {
|
||||
using namespace test_sign_internal;
|
||||
OpTester test("Sign", 9);
|
||||
OpTester test("Sign", 13);
|
||||
|
||||
std::vector<int64_t> input_dims{7};
|
||||
std::vector<float> input;
|
||||
|
|
@ -162,7 +162,7 @@ TEST(MathOpTest, Sign_float) {
|
|||
|
||||
TEST(MathOpTest, Sign_double) {
|
||||
using namespace test_sign_internal;
|
||||
OpTester test("Sign", 9);
|
||||
OpTester test("Sign", 13);
|
||||
|
||||
std::vector<int64_t> input_dims{7};
|
||||
std::vector<double> input;
|
||||
|
|
@ -177,7 +177,7 @@ TEST(MathOpTest, Sign_double) {
|
|||
}
|
||||
TEST(MathOpTest, Sign_MLFloat16) {
|
||||
using namespace test_sign_internal;
|
||||
OpTester test("Sign", 9);
|
||||
OpTester test("Sign", 13);
|
||||
|
||||
std::vector<int64_t> input_dims{7};
|
||||
std::vector<MLFloat16> input;
|
||||
|
|
|
|||
|
|
@ -1065,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc");
|
||||
checkpoint_state
|
||||
.def(py::init())
|
||||
.def("add_property", [](onnxruntime::training::api::CheckpointState* state,
|
||||
const std::string& property_name,
|
||||
const std::variant<int64_t, float, std::string>& property_value) {
|
||||
state->property_bag.AddProperty(property_name, property_value);
|
||||
})
|
||||
.def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
|
||||
return state->property_bag.GetProperty<onnxruntime::training::api::PropertyDataType>(property_name);
|
||||
})
|
||||
.def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
|
||||
return state->property_bag.HasProperty(property_name);
|
||||
});
|
||||
.def("add_property",
|
||||
[](onnxruntime::training::api::CheckpointState* state,
|
||||
const std::string& property_name,
|
||||
const std::variant<int64_t, float, std::string>& property_value) {
|
||||
state->property_bag.AddProperty(property_name, property_value);
|
||||
})
|
||||
.def("get_property",
|
||||
[](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
|
||||
return state->property_bag.GetProperty<onnxruntime::training::api::PropertyDataType>(property_name);
|
||||
})
|
||||
.def("has_property",
|
||||
[](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
|
||||
return state->property_bag.HasProperty(property_name);
|
||||
})
|
||||
.def("copy_parameter_from",
|
||||
[](onnxruntime::training::api::CheckpointState* state,
|
||||
const std::string& parameter_name, OrtValue& value) -> void {
|
||||
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
|
||||
if (it == state->module_checkpoint_state.named_parameters.end()) {
|
||||
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
|
||||
}
|
||||
ORT_THROW_IF_ERROR(it->second->CopyFrom(
|
||||
state->module_checkpoint_state.train_session_data_transfer_mgr, value));
|
||||
})
|
||||
.def("get_parameter",
|
||||
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
|
||||
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
|
||||
if (it == state->module_checkpoint_state.named_parameters.end()) {
|
||||
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
|
||||
}
|
||||
return it->second;
|
||||
})
|
||||
.def("has_parameter",
|
||||
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
|
||||
return state->module_checkpoint_state.named_parameters.count(parameter_name);
|
||||
})
|
||||
.def("parameter_names",
|
||||
[](onnxruntime::training::api::CheckpointState* state) {
|
||||
std::vector<std::string> names;
|
||||
for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) {
|
||||
names.push_back(name);
|
||||
}
|
||||
std::sort(names.begin(), names.end());
|
||||
return names;
|
||||
})
|
||||
.def("property_names",
|
||||
[](onnxruntime::training::api::CheckpointState* state) {
|
||||
std::vector<std::string> names;
|
||||
for ([[maybe_unused]] auto& [name, value] : state->property_bag) {
|
||||
names.push_back(name);
|
||||
}
|
||||
std::sort(names.begin(), names.end());
|
||||
return names;
|
||||
});
|
||||
|
||||
py::class_<PyOptimizer>
|
||||
training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
|
||||
|
|
@ -1111,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
ORT_THROW_IF_ERROR(scheduler->Step());
|
||||
});
|
||||
|
||||
py::class_<onnxruntime::training::api::Parameter,
|
||||
std::unique_ptr<onnxruntime::training::api::Parameter, py::nodelete>>
|
||||
parameter(m, "Parameter");
|
||||
parameter
|
||||
.def_property_readonly("name", &onnxruntime::training::api::Parameter::Name)
|
||||
.def_property_readonly("data", &onnxruntime::training::api::Parameter::Data)
|
||||
.def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient)
|
||||
.def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad)
|
||||
.def("copy_from",
|
||||
[](onnxruntime::training::api::Parameter* parameter,
|
||||
onnxruntime::training::api::CheckpointState* state,
|
||||
OrtValue& value) -> void {
|
||||
ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value));
|
||||
});
|
||||
|
||||
m.def(
|
||||
"save_checkpoint",
|
||||
[](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,198 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
|
||||
|
||||
|
||||
class Parameter:
|
||||
"""Class that represents a model parameter
|
||||
|
||||
This class represents a model parameter and provides access to its data,
|
||||
gradient and other properties. This class is not expected to be instantiated directly.
|
||||
Instead, it is returned by the `CheckpointState` object.
|
||||
|
||||
Args:
|
||||
parameter: The C.Parameter object that holds the underlying parameter data.
|
||||
state: The C.CheckpointState object that holds the underlying session state.
|
||||
"""
|
||||
|
||||
def __init__(self, parameter: C.Parameter, state: C.CheckpointState):
|
||||
self._parameter = parameter
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the parameter"""
|
||||
return self._parameter.name
|
||||
|
||||
@property
|
||||
def data(self) -> np.ndarray:
|
||||
"""The data of the parameter"""
|
||||
return self._parameter.data.numpy()
|
||||
|
||||
@data.setter
|
||||
def data(self, value: np.ndarray) -> None:
|
||||
"""Sets the data of the parameter"""
|
||||
self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue)
|
||||
|
||||
@property
|
||||
def grad(self) -> np.ndarray:
|
||||
"""The gradient of the parameter"""
|
||||
return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None
|
||||
|
||||
@property
|
||||
def requires_grad(self) -> bool:
|
||||
"""Whether or not the parameter requires its gradient to be computed"""
|
||||
return self._parameter.requires_grad
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Returns a string representation of the parameter"""
|
||||
return f"Parameter(name={self.name}, requires_grad={self.requires_grad})"
|
||||
|
||||
|
||||
class Parameters:
|
||||
"""Class that holds all the model parameters
|
||||
|
||||
This class holds all the model parameters and provides access to them.
|
||||
This class is not expected to be instantiated directly. Instead, it is returned by the
|
||||
`CheckpointState`'s parameters attribute.
|
||||
This class behaves like a dictionary and provides access to the parameters by name.
|
||||
|
||||
Args:
|
||||
state: The C.CheckpointState object that holds the underlying session state.
|
||||
"""
|
||||
|
||||
def __init__(self, state: C.CheckpointState):
|
||||
self._state = state
|
||||
|
||||
def __getitem__(self, name: str) -> Parameter:
|
||||
"""Gets the parameter associated with the given name
|
||||
|
||||
Searches for the name in the parameters of the checkpoint state.
|
||||
|
||||
Args:
|
||||
name: The name of the parameter
|
||||
|
||||
Returns:
|
||||
The value of the parameter
|
||||
|
||||
Raises:
|
||||
KeyError: If the parameter is not found
|
||||
"""
|
||||
|
||||
if name not in self:
|
||||
raise KeyError(f"Parameter {name} not found.")
|
||||
|
||||
return Parameter(self._state.get_parameter(name), self._state)
|
||||
|
||||
def __setitem__(self, name: str, value: np.ndarray) -> None:
|
||||
"""Sets the parameter value for the given name
|
||||
|
||||
Searches for the name in the parameters of the checkpoint state.
|
||||
If the name is found in parameters, the value is updated.
|
||||
|
||||
Args:
|
||||
name: The name of the parameter
|
||||
value: The value of the parameter as a numpy array
|
||||
|
||||
Raises:
|
||||
KeyError: If the parameter is not found
|
||||
"""
|
||||
if name not in self:
|
||||
raise KeyError(f"Parameter {name} not found.")
|
||||
|
||||
self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""Checks if the parameter exists in the state
|
||||
|
||||
Args:
|
||||
name: The name of the parameter
|
||||
|
||||
Returns:
|
||||
True if the name is a parameter False otherwise
|
||||
"""
|
||||
|
||||
return self._state.has_parameter(name)
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator over the properties"""
|
||||
for parameter_name in self._state.parameter_names():
|
||||
yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Returns a string representation of the parameters"""
|
||||
return self._state.parameter_names()
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of parameters"""
|
||||
return len(self._state.parameter_names())
|
||||
|
||||
|
||||
class Properties:
|
||||
def __init__(self, state: C.CheckpointState):
|
||||
self._state = state
|
||||
|
||||
def __getitem__(self, name: str) -> int | float | str:
|
||||
"""Gets the property associated with the given name
|
||||
|
||||
Searches for the name in the properties of the checkpoint state.
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
|
||||
Returns:
|
||||
The value of the property
|
||||
|
||||
Raises:
|
||||
KeyError: If the property is not found
|
||||
"""
|
||||
|
||||
if name not in self:
|
||||
raise KeyError(f"Property {name} not found.")
|
||||
|
||||
return self._state.get_property(name)
|
||||
|
||||
def __setitem__(self, name: str, value: int | float | str) -> None:
|
||||
"""Sets the property value for the given name
|
||||
|
||||
Searches for the name in the properties of the checkpoint state.
|
||||
The value is added or updated in the properties.
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
value: The value of the property
|
||||
Properties only support int, float and str values.
|
||||
"""
|
||||
self._state.add_property(name, value)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""Checks if the property exists in the state
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
|
||||
Returns:
|
||||
True if the name is a property, False otherwise
|
||||
"""
|
||||
|
||||
return self._state.has_property(name)
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator over the properties"""
|
||||
for property_name in self._state.property_names():
|
||||
yield property_name, self._state.get_property(property_name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Returns a string representation of the properties"""
|
||||
return self._state.property_names()
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of properties"""
|
||||
return len(self._state.property_names())
|
||||
|
||||
|
||||
class CheckpointState:
|
||||
|
|
@ -14,8 +205,6 @@ class CheckpointState:
|
|||
This class holds all the state information of the training session such as the model parameters,
|
||||
its gradients, the optimizer state and user defined properties.
|
||||
|
||||
User defined properties can be indexed by name from the `CheckpointState` object.
|
||||
|
||||
To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method.
|
||||
|
||||
Args:
|
||||
|
|
@ -26,6 +215,8 @@ class CheckpointState:
|
|||
if not isinstance(state, C.CheckpointState):
|
||||
raise TypeError(f"Invalid argument for CheckpointState received {type(state)}")
|
||||
self._state = state
|
||||
self._parameters = Parameters(self._state)
|
||||
self._properties = Properties(self._state)
|
||||
|
||||
@classmethod
|
||||
def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState:
|
||||
|
|
@ -52,33 +243,12 @@ class CheckpointState:
|
|||
"""
|
||||
C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state)
|
||||
|
||||
def __getitem__(self, name: str) -> int | float | str:
|
||||
"""Gets the property associated with the given name
|
||||
@property
|
||||
def parameters(self) -> Parameters:
|
||||
"""Returns the model parameters from the checkpoint state"""
|
||||
return self._parameters
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
|
||||
Returns:
|
||||
The value of the property
|
||||
"""
|
||||
return self._state.get_property(name)
|
||||
|
||||
def __setitem__(self, name: str, value: int | float | str) -> None:
|
||||
"""Sets the property value for the given name
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
value: The value of the property
|
||||
"""
|
||||
self._state.add_property(name, value)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""Checks if the property exists in the state
|
||||
|
||||
Args:
|
||||
name: The name of the property
|
||||
|
||||
Returns:
|
||||
True if the property exists, False otherwise
|
||||
"""
|
||||
return self._state.has_property(name)
|
||||
@property
|
||||
def properties(self) -> Properties:
|
||||
"""Returns the properties from the checkpoint state"""
|
||||
return self._properties
|
||||
|
|
|
|||
|
|
@ -360,14 +360,18 @@ def test_add_get_property(property_value):
|
|||
if isinstance(property_value, float):
|
||||
property_value = float(np.float32(property_value))
|
||||
|
||||
state["property"] = property_value
|
||||
assert "property" in state
|
||||
assert state["property"] == property_value
|
||||
assert len(state.properties) == 0
|
||||
|
||||
state.properties["property"] = property_value
|
||||
assert "property" in state.properties
|
||||
assert state.properties["property"] == property_value
|
||||
assert len(state.properties) == 1
|
||||
|
||||
CheckpointState.save_checkpoint(state, checkpoint_file_path)
|
||||
new_state = CheckpointState.load_checkpoint(checkpoint_file_path)
|
||||
assert "property" in new_state
|
||||
assert new_state["property"] == property_value
|
||||
assert "property" in new_state.properties
|
||||
assert new_state.properties["property"] == property_value
|
||||
assert len(new_state.properties) == 1
|
||||
|
||||
|
||||
def test_get_input_output_names():
|
||||
|
|
@ -563,3 +567,60 @@ def test_eval_step_with_ort_values():
|
|||
fetches = model(inputs, labels)
|
||||
assert isinstance(fetches, OrtValue)
|
||||
assert fetches
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||
def test_get_and_set_parameter_values(device):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
(
|
||||
checkpoint_file_path,
|
||||
training_model_file_path,
|
||||
eval_model_file_path,
|
||||
_,
|
||||
pt_model,
|
||||
) = _create_training_artifacts(
|
||||
temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"]
|
||||
)
|
||||
|
||||
state = CheckpointState.load_checkpoint(checkpoint_file_path)
|
||||
|
||||
model = Module(training_model_file_path, state, eval_model_file_path, device=device)
|
||||
|
||||
state_dict = pt_model.state_dict()
|
||||
assert len(state_dict) == len(state.parameters)
|
||||
for parameter_name, _ in state.parameters:
|
||||
assert parameter_name in state_dict
|
||||
|
||||
for name, pt_param in pt_model.named_parameters():
|
||||
ort_param = state.parameters[name]
|
||||
assert ort_param.name == name
|
||||
assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data)
|
||||
if name in ["fc1.weight", "fc1.bias"]:
|
||||
assert ort_param.requires_grad is False
|
||||
assert ort_param.grad is None
|
||||
else:
|
||||
assert ort_param.requires_grad is True
|
||||
assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32))
|
||||
|
||||
original_param = state.parameters["fc1.weight"].data
|
||||
state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32)
|
||||
updated_param = state.parameters["fc1.weight"].data
|
||||
assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32))
|
||||
|
||||
model.train()
|
||||
inputs = torch.randn(64, 784).numpy()
|
||||
labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy()
|
||||
loss = model(inputs, labels)
|
||||
assert loss is not None
|
||||
for name, _ in pt_model.named_parameters():
|
||||
ort_param = state.parameters[name]
|
||||
assert ort_param.name == name
|
||||
if name in ["fc1.weight", "fc1.bias"]:
|
||||
assert ort_param.requires_grad is False
|
||||
assert ort_param.grad is None
|
||||
else:
|
||||
assert ort_param.requires_grad is True
|
||||
assert ort_param.grad.any()
|
||||
|
||||
state.parameters["fc1.weight"] = original_param
|
||||
assert np.allclose(state.parameters["fc1.weight"].data, original_param)
|
||||
|
|
|
|||
|
|
@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) {
|
|||
testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL."));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TrainingCApiTest, GetParameter) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
Ort::Env env;
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
|
||||
Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
|
||||
auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
|
||||
auto shape = tensor_info.GetShape();
|
||||
ASSERT_EQ(shape.size(), 2U);
|
||||
ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
|
||||
ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
|
||||
}
|
||||
|
||||
TEST(TrainingCApiTest, UpdateParameter) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
Ort::Env env;
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri);
|
||||
|
||||
Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
|
||||
auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
|
||||
auto shape = tensor_info.GetShape();
|
||||
ASSERT_EQ(shape.size(), 2U);
|
||||
ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
|
||||
ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
|
||||
|
||||
OrtValue* updated_param_value = std::make_unique<OrtValue>().release();
|
||||
GenerateRandomInput(std::array<int64_t, 2>{500, 784}, *updated_param_value);
|
||||
Ort::Value updated_parameter{updated_param_value};
|
||||
checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);
|
||||
|
||||
Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
|
||||
gsl::span actual = gsl::span(current_parameter.GetTensorMutableData<float>(),
|
||||
current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData<float>(),
|
||||
updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
gsl::span not_expected = gsl::span(parameter.GetTensorMutableData<float>(),
|
||||
parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
ASSERT_EQ(actual, expected);
|
||||
ASSERT_NE(actual, not_expected);
|
||||
|
||||
checkpoint_state.UpdateParameter("fc1.weight", parameter);
|
||||
current_parameter = checkpoint_state.GetParameter("fc1.weight");
|
||||
actual = gsl::span(current_parameter.GetTensorMutableData<float>(),
|
||||
current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
expected = gsl::span(parameter.GetTensorMutableData<float>(),
|
||||
parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
not_expected = gsl::span(updated_parameter.GetTensorMutableData<float>(),
|
||||
updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
ASSERT_EQ(actual, expected);
|
||||
ASSERT_NE(actual, not_expected);
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
TEST(TrainingCApiTest, UpdateParameterDifferentDevices) {
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
Ort::Env env;
|
||||
Ort::SessionOptions session_options;
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
|
||||
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
|
||||
Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri);
|
||||
|
||||
Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight");
|
||||
auto tensor_info = parameter.GetTensorTypeAndShapeInfo();
|
||||
auto shape = tensor_info.GetShape();
|
||||
ASSERT_EQ(shape.size(), 2U);
|
||||
ASSERT_EQ(shape.front(), static_cast<int64_t>(500));
|
||||
ASSERT_EQ(shape.back(), static_cast<int64_t>(784));
|
||||
|
||||
OrtValue* updated_param_value = std::make_unique<OrtValue>().release();
|
||||
GenerateRandomInput(std::array<int64_t, 2>{500, 784}, *updated_param_value);
|
||||
Ort::Value updated_parameter{updated_param_value};
|
||||
checkpoint_state.UpdateParameter("fc1.weight", updated_parameter);
|
||||
|
||||
Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight");
|
||||
gsl::span actual = gsl::span(current_parameter.GetTensorMutableData<float>(),
|
||||
current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData<float>(),
|
||||
updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
gsl::span not_expected = gsl::span(parameter.GetTensorMutableData<float>(),
|
||||
parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
ASSERT_EQ(actual, expected);
|
||||
ASSERT_NE(actual, not_expected);
|
||||
|
||||
checkpoint_state.UpdateParameter("fc1.weight", parameter);
|
||||
current_parameter = checkpoint_state.GetParameter("fc1.weight");
|
||||
actual = gsl::span(current_parameter.GetTensorMutableData<float>(),
|
||||
current_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
expected = gsl::span(parameter.GetTensorMutableData<float>(),
|
||||
parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
not_expected = gsl::span(updated_parameter.GetTensorMutableData<float>(),
|
||||
updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount());
|
||||
ASSERT_EQ(actual, expected);
|
||||
ASSERT_NE(actual, not_expected);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace onnxruntime::training::test
|
||||
|
|
|
|||
|
|
@ -22,10 +22,12 @@ struct PropertyBag {
|
|||
PropertyBag() = default;
|
||||
|
||||
void AddProperty(const std::string& name, const PropertyDataType& val) {
|
||||
ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(),
|
||||
"Duplicated property named ", name);
|
||||
|
||||
named_properties_.insert({name, val});
|
||||
auto it = named_properties_.find(name);
|
||||
if (it == named_properties_.end()) {
|
||||
named_properties_.insert({name, val});
|
||||
} else {
|
||||
it->second = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -608,14 +608,14 @@ struct OrtTrainingApi {
|
|||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
|
||||
/** \brief Adds the given property to the checkpoint state.
|
||||
/** \brief Adds or updates the given property to/in the checkpoint state.
|
||||
*
|
||||
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
* state by the user if they desire by calling this function with the appropriate property name and
|
||||
* value. The given property name must be unique to be able to successfully add the property.
|
||||
* state by the user by calling this function with the corresponding property name and value.
|
||||
* The given property name must be unique to be able to successfully add the property.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state which should hold the property.
|
||||
* \param[in] property_name Unique name of the property being added.
|
||||
* \param[in] property_name Name of the property being added or updated.
|
||||
* \param[in] property_type Type of the property associated with the given name.
|
||||
* \param[in] property_value Property value associated with the given name.
|
||||
*
|
||||
|
|
@ -632,7 +632,7 @@ struct OrtTrainingApi {
|
|||
* exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
|
||||
* \param[in] property_name Unique name of the property being retrieved.
|
||||
* \param[in] property_name Name of the property being retrieved.
|
||||
* \param[in] allocator Allocator used to allocate the memory for the property_value.
|
||||
* \param[out] property_type Type of the property associated with the given name.
|
||||
* \param[out] property_value Property value associated with the given name.
|
||||
|
|
@ -669,6 +669,57 @@ struct OrtTrainingApi {
|
|||
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
|
||||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
/** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
|
||||
*
|
||||
* This function retrieves the type and shape of the parameter associated with the given parameter name.
|
||||
* The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||
* \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
|
||||
|
||||
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||
* The training session must be already created with the checkpoint state that contains the parameter
|
||||
* being updated. The given parameter is copied over to the registered device for the training session.
|
||||
* The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being updated.
|
||||
* \param[in] parameter The parameter data that should replace the existing parameter data.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _In_ OrtValue* parameter);
|
||||
|
||||
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||
* The parameter is copied over and returned as an OrtValue. The training session must be already created
|
||||
* with the checkpoint state that contains the parameter being retrieved.
|
||||
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||
* \param[in] allocator Allocator used to allocate the memory for the parameter.
|
||||
* \param[out] parameter The parameter data that is retrieved from the checkpoint state.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
|
||||
_Outptr_ OrtValue** parameter);
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -112,13 +112,13 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
|
|||
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
|
||||
const bool include_optimizer_state = false);
|
||||
|
||||
/** \brief Adds the given property to the checkpoint state.
|
||||
/** \brief Adds or updates the given property to/in the checkpoint state.
|
||||
*
|
||||
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
* state by the user if they desire by calling this function with the appropriate property name and
|
||||
* value. The given property name must be unique to be able to successfully add the property.
|
||||
* state by the user by calling this function with the corresponding property name and value.
|
||||
* The given property name must be unique to be able to successfully add the property.
|
||||
*
|
||||
* \param[in] property_name Unique name of the property being added.
|
||||
* \param[in] property_name Name of the property being added or updated.
|
||||
* \param[in] property_value Property value associated with the given name.
|
||||
*
|
||||
*/
|
||||
|
|
@ -129,12 +129,38 @@ class CheckpointState : public detail::Base<OrtCheckpointState> {
|
|||
* Gets the property value from an existing entry in the checkpoint state. The property must
|
||||
* exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] property_name Unique name of the property being retrieved.
|
||||
* \param[in] property_name Name of the property being retrieved.
|
||||
* \return Property value associated with the given property name.
|
||||
*
|
||||
*/
|
||||
Property GetProperty(const std::string& property_name);
|
||||
|
||||
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||
* The training session must be already created with the checkpoint state that contains the parameter
|
||||
* being updated. The given parameter is copied over to the registered device for the training session.
|
||||
* The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||
*
|
||||
* \param[in] parameter_name Name of the parameter being updated.
|
||||
* \param[in] parameter The parameter data that should replace the existing parameter data.
|
||||
*
|
||||
*/
|
||||
void UpdateParameter(const std::string& parameter_name, const Value& parameter);
|
||||
|
||||
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||
* The parameter is copied over to the provided OrtValue. The training session must be already created
|
||||
* with the checkpoint state that contains the parameter being retrieved.
|
||||
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||
* \return The parameter data that is retrieved from the checkpoint state.
|
||||
*
|
||||
*/
|
||||
Value GetParameter(const std::string& parameter_name);
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -279,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) {
|
|||
return property;
|
||||
}
|
||||
|
||||
inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
|
||||
ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
|
||||
}
|
||||
|
||||
inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
|
||||
AllocatorWithDefaultOptions allocator;
|
||||
OrtValue* parameter;
|
||||
ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter));
|
||||
|
||||
return Value{parameter};
|
||||
}
|
||||
|
||||
} // namespace Ort
|
||||
|
|
|
|||
|
|
@ -119,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph,
|
|||
#endif
|
||||
} // namespace
|
||||
|
||||
Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const {
|
||||
ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it.");
|
||||
ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
|
||||
ORT_ENFORCE(data.Get<Tensor>().Shape() == data_.Get<Tensor>().Shape(),
|
||||
"Parameter data shape mismatch. Expected: ", data_.Get<Tensor>().Shape().ToString(),
|
||||
", Got: ", data.Get<Tensor>().Shape().ToString());
|
||||
#ifdef ENABLE_STRIDED_TENSORS
|
||||
auto data_strides = data.Get<Tensor>().Strides();
|
||||
auto param_strides = data_.Get<Tensor>().Strides();
|
||||
ORT_ENFORCE(data_strides.size() == param_strides.size(),
|
||||
"Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
|
||||
", Got: ", data_strides.size());
|
||||
ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
|
||||
"Parameter data stride value mismatch.");
|
||||
#endif
|
||||
ORT_ENFORCE(data.Get<Tensor>().DataType() == data_.Get<Tensor>().DataType(),
|
||||
"Parameter data type mismatch. Expected: ", data_.Get<Tensor>().DataType(),
|
||||
", Got: ", data.Get<Tensor>().DataType());
|
||||
ORT_ENFORCE(data_transfer_manager != nullptr,
|
||||
"Data transfer manager must be provided to copy data to the parameter. "
|
||||
"Please create the TrainingSession before trying to update the parameter.");
|
||||
|
||||
ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get<Tensor>(), *data.GetMutable<Tensor>()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) {
|
||||
ORT_ENFORCE(data_.IsAllocated(),
|
||||
"The checkpoint parameter is not allocated. Cannot copy the given parameter data to it.");
|
||||
ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type.");
|
||||
ORT_ENFORCE(data.Get<Tensor>().Shape() == data_.Get<Tensor>().Shape(),
|
||||
"Parameter data shape mismatch. Expected: ", data_.Get<Tensor>().Shape().ToString(),
|
||||
", Got: ", data.Get<Tensor>().Shape().ToString());
|
||||
#ifdef ENABLE_STRIDED_TENSORS
|
||||
auto data_strides = data.Get<Tensor>().Strides();
|
||||
auto param_strides = data_.Get<Tensor>().Strides();
|
||||
ORT_ENFORCE(data_strides.size() == param_strides.size(),
|
||||
"Parameter data stride mismatch. Expected strides of size: ", param_strides.size(),
|
||||
", Got: ", data_strides.size());
|
||||
ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()),
|
||||
"Parameter data stride value mismatch.");
|
||||
#endif
|
||||
ORT_ENFORCE(data.Get<Tensor>().DataType() == data_.Get<Tensor>().DataType(),
|
||||
"Parameter data type mismatch. Expected: ", data_.Get<Tensor>().DataType(),
|
||||
", Got: ", data.Get<Tensor>().DataType());
|
||||
ORT_ENFORCE(data_transfer_manager != nullptr,
|
||||
"Data transfer manager must be provided to copy data to the parameter. "
|
||||
"Please create the TrainingSession before trying to update the parameter.");
|
||||
|
||||
ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get<Tensor>(), *data_.GetMutable<Tensor>()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) {
|
||||
// assert param is allocated
|
||||
ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient.");
|
||||
|
|
@ -334,6 +389,10 @@ Module::Module(const ModelIdentifiers& model_identifiers,
|
|||
}
|
||||
}
|
||||
|
||||
Module::~Module() {
|
||||
state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr;
|
||||
}
|
||||
|
||||
size_t Module::GetTrainingModelOutputCount() const noexcept {
|
||||
return train_output_names_.size();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ struct Parameter {
|
|||
|
||||
// Return the mutable data.
|
||||
OrtValue& Data() { return data_; }
|
||||
Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const;
|
||||
Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data);
|
||||
const std::string& Name() const { return name_; }
|
||||
|
||||
// Returns whether this parameter is trainable or not.
|
||||
|
|
@ -34,7 +36,6 @@ struct Parameter {
|
|||
// Reset and release the gradient buffer of this Parameter greedily.
|
||||
Status ResetGrad();
|
||||
|
||||
protected:
|
||||
Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad);
|
||||
|
||||
private:
|
||||
|
|
@ -83,6 +84,8 @@ struct Module {
|
|||
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
|
||||
gsl::span<OrtCustomOpDomain* const> op_domains = gsl::span<OrtCustomOpDomain* const>());
|
||||
|
||||
~Module();
|
||||
|
||||
// Return the trainable/nontrainable parameters
|
||||
std::vector<std::shared_ptr<Parameter>> Parameters() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void*
|
|||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
if (checkpoint_buffer == nullptr) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr.");
|
||||
}
|
||||
|
||||
*checkpoint_state = nullptr;
|
||||
auto chkpt_state = std::make_unique<onnxruntime::training::api::CheckpointState>();
|
||||
const auto* checkpoint_bytes = reinterpret_cast<const uint8_t*>(checkpoint_buffer);
|
||||
|
|
@ -559,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState*
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
auto chkpt_state = reinterpret_cast<const onnxruntime::training::api::CheckpointState*>(checkpoint_state);
|
||||
auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
|
||||
if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
|
||||
std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
|
||||
}
|
||||
|
||||
return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape);
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _In_ OrtValue* parameter) {
|
||||
API_IMPL_BEGIN
|
||||
if (parameter == nullptr) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr.");
|
||||
}
|
||||
|
||||
auto chkpt_state = reinterpret_cast<const onnxruntime::training::api::CheckpointState*>(checkpoint_state);
|
||||
auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
|
||||
if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
|
||||
std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
|
||||
}
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom(
|
||||
chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter));
|
||||
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
|
||||
_Outptr_ OrtValue** parameter) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
if (parameter == nullptr) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr.");
|
||||
}
|
||||
|
||||
auto chkpt_state = reinterpret_cast<const onnxruntime::training::api::CheckpointState*>(checkpoint_state);
|
||||
auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name);
|
||||
if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) {
|
||||
std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state.";
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str());
|
||||
}
|
||||
|
||||
if (!it->second->Data().IsTensor()) {
|
||||
return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type.");
|
||||
}
|
||||
const auto& parameter_tensor = it->second->Data().Get<onnxruntime::Tensor>();
|
||||
ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue(
|
||||
allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(),
|
||||
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter));
|
||||
|
||||
auto status = it->second->CopyTo(
|
||||
chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter);
|
||||
if (!status.IsOK()) {
|
||||
OrtApis::ReleaseValue(*parameter);
|
||||
return onnxruntime::ToOrtStatus(status);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
static constexpr OrtTrainingApi ort_training_api = {
|
||||
// NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially
|
||||
// released, it is OK to change the order here, however a corresponding matching change should also be done in the
|
||||
|
|
@ -592,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = {
|
|||
&OrtTrainingApis::TrainingSessionGetEvalModelInputName,
|
||||
&OrtTrainingApis::AddProperty,
|
||||
&OrtTrainingApis::GetProperty,
|
||||
&OrtTrainingApis::LoadCheckpointFromBuffer};
|
||||
&OrtTrainingApis::LoadCheckpointFromBuffer,
|
||||
&OrtTrainingApis::GetParameterTypeAndShape,
|
||||
&OrtTrainingApis::UpdateParameter,
|
||||
&OrtTrainingApis::GetParameter};
|
||||
|
||||
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
|
||||
// No constraints on the API version yet.
|
||||
|
|
|
|||
|
|
@ -94,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state
|
|||
ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
|
||||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
|
||||
|
||||
ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _In_ OrtValue* parameter);
|
||||
|
||||
ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
|
||||
_Outptr_ OrtValue** parameter);
|
||||
|
||||
} // namespace OrtTrainingApis
|
||||
|
|
|
|||
|
|
@ -163,10 +163,8 @@ std::unordered_map<std::string, std::pair<std::string, std::string>> disabledGpu
|
|||
test name -> absolute difference sampleTolerance
|
||||
*/
|
||||
std::unordered_map<std::string, double> sampleTolerancePerTests({
|
||||
{"fp16_inception_v1_opset7_GPU",0.005 },
|
||||
{"fp16_inception_v1_opset8_GPU", 0.005},
|
||||
{ "candy_opset9_GPU",
|
||||
0.00150000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
|
||||
{ "fp16_tiny_yolov2_opset8_GPU",
|
||||
0.109000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
|
||||
{"fp16_inception_v1_opset7_GPU", 0.005},
|
||||
{"fp16_inception_v1_opset8_GPU", 0.005},
|
||||
{ "candy_opset9_GPU", 0.00150000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
|
||||
{ "fp16_tiny_yolov2_opset8_GPU", 0.109000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/
|
||||
});
|
||||
|
|
|
|||
Loading…
Reference in a new issue