mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
[java] Multi-LoRA support (#22280)
### Description Java parts of Multi-LoRA support - #22046. ### Motivation and Context API equivalence with Python & C#. --------- Co-authored-by: Dmitri Smirnov <dmitrism@microsoft.com>
This commit is contained in:
parent
1fc2b94644
commit
14d1bfc34b
5 changed files with 381 additions and 8 deletions
161
java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java
Normal file
161
java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
/*
|
||||
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* A container for an adapter which can be supplied to {@link
|
||||
* OrtSession.RunOptions#addActiveLoraAdapter(OrtLoraAdapter)} to apply the adapter to a specific
|
||||
* execution of a model.
|
||||
*/
|
||||
public final class OrtLoraAdapter implements AutoCloseable {
|
||||
static {
|
||||
try {
|
||||
OnnxRuntime.init();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load onnx-runtime library", e);
|
||||
}
|
||||
}
|
||||
|
||||
private final long nativeHandle;
|
||||
|
||||
private boolean closed = false;
|
||||
|
||||
private OrtLoraAdapter(long nativeHandle) {
|
||||
this.nativeHandle = nativeHandle;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of OrtLoraAdapter from a byte array.
|
||||
*
|
||||
* @param loraArray The LoRA stored in a byte array.
|
||||
* @throws OrtException If the native call failed.
|
||||
* @return An OrtLoraAdapter instance.
|
||||
*/
|
||||
public static OrtLoraAdapter create(byte[] loraArray) throws OrtException {
|
||||
return create(loraArray, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of OrtLoraAdapter from a byte array.
|
||||
*
|
||||
* @param loraArray The LoRA stored in a byte array.
|
||||
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
|
||||
* allocator memory.
|
||||
* @throws OrtException If the native call failed.
|
||||
* @return An OrtLoraAdapter instance.
|
||||
*/
|
||||
static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) throws OrtException {
|
||||
Objects.requireNonNull(loraArray, "LoRA array must not be null");
|
||||
long allocatorHandle = allocator == null ? 0 : allocator.handle;
|
||||
return new OrtLoraAdapter(
|
||||
createLoraAdapterFromArray(OnnxRuntime.ortApiHandle, loraArray, allocatorHandle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
|
||||
*
|
||||
* @param loraBuffer The buffer to load.
|
||||
* @throws OrtException If the native call failed.
|
||||
* @return An OrtLoraAdapter instance.
|
||||
*/
|
||||
public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException {
|
||||
return create(loraBuffer, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
|
||||
*
|
||||
* @param loraBuffer The buffer to load.
|
||||
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
|
||||
* allocator memory.
|
||||
* @throws OrtException If the native call failed.
|
||||
* @return An OrtLoraAdapter instance.
|
||||
*/
|
||||
static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) throws OrtException {
|
||||
Objects.requireNonNull(loraBuffer, "LoRA buffer must not be null");
|
||||
if (loraBuffer.remaining() == 0) {
|
||||
throw new OrtException("Invalid LoRA buffer, no elements remaining.");
|
||||
} else if (!loraBuffer.isDirect()) {
|
||||
throw new OrtException("ByteBuffer is not direct.");
|
||||
}
|
||||
long allocatorHandle = allocator == null ? 0 : allocator.handle;
|
||||
return new OrtLoraAdapter(
|
||||
createLoraAdapterFromBuffer(
|
||||
OnnxRuntime.ortApiHandle,
|
||||
loraBuffer,
|
||||
loraBuffer.position(),
|
||||
loraBuffer.remaining(),
|
||||
allocatorHandle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of OrtLoraAdapter.
|
||||
*
|
||||
* @param adapterPath path to the adapter file that is going to be memory mapped.
|
||||
* @throws OrtException If the native call failed.
|
||||
* @return An OrtLoraAdapter instance.
|
||||
*/
|
||||
public static OrtLoraAdapter create(String adapterPath) throws OrtException {
|
||||
return create(adapterPath, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an instance of OrtLoraAdapter.
|
||||
*
|
||||
* @param adapterPath path to the adapter file that is going to be memory mapped.
|
||||
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
|
||||
* allocator memory.
|
||||
* @throws OrtException If the native call failed.
|
||||
* @return An OrtLoraAdapter instance.
|
||||
*/
|
||||
static OrtLoraAdapter create(String adapterPath, OrtAllocator allocator) throws OrtException {
|
||||
long allocatorHandle = allocator == null ? 0 : allocator.handle;
|
||||
return new OrtLoraAdapter(
|
||||
createLoraAdapter(OnnxRuntime.ortApiHandle, adapterPath, allocatorHandle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Package accessor for native pointer.
|
||||
*
|
||||
* @return The native pointer.
|
||||
*/
|
||||
long getNativeHandle() {
|
||||
return nativeHandle;
|
||||
}
|
||||
|
||||
/** Checks if the OrtLoraAdapter is closed, if so throws {@link IllegalStateException}. */
|
||||
void checkClosed() {
|
||||
if (closed) {
|
||||
throw new IllegalStateException("Trying to use a closed OrtLoraAdapter");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (!closed) {
|
||||
close(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
closed = true;
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to close an already closed OrtLoraAdapter");
|
||||
}
|
||||
}
|
||||
|
||||
private static native long createLoraAdapter(
|
||||
long apiHandle, String adapterPath, long allocatorHandle) throws OrtException;
|
||||
|
||||
private static native long createLoraAdapterFromArray(
|
||||
long apiHandle, byte[] loraBytes, long allocatorHandle) throws OrtException;
|
||||
|
||||
private static native long createLoraAdapterFromBuffer(
|
||||
long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle)
|
||||
throws OrtException;
|
||||
|
||||
private static native void close(long apiHandle, long nativeHandle);
|
||||
}
|
||||
|
|
@ -635,8 +635,8 @@ public class OrtSession implements AutoCloseable {
|
|||
* The optimisation level to use. Needs to be kept in sync with the GraphOptimizationLevel enum
|
||||
* in the C API.
|
||||
*
|
||||
* <p>See <a
|
||||
* href="https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html">Graph
|
||||
* <p>See <a href=
|
||||
* "https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html">Graph
|
||||
* Optimizations</a> for more details.
|
||||
*/
|
||||
public enum OptLevel {
|
||||
|
|
@ -684,6 +684,7 @@ public class OrtSession implements AutoCloseable {
|
|||
SEQUENTIAL(0),
|
||||
/** Executes some nodes in parallel. */
|
||||
PARALLEL(1);
|
||||
|
||||
private final int id;
|
||||
|
||||
ExecutionMode(int id) {
|
||||
|
|
@ -1391,17 +1392,19 @@ public class OrtSession implements AutoCloseable {
|
|||
throws OrtException;
|
||||
|
||||
/*
|
||||
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
|
||||
* functions to enable them in the session:
|
||||
* To use additional providers, you must build ORT with the extra providers enabled. Then call
|
||||
* one of these functions to enable them in the session:
|
||||
*
|
||||
* OrtSessionOptionsAppendExecutionProvider_CPU
|
||||
* OrtSessionOptionsAppendExecutionProvider_CUDA
|
||||
* OrtSessionOptionsAppendExecutionProvider_ROCM
|
||||
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
|
||||
* The order they care called indicates the preference order as well. In other words call this method
|
||||
* on your most preferred execution provider first followed by the less preferred ones.
|
||||
* If none are called Ort will use its internal CPU execution provider.
|
||||
*
|
||||
* If a backend is unavailable then it throws an OrtException
|
||||
* The order they are called indicates the preference order as well. In other words call this
|
||||
* method on your most preferred execution provider first followed by the less preferred ones.
|
||||
* If none are called ORT will use its internal CPU execution provider.
|
||||
*
|
||||
* If a backend is unavailable then it throws an OrtException.
|
||||
*/
|
||||
private native void addCPU(long apiHandle, long nativeHandle, int useArena) throws OrtException;
|
||||
|
||||
|
|
@ -1579,6 +1582,18 @@ public class OrtSession implements AutoCloseable {
|
|||
addRunConfigEntry(OnnxRuntime.ortApiHandle, nativeHandle, key, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds the specified adapter to the list of active adapters for this run.
|
||||
*
|
||||
* @param loraAdapter valid OrtLoraAdapter object
|
||||
* @throws OrtException of the native library call failed
|
||||
*/
|
||||
public void addActiveLoraAdapter(OrtLoraAdapter loraAdapter) throws OrtException {
|
||||
checkClosed();
|
||||
loraAdapter.checkClosed();
|
||||
addActiveLoraAdapter(OnnxRuntime.ortApiHandle, nativeHandle, loraAdapter.getNativeHandle());
|
||||
}
|
||||
|
||||
/** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */
|
||||
private void checkClosed() {
|
||||
if (closed) {
|
||||
|
|
@ -1619,6 +1634,9 @@ public class OrtSession implements AutoCloseable {
|
|||
private native void addRunConfigEntry(
|
||||
long apiHandle, long nativeHandle, String key, String value) throws OrtException;
|
||||
|
||||
private native void addActiveLoraAdapter(
|
||||
long apiHandle, long nativeHandle, long loraAdapterHandle) throws OrtException;
|
||||
|
||||
private static native void close(long apiHandle, long nativeHandle);
|
||||
}
|
||||
|
||||
|
|
|
|||
106
java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c
Normal file
106
java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
/*
|
||||
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
#include <string.h>
|
||||
#include "onnxruntime/core/session/onnxruntime_c_api.h"
|
||||
#include "OrtJniUtil.h"
|
||||
#include "ai_onnxruntime_OrtLoraAdapter.h"
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtLoraAdapter
|
||||
* Method: createLoraAdapter
|
||||
* Signature: (JLjava/lang/String;J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter
|
||||
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jstring loraPath, jlong allocatorHandle) {
|
||||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
OrtLoraAdapter* lora;
|
||||
|
||||
#ifdef _WIN32
|
||||
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, loraPath, NULL);
|
||||
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, loraPath);
|
||||
wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
|
||||
if (newString == NULL) {
|
||||
(*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
|
||||
throwOrtException(jniEnv, 1, "Not enough memory");
|
||||
return 0;
|
||||
}
|
||||
wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength);
|
||||
checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(newString, allocator, &lora));
|
||||
free(newString);
|
||||
(*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
|
||||
#else
|
||||
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, loraPath, NULL);
|
||||
checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(cPath, allocator, &lora));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, loraPath, cPath);
|
||||
#endif
|
||||
|
||||
return (jlong) lora;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtLoraAdapter
|
||||
* Method: createLoraAdapterFromBuffer
|
||||
* Signature: (JLjava/nio/ByteBuffer;IIJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromBuffer
|
||||
(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong allocatorHandle) {
|
||||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
OrtLoraAdapter* lora;
|
||||
|
||||
// Extract the buffer
|
||||
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
|
||||
// Increment by bufferPos bytes
|
||||
bufferArr = bufferArr + bufferPos;
|
||||
|
||||
// Create the adapter
|
||||
checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) bufferArr, bufferSize, allocator, &lora));
|
||||
|
||||
return (jlong) lora;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtLoraAdapter
|
||||
* Method: createLoraAdapterFromArray
|
||||
* Signature: (J[BJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromArray
|
||||
(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jbyteArray jLoraArray, jlong allocatorHandle) {
|
||||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
|
||||
OrtLoraAdapter* lora;
|
||||
|
||||
size_t loraLength = (*jniEnv)->GetArrayLength(jniEnv, jLoraArray);
|
||||
if (loraLength == 0) {
|
||||
throwOrtException(jniEnv, 2, "Invalid LoRA, the byte array is zero length.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Get a reference to the byte array elements
|
||||
jbyte* loraArr = (*jniEnv)->GetByteArrayElements(jniEnv, jLoraArray, NULL);
|
||||
checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) loraArr, loraLength, allocator, &lora));
|
||||
// Release the C array.
|
||||
(*jniEnv)->ReleaseByteArrayElements(jniEnv, jLoraArray, loraArr, JNI_ABORT);
|
||||
|
||||
return (jlong) lora;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtLoraAdapter
|
||||
* Method: close
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtLoraAdapter_close
|
||||
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong loraHandle) {
|
||||
(void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
api->ReleaseLoraAdapter((OrtLoraAdapter*) loraHandle);
|
||||
}
|
||||
|
||||
|
|
@ -124,6 +124,18 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addRunConf
|
|||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, valueStr, value);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: addActiveLoraAdapter
|
||||
* Signature: (JJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addActiveLoraAdapter
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong loraHandle) {
|
||||
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*) apiHandle;
|
||||
checkOrtStatus(jniEnv, api, api->RunOptionsAddActiveLoraAdapter((OrtRunOptions*) nativeHandle, (OrtLoraAdapter*) loraHandle));
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_RunOptions
|
||||
* Method: setTerminate
|
||||
|
|
|
|||
|
|
@ -1310,6 +1310,82 @@ public class InferenceTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunWithLoraAdapter() throws IOException, OrtException {
|
||||
Path modelPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx");
|
||||
Path adapterPath = TestHelpers.getResourcePath("/lora/two_params_lora_model.onnx_adapter");
|
||||
|
||||
long[] inputShape = new long[] {4, 4};
|
||||
float[] inputData = new float[16];
|
||||
Arrays.fill(inputData, 1.f);
|
||||
FloatBuffer buf =
|
||||
ByteBuffer.allocateDirect(Float.BYTES * 16).order(ByteOrder.nativeOrder()).asFloatBuffer();
|
||||
buf.put(inputData);
|
||||
buf.rewind();
|
||||
|
||||
float[][] expectedOutput =
|
||||
new float[][] {
|
||||
{28.f, 32.f, 36.f, 40.f},
|
||||
{28.f, 32.f, 36.f, 40.f},
|
||||
{28.f, 32.f, 36.f, 40.f},
|
||||
{28.f, 32.f, 36.f, 40.f}
|
||||
};
|
||||
|
||||
float[][] expectedLoRAOutput =
|
||||
new float[][] {
|
||||
{154.f, 176.f, 198.f, 220.f},
|
||||
{154.f, 176.f, 198.f, 220.f},
|
||||
{154.f, 176.f, 198.f, 220.f},
|
||||
{154.f, 176.f, 198.f, 220.f}
|
||||
};
|
||||
|
||||
try (OrtSession session = env.createSession(modelPath.toString());
|
||||
OnnxTensor tensor = OnnxTensor.createTensor(env, buf, inputShape)) {
|
||||
|
||||
Map<String, OnnxTensor> inputs = Collections.singletonMap("input_x", tensor);
|
||||
|
||||
// Without LoRA
|
||||
try (OrtSession.Result result = session.run(inputs)) {
|
||||
float[][] resultArr = (float[][]) result.get(0).getValue();
|
||||
Assertions.assertArrayEquals(expectedOutput, resultArr);
|
||||
}
|
||||
|
||||
// With LoRA from path
|
||||
try (OrtLoraAdapter adapter = OrtLoraAdapter.create(adapterPath.toString());
|
||||
OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
|
||||
runOptions.addActiveLoraAdapter(adapter);
|
||||
try (OrtSession.Result result = session.run(inputs, runOptions)) {
|
||||
float[][] resultArr = (float[][]) result.get(0).getValue();
|
||||
Assertions.assertArrayEquals(expectedLoRAOutput, resultArr);
|
||||
}
|
||||
}
|
||||
|
||||
// With LoRA from array
|
||||
byte[] loraArray = Files.readAllBytes(adapterPath);
|
||||
try (OrtLoraAdapter adapter = OrtLoraAdapter.create(loraArray);
|
||||
OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
|
||||
runOptions.addActiveLoraAdapter(adapter);
|
||||
try (OrtSession.Result result = session.run(inputs, runOptions)) {
|
||||
float[][] resultArr = (float[][]) result.get(0).getValue();
|
||||
Assertions.assertArrayEquals(expectedLoRAOutput, resultArr);
|
||||
}
|
||||
}
|
||||
|
||||
// With LoRA from buffer
|
||||
ByteBuffer loraBuf = ByteBuffer.allocateDirect(loraArray.length);
|
||||
loraBuf.put(loraArray);
|
||||
loraBuf.rewind();
|
||||
try (OrtLoraAdapter adapter = OrtLoraAdapter.create(loraBuf);
|
||||
OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
|
||||
runOptions.addActiveLoraAdapter(adapter);
|
||||
try (OrtSession.Result result = session.run(inputs, runOptions)) {
|
||||
float[][] resultArr = (float[][]) result.get(0).getValue();
|
||||
Assertions.assertArrayEquals(expectedLoRAOutput, resultArr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExtraSessionOptions() throws OrtException, IOException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
|
|
|
|||
Loading…
Reference in a new issue