[java] Multi-LoRA support (#22280)

### Description
Java parts of Multi-LoRA support - #22046.

### Motivation and Context
API equivalence with Python & C#.

---------

Co-authored-by: Dmitri Smirnov <dmitrism@microsoft.com>
This commit is contained in:
Adam Pocock 2024-10-01 16:54:37 -04:00 committed by GitHub
parent 1fc2b94644
commit 14d1bfc34b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 381 additions and 8 deletions

View file

@ -0,0 +1,161 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Objects;
/**
 * A container for an adapter which can be supplied to {@link
 * OrtSession.RunOptions#addActiveLoraAdapter(OrtLoraAdapter)} to apply the adapter to a specific
 * execution of a model.
 *
 * <p>Instances wrap a native handle and must be released with {@link #close()} once they are no
 * longer needed.
 */
public final class OrtLoraAdapter implements AutoCloseable {

  static {
    try {
      OnnxRuntime.init();
    } catch (IOException e) {
      throw new RuntimeException("Failed to load onnx-runtime library", e);
    }
  }

  /** The native pointer returned by one of the native factory methods. */
  private final long nativeHandle;

  /** Set once {@link #close()} has released the native handle. */
  private boolean closed = false;

  /**
   * Wraps a native adapter handle.
   *
   * @param nativeHandle The pointer returned by one of the native factory methods.
   */
  private OrtLoraAdapter(long nativeHandle) {
    this.nativeHandle = nativeHandle;
  }

  /**
   * Creates an instance of OrtLoraAdapter from a byte array.
   *
   * @param loraArray The LoRA stored in a byte array.
   * @throws OrtException If the native call failed.
   * @return An OrtLoraAdapter instance.
   */
  public static OrtLoraAdapter create(byte[] loraArray) throws OrtException {
    return create(loraArray, null);
  }

  /**
   * Creates an instance of OrtLoraAdapter from a byte array.
   *
   * @param loraArray The LoRA stored in a byte array.
   * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
   *     allocator memory.
   * @throws NullPointerException If {@code loraArray} is null.
   * @throws OrtException If the native call failed.
   * @return An OrtLoraAdapter instance.
   */
  static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) throws OrtException {
    Objects.requireNonNull(loraArray, "LoRA array must not be null");
    // A zero handle tells the native code that no allocator was supplied.
    long allocatorHandle = allocator == null ? 0 : allocator.handle;
    return new OrtLoraAdapter(
        createLoraAdapterFromArray(OnnxRuntime.ortApiHandle, loraArray, allocatorHandle));
  }

  /**
   * Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
   *
   * @param loraBuffer The buffer to load.
   * @throws OrtException If the native call failed.
   * @return An OrtLoraAdapter instance.
   */
  public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException {
    return create(loraBuffer, null);
  }

  /**
   * Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
   *
   * <p>The buffer's current position and remaining byte count are passed to the native library;
   * this method does not change the buffer's position or limit.
   *
   * @param loraBuffer The buffer to load. Must be direct and have at least one byte remaining.
   * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
   *     allocator memory.
   * @throws NullPointerException If {@code loraBuffer} is null.
   * @throws OrtException If the buffer is empty or not direct, or if the native call failed.
   * @return An OrtLoraAdapter instance.
   */
  static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) throws OrtException {
    Objects.requireNonNull(loraBuffer, "LoRA buffer must not be null");
    if (loraBuffer.remaining() == 0) {
      throw new OrtException("Invalid LoRA buffer, no elements remaining.");
    } else if (!loraBuffer.isDirect()) {
      // The native side reads the buffer through its direct address.
      throw new OrtException("ByteBuffer is not direct.");
    }
    // A zero handle tells the native code that no allocator was supplied.
    long allocatorHandle = allocator == null ? 0 : allocator.handle;
    return new OrtLoraAdapter(
        createLoraAdapterFromBuffer(
            OnnxRuntime.ortApiHandle,
            loraBuffer,
            loraBuffer.position(),
            loraBuffer.remaining(),
            allocatorHandle));
  }

  /**
   * Creates an instance of OrtLoraAdapter.
   *
   * @param adapterPath path to the adapter file that is going to be memory mapped.
   * @throws OrtException If the native call failed.
   * @return An OrtLoraAdapter instance.
   */
  public static OrtLoraAdapter create(String adapterPath) throws OrtException {
    return create(adapterPath, null);
  }

  /**
   * Creates an instance of OrtLoraAdapter.
   *
   * @param adapterPath path to the adapter file that is going to be memory mapped.
   * @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
   *     allocator memory.
   * @throws NullPointerException If {@code adapterPath} is null.
   * @throws OrtException If the native call failed.
   * @return An OrtLoraAdapter instance.
   */
  static OrtLoraAdapter create(String adapterPath, OrtAllocator allocator) throws OrtException {
    // Reject null here, matching the other factories: the native code hands the string straight
    // to the JNI string accessors, which must not be given a null reference.
    Objects.requireNonNull(adapterPath, "LoRA adapter path must not be null");
    // A zero handle tells the native code that no allocator was supplied.
    long allocatorHandle = allocator == null ? 0 : allocator.handle;
    return new OrtLoraAdapter(
        createLoraAdapter(OnnxRuntime.ortApiHandle, adapterPath, allocatorHandle));
  }

  /**
   * Package accessor for native pointer.
   *
   * @return The native pointer.
   */
  long getNativeHandle() {
    return nativeHandle;
  }

  /** Checks if the OrtLoraAdapter is closed, if so throws {@link IllegalStateException}. */
  void checkClosed() {
    if (closed) {
      throw new IllegalStateException("Trying to use a closed OrtLoraAdapter");
    }
  }

  /**
   * Releases the native adapter.
   *
   * @throws IllegalStateException If this adapter has already been closed.
   */
  @Override
  public void close() {
    if (!closed) {
      close(OnnxRuntime.ortApiHandle, nativeHandle);
      closed = true;
    } else {
      throw new IllegalStateException("Trying to close an already closed OrtLoraAdapter");
    }
  }

  private static native long createLoraAdapter(
      long apiHandle, String adapterPath, long allocatorHandle) throws OrtException;

  private static native long createLoraAdapterFromArray(
      long apiHandle, byte[] loraBytes, long allocatorHandle) throws OrtException;

  private static native long createLoraAdapterFromBuffer(
      long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle)
      throws OrtException;

  private static native void close(long apiHandle, long nativeHandle);
}

View file

@ -635,8 +635,8 @@ public class OrtSession implements AutoCloseable {
* The optimisation level to use. Needs to be kept in sync with the GraphOptimizationLevel enum
* in the C API.
*
* <p>See <a
* href="https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html">Graph
* <p>See <a href=
* "https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html">Graph
* Optimizations</a> for more details.
*/
public enum OptLevel {
@ -684,6 +684,7 @@ public class OrtSession implements AutoCloseable {
SEQUENTIAL(0),
/** Executes some nodes in parallel. */
PARALLEL(1);
private final int id;
ExecutionMode(int id) {
@ -1391,17 +1392,19 @@ public class OrtSession implements AutoCloseable {
throws OrtException;
/*
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
* To use additional providers, you must build ORT with the extra providers enabled. Then call
* one of these functions to enable them in the session:
*
* OrtSessionOptionsAppendExecutionProvider_CPU
* OrtSessionOptionsAppendExecutionProvider_CUDA
* OrtSessionOptionsAppendExecutionProvider_ROCM
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
* The order they care called indicates the preference order as well. In other words call this method
* on your most preferred execution provider first followed by the less preferred ones.
* If none are called Ort will use its internal CPU execution provider.
*
* If a backend is unavailable then it throws an OrtException
* The order they are called indicates the preference order as well. In other words call this
* method on your most preferred execution provider first followed by the less preferred ones.
* If none are called ORT will use its internal CPU execution provider.
*
* If a backend is unavailable then it throws an OrtException.
*/
private native void addCPU(long apiHandle, long nativeHandle, int useArena) throws OrtException;
@ -1579,6 +1582,18 @@ public class OrtSession implements AutoCloseable {
addRunConfigEntry(OnnxRuntime.ortApiHandle, nativeHandle, key, value);
}
/**
 * Registers the supplied adapter as one of the active LoRA adapters for runs which use these
 * options.
 *
 * @param loraAdapter A valid, open OrtLoraAdapter.
 * @throws OrtException If the native library call failed.
 * @throws IllegalStateException If these run options or the adapter have been closed.
 */
public void addActiveLoraAdapter(OrtLoraAdapter loraAdapter) throws OrtException {
  // Both this object and the adapter must still own live native handles.
  checkClosed();
  loraAdapter.checkClosed();
  long adapterHandle = loraAdapter.getNativeHandle();
  addActiveLoraAdapter(OnnxRuntime.ortApiHandle, nativeHandle, adapterHandle);
}
/** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */
private void checkClosed() {
if (closed) {
@ -1619,6 +1634,9 @@ public class OrtSession implements AutoCloseable {
// Native method invoked with the ORT API handle, this object's native handle and a key/value
// pair of strings.
private native void addRunConfigEntry(
long apiHandle, long nativeHandle, String key, String value) throws OrtException;
// Native method invoked with the ORT API handle, this object's native handle and the native
// handle of the OrtLoraAdapter which should be active for the run.
private native void addActiveLoraAdapter(
long apiHandle, long nativeHandle, long loraAdapterHandle) throws OrtException;
// NOTE(review): presumably releases the native object behind nativeHandle - confirm against the
// JNI implementation.
private static native void close(long apiHandle, long nativeHandle);
}

View file

@ -0,0 +1,106 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include <string.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OrtLoraAdapter.h"
/*
 * Class:     ai_onnxruntime_OrtLoraAdapter
 * Method:    createLoraAdapter
 * Signature: (JLjava/lang/String;J)J
 *
 * Creates a LoRA adapter from the file at loraPath via OrtApi::CreateLoraAdapter.
 * allocatorHandle may be 0, in which case a NULL allocator is passed to ORT.
 * Returns the adapter pointer as a jlong, or 0 if the adapter could not be created.
 */
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter
  (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jstring loraPath, jlong allocatorHandle) {
  (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
  // Initialised so a failed create returns 0 rather than reading an indeterminate pointer.
  OrtLoraAdapter* lora = NULL;

  // The JNI string accessors must not be called with a NULL reference.
  if (loraPath == NULL) {
    throwOrtException(jniEnv, 2, "Invalid LoRA path, the path is null.");
    return 0;
  }

#ifdef _WIN32
  const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, loraPath, NULL);
  if (cPath == NULL) {
    // The JVM could not provide the characters; return and let any pending exception propagate.
    return 0;
  }
  size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, loraPath);
  // Copy into a null terminated wide string, as the JNI characters are not guaranteed to be terminated.
  wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
  if (newString == NULL) {
    (*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
    throwOrtException(jniEnv, 1, "Not enough memory");
    return 0;
  }
  wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength);
  checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(newString, allocator, &lora));
  free(newString);
  (*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
#else
  const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, loraPath, NULL);
  if (cPath == NULL) {
    // The JVM could not provide the characters; return and let any pending exception propagate.
    return 0;
  }
  checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(cPath, allocator, &lora));
  (*jniEnv)->ReleaseStringUTFChars(jniEnv, loraPath, cPath);
#endif
  return (jlong) lora;
}
/*
 * Class:     ai_onnxruntime_OrtLoraAdapter
 * Method:    createLoraAdapterFromBuffer
 * Signature: (JLjava/nio/ByteBuffer;IIJ)J
 *
 * Creates a LoRA adapter from bufferSize bytes of a direct ByteBuffer, starting bufferPos bytes
 * into the buffer, via OrtApi::CreateLoraAdapterFromArray.
 * allocatorHandle may be 0, in which case a NULL allocator is passed to ORT.
 * Returns the adapter pointer as a jlong, or 0 if the adapter could not be created.
 */
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromBuffer
  (JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong allocatorHandle) {
  (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
  // Initialised so a failed create returns 0 rather than reading an indeterminate pointer.
  OrtLoraAdapter* lora = NULL;

  // Extract the buffer
  char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
  if (bufferArr == NULL) {
    // No direct address is available, so there is nothing valid to offset and pass to ORT.
    throwOrtException(jniEnv, 2, "Invalid LoRA buffer, could not access the direct buffer address.");
    return 0;
  }
  // Increment by bufferPos bytes
  bufferArr = bufferArr + bufferPos;

  // Create the adapter
  checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) bufferArr, bufferSize, allocator, &lora));

  return (jlong) lora;
}
/*
 * Class:     ai_onnxruntime_OrtLoraAdapter
 * Method:    createLoraAdapterFromArray
 * Signature: (J[BJ)J
 *
 * Creates a LoRA adapter from the contents of a Java byte array via
 * OrtApi::CreateLoraAdapterFromArray.
 * allocatorHandle may be 0, in which case a NULL allocator is passed to ORT.
 * Returns the adapter pointer as a jlong, or 0 if the adapter could not be created.
 */
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromArray
  (JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jbyteArray jLoraArray, jlong allocatorHandle) {
  (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
  // Initialised so a failed create returns 0 rather than reading an indeterminate pointer.
  OrtLoraAdapter* lora = NULL;

  size_t loraLength = (*jniEnv)->GetArrayLength(jniEnv, jLoraArray);
  if (loraLength == 0) {
    throwOrtException(jniEnv, 2, "Invalid LoRA, the byte array is zero length.");
    return 0;
  }

  // Get a reference to the byte array elements
  jbyte* loraArr = (*jniEnv)->GetByteArrayElements(jniEnv, jLoraArray, NULL);
  if (loraArr == NULL) {
    // The JVM could not provide the elements; return and let any pending exception propagate.
    return 0;
  }
  checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) loraArr, loraLength, allocator, &lora));
  // Release the C array. JNI_ABORT is used as the elements are only read, so nothing needs copying back.
  (*jniEnv)->ReleaseByteArrayElements(jniEnv, jLoraArray, loraArr, JNI_ABORT);

  return (jlong) lora;
}
/*
 * Class:     ai_onnxruntime_OrtLoraAdapter
 * Method:    close
 * Signature: (JJ)V
 *
 * Hands the adapter behind loraHandle to OrtApi::ReleaseLoraAdapter.
 */
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtLoraAdapter_close
  (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong loraHandle) {
  // Required JNI parameters which this function has no use for.
  (void) jniEnv;
  (void) jclazz;
  OrtLoraAdapter* adapter = (OrtLoraAdapter*) loraHandle;
  ((const OrtApi*) apiHandle)->ReleaseLoraAdapter(adapter);
}

View file

@ -124,6 +124,18 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addRunConf
(*jniEnv)->ReleaseStringUTFChars(jniEnv, valueStr, value);
}
/*
 * Class:     ai_onnxruntime_OrtSession_RunOptions
 * Method:    addActiveLoraAdapter
 * Signature: (JJJ)V
 *
 * Passes the run options and LoRA adapter handles to OrtApi::RunOptionsAddActiveLoraAdapter
 * and hands the resulting status to checkOrtStatus.
 */
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addActiveLoraAdapter
  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong loraHandle) {
  (void) jobj; // Required JNI parameter which this function has no use for.
  const OrtApi* api = (const OrtApi*) apiHandle;
  // Recover the typed pointers from the handles supplied by Java.
  OrtRunOptions* runOptions = (OrtRunOptions*) nativeHandle;
  OrtLoraAdapter* adapter = (OrtLoraAdapter*) loraHandle;
  checkOrtStatus(jniEnv, api, api->RunOptionsAddActiveLoraAdapter(runOptions, adapter));
}
/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: setTerminate

View file

@ -1310,6 +1310,82 @@ public class InferenceTest {
}
}
/**
 * Runs the two parameter LoRA test model without an adapter, then with the adapter loaded from a
 * file path, a byte array and a direct ByteBuffer, checking the output each time.
 */
@Test
public void testRunWithLoraAdapter() throws IOException, OrtException {
  Path modelPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx");
  Path adapterPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx_adapter");

  // The input is a 4x4 tensor of ones held in a direct, native ordered buffer.
  final long[] shape = {4, 4};
  final int elementCount = 16;
  FloatBuffer onesBuffer =
      ByteBuffer.allocateDirect(elementCount * Float.BYTES)
          .order(ByteOrder.nativeOrder())
          .asFloatBuffer();
  for (int i = 0; i < elementCount; i++) {
    onesBuffer.put(1.f);
  }
  onesBuffer.rewind();

  // Every output row is identical, both with and without the adapter.
  float[] baseRow = {28.f, 32.f, 36.f, 40.f};
  float[][] expectedOutput = {baseRow, baseRow, baseRow, baseRow};
  float[] loraRow = {154.f, 176.f, 198.f, 220.f};
  float[][] expectedLoRAOutput = {loraRow, loraRow, loraRow, loraRow};

  try (OrtSession session = env.createSession(modelPath.toString());
      OnnxTensor inputTensor = OnnxTensor.createTensor(env, onesBuffer, shape)) {
    Map<String, OnnxTensor> feeds = Collections.singletonMap("input_x", inputTensor);

    // Baseline run, no adapter active.
    try (OrtSession.Result baseResult = session.run(feeds)) {
      float[][] actual = (float[][]) baseResult.get(0).getValue();
      Assertions.assertArrayEquals(expectedOutput, actual);
    }

    // Adapter loaded from the file path.
    try (OrtLoraAdapter pathAdapter = OrtLoraAdapter.create(adapterPath.toString());
        OrtSession.RunOptions options = new OrtSession.RunOptions()) {
      options.addActiveLoraAdapter(pathAdapter);
      try (OrtSession.Result loraResult = session.run(feeds, options)) {
        float[][] actual = (float[][]) loraResult.get(0).getValue();
        Assertions.assertArrayEquals(expectedLoRAOutput, actual);
      }
    }

    // Adapter loaded from an in-memory byte array.
    byte[] adapterBytes = Files.readAllBytes(adapterPath);
    try (OrtLoraAdapter arrayAdapter = OrtLoraAdapter.create(adapterBytes);
        OrtSession.RunOptions options = new OrtSession.RunOptions()) {
      options.addActiveLoraAdapter(arrayAdapter);
      try (OrtSession.Result loraResult = session.run(feeds, options)) {
        float[][] actual = (float[][]) loraResult.get(0).getValue();
        Assertions.assertArrayEquals(expectedLoRAOutput, actual);
      }
    }

    // Adapter loaded from a direct ByteBuffer holding the same bytes.
    ByteBuffer adapterBuffer = ByteBuffer.allocateDirect(adapterBytes.length);
    adapterBuffer.put(adapterBytes);
    adapterBuffer.rewind();
    try (OrtLoraAdapter bufferAdapter = OrtLoraAdapter.create(adapterBuffer);
        OrtSession.RunOptions options = new OrtSession.RunOptions()) {
      options.addActiveLoraAdapter(bufferAdapter);
      try (OrtSession.Result loraResult = session.run(feeds, options)) {
        float[][] actual = (float[][]) loraResult.get(0).getValue();
        Assertions.assertArrayEquals(expectedLoRAOutput, actual);
      }
    }
  }
}
@Test
public void testExtraSessionOptions() throws OrtException, IOException {
// model takes 1x5 input of fixed type, echoes back