mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
[Bug Fix] Incorrect comparison for FromBuffer in TrainingSession.cs (#16022)
This commit is contained in:
parent
2fddc65c8c
commit
de0a973b6e
1 changed files with 10 additions and 9 deletions
|
|
@ -496,21 +496,22 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
throw new ArgumentException(errorMessage);
|
||||
}
|
||||
|
||||
IntPtr numElementsTrainingOnly = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly));
|
||||
// Here buffer size represents the number of elements in the buffer
|
||||
IntPtr bufferSize = IntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out bufferSize));
|
||||
|
||||
UIntPtr bufferSize = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true));
|
||||
if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64())
|
||||
// OrtGetParametersSize returns the total number of elements in the model's parameters.
|
||||
UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true));
|
||||
if (bufferSize.ToInt64() == (long)numElementsTrainingOnly.ToUInt64())
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
|
||||
return;
|
||||
}
|
||||
|
||||
IntPtr numElements = IntPtr.Zero;
|
||||
bufferSize = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false));
|
||||
if ((long)bufferSize.ToUInt64() != numElements.ToInt64())
|
||||
UIntPtr numElements = UIntPtr.Zero;
|
||||
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false));
|
||||
if (bufferSize.ToInt64() != (long)numElements.ToUInt64())
|
||||
{
|
||||
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
|
||||
throw new ArgumentException(errorMessage);
|
||||
|
|
|
|||
Loading…
Reference in a new issue