Bug Fix - C# API order incompatibile with C API (#13191)

### Description
Training C# bindings (ReleaseTrainingSession and ReleaseCheckpointState)
broke after an API order change in Training C API. This PR fixes this
issue.



### Motivation and Context
Bug Fix for Training C# bindings
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Ashwini Khade 2022-10-04 09:29:20 -07:00 committed by GitHub
parent 595a0c8658
commit 4fc8f7139a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 1 deletions

View file

@ -7,6 +7,7 @@ using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_ON_DEVICE__
// NOTE: The order of the APIs in this struct should match exactly that in
// OrtTrainingApi (onnxruntime_training_c_api.cc)
[StructLayout(LayoutKind.Sequential)]
public struct OrtTrainingApi
@ -26,6 +27,9 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr OptimizerStep;
public IntPtr RegisterLinearLRScheduler;
public IntPtr SchedulerStep;
public IntPtr GetParametersSize;
public IntPtr CopyParametersToBuffer;
public IntPtr CopyBufferToParameters;
public IntPtr ReleaseTrainingSession;
public IntPtr ReleaseCheckpointState;
}

View file

@ -312,7 +312,7 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::CopyParametersToBuffer, _Inout_ OrtTraining
}
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
ORT_API_RETURN_IF_STATUS_NOT_OK(session->CopyParametersToBuffer(*parameters_buffer, trainable_only));
return nullptr;
API_IMPL_END
}
@ -339,6 +339,9 @@ ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckp
}
static constexpr OrtTrainingApi ort_training_api = {
// NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially
// released, it is OK to change the order here, however a corresponding matching change should also be done in the
// "OrtTrainingApi" struct in NativeTrainingMethods.shared.cs
&OrtTrainingApis::LoadCheckpoint,
&OrtTrainingApis::SaveCheckpoint,
&OrtTrainingApis::CreateTrainingSession,