diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 445662151f..63713a3977 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -7,6 +7,7 @@ using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntime { #if __ENABLE_TRAINING_ON_DEVICE__ + // NOTE: The order of the APIs in this struct should match exactly that in // OrtTrainingApi (onnxruntime_training_c_api.cc) [StructLayout(LayoutKind.Sequential)] public struct OrtTrainingApi @@ -26,6 +27,9 @@ namespace Microsoft.ML.OnnxRuntime public IntPtr OptimizerStep; public IntPtr RegisterLinearLRScheduler; public IntPtr SchedulerStep; + public IntPtr GetParametersSize; + public IntPtr CopyParametersToBuffer; + public IntPtr CopyBufferToParameters; public IntPtr ReleaseTrainingSession; public IntPtr ReleaseCheckpointState; } diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index eb95549433..6cd45f7f15 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -312,7 +312,7 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::CopyParametersToBuffer, _Inout_ OrtTraining } auto session = reinterpret_cast(sess); ORT_API_RETURN_IF_STATUS_NOT_OK(session->CopyParametersToBuffer(*parameters_buffer, trainable_only)); - + return nullptr; API_IMPL_END } @@ -339,6 +339,9 @@ ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckp } static constexpr OrtTrainingApi ort_training_api = { + // NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially + // released, it is OK to change the order here, however a corresponding matching change should also be done in the + // "OrtTrainingApi" struct in NativeTrainingMethods.shared.cs &OrtTrainingApis::LoadCheckpoint, &OrtTrainingApis::SaveCheckpoint, &OrtTrainingApis::CreateTrainingSession,