mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[java] Adding addExternalInitializers and addInitializer to OrtSession.SessionOptions (#16198)
### Description Adds support for adding external initializers or overriding initializers to a session options from Java. ### Motivation and Context We want to instantiate large models from Java without filesystem access. cc @yuslepukhin
This commit is contained in:
parent
661fd4b978
commit
ba91457183
9 changed files with 444 additions and 16 deletions
|
|
@ -793,6 +793,54 @@ public class OrtSession implements AutoCloseable {
|
|||
return Collections.unmodifiableMap(configEntries);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds in the supplied externally loaded initializers.
|
||||
*
|
||||
* <p>Note the initializers are copied into the session once it has been created, and the native
|
||||
* references are removed from this {@code SessionOptions}. Once the session has been created
|
||||
* those initializers can be closed. This is a different lifetime to initializers added via
|
||||
* {@link #addInitializer(String, OnnxTensorLike)}. The initializers must be created from {@link
|
||||
* java.nio.Buffer} objects.
|
||||
*
|
||||
* @param initializers The map of names to initializers.
|
||||
* @throws OrtException If the initializers could not be loaded.
|
||||
*/
|
||||
public void addExternalInitializers(Map<String, OnnxTensorLike> initializers)
|
||||
throws OrtException {
|
||||
checkClosed();
|
||||
if (initializers.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
String[] names = new String[initializers.size()];
|
||||
long[] handles = new long[initializers.size()];
|
||||
int i = 0;
|
||||
for (Map.Entry<String, OnnxTensorLike> e : initializers.entrySet()) {
|
||||
names[i] = e.getKey();
|
||||
handles[i] = e.getValue().nativeHandle;
|
||||
i++;
|
||||
}
|
||||
addExternalInitializers(OnnxRuntime.ortApiHandle, nativeHandle, names, handles);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds an initializer to override one from the ONNX model.
|
||||
*
|
||||
* <p>Note the initializer lifetime must outlive the session and session options. This is a
|
||||
* different lifetime to initializers added via {@link #addExternalInitializers(Map)}. The
|
||||
* initializers must be created from {@link java.nio.Buffer} objects.
|
||||
*
|
||||
* @param name The initializer name.
|
||||
* @param initializer The initializer value.
|
||||
* @throws OrtException If the initializer could not be loaded into the session options.
|
||||
*/
|
||||
public void addInitializer(String name, OnnxTensorLike initializer) throws OrtException {
|
||||
checkClosed();
|
||||
if (name.trim().isEmpty()) {
|
||||
throw new IllegalArgumentException("Initializer name was blank");
|
||||
}
|
||||
addInitializer(OnnxRuntime.ortApiHandle, nativeHandle, name, initializer.getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
* Add CUDA as an execution backend, using device 0.
|
||||
*
|
||||
|
|
@ -1096,6 +1144,13 @@ public class OrtSession implements AutoCloseable {
|
|||
long apiHandle, long nativeHandle, String dimensionName, long dimensionValue)
|
||||
throws OrtException;
|
||||
|
||||
private native void addExternalInitializers(
|
||||
long apiHandle, long nativeHandle, String[] names, long[] tensorHandles)
|
||||
throws OrtException;
|
||||
|
||||
private native void addInitializer(
|
||||
long apiHandle, long nativeHandle, String name, long tensorHandle) throws OrtException;
|
||||
|
||||
private native void disablePerSessionThreads(long apiHandle, long nativeHandle)
|
||||
throws OrtException;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -350,6 +350,85 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addFre
|
|||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,dimensionName,cName);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: addExternalInitializers
|
||||
* Signature: (JJ[Ljava/lang/String;[J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExternalInitializers
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray namesArray, jlongArray handlesArray) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
|
||||
size_t arrLength = (*jniEnv)->GetArrayLength(jniEnv, handlesArray);
|
||||
|
||||
const char** names = allocarray(arrLength, sizeof(char*));
|
||||
if (names == NULL) {
|
||||
// Nothing to cleanup, return and throw exception
|
||||
throwOrtException(jniEnv, 1, "Not enough memory");
|
||||
return;
|
||||
}
|
||||
jobject* javaNameStrings = allocarray(arrLength, sizeof(jobject));
|
||||
if (javaNameStrings == NULL) {
|
||||
goto cleanup_names;
|
||||
}
|
||||
const OrtValue** initializers = allocarray(arrLength, sizeof(OrtValue*));
|
||||
if (initializers == NULL) {
|
||||
goto cleanup_java_input_strings;
|
||||
}
|
||||
|
||||
// Extract a C array of longs which are pointers to the input tensors.
|
||||
// The Java-side objects store native pointers as 64-bit longs, and on 32-bit systems
|
||||
// we cannot cast the long array to a pointer array as they are different sizes,
|
||||
// so we copy the longs applying the appropriate cast.
|
||||
jlong* initializersArr = (*jniEnv)->GetLongArrayElements(jniEnv, handlesArray, NULL);
|
||||
|
||||
for (size_t i = 0; i < arrLength; i++) {
|
||||
// Extract the string chars and cast the tensor
|
||||
javaNameStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, namesArray, (jint) i);
|
||||
names[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaNameStrings[i], NULL);
|
||||
initializers[i] = (const OrtValue*) initializersArr[i];
|
||||
}
|
||||
|
||||
checkOrtStatus(jniEnv,api,api->AddExternalInitializers(options,names,initializers,arrLength));
|
||||
|
||||
// Release the java array copy of pointers to the tensors.
|
||||
(*jniEnv)->ReleaseLongArrayElements(jniEnv, handlesArray, initializersArr, JNI_ABORT);
|
||||
free(initializers);
|
||||
cleanup_java_input_strings:
|
||||
// Release the Java strings
|
||||
for (size_t i = 0; i < arrLength; i++) {
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, javaNameStrings[i], names[i]);
|
||||
}
|
||||
free(javaNameStrings);
|
||||
cleanup_names:
|
||||
free(names);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: addInitializer
|
||||
* Signature: (JJLjava/lang/String;J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addInitializer
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring name, jlong tensorHandle) {
|
||||
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
|
||||
const OrtApi* api = (const OrtApi*)apiHandle;
|
||||
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
|
||||
|
||||
// Extract the string chars
|
||||
const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL);
|
||||
|
||||
// Cast the onnx value
|
||||
const OrtValue* tensor = (const OrtValue*) tensorHandle;
|
||||
|
||||
checkOrtStatus(jniEnv,api,api->AddInitializer(options,cName,tensor));
|
||||
|
||||
// Release the string chars
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,name,cName);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: disablePerSessionThreads
|
||||
|
|
|
|||
|
|
@ -664,6 +664,86 @@ public class InferenceTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExternalInitializers() throws IOException, OrtException {
|
||||
String modelPath = TestHelpers.getResourcePath("/java-external-matmul.onnx").toString();
|
||||
|
||||
// Run by loading the external initializer from disk
|
||||
// initializer is 1...16 in a 4x4 matrix.
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
|
||||
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
|
||||
OnnxTensor output = (OnnxTensor) res.get(0);
|
||||
float[][] outputArr = (float[][]) output.getValue();
|
||||
assertArrayEquals(new float[] {90, 100, 110, 120}, outputArr[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Run by overriding the initializer with the identity matrix
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
|
||||
options.addExternalInitializers(Collections.singletonMap("tensor", tensor));
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
|
||||
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
|
||||
OnnxTensor output = (OnnxTensor) res.get(0);
|
||||
float[][] outputArr = (float[][]) output.getValue();
|
||||
assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
|
||||
}
|
||||
}
|
||||
tensor.close();
|
||||
}
|
||||
// Run by overriding the initializer with the identity matrix loaded from a byte array
|
||||
byte[] modelBytes =
|
||||
Files.readAllBytes(TestHelpers.getResourcePath("/java-external-matmul.onnx"));
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
|
||||
options.addExternalInitializers(Collections.singletonMap("tensor", tensor));
|
||||
try (OrtSession session = env.createSession(modelBytes, options)) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
|
||||
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
|
||||
OnnxTensor output = (OnnxTensor) res.get(0);
|
||||
float[][] outputArr = (float[][]) output.getValue();
|
||||
assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
|
||||
}
|
||||
}
|
||||
tensor.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOverridingInitializer() throws OrtException {
|
||||
String modelPath = TestHelpers.getResourcePath("/java-matmul.onnx").toString();
|
||||
|
||||
// Run with the normal initializer
|
||||
// initializer is 1...16 in a 4x4 matrix.
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
|
||||
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
|
||||
OnnxTensor output = (OnnxTensor) res.get(0);
|
||||
float[][] outputArr = (float[][]) output.getValue();
|
||||
assertArrayEquals(new float[] {90, 100, 110, 120}, outputArr[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Run by overriding the initializer with the identity matrix
|
||||
try (SessionOptions options = new SessionOptions()) {
|
||||
OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
|
||||
options.addInitializer("tensor", tensor);
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
|
||||
OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
|
||||
OnnxTensor output = (OnnxTensor) res.get(0);
|
||||
float[][] outputArr = (float[][]) output.getValue();
|
||||
assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
|
||||
}
|
||||
}
|
||||
tensor.close();
|
||||
}
|
||||
}
|
||||
|
||||
private static File getTestModelsDir() throws IOException {
|
||||
// get build directory, append downloaded models location
|
||||
String cwd = System.getProperty("user.dir");
|
||||
|
|
|
|||
184
java/src/test/java/ai/onnxruntime/ModelGenerators.java
Normal file
184
java/src/test/java/ai/onnxruntime/ModelGenerators.java
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
/*
|
||||
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import ai.onnxruntime.OnnxMl.StringStringEntryProto;
|
||||
import ai.onnxruntime.OnnxMl.TensorProto.DataLocation;
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Arrays;
|
||||
|
||||
/** Methods to generate test models. */
|
||||
public final class ModelGenerators {
|
||||
private ModelGenerators() {}
|
||||
|
||||
public static OnnxMl.TensorShapeProto getShapeProto(
|
||||
long[] dimensions, String[] dimensionOverrides) {
|
||||
OnnxMl.TensorShapeProto.Builder builder = OnnxMl.TensorShapeProto.newBuilder();
|
||||
for (int i = 0; i < dimensions.length; i++) {
|
||||
if (dimensions[i] == -1) {
|
||||
builder.addDim(
|
||||
OnnxMl.TensorShapeProto.Dimension.newBuilder()
|
||||
.setDimParam(dimensionOverrides[i])
|
||||
.build());
|
||||
} else {
|
||||
builder.addDim(
|
||||
OnnxMl.TensorShapeProto.Dimension.newBuilder().setDimValue(dimensions[i]).build());
|
||||
}
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public static OnnxMl.TypeProto buildTensorTypeNode(
|
||||
long[] dimensions, String[] dimensionOverrides, OnnxMl.TensorProto.DataType type) {
|
||||
OnnxMl.TypeProto.Builder builder = OnnxMl.TypeProto.newBuilder();
|
||||
|
||||
OnnxMl.TypeProto.Tensor.Builder tensorBuilder = OnnxMl.TypeProto.Tensor.newBuilder();
|
||||
tensorBuilder.setElemType(type.getNumber());
|
||||
tensorBuilder.setShape(getShapeProto(dimensions, dimensionOverrides));
|
||||
builder.setTensorType(tensorBuilder.build());
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public void generateExternalMatMul() throws IOException {
|
||||
OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder();
|
||||
graph.setName("ort-test-matmul");
|
||||
|
||||
// Add placeholders
|
||||
OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder();
|
||||
input.setName("input");
|
||||
OnnxMl.TypeProto inputType =
|
||||
buildTensorTypeNode(
|
||||
new long[] {-1, 4},
|
||||
new String[] {"batch_size", null},
|
||||
OnnxMl.TensorProto.DataType.FLOAT);
|
||||
input.setType(inputType);
|
||||
graph.addInput(input);
|
||||
OnnxMl.ValueInfoProto.Builder output = OnnxMl.ValueInfoProto.newBuilder();
|
||||
output.setName("output");
|
||||
OnnxMl.TypeProto outputType =
|
||||
buildTensorTypeNode(
|
||||
new long[] {-1, 4},
|
||||
new String[] {"batch_size", null},
|
||||
OnnxMl.TensorProto.DataType.FLOAT);
|
||||
output.setType(outputType);
|
||||
graph.addOutput(output);
|
||||
|
||||
// Add initializers
|
||||
OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder();
|
||||
tensor.addDims(4);
|
||||
tensor.addDims(4);
|
||||
tensor.setDataLocation(DataLocation.EXTERNAL);
|
||||
tensor.addExternalData(
|
||||
StringStringEntryProto.newBuilder()
|
||||
.setKey("location")
|
||||
.setValue("external-matmul.out")
|
||||
.build());
|
||||
tensor.addExternalData(
|
||||
StringStringEntryProto.newBuilder().setKey("offset").setValue("0").build());
|
||||
tensor.addExternalData(
|
||||
StringStringEntryProto.newBuilder().setKey("length").setValue("64").build());
|
||||
float[] floats =
|
||||
new float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f};
|
||||
ByteBuffer buf = ByteBuffer.allocate(64).order(ByteOrder.LITTLE_ENDIAN);
|
||||
FloatBuffer floatBuf = buf.asFloatBuffer();
|
||||
floatBuf.put(floats);
|
||||
floatBuf.rewind();
|
||||
buf.rewind();
|
||||
try (OutputStream os =
|
||||
Files.newOutputStream(Paths.get("src", "test", "resources", "external-matmul.out"))) {
|
||||
os.write(buf.array());
|
||||
}
|
||||
tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
|
||||
tensor.setName("tensor");
|
||||
graph.addInitializer(tensor);
|
||||
|
||||
// Add operations
|
||||
OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder();
|
||||
matmul.setName("matmul-0");
|
||||
matmul.setOpType("MatMul");
|
||||
matmul.addInput("input");
|
||||
matmul.addInput("tensor");
|
||||
matmul.addOutput("output");
|
||||
graph.addNode(matmul);
|
||||
|
||||
// Build model
|
||||
OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder();
|
||||
model.setGraph(graph);
|
||||
model.setDocString("ORT matmul test");
|
||||
model.setModelVersion(0);
|
||||
model.setIrVersion(8);
|
||||
model.setDomain("ai.onnxruntime.test");
|
||||
model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build());
|
||||
try (OutputStream os =
|
||||
Files.newOutputStream(Paths.get("src", "test", "resources", "java-external-matmul.onnx"))) {
|
||||
model.build().writeTo(os);
|
||||
}
|
||||
}
|
||||
|
||||
public void generateMatMul() throws IOException {
|
||||
OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder();
|
||||
graph.setName("ort-test-matmul");
|
||||
|
||||
// Add placeholders
|
||||
OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder();
|
||||
input.setName("input");
|
||||
OnnxMl.TypeProto inputType =
|
||||
buildTensorTypeNode(
|
||||
new long[] {-1, 4},
|
||||
new String[] {"batch_size", null},
|
||||
OnnxMl.TensorProto.DataType.FLOAT);
|
||||
input.setType(inputType);
|
||||
graph.addInput(input);
|
||||
OnnxMl.ValueInfoProto.Builder output = OnnxMl.ValueInfoProto.newBuilder();
|
||||
output.setName("output");
|
||||
OnnxMl.TypeProto outputType =
|
||||
buildTensorTypeNode(
|
||||
new long[] {-1, 4},
|
||||
new String[] {"batch_size", null},
|
||||
OnnxMl.TensorProto.DataType.FLOAT);
|
||||
output.setType(outputType);
|
||||
graph.addOutput(output);
|
||||
|
||||
// Add initializers
|
||||
OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder();
|
||||
tensor.addDims(4);
|
||||
tensor.addDims(4);
|
||||
Float[] floats =
|
||||
new Float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f};
|
||||
tensor.addAllFloatData(Arrays.asList(floats));
|
||||
tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
|
||||
tensor.setName("tensor");
|
||||
graph.addInitializer(tensor);
|
||||
|
||||
// Add operations
|
||||
OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder();
|
||||
matmul.setName("matmul-0");
|
||||
matmul.setOpType("MatMul");
|
||||
matmul.addInput("input");
|
||||
matmul.addInput("tensor");
|
||||
matmul.addOutput("output");
|
||||
graph.addNode(matmul);
|
||||
|
||||
// Build model
|
||||
OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder();
|
||||
model.setGraph(graph);
|
||||
model.setDocString("ORT matmul test");
|
||||
model.setModelVersion(0);
|
||||
model.setIrVersion(8);
|
||||
model.setDomain("ai.onnxruntime.test");
|
||||
model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build());
|
||||
try (OutputStream os =
|
||||
Files.newOutputStream(Paths.get("src", "test", "resources", "java-matmul.onnx"))) {
|
||||
model.build().writeTo(os);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -31,7 +31,7 @@ public class SparseTensorTest {
|
|||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
|
||||
|
||||
OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
|
||||
OnnxTensor denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 3);
|
||||
long[] shape = new long[] {3, 3};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
|
|
@ -152,7 +152,7 @@ public class SparseTensorTest {
|
|||
inputMap.clear();
|
||||
denseIdMatrix.close();
|
||||
|
||||
denseIdMatrix = makeIdentityMatrix(env, 4);
|
||||
denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 4);
|
||||
long[] vectorShape = new long[] {1, 4};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
|
|
@ -212,7 +212,7 @@ public class SparseTensorTest {
|
|||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
Map<String, OnnxTensorLike> inputMap = new HashMap<>();
|
||||
|
||||
OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
|
||||
OnnxTensor denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 3);
|
||||
long[] shape = new long[] {3, 3};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
|
|
@ -341,7 +341,7 @@ public class SparseTensorTest {
|
|||
inputMap.clear();
|
||||
denseIdMatrix.close();
|
||||
|
||||
denseIdMatrix = makeIdentityMatrix(env, 4);
|
||||
denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 4);
|
||||
long[] vectorShape = new long[] {1, 4};
|
||||
/*
|
||||
* Sparse matrix:
|
||||
|
|
@ -439,13 +439,4 @@ public class SparseTensorTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException {
|
||||
float[][] values = new float[size][size];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
values[i][i] = 1.0f;
|
||||
}
|
||||
|
||||
return OnnxTensor.createTensor(env, values);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -13,6 +13,8 @@ import java.io.IOException;
|
|||
import java.io.InputStream;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
|
|
@ -411,6 +413,25 @@ public class TestHelpers {
|
|||
return new StringTensorPair(nodeName, onnxTensor);
|
||||
}
|
||||
|
||||
public static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException {
|
||||
float[][] values = new float[size][size];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
values[i][i] = 1.0f;
|
||||
}
|
||||
|
||||
return OnnxTensor.createTensor(env, values);
|
||||
}
|
||||
|
||||
public static OnnxTensor makeIdentityMatrixBuf(OrtEnvironment env, int size) throws OrtException {
|
||||
FloatBuffer buf =
|
||||
ByteBuffer.allocateDirect(size * size * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
|
||||
for (int i = 0; i < size; i++) {
|
||||
buf.put(i * size + i, 1.0f);
|
||||
}
|
||||
|
||||
return OnnxTensor.createTensor(env, buf, new long[] {size, size});
|
||||
}
|
||||
|
||||
private static class TypeWidth {
|
||||
public final OnnxJavaType type;
|
||||
public final int width;
|
||||
|
|
|
|||
BIN
java/src/test/resources/external-matmul.out
Normal file
BIN
java/src/test/resources/external-matmul.out
Normal file
Binary file not shown.
18
java/src/test/resources/java-external-matmul.onnx
Normal file
18
java/src/test/resources/java-external-matmul.onnx
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"ai.onnxruntime.test2ORT matmul test:Ñ
|
||||
)
|
||||
input
|
||||
tensoroutputmatmul-0"MatMulort-test-matmul*L
|
||||
Btensorj
|
||||
locationexternal-matmul.outj
|
||||
offset0j
|
||||
length64pZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
BIN
java/src/test/resources/java-matmul.onnx
Normal file
BIN
java/src/test/resources/java-matmul.onnx
Normal file
Binary file not shown.
Loading…
Reference in a new issue