onnxruntime/java/src/main/native/ai_onnxruntime_OrtTrainingSession_OrtCheckpointState.c

95 lines
4.1 KiB
C
Raw Normal View History

/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include <string.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "onnxruntime_training_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OrtTrainingSession_OrtCheckpointState.h"
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: loadCheckpoint
* Signature: (JJLjava/lang/String;)J
*/
/*
 * Loads a training checkpoint from the path in `directory` via
 * OrtTrainingApi::LoadCheckpoint and returns the resulting OrtCheckpointState
 * pointer as a jlong handle.
 *
 * apiHandle         - pointer to the OrtApi table, passed to checkOrtStatus.
 * trainingApiHandle - pointer to the OrtTrainingApi table.
 * directory         - Java string holding the checkpoint path.
 *
 * Returns 0 if the path characters could not be obtained from the JVM or if
 * the temporary wide-character buffer could not be allocated (Windows only).
 * The status returned by LoadCheckpoint is handed to checkOrtStatus; the
 * checkpoint pointer is returned as-is afterwards (it is initialised to NULL).
 */
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_loadCheckpoint
  (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong trainingApiHandle, jstring directory) {
  (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;

  OrtCheckpointState* checkpoint = NULL;

#ifdef _WIN32
  // The output of GetStringChars is not null-terminated, so we copy it and add a terminator
  const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, directory, NULL);
  if (cPath == NULL) {
    // GetStringChars returns NULL on failure. Only raise our own exception
    // when the JVM has not already left one pending.
    if (!(*jniEnv)->ExceptionCheck(jniEnv)) {
      throwOrtException(jniEnv, 1, "Not enough memory");
    }
    return 0;
  }
  size_t stringLength = (size_t) (*jniEnv)->GetStringLength(jniEnv, directory);
  // calloc zero-fills, so the extra element is the terminator.
  wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
  if (newString == NULL) {
    (*jniEnv)->ReleaseStringChars(jniEnv, directory, cPath);
    throwOrtException(jniEnv, 1, "Not enough memory");
    return 0;
  }
  wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength);
  checkOrtStatus(jniEnv, api,
                 trainApi->LoadCheckpoint(newString, &checkpoint));
  free(newString);
  (*jniEnv)->ReleaseStringChars(jniEnv, directory, cPath);
#else
  // GetStringUTFChars is null terminated, so can be used directly
  const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, directory, NULL);
  if (cPath == NULL) {
    // GetStringUTFChars returns NULL on failure. Only raise our own exception
    // when the JVM has not already left one pending.
    if (!(*jniEnv)->ExceptionCheck(jniEnv)) {
      throwOrtException(jniEnv, 1, "Not enough memory");
    }
    return 0;
  }
  checkOrtStatus(jniEnv, api, trainApi->LoadCheckpoint(cPath, &checkpoint));
  (*jniEnv)->ReleaseStringUTFChars(jniEnv, directory, cPath);
#endif

  return (jlong) checkpoint;
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: saveCheckpoint
* Signature: (JJJLjava/lang/String;Z)V
*/
/*
 * Saves the checkpoint state referenced by `nativeHandle` to `outputPath` via
 * OrtTrainingApi::SaveCheckpoint.
 *
 * apiHandle         - pointer to the OrtApi table, passed to checkOrtStatus.
 * trainingApiHandle - pointer to the OrtTrainingApi table.
 * nativeHandle      - pointer to the OrtCheckpointState to save.
 * outputPath        - Java string holding the destination path.
 * saveOptimizer     - forwarded unchanged as the last SaveCheckpoint argument.
 *
 * Returns without calling SaveCheckpoint if the path characters could not be
 * obtained from the JVM or if the temporary wide-character buffer could not be
 * allocated (Windows only). The status returned by SaveCheckpoint is handed to
 * checkOrtStatus.
 */
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_saveCheckpoint
  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jstring outputPath, jboolean saveOptimizer) {
  (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;

  OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;

#ifdef _WIN32
  // The output of GetStringChars is not null-terminated, so we copy it and add a terminator
  const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, outputPath, NULL);
  if (cPath == NULL) {
    // GetStringChars returns NULL on failure. Only raise our own exception
    // when the JVM has not already left one pending.
    if (!(*jniEnv)->ExceptionCheck(jniEnv)) {
      throwOrtException(jniEnv, 1, "Not enough memory");
    }
    return;
  }
  size_t stringLength = (size_t) (*jniEnv)->GetStringLength(jniEnv, outputPath);
  // calloc zero-fills, so the extra element is the terminator.
  wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
  if (newString == NULL) {
    (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath);
    throwOrtException(jniEnv, 1, "Not enough memory");
    return;
  }
  wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength);
  checkOrtStatus(jniEnv, api,
                 trainApi->SaveCheckpoint(checkpointState, newString, saveOptimizer));
  free(newString);
  (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath);
#else
  // GetStringUTFChars is null terminated, so can be used directly
  const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, outputPath, NULL);
  if (cPath == NULL) {
    // GetStringUTFChars returns NULL on failure. Only raise our own exception
    // when the JVM has not already left one pending.
    if (!(*jniEnv)->ExceptionCheck(jniEnv)) {
      throwOrtException(jniEnv, 1, "Not enough memory");
    }
    return;
  }
  checkOrtStatus(jniEnv, api, trainApi->SaveCheckpoint(checkpointState, cPath, saveOptimizer));
  (*jniEnv)->ReleaseStringUTFChars(jniEnv, outputPath, cPath);
#endif
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: close
* Signature: (JJ)V
*/
/*
 * Releases the native checkpoint state referenced by `handle` by calling
 * OrtTrainingApi::ReleaseCheckpointState.
 *
 * apiHandle - pointer to the OrtTrainingApi table (note: the training API,
 *             not the core OrtApi, despite the generic parameter name).
 * handle    - pointer to the OrtCheckpointState to release.
 */
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_close
  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
  // Required JNI parameters; neither the environment nor the host object is used here.
  (void) jniEnv;
  (void) jobj;

  const OrtTrainingApi* trainApi = (const OrtTrainingApi*) apiHandle;
  OrtCheckpointState* checkpointState = (OrtCheckpointState*) handle;

  trainApi->ReleaseCheckpointState(checkpointState);
}