[java] Adding addExternalInitializers and addInitializer to OrtSession.SessionOptions (#16198)

### Description
Adds support for adding external initializers or overriding initializers
to a session options from Java.

### Motivation and Context
We want to instantiate large models from Java without filesystem access.

cc @yuslepukhin
This commit is contained in:
Adam Pocock 2023-07-05 15:51:59 -04:00 committed by GitHub
parent 661fd4b978
commit ba91457183
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 444 additions and 16 deletions

View file

@ -793,6 +793,54 @@ public class OrtSession implements AutoCloseable {
return Collections.unmodifiableMap(configEntries);
}
/**
* Adds in the supplied externally loaded initializers.
*
* <p>Note the initializers are copied into the session once it has been created, and the native
* references are removed from this {@code SessionOptions}. Once the session has been created
* those initializers can be closed. This is a different lifetime to initializers added via
* {@link #addInitializer(String, OnnxTensorLike)}. The initializers must be created from {@link
* java.nio.Buffer} objects.
*
* @param initializers The map of names to initializers.
* @throws OrtException If the initializers could not be loaded.
*/
public void addExternalInitializers(Map<String, OnnxTensorLike> initializers)
throws OrtException {
checkClosed();
if (initializers.isEmpty()) {
return;
}
String[] names = new String[initializers.size()];
long[] handles = new long[initializers.size()];
int i = 0;
for (Map.Entry<String, OnnxTensorLike> e : initializers.entrySet()) {
names[i] = e.getKey();
handles[i] = e.getValue().nativeHandle;
i++;
}
addExternalInitializers(OnnxRuntime.ortApiHandle, nativeHandle, names, handles);
}
/**
* Adds an initializer to override one from the ONNX model.
*
* <p>Note the initializer lifetime must outlive the session and session options. This is a
* different lifetime to initializers added via {@link #addExternalInitializers(Map)}. The
* initializers must be created from {@link java.nio.Buffer} objects.
*
* @param name The initializer name.
* @param initializer The initializer value.
* @throws OrtException If the initializer could not be loaded into the session options.
*/
public void addInitializer(String name, OnnxTensorLike initializer) throws OrtException {
checkClosed();
if (name.trim().isEmpty()) {
throw new IllegalArgumentException("Initializer name was blank");
}
addInitializer(OnnxRuntime.ortApiHandle, nativeHandle, name, initializer.getNativeHandle());
}
/**
* Add CUDA as an execution backend, using device 0.
*
@ -1096,6 +1144,13 @@ public class OrtSession implements AutoCloseable {
long apiHandle, long nativeHandle, String dimensionName, long dimensionValue)
throws OrtException;
private native void addExternalInitializers(
long apiHandle, long nativeHandle, String[] names, long[] tensorHandles)
throws OrtException;
private native void addInitializer(
long apiHandle, long nativeHandle, String name, long tensorHandle) throws OrtException;
private native void disablePerSessionThreads(long apiHandle, long nativeHandle)
throws OrtException;

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -350,6 +350,85 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addFre
(*jniEnv)->ReleaseStringUTFChars(jniEnv,dimensionName,cName);
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: addExternalInitializers
* Signature: (JJ[Ljava/lang/String;[J)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExternalInitializers
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray namesArray, jlongArray handlesArray) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
size_t arrLength = (*jniEnv)->GetArrayLength(jniEnv, handlesArray);
const char** names = allocarray(arrLength, sizeof(char*));
if (names == NULL) {
// Nothing to cleanup, return and throw exception
throwOrtException(jniEnv, 1, "Not enough memory");
return;
}
jobject* javaNameStrings = allocarray(arrLength, sizeof(jobject));
if (javaNameStrings == NULL) {
goto cleanup_names;
}
const OrtValue** initializers = allocarray(arrLength, sizeof(OrtValue*));
if (initializers == NULL) {
goto cleanup_java_input_strings;
}
// Extract a C array of longs which are pointers to the input tensors.
// The Java-side objects store native pointers as 64-bit longs, and on 32-bit systems
// we cannot cast the long array to a pointer array as they are different sizes,
// so we copy the longs applying the appropriate cast.
jlong* initializersArr = (*jniEnv)->GetLongArrayElements(jniEnv, handlesArray, NULL);
for (size_t i = 0; i < arrLength; i++) {
// Extract the string chars and cast the tensor
javaNameStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, namesArray, (jint) i);
names[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaNameStrings[i], NULL);
initializers[i] = (const OrtValue*) initializersArr[i];
}
checkOrtStatus(jniEnv,api,api->AddExternalInitializers(options,names,initializers,arrLength));
// Release the java array copy of pointers to the tensors.
(*jniEnv)->ReleaseLongArrayElements(jniEnv, handlesArray, initializersArr, JNI_ABORT);
free(initializers);
cleanup_java_input_strings:
// Release the Java strings
for (size_t i = 0; i < arrLength; i++) {
(*jniEnv)->ReleaseStringUTFChars(jniEnv, javaNameStrings[i], names[i]);
}
free(javaNameStrings);
cleanup_names:
free(names);
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: addInitializer
* Signature: (JJLjava/lang/String;J)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addInitializer
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring name, jlong tensorHandle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
// Extract the string chars
const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL);
// Cast the onnx value
const OrtValue* tensor = (const OrtValue*) tensorHandle;
checkOrtStatus(jniEnv,api,api->AddInitializer(options,cName,tensor));
// Release the string chars
(*jniEnv)->ReleaseStringUTFChars(jniEnv,name,cName);
}
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: disablePerSessionThreads

View file

@ -664,6 +664,86 @@ public class InferenceTest {
}
}
@Test
public void testExternalInitializers() throws IOException, OrtException {
String modelPath = TestHelpers.getResourcePath("/java-external-matmul.onnx").toString();
// Run by loading the external initializer from disk
// initializer is 1...16 in a 4x4 matrix.
try (SessionOptions options = new SessionOptions()) {
try (OrtSession session = env.createSession(modelPath, options)) {
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
OnnxTensor output = (OnnxTensor) res.get(0);
float[][] outputArr = (float[][]) output.getValue();
assertArrayEquals(new float[] {90, 100, 110, 120}, outputArr[0]);
}
}
}
// Run by overriding the initializer with the identity matrix
try (SessionOptions options = new SessionOptions()) {
OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
options.addExternalInitializers(Collections.singletonMap("tensor", tensor));
try (OrtSession session = env.createSession(modelPath, options)) {
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
OnnxTensor output = (OnnxTensor) res.get(0);
float[][] outputArr = (float[][]) output.getValue();
assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
}
}
tensor.close();
}
// Run by overriding the initializer with the identity matrix loaded from a byte array
byte[] modelBytes =
Files.readAllBytes(TestHelpers.getResourcePath("/java-external-matmul.onnx"));
try (SessionOptions options = new SessionOptions()) {
OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
options.addExternalInitializers(Collections.singletonMap("tensor", tensor));
try (OrtSession session = env.createSession(modelBytes, options)) {
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
OnnxTensor output = (OnnxTensor) res.get(0);
float[][] outputArr = (float[][]) output.getValue();
assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
}
}
tensor.close();
}
}
@Test
public void testOverridingInitializer() throws OrtException {
String modelPath = TestHelpers.getResourcePath("/java-matmul.onnx").toString();
// Run with the normal initializer
// initializer is 1...16 in a 4x4 matrix.
try (SessionOptions options = new SessionOptions()) {
try (OrtSession session = env.createSession(modelPath, options)) {
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
OnnxTensor output = (OnnxTensor) res.get(0);
float[][] outputArr = (float[][]) output.getValue();
assertArrayEquals(new float[] {90, 100, 110, 120}, outputArr[0]);
}
}
}
// Run by overriding the initializer with the identity matrix
try (SessionOptions options = new SessionOptions()) {
OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
options.addInitializer("tensor", tensor);
try (OrtSession session = env.createSession(modelPath, options)) {
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
OnnxTensor output = (OnnxTensor) res.get(0);
float[][] outputArr = (float[][]) output.getValue();
assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
}
}
tensor.close();
}
}
private static File getTestModelsDir() throws IOException {
// get build directory, append downloaded models location
String cwd = System.getProperty("user.dir");

View file

@ -0,0 +1,184 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import ai.onnxruntime.OnnxMl.StringStringEntryProto;
import ai.onnxruntime.OnnxMl.TensorProto.DataLocation;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
/** Methods to generate test models. */
public final class ModelGenerators {
private ModelGenerators() {}
public static OnnxMl.TensorShapeProto getShapeProto(
long[] dimensions, String[] dimensionOverrides) {
OnnxMl.TensorShapeProto.Builder builder = OnnxMl.TensorShapeProto.newBuilder();
for (int i = 0; i < dimensions.length; i++) {
if (dimensions[i] == -1) {
builder.addDim(
OnnxMl.TensorShapeProto.Dimension.newBuilder()
.setDimParam(dimensionOverrides[i])
.build());
} else {
builder.addDim(
OnnxMl.TensorShapeProto.Dimension.newBuilder().setDimValue(dimensions[i]).build());
}
}
return builder.build();
}
public static OnnxMl.TypeProto buildTensorTypeNode(
long[] dimensions, String[] dimensionOverrides, OnnxMl.TensorProto.DataType type) {
OnnxMl.TypeProto.Builder builder = OnnxMl.TypeProto.newBuilder();
OnnxMl.TypeProto.Tensor.Builder tensorBuilder = OnnxMl.TypeProto.Tensor.newBuilder();
tensorBuilder.setElemType(type.getNumber());
tensorBuilder.setShape(getShapeProto(dimensions, dimensionOverrides));
builder.setTensorType(tensorBuilder.build());
return builder.build();
}
public void generateExternalMatMul() throws IOException {
OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder();
graph.setName("ort-test-matmul");
// Add placeholders
OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder();
input.setName("input");
OnnxMl.TypeProto inputType =
buildTensorTypeNode(
new long[] {-1, 4},
new String[] {"batch_size", null},
OnnxMl.TensorProto.DataType.FLOAT);
input.setType(inputType);
graph.addInput(input);
OnnxMl.ValueInfoProto.Builder output = OnnxMl.ValueInfoProto.newBuilder();
output.setName("output");
OnnxMl.TypeProto outputType =
buildTensorTypeNode(
new long[] {-1, 4},
new String[] {"batch_size", null},
OnnxMl.TensorProto.DataType.FLOAT);
output.setType(outputType);
graph.addOutput(output);
// Add initializers
OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder();
tensor.addDims(4);
tensor.addDims(4);
tensor.setDataLocation(DataLocation.EXTERNAL);
tensor.addExternalData(
StringStringEntryProto.newBuilder()
.setKey("location")
.setValue("external-matmul.out")
.build());
tensor.addExternalData(
StringStringEntryProto.newBuilder().setKey("offset").setValue("0").build());
tensor.addExternalData(
StringStringEntryProto.newBuilder().setKey("length").setValue("64").build());
float[] floats =
new float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f};
ByteBuffer buf = ByteBuffer.allocate(64).order(ByteOrder.LITTLE_ENDIAN);
FloatBuffer floatBuf = buf.asFloatBuffer();
floatBuf.put(floats);
floatBuf.rewind();
buf.rewind();
try (OutputStream os =
Files.newOutputStream(Paths.get("src", "test", "resources", "external-matmul.out"))) {
os.write(buf.array());
}
tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
tensor.setName("tensor");
graph.addInitializer(tensor);
// Add operations
OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder();
matmul.setName("matmul-0");
matmul.setOpType("MatMul");
matmul.addInput("input");
matmul.addInput("tensor");
matmul.addOutput("output");
graph.addNode(matmul);
// Build model
OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder();
model.setGraph(graph);
model.setDocString("ORT matmul test");
model.setModelVersion(0);
model.setIrVersion(8);
model.setDomain("ai.onnxruntime.test");
model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build());
try (OutputStream os =
Files.newOutputStream(Paths.get("src", "test", "resources", "java-external-matmul.onnx"))) {
model.build().writeTo(os);
}
}
public void generateMatMul() throws IOException {
OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder();
graph.setName("ort-test-matmul");
// Add placeholders
OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder();
input.setName("input");
OnnxMl.TypeProto inputType =
buildTensorTypeNode(
new long[] {-1, 4},
new String[] {"batch_size", null},
OnnxMl.TensorProto.DataType.FLOAT);
input.setType(inputType);
graph.addInput(input);
OnnxMl.ValueInfoProto.Builder output = OnnxMl.ValueInfoProto.newBuilder();
output.setName("output");
OnnxMl.TypeProto outputType =
buildTensorTypeNode(
new long[] {-1, 4},
new String[] {"batch_size", null},
OnnxMl.TensorProto.DataType.FLOAT);
output.setType(outputType);
graph.addOutput(output);
// Add initializers
OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder();
tensor.addDims(4);
tensor.addDims(4);
Float[] floats =
new Float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f};
tensor.addAllFloatData(Arrays.asList(floats));
tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
tensor.setName("tensor");
graph.addInitializer(tensor);
// Add operations
OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder();
matmul.setName("matmul-0");
matmul.setOpType("MatMul");
matmul.addInput("input");
matmul.addInput("tensor");
matmul.addOutput("output");
graph.addNode(matmul);
// Build model
OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder();
model.setGraph(graph);
model.setDocString("ORT matmul test");
model.setModelVersion(0);
model.setIrVersion(8);
model.setDomain("ai.onnxruntime.test");
model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build());
try (OutputStream os =
Files.newOutputStream(Paths.get("src", "test", "resources", "java-matmul.onnx"))) {
model.build().writeTo(os);
}
}
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -31,7 +31,7 @@ public class SparseTensorTest {
try (OrtSession session = env.createSession(modelPath, options)) {
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
OnnxTensor denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 3);
long[] shape = new long[] {3, 3};
/*
* Sparse matrix:
@ -152,7 +152,7 @@ public class SparseTensorTest {
inputMap.clear();
denseIdMatrix.close();
denseIdMatrix = makeIdentityMatrix(env, 4);
denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 4);
long[] vectorShape = new long[] {1, 4};
/*
* Sparse matrix:
@ -212,7 +212,7 @@ public class SparseTensorTest {
try (OrtSession session = env.createSession(modelPath, options)) {
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
OnnxTensor denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 3);
long[] shape = new long[] {3, 3};
/*
* Sparse matrix:
@ -341,7 +341,7 @@ public class SparseTensorTest {
inputMap.clear();
denseIdMatrix.close();
denseIdMatrix = makeIdentityMatrix(env, 4);
denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 4);
long[] vectorShape = new long[] {1, 4};
/*
* Sparse matrix:
@ -439,13 +439,4 @@ public class SparseTensorTest {
}
}
}
private static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException {
float[][] values = new float[size][size];
for (int i = 0; i < values.length; i++) {
values[i][i] = 1.0f;
}
return OnnxTensor.createTensor(env, values);
}
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -13,6 +13,8 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
@ -411,6 +413,25 @@ public class TestHelpers {
return new StringTensorPair(nodeName, onnxTensor);
}
public static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException {
float[][] values = new float[size][size];
for (int i = 0; i < values.length; i++) {
values[i][i] = 1.0f;
}
return OnnxTensor.createTensor(env, values);
}
public static OnnxTensor makeIdentityMatrixBuf(OrtEnvironment env, int size) throws OrtException {
FloatBuffer buf =
ByteBuffer.allocateDirect(size * size * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
for (int i = 0; i < size; i++) {
buf.put(i * size + i, 1.0f);
}
return OnnxTensor.createTensor(env, buf, new long[] {size, size});
}
private static class TypeWidth {
public final OnnxJavaType type;
public final int width;

Binary file not shown.

View file

@ -0,0 +1,18 @@
"ai.onnxruntime.test2ORT matmul test:Ñ
)
input
tensoroutputmatmul-0"MatMulort-test-matmul*L
Btensorj
locationexternal-matmul.outj
offset0j
length64pZ!
input


batch_size
b"
output


batch_size
B

Binary file not shown.