mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[Java] CheckpointState AddProperty & GetProperty support (#15730)
This commit is contained in:
parent
be08b47e7b
commit
8a1a40ac63
3 changed files with 510 additions and 8 deletions
|
|
@ -221,6 +221,72 @@ public final class OrtTrainingSession implements AutoCloseable {
|
|||
return evalOutputNames;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a float property to this training session checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @param value The property value.
|
||||
* @throws OrtException If the call failed.
|
||||
*/
|
||||
public void addProperty(String name, float value) throws OrtException {
|
||||
checkpoint.addProperty(name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a int property to this training session checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @param value The property value.
|
||||
* @throws OrtException If the call failed.
|
||||
*/
|
||||
public void addProperty(String name, int value) throws OrtException {
|
||||
checkpoint.addProperty(name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a String property to this training session checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @param value The property value.
|
||||
* @throws OrtException If the call failed.
|
||||
*/
|
||||
public void addProperty(String name, String value) throws OrtException {
|
||||
checkpoint.addProperty(name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a float property from this training session checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @return The property value.
|
||||
* @throws OrtException If the property does not exist, or is of the wrong type.
|
||||
*/
|
||||
public float getFloatProperty(String name) throws OrtException {
|
||||
return checkpoint.getFloatProperty(allocator, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a int property from this training session checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @return The property value.
|
||||
* @throws OrtException If the property does not exist, or is of the wrong type.
|
||||
*/
|
||||
public int getIntProperty(String name) throws OrtException {
|
||||
return checkpoint.getIntProperty(allocator, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a String property from this training session checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @return The property value.
|
||||
* @throws OrtException If the property does not exist, or is of the wrong type.
|
||||
*/
|
||||
public String getStringProperty(String name) throws OrtException {
|
||||
return checkpoint.getStringProperty(allocator, name);
|
||||
}
|
||||
|
||||
/** Checks if the OrtTrainingSession is closed, if so throws {@link IllegalStateException}. */
|
||||
private void checkClosed() {
|
||||
if (closed) {
|
||||
|
|
@ -927,6 +993,93 @@ public final class OrtTrainingSession implements AutoCloseable {
|
|||
saveOptimizer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a float property to this checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @param value The property value.
|
||||
* @throws OrtException If the call failed.
|
||||
*/
|
||||
public void addProperty(String name, float value) throws OrtException {
|
||||
addProperty(
|
||||
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a int property to this checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @param value The property value.
|
||||
* @throws OrtException If the call failed.
|
||||
*/
|
||||
public void addProperty(String name, int value) throws OrtException {
|
||||
addProperty(
|
||||
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a String property to this checkpoint.
|
||||
*
|
||||
* @param name The property name.
|
||||
* @param value The property value.
|
||||
* @throws OrtException If the call failed.
|
||||
*/
|
||||
public void addProperty(String name, String value) throws OrtException {
|
||||
addProperty(
|
||||
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a float property from this checkpoint.
|
||||
*
|
||||
* @param allocator The allocator.
|
||||
* @param name The property name.
|
||||
* @return The property value.
|
||||
* @throws OrtException If the property does not exist, or is of the wrong type.
|
||||
*/
|
||||
public float getFloatProperty(OrtAllocator allocator, String name) throws OrtException {
|
||||
return getFloatProperty(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
OnnxRuntime.ortTrainingApiHandle,
|
||||
nativeHandle,
|
||||
allocator.handle,
|
||||
name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a int property from this checkpoint.
|
||||
*
|
||||
* @param allocator The allocator.
|
||||
* @param name The property name.
|
||||
* @return The property value.
|
||||
* @throws OrtException If the property does not exist, or is of the wrong type.
|
||||
*/
|
||||
public int getIntProperty(OrtAllocator allocator, String name) throws OrtException {
|
||||
return getIntProperty(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
OnnxRuntime.ortTrainingApiHandle,
|
||||
nativeHandle,
|
||||
allocator.handle,
|
||||
name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a String property from this checkpoint.
|
||||
*
|
||||
* @param allocator The allocator.
|
||||
* @param name The property name.
|
||||
* @return The property value.
|
||||
* @throws OrtException If the property does not exist, or is of the wrong type.
|
||||
*/
|
||||
public String getStringProperty(OrtAllocator allocator, String name) throws OrtException {
|
||||
return getStringProperty(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
OnnxRuntime.ortTrainingApiHandle,
|
||||
nativeHandle,
|
||||
allocator.handle,
|
||||
name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
close(OnnxRuntime.ortTrainingApiHandle, nativeHandle);
|
||||
|
|
@ -969,6 +1122,88 @@ public final class OrtTrainingSession implements AutoCloseable {
|
|||
long apiHandle, long trainingHandle, long nativeHandle, String path, boolean saveOptimizer)
|
||||
throws OrtException;
|
||||
|
||||
/* \brief Adds the given property to the checkpoint state.
|
||||
*
|
||||
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
* state by the user if they desire by calling this function with the appropriate property name and
|
||||
* value. The given property name must be unique to be able to successfully add the property.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state which should hold the property.
|
||||
* \param[in] property_name Unique name of the property being added.
|
||||
* \param[in] property_type Type of the property associated with the given name.
|
||||
* \param[in] property_value Property value associated with the given name.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* property_name, _In_ enum OrtPropertyType property_type,
|
||||
_In_ void* property_value);
|
||||
*/
|
||||
private native void addProperty(
|
||||
long apiHandle,
|
||||
long trainingHandle,
|
||||
long nativeHandle,
|
||||
String propertyName,
|
||||
int propertyValue)
|
||||
throws OrtException;
|
||||
|
||||
private native void addProperty(
|
||||
long apiHandle,
|
||||
long trainingHandle,
|
||||
long nativeHandle,
|
||||
String propertyName,
|
||||
float propertyValue)
|
||||
throws OrtException;
|
||||
|
||||
private native void addProperty(
|
||||
long apiHandle,
|
||||
long trainingHandle,
|
||||
long nativeHandle,
|
||||
String propertyName,
|
||||
String propertyValue)
|
||||
throws OrtException;
|
||||
|
||||
/* \brief Gets the property value associated with the given name from the checkpoint state.
|
||||
*
|
||||
* Gets the property value from an existing entry in the checkpoint state. The property must
|
||||
* exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
|
||||
* \param[in] property_name Unique name of the property being retrieved.
|
||||
* \param[in] allocator Allocator used to allocate the memory for the property_value.
|
||||
* \param[out] property_type Type of the property associated with the given name.
|
||||
* \param[out] property_value Property value associated with the given name.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
|
||||
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
|
||||
*/
|
||||
private native int getIntProperty(
|
||||
long apiHandle,
|
||||
long trainingHandle,
|
||||
long nativeHandle,
|
||||
long allocatorHandle,
|
||||
String propertyName)
|
||||
throws OrtException;
|
||||
|
||||
private native float getFloatProperty(
|
||||
long apiHandle,
|
||||
long trainingHandle,
|
||||
long nativeHandle,
|
||||
long allocatorHandle,
|
||||
String propertyName)
|
||||
throws OrtException;
|
||||
|
||||
private native String getStringProperty(
|
||||
long apiHandle,
|
||||
long trainingHandle,
|
||||
long nativeHandle,
|
||||
long allocatorHandle,
|
||||
String propertyName)
|
||||
throws OrtException;
|
||||
|
||||
private native void close(long trainingApiHandle, long nativeHandle);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -81,6 +81,158 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpoint
|
|||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
|
||||
* Method: addProperty
|
||||
* Signature: (JJJLjava/lang/String;I)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_addProperty__JJJLjava_lang_String_2I
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jstring propName, jint propValue) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
|
||||
|
||||
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
|
||||
|
||||
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
|
||||
checkOrtStatus(jniEnv, api, trainApi->AddProperty(checkpointState, cPropName, OrtIntProperty, &propValue));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
|
||||
* Method: addProperty
|
||||
* Signature: (JJJLjava/lang/String;F)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_addProperty__JJJLjava_lang_String_2F
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jstring propName, jfloat propValue) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
|
||||
|
||||
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
|
||||
|
||||
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
|
||||
checkOrtStatus(jniEnv, api, trainApi->AddProperty(checkpointState, cPropName, OrtFloatProperty, &propValue));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
|
||||
* Method: addProperty
|
||||
* Signature: (JJJLjava/lang/String;Ljava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_addProperty__JJJLjava_lang_String_2Ljava_lang_String_2
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jstring propName, jstring propValue) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
|
||||
|
||||
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
|
||||
|
||||
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
|
||||
const char* cPropValue = (*jniEnv)->GetStringUTFChars(jniEnv, propValue, NULL);
|
||||
checkOrtStatus(jniEnv, api, trainApi->AddProperty(checkpointState, cPropName, OrtStringProperty, (void*)cPropValue));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propValue, cPropValue);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
|
||||
* Method: getIntProperty
|
||||
* Signature: (JJJJLjava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_getIntProperty
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jlong allocatorHandle, jstring propName) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
|
||||
|
||||
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
|
||||
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
|
||||
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
|
||||
enum OrtPropertyType type;
|
||||
int* propValue = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, trainApi->GetProperty(checkpointState, cPropName, allocator, &type, (void**)&propValue));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
|
||||
if (code == ORT_OK) {
|
||||
if (type == OrtIntProperty) {
|
||||
int output = *propValue;
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, propValue));
|
||||
return output;
|
||||
} else {
|
||||
throwOrtException(jniEnv, 2, "Requested an int property but this property is not an int");
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
|
||||
* Method: getFloatProperty
|
||||
* Signature: (JJJJLjava/lang/String;)F
|
||||
*/
|
||||
JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_getFloatProperty
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jlong allocatorHandle, jstring propName) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
|
||||
|
||||
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
|
||||
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
|
||||
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
|
||||
enum OrtPropertyType type;
|
||||
float* propValue = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, trainApi->GetProperty(checkpointState, cPropName, allocator, &type, (void**)&propValue));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
|
||||
if (code == ORT_OK) {
|
||||
if (type == OrtFloatProperty) {
|
||||
float output = *propValue;
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, propValue));
|
||||
return output;
|
||||
} else {
|
||||
throwOrtException(jniEnv, 2, "Requested a float property but this property is not a float");
|
||||
}
|
||||
}
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
|
||||
* Method: getStringProperty
|
||||
* Signature: (JJJJLjava/lang/String;)Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_getStringProperty
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jlong allocatorHandle, jstring propName) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
|
||||
|
||||
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
|
||||
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
|
||||
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
|
||||
enum OrtPropertyType type;
|
||||
char* propValue = NULL;
|
||||
OrtErrorCode code = checkOrtStatus(jniEnv, api, trainApi->GetProperty(checkpointState, cPropName, allocator, &type, (void**)&propValue));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
|
||||
if (code == ORT_OK) {
|
||||
if (type == OrtStringProperty) {
|
||||
jstring output = (*jniEnv)->NewStringUTF(jniEnv, propValue);
|
||||
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, propValue));
|
||||
return output;
|
||||
} else {
|
||||
throwOrtException(jniEnv, 2, "Requested a string property but this property is not a string");
|
||||
}
|
||||
}
|
||||
return (jstring) 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
|
||||
* Method: close
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -45,15 +45,34 @@ public class TrainingTest {
|
|||
Assertions.assertNotNull(trainingSession);
|
||||
Set<String> inputNames = trainingSession.getTrainInputNames();
|
||||
Assertions.assertFalse(inputNames.isEmpty());
|
||||
Assertions.assertTrue(inputNames.contains("input-0"));
|
||||
Set<String> outputNames = trainingSession.getTrainOutputNames();
|
||||
Assertions.assertFalse(outputNames.isEmpty());
|
||||
Assertions.assertTrue(outputNames.contains("onnx::loss::21273"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateTrainingSessionWithEval() throws OrtException {
|
||||
String ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
|
||||
String trainPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
|
||||
String evalPath = TestHelpers.getResourcePath("/eval_model.onnx").toString();
|
||||
try (OrtTrainingSession trainingSession =
|
||||
env.createTrainingSession(ckptPath, trainPath, evalPath, null)) {
|
||||
Assertions.assertNotNull(trainingSession);
|
||||
Set<String> inputNames = trainingSession.getEvalInputNames();
|
||||
Assertions.assertFalse(inputNames.isEmpty());
|
||||
Assertions.assertTrue(inputNames.contains("input-0"));
|
||||
Set<String> outputNames = trainingSession.getEvalOutputNames();
|
||||
Assertions.assertFalse(outputNames.isEmpty());
|
||||
Assertions.assertTrue(outputNames.contains("onnx::loss::21273"));
|
||||
}
|
||||
}
|
||||
|
||||
// this test is not enabled as ORT Java doesn't support supplying an output buffer
|
||||
@Disabled
|
||||
@Test
|
||||
public void TestTrainingSessionTrainStep() throws OrtException {
|
||||
public void testTrainingSessionTrainStep() throws OrtException {
|
||||
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
|
||||
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
|
||||
float[] expectedOutput =
|
||||
|
|
@ -134,7 +153,7 @@ public class TrainingTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void TestTrainingSessionTrainStepOrtOutput() throws OrtException {
|
||||
public void testTrainingSessionTrainStepOrtOutput() throws OrtException {
|
||||
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
|
||||
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
|
||||
try (OrtTrainingSession trainingSession =
|
||||
|
|
@ -144,7 +163,7 @@ public class TrainingTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void TestSaveCheckpoint() throws IOException, OrtException {
|
||||
public void testSaveCheckpoint() throws IOException, OrtException {
|
||||
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
|
||||
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
|
||||
|
||||
|
|
@ -168,7 +187,7 @@ public class TrainingTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void TestTrainingSessionOptimizerStep() throws OrtException {
|
||||
public void testTrainingSessionOptimizerStep() throws OrtException {
|
||||
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
|
||||
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
|
||||
String optimizerPath = TestHelpers.getResourcePath("/adamw.onnx").toString();
|
||||
|
|
@ -214,7 +233,7 @@ public class TrainingTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void TestTrainingSessionSetLearningRate() throws OrtException {
|
||||
public void testTrainingSessionSetLearningRate() throws OrtException {
|
||||
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
|
||||
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
|
||||
String optimizerPath = TestHelpers.getResourcePath("/adamw.onnx").toString();
|
||||
|
|
@ -229,7 +248,7 @@ public class TrainingTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void TestTrainingSessionLinearLRScheduler() throws OrtException {
|
||||
public void testTrainingSessionLinearLRScheduler() throws OrtException {
|
||||
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
|
||||
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
|
||||
String optimizerPath = TestHelpers.getResourcePath("/adamw.onnx").toString();
|
||||
|
|
@ -253,4 +272,100 @@ public class TrainingTest {
|
|||
Assertions.assertEquals(0.0f, trainingSession.getLearningRate());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTrainingSessionExportModelForInferencing() throws IOException, OrtException {
|
||||
|
||||
String ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
|
||||
String trainPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
|
||||
String evalPath = TestHelpers.getResourcePath("/eval_model.onnx").toString();
|
||||
try (OrtTrainingSession trainingSession =
|
||||
env.createTrainingSession(ckptPath, trainPath, evalPath, null)) {
|
||||
String[] graphOutputs = new String[] {"output-0"};
|
||||
|
||||
Path inferencePath = Files.createTempFile("inference_model", ".onnx");
|
||||
|
||||
trainingSession.exportModelForInference(inferencePath, graphOutputs);
|
||||
Assertions.assertTrue(inferencePath.toFile().exists());
|
||||
inferencePath.toFile().delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCheckpointStateAddIntProperty() throws OrtException {
|
||||
Path ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt");
|
||||
try (OrtCheckpointState ckpt = OrtCheckpointState.loadCheckpoint(ckptPath)) {
|
||||
String propertyName = "days in a week";
|
||||
ckpt.addProperty(propertyName, 7);
|
||||
|
||||
int value = ckpt.getIntProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.assertEquals(7, value);
|
||||
|
||||
try {
|
||||
String strVal = ckpt.getStringProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.fail("Should have thrown");
|
||||
} catch (OrtException e) {
|
||||
// pass
|
||||
}
|
||||
|
||||
try {
|
||||
float floatVal = ckpt.getFloatProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.fail("Should have thrown");
|
||||
} catch (OrtException e) {
|
||||
// pass
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCheckpointStateAddFloatProperty() throws OrtException {
|
||||
Path ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt");
|
||||
try (OrtCheckpointState ckpt = OrtCheckpointState.loadCheckpoint(ckptPath)) {
|
||||
String propertyName = "pi";
|
||||
ckpt.addProperty(propertyName, 3.14f);
|
||||
|
||||
float value = ckpt.getFloatProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.assertEquals(3.14f, value);
|
||||
|
||||
try {
|
||||
String strVal = ckpt.getStringProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.fail("Should have thrown");
|
||||
} catch (OrtException e) {
|
||||
// pass
|
||||
}
|
||||
|
||||
try {
|
||||
int intVal = ckpt.getIntProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.fail("Should have thrown");
|
||||
} catch (OrtException e) {
|
||||
// pass
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCheckpointStateAddStringProperty() throws OrtException {
|
||||
Path ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt");
|
||||
try (OrtCheckpointState ckpt = OrtCheckpointState.loadCheckpoint(ckptPath)) {
|
||||
String propertyName = "best ai framework";
|
||||
ckpt.addProperty(propertyName, "onnxruntime");
|
||||
|
||||
String value = ckpt.getStringProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.assertEquals("onnxruntime", value);
|
||||
|
||||
try {
|
||||
float floatVal = ckpt.getFloatProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.fail("Should have thrown");
|
||||
} catch (OrtException e) {
|
||||
// pass
|
||||
}
|
||||
|
||||
try {
|
||||
int intVal = ckpt.getIntProperty(env.defaultAllocator, propertyName);
|
||||
Assertions.fail("Should have thrown");
|
||||
} catch (OrtException e) {
|
||||
// pass
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue