[Bug Fix] Incorrect comparison for FromBuffer in TrainingSession.cs (#16022)

This commit is contained in:
Baiju Meswani 2023-05-22 21:21:54 -07:00 committed by GitHub
parent 2fddc65c8c
commit de0a973b6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -496,21 +496,22 @@ namespace Microsoft.ML.OnnxRuntime
throw new ArgumentException(errorMessage);
}
IntPtr numElementsTrainingOnly = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly));
// Here buffer size represents the number of elements in the buffer
IntPtr bufferSize = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out bufferSize));
UIntPtr bufferSize = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true));
if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64())
// OrtGetParametersSize returns the total number of elements in the model's parameters.
UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true));
if (bufferSize.ToInt64() == (long)numElementsTrainingOnly.ToUInt64())
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
return;
}
IntPtr numElements = IntPtr.Zero;
bufferSize = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false));
if ((long)bufferSize.ToUInt64() != numElements.ToInt64())
UIntPtr numElements = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false));
if (bufferSize.ToInt64() != (long)numElements.ToUInt64())
{
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
throw new ArgumentException(errorMessage);