diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 558ac3de36..d4a8750999 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -496,21 +496,22 @@ namespace Microsoft.ML.OnnxRuntime throw new ArgumentException(errorMessage); } - IntPtr numElementsTrainingOnly = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly)); + // Here buffer size represents the number of elements in the buffer + IntPtr bufferSize = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out bufferSize)); - UIntPtr bufferSize = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true)); - if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64()) + // OrtGetParametersSize returns the total number of elements in the model's parameters. + UIntPtr numElementsTrainingOnly = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); + if (bufferSize.ToInt64() == (long)numElementsTrainingOnly.ToUInt64()) { NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); return; } - IntPtr numElements = IntPtr.Zero; - bufferSize = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false)); - if ((long)bufferSize.ToUInt64() != numElements.ToInt64()) + UIntPtr numElements = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); + if (bufferSize.ToInt64() != (long)numElements.ToUInt64()) { string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); throw new ArgumentException(errorMessage);