[Java] CheckpointState AddProperty & GetProperty support (#15730)

This commit is contained in:
Adam Pocock 2023-04-28 12:52:52 -04:00 committed by GitHub
parent be08b47e7b
commit 8a1a40ac63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 510 additions and 8 deletions

View file

@ -221,6 +221,72 @@ public final class OrtTrainingSession implements AutoCloseable {
return evalOutputNames;
}
/**
* Adds a float property to this training session checkpoint.
*
* @param name The property name.
* @param value The property value.
* @throws OrtException If the call failed.
*/
public void addProperty(String name, float value) throws OrtException {
checkpoint.addProperty(name, value);
}
/**
* Adds a int property to this training session checkpoint.
*
* @param name The property name.
* @param value The property value.
* @throws OrtException If the call failed.
*/
public void addProperty(String name, int value) throws OrtException {
checkpoint.addProperty(name, value);
}
/**
* Adds a String property to this training session checkpoint.
*
* @param name The property name.
* @param value The property value.
* @throws OrtException If the call failed.
*/
public void addProperty(String name, String value) throws OrtException {
checkpoint.addProperty(name, value);
}
/**
* Gets a float property from this training session checkpoint.
*
* @param name The property name.
* @return The property value.
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public float getFloatProperty(String name) throws OrtException {
return checkpoint.getFloatProperty(allocator, name);
}
/**
* Gets a int property from this training session checkpoint.
*
* @param name The property name.
* @return The property value.
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public int getIntProperty(String name) throws OrtException {
return checkpoint.getIntProperty(allocator, name);
}
/**
* Gets a String property from this training session checkpoint.
*
* @param name The property name.
* @return The property value.
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public String getStringProperty(String name) throws OrtException {
return checkpoint.getStringProperty(allocator, name);
}
/** Checks if the OrtTrainingSession is closed, if so throws {@link IllegalStateException}. */
private void checkClosed() {
if (closed) {
@ -927,6 +993,93 @@ public final class OrtTrainingSession implements AutoCloseable {
saveOptimizer);
}
/**
* Adds a float property to this checkpoint.
*
* @param name The property name.
* @param value The property value.
* @throws OrtException If the call failed.
*/
public void addProperty(String name, float value) throws OrtException {
addProperty(
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
}
/**
* Adds a int property to this checkpoint.
*
* @param name The property name.
* @param value The property value.
* @throws OrtException If the call failed.
*/
public void addProperty(String name, int value) throws OrtException {
addProperty(
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
}
/**
* Adds a String property to this checkpoint.
*
* @param name The property name.
* @param value The property value.
* @throws OrtException If the call failed.
*/
public void addProperty(String name, String value) throws OrtException {
addProperty(
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
}
/**
* Gets a float property from this checkpoint.
*
* @param allocator The allocator.
* @param name The property name.
* @return The property value.
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public float getFloatProperty(OrtAllocator allocator, String name) throws OrtException {
return getFloatProperty(
OnnxRuntime.ortApiHandle,
OnnxRuntime.ortTrainingApiHandle,
nativeHandle,
allocator.handle,
name);
}
/**
* Gets a int property from this checkpoint.
*
* @param allocator The allocator.
* @param name The property name.
* @return The property value.
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public int getIntProperty(OrtAllocator allocator, String name) throws OrtException {
return getIntProperty(
OnnxRuntime.ortApiHandle,
OnnxRuntime.ortTrainingApiHandle,
nativeHandle,
allocator.handle,
name);
}
/**
* Gets a String property from this checkpoint.
*
* @param allocator The allocator.
* @param name The property name.
* @return The property value.
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public String getStringProperty(OrtAllocator allocator, String name) throws OrtException {
return getStringProperty(
OnnxRuntime.ortApiHandle,
OnnxRuntime.ortTrainingApiHandle,
nativeHandle,
allocator.handle,
name);
}
@Override
public void close() {
close(OnnxRuntime.ortTrainingApiHandle, nativeHandle);
@ -969,6 +1122,88 @@ public final class OrtTrainingSession implements AutoCloseable {
long apiHandle, long trainingHandle, long nativeHandle, String path, boolean saveOptimizer)
throws OrtException;
/* \brief Adds the given property to the checkpoint state.
*
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
* state by the user if they desire by calling this function with the appropriate property name and
* value. The given property name must be unique to be able to successfully add the property.
*
* \param[in] checkpoint_state The checkpoint state which should hold the property.
* \param[in] property_name Unique name of the property being added.
* \param[in] property_type Type of the property associated with the given name.
* \param[in] property_value Property value associated with the given name.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const char* property_name, _In_ enum OrtPropertyType property_type,
_In_ void* property_value);
*/
private native void addProperty(
long apiHandle,
long trainingHandle,
long nativeHandle,
String propertyName,
int propertyValue)
throws OrtException;
private native void addProperty(
long apiHandle,
long trainingHandle,
long nativeHandle,
String propertyName,
float propertyValue)
throws OrtException;
private native void addProperty(
long apiHandle,
long trainingHandle,
long nativeHandle,
String propertyName,
String propertyValue)
throws OrtException;
/* \brief Gets the property value associated with the given name from the checkpoint state.
*
* Gets the property value from an existing entry in the checkpoint state. The property must
* exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
* \param[in] property_name Unique name of the property being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the property_value.
* \param[out] property_type Type of the property associated with the given name.
* \param[out] property_value Property value associated with the given name.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
*/
private native int getIntProperty(
long apiHandle,
long trainingHandle,
long nativeHandle,
long allocatorHandle,
String propertyName)
throws OrtException;
private native float getFloatProperty(
long apiHandle,
long trainingHandle,
long nativeHandle,
long allocatorHandle,
String propertyName)
throws OrtException;
private native String getStringProperty(
long apiHandle,
long trainingHandle,
long nativeHandle,
long allocatorHandle,
String propertyName)
throws OrtException;
private native void close(long trainingApiHandle, long nativeHandle);
}
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -81,6 +81,158 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpoint
#endif
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: addProperty
* Signature: (JJJLjava/lang/String;I)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_addProperty__JJJLjava_lang_String_2I
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jstring propName, jint propValue) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
checkOrtStatus(jniEnv, api, trainApi->AddProperty(checkpointState, cPropName, OrtIntProperty, &propValue));
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: addProperty
* Signature: (JJJLjava/lang/String;F)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_addProperty__JJJLjava_lang_String_2F
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jstring propName, jfloat propValue) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
checkOrtStatus(jniEnv, api, trainApi->AddProperty(checkpointState, cPropName, OrtFloatProperty, &propValue));
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: addProperty
* Signature: (JJJLjava/lang/String;Ljava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_addProperty__JJJLjava_lang_String_2Ljava_lang_String_2
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jstring propName, jstring propValue) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
const char* cPropValue = (*jniEnv)->GetStringUTFChars(jniEnv, propValue, NULL);
checkOrtStatus(jniEnv, api, trainApi->AddProperty(checkpointState, cPropName, OrtStringProperty, (void*)cPropValue));
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propValue, cPropValue);
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: getIntProperty
* Signature: (JJJJLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_getIntProperty
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jlong allocatorHandle, jstring propName) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
enum OrtPropertyType type;
int* propValue = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, trainApi->GetProperty(checkpointState, cPropName, allocator, &type, (void**)&propValue));
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
if (code == ORT_OK) {
if (type == OrtIntProperty) {
int output = *propValue;
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, propValue));
return output;
} else {
throwOrtException(jniEnv, 2, "Requested an int property but this property is not an int");
}
}
return 0;
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: getFloatProperty
* Signature: (JJJJLjava/lang/String;)F
*/
JNIEXPORT jfloat JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_getFloatProperty
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jlong allocatorHandle, jstring propName) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
enum OrtPropertyType type;
float* propValue = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, trainApi->GetProperty(checkpointState, cPropName, allocator, &type, (void**)&propValue));
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
if (code == ORT_OK) {
if (type == OrtFloatProperty) {
float output = *propValue;
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, propValue));
return output;
} else {
throwOrtException(jniEnv, 2, "Requested a float property but this property is not a float");
}
}
return 0.0f;
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: getStringProperty
* Signature: (JJJJLjava/lang/String;)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtTrainingSession_00024OrtCheckpointState_getStringProperty
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainingApiHandle, jlong nativeHandle, jlong allocatorHandle, jstring propName) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
const OrtTrainingApi* trainApi = (const OrtTrainingApi*) trainingApiHandle;
OrtCheckpointState* checkpointState = (OrtCheckpointState*) nativeHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
const char* cPropName = (*jniEnv)->GetStringUTFChars(jniEnv, propName, NULL);
enum OrtPropertyType type;
char* propValue = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, trainApi->GetProperty(checkpointState, cPropName, allocator, &type, (void**)&propValue));
(*jniEnv)->ReleaseStringUTFChars(jniEnv, propName, cPropName);
if (code == ORT_OK) {
if (type == OrtStringProperty) {
jstring output = (*jniEnv)->NewStringUTF(jniEnv, propValue);
checkOrtStatus(jniEnv, api, api->AllocatorFree(allocator, propValue));
return output;
} else {
throwOrtException(jniEnv, 2, "Requested a string property but this property is not a string");
}
}
return (jstring) 0;
}
/*
* Class: ai_onnxruntime_OrtTrainingSession_OrtCheckpointState
* Method: close

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -45,15 +45,34 @@ public class TrainingTest {
Assertions.assertNotNull(trainingSession);
Set<String> inputNames = trainingSession.getTrainInputNames();
Assertions.assertFalse(inputNames.isEmpty());
Assertions.assertTrue(inputNames.contains("input-0"));
Set<String> outputNames = trainingSession.getTrainOutputNames();
Assertions.assertFalse(outputNames.isEmpty());
Assertions.assertTrue(outputNames.contains("onnx::loss::21273"));
}
}
@Test
public void testCreateTrainingSessionWithEval() throws OrtException {
String ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
String trainPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
String evalPath = TestHelpers.getResourcePath("/eval_model.onnx").toString();
try (OrtTrainingSession trainingSession =
env.createTrainingSession(ckptPath, trainPath, evalPath, null)) {
Assertions.assertNotNull(trainingSession);
Set<String> inputNames = trainingSession.getEvalInputNames();
Assertions.assertFalse(inputNames.isEmpty());
Assertions.assertTrue(inputNames.contains("input-0"));
Set<String> outputNames = trainingSession.getEvalOutputNames();
Assertions.assertFalse(outputNames.isEmpty());
Assertions.assertTrue(outputNames.contains("onnx::loss::21273"));
}
}
// this test is not enabled as ORT Java doesn't support supplying an output buffer
@Disabled
@Test
public void TestTrainingSessionTrainStep() throws OrtException {
public void testTrainingSessionTrainStep() throws OrtException {
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
float[] expectedOutput =
@ -134,7 +153,7 @@ public class TrainingTest {
}
@Test
public void TestTrainingSessionTrainStepOrtOutput() throws OrtException {
public void testTrainingSessionTrainStepOrtOutput() throws OrtException {
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
try (OrtTrainingSession trainingSession =
@ -144,7 +163,7 @@ public class TrainingTest {
}
@Test
public void TestSaveCheckpoint() throws IOException, OrtException {
public void testSaveCheckpoint() throws IOException, OrtException {
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
@ -168,7 +187,7 @@ public class TrainingTest {
}
@Test
public void TestTrainingSessionOptimizerStep() throws OrtException {
public void testTrainingSessionOptimizerStep() throws OrtException {
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
String optimizerPath = TestHelpers.getResourcePath("/adamw.onnx").toString();
@ -214,7 +233,7 @@ public class TrainingTest {
}
@Test
public void TestTrainingSessionSetLearningRate() throws OrtException {
public void testTrainingSessionSetLearningRate() throws OrtException {
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
String optimizerPath = TestHelpers.getResourcePath("/adamw.onnx").toString();
@ -229,7 +248,7 @@ public class TrainingTest {
}
@Test
public void TestTrainingSessionLinearLRScheduler() throws OrtException {
public void testTrainingSessionLinearLRScheduler() throws OrtException {
String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
String optimizerPath = TestHelpers.getResourcePath("/adamw.onnx").toString();
@ -253,4 +272,100 @@ public class TrainingTest {
Assertions.assertEquals(0.0f, trainingSession.getLearningRate());
}
}
@Test
public void testTrainingSessionExportModelForInferencing() throws IOException, OrtException {
String ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString();
String trainPath = TestHelpers.getResourcePath("/training_model.onnx").toString();
String evalPath = TestHelpers.getResourcePath("/eval_model.onnx").toString();
try (OrtTrainingSession trainingSession =
env.createTrainingSession(ckptPath, trainPath, evalPath, null)) {
String[] graphOutputs = new String[] {"output-0"};
Path inferencePath = Files.createTempFile("inference_model", ".onnx");
trainingSession.exportModelForInference(inferencePath, graphOutputs);
Assertions.assertTrue(inferencePath.toFile().exists());
inferencePath.toFile().delete();
}
}
@Test
public void testCheckpointStateAddIntProperty() throws OrtException {
Path ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt");
try (OrtCheckpointState ckpt = OrtCheckpointState.loadCheckpoint(ckptPath)) {
String propertyName = "days in a week";
ckpt.addProperty(propertyName, 7);
int value = ckpt.getIntProperty(env.defaultAllocator, propertyName);
Assertions.assertEquals(7, value);
try {
String strVal = ckpt.getStringProperty(env.defaultAllocator, propertyName);
Assertions.fail("Should have thrown");
} catch (OrtException e) {
// pass
}
try {
float floatVal = ckpt.getFloatProperty(env.defaultAllocator, propertyName);
Assertions.fail("Should have thrown");
} catch (OrtException e) {
// pass
}
}
}
@Test
public void testCheckpointStateAddFloatProperty() throws OrtException {
Path ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt");
try (OrtCheckpointState ckpt = OrtCheckpointState.loadCheckpoint(ckptPath)) {
String propertyName = "pi";
ckpt.addProperty(propertyName, 3.14f);
float value = ckpt.getFloatProperty(env.defaultAllocator, propertyName);
Assertions.assertEquals(3.14f, value);
try {
String strVal = ckpt.getStringProperty(env.defaultAllocator, propertyName);
Assertions.fail("Should have thrown");
} catch (OrtException e) {
// pass
}
try {
int intVal = ckpt.getIntProperty(env.defaultAllocator, propertyName);
Assertions.fail("Should have thrown");
} catch (OrtException e) {
// pass
}
}
}
@Test
public void testCheckpointStateAddStringProperty() throws OrtException {
Path ckptPath = TestHelpers.getResourcePath("/checkpoint.ckpt");
try (OrtCheckpointState ckpt = OrtCheckpointState.loadCheckpoint(ckptPath)) {
String propertyName = "best ai framework";
ckpt.addProperty(propertyName, "onnxruntime");
String value = ckpt.getStringProperty(env.defaultAllocator, propertyName);
Assertions.assertEquals("onnxruntime", value);
try {
float floatVal = ckpt.getFloatProperty(env.defaultAllocator, propertyName);
Assertions.fail("Should have thrown");
} catch (OrtException e) {
// pass
}
try {
int intVal = ckpt.getIntProperty(env.defaultAllocator, propertyName);
Assertions.fail("Should have thrown");
} catch (OrtException e) {
// pass
}
}
}
}