From 4fc8f7139ac13ee83d6806111dfd608981ee2f84 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 4 Oct 2022 09:29:20 -0700 Subject: [PATCH] Bug Fix - C# API order incompatibile with C API (#13191) ### Description Training C# bindings (ReleaseTrainingSession and ReleaseCheckpointState) broke after an API order change in Training C API. This PR fixes this issue. ### Motivation and Context Bug Fix for Training C# bindings --- .../Training/NativeTrainingMethods.shared.cs | 4 ++++ .../orttraining/training_api/onnxruntime_training_c_api.cc | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 445662151f..63713a3977 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -7,6 +7,7 @@ using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntime { #if __ENABLE_TRAINING_ON_DEVICE__ + // NOTE: The order of the APIs in this struct should match exactly that in // OrtTrainingApi (onnxruntime_training_c_api.cc) [StructLayout(LayoutKind.Sequential)] public struct OrtTrainingApi @@ -26,6 +27,9 @@ namespace Microsoft.ML.OnnxRuntime public IntPtr OptimizerStep; public IntPtr RegisterLinearLRScheduler; public IntPtr SchedulerStep; + public IntPtr GetParametersSize; + public IntPtr CopyParametersToBuffer; + public IntPtr CopyBufferToParameters; public IntPtr ReleaseTrainingSession; public IntPtr ReleaseCheckpointState; } diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index eb95549433..6cd45f7f15 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -312,7 +312,7 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::CopyParametersToBuffer, _Inout_ OrtTraining } auto session = reinterpret_cast(sess); ORT_API_RETURN_IF_STATUS_NOT_OK(session->CopyParametersToBuffer(*parameters_buffer, trainable_only)); - + return nullptr; API_IMPL_END } @@ -339,6 +339,9 @@ ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckp } static constexpr OrtTrainingApi ort_training_api = { + // NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially + // released, it is OK to change the order here, however a corresponding matching change should also be done in the + // "OrtTrainingApi" struct in NativeTrainingMethods.shared.cs &OrtTrainingApis::LoadCheckpoint, &OrtTrainingApis::SaveCheckpoint, &OrtTrainingApis::CreateTrainingSession,