diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java
index 3191e0b2e2..7df69c8bfe 100644
--- a/java/src/main/java/ai/onnxruntime/OrtSession.java
+++ b/java/src/main/java/ai/onnxruntime/OrtSession.java
@@ -793,6 +793,54 @@ public class OrtSession implements AutoCloseable {
return Collections.unmodifiableMap(configEntries);
}
+ /**
+ * Adds in the supplied externally loaded initializers.
+ *
+ *
Note the initializers are copied into the session once it has been created, and the native
+ * references are removed from this {@code SessionOptions}. Once the session has been created
+ * those initializers can be closed. This is a different lifetime to initializers added via
+ * {@link #addInitializer(String, OnnxTensorLike)}. The initializers must be created from {@link
+ * java.nio.Buffer} objects.
+ *
+ * @param initializers The map of names to initializers.
+ * @throws OrtException If the initializers could not be loaded.
+ */
+ public void addExternalInitializers(Map initializers)
+ throws OrtException {
+ checkClosed();
+ if (initializers.isEmpty()) {
+ return;
+ }
+ String[] names = new String[initializers.size()];
+ long[] handles = new long[initializers.size()];
+ int i = 0;
+ for (Map.Entry e : initializers.entrySet()) {
+ names[i] = e.getKey();
+ handles[i] = e.getValue().nativeHandle;
+ i++;
+ }
+ addExternalInitializers(OnnxRuntime.ortApiHandle, nativeHandle, names, handles);
+ }
+
+ /**
+ * Adds an initializer to override one from the ONNX model.
+ *
+ * Note the initializer lifetime must outlive the session and session options. This is a
+ * different lifetime to initializers added via {@link #addExternalInitializers(Map)}. The
+ * initializers must be created from {@link java.nio.Buffer} objects.
+ *
+ * @param name The initializer name.
+ * @param initializer The initializer value.
+ * @throws OrtException If the initializer could not be loaded into the session options.
+ */
+ public void addInitializer(String name, OnnxTensorLike initializer) throws OrtException {
+ checkClosed();
+ if (name.trim().isEmpty()) {
+ throw new IllegalArgumentException("Initializer name was blank");
+ }
+ addInitializer(OnnxRuntime.ortApiHandle, nativeHandle, name, initializer.getNativeHandle());
+ }
+
/**
* Add CUDA as an execution backend, using device 0.
*
@@ -1096,6 +1144,13 @@ public class OrtSession implements AutoCloseable {
long apiHandle, long nativeHandle, String dimensionName, long dimensionValue)
throws OrtException;
+ private native void addExternalInitializers(
+ long apiHandle, long nativeHandle, String[] names, long[] tensorHandles)
+ throws OrtException;
+
+ private native void addInitializer(
+ long apiHandle, long nativeHandle, String name, long tensorHandle) throws OrtException;
+
private native void disablePerSessionThreads(long apiHandle, long nativeHandle)
throws OrtException;
diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
index c805e00dc4..d3239c7442 100644
--- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
+++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include
@@ -350,6 +350,85 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addFre
(*jniEnv)->ReleaseStringUTFChars(jniEnv,dimensionName,cName);
}
+/*
+ * Class: ai_onnxruntime_OrtSession_SessionOptions
+ * Method: addExternalInitializers
+ * Signature: (JJ[Ljava/lang/String;[J)V
+ */
+JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExternalInitializers
+ (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray namesArray, jlongArray handlesArray) {
+ (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
+ const OrtApi* api = (const OrtApi*)apiHandle;
+ OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
+
+ size_t arrLength = (*jniEnv)->GetArrayLength(jniEnv, handlesArray);
+
+ const char** names = allocarray(arrLength, sizeof(char*));
+ if (names == NULL) {
+ // Nothing to cleanup, return and throw exception
+ throwOrtException(jniEnv, 1, "Not enough memory");
+ return;
+ }
+ jobject* javaNameStrings = allocarray(arrLength, sizeof(jobject));
+ if (javaNameStrings == NULL) {
+ goto cleanup_names;
+ }
+ const OrtValue** initializers = allocarray(arrLength, sizeof(OrtValue*));
+ if (initializers == NULL) {
+ goto cleanup_java_input_strings;
+ }
+
+ // Extract a C array of longs which are pointers to the input tensors.
+ // The Java-side objects store native pointers as 64-bit longs, and on 32-bit systems
+ // we cannot cast the long array to a pointer array as they are different sizes,
+ // so we copy the longs applying the appropriate cast.
+ jlong* initializersArr = (*jniEnv)->GetLongArrayElements(jniEnv, handlesArray, NULL);
+
+ for (size_t i = 0; i < arrLength; i++) {
+ // Extract the string chars and cast the tensor
+ javaNameStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, namesArray, (jint) i);
+ names[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaNameStrings[i], NULL);
+ initializers[i] = (const OrtValue*) initializersArr[i];
+ }
+
+ checkOrtStatus(jniEnv,api,api->AddExternalInitializers(options,names,initializers,arrLength));
+
+ // Release the java array copy of pointers to the tensors.
+ (*jniEnv)->ReleaseLongArrayElements(jniEnv, handlesArray, initializersArr, JNI_ABORT);
+ free(initializers);
+cleanup_java_input_strings:
+ // Release the Java strings
+ for (size_t i = 0; i < arrLength; i++) {
+ (*jniEnv)->ReleaseStringUTFChars(jniEnv, javaNameStrings[i], names[i]);
+ }
+ free(javaNameStrings);
+cleanup_names:
+ free(names);
+}
+
+/*
+ * Class: ai_onnxruntime_OrtSession_SessionOptions
+ * Method: addInitializer
+ * Signature: (JJLjava/lang/String;J)V
+ */
+JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addInitializer
+ (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring name, jlong tensorHandle) {
+ (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
+ const OrtApi* api = (const OrtApi*)apiHandle;
+ OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
+
+ // Extract the string chars
+ const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL);
+
+ // Cast the onnx value
+ const OrtValue* tensor = (const OrtValue*) tensorHandle;
+
+ checkOrtStatus(jniEnv,api,api->AddInitializer(options,cName,tensor));
+
+ // Release the string chars
+ (*jniEnv)->ReleaseStringUTFChars(jniEnv,name,cName);
+}
+
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: disablePerSessionThreads
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 76ac083049..5ca7bbc651 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -664,6 +664,86 @@ public class InferenceTest {
}
}
+ @Test
+ public void testExternalInitializers() throws IOException, OrtException {
+ String modelPath = TestHelpers.getResourcePath("/java-external-matmul.onnx").toString();
+
+ // Run by loading the external initializer from disk
+ // initializer is 1...16 in a 4x4 matrix.
+ try (SessionOptions options = new SessionOptions()) {
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
+ OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
+ OnnxTensor output = (OnnxTensor) res.get(0);
+ float[][] outputArr = (float[][]) output.getValue();
+ assertArrayEquals(new float[] {90, 100, 110, 120}, outputArr[0]);
+ }
+ }
+ }
+ // Run by overriding the initializer with the identity matrix
+ try (SessionOptions options = new SessionOptions()) {
+ OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
+ options.addExternalInitializers(Collections.singletonMap("tensor", tensor));
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
+ OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
+ OnnxTensor output = (OnnxTensor) res.get(0);
+ float[][] outputArr = (float[][]) output.getValue();
+ assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
+ }
+ }
+ tensor.close();
+ }
+ // Run by overriding the initializer with the identity matrix loaded from a byte array
+ byte[] modelBytes =
+ Files.readAllBytes(TestHelpers.getResourcePath("/java-external-matmul.onnx"));
+ try (SessionOptions options = new SessionOptions()) {
+ OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
+ options.addExternalInitializers(Collections.singletonMap("tensor", tensor));
+ try (OrtSession session = env.createSession(modelBytes, options)) {
+ try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
+ OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
+ OnnxTensor output = (OnnxTensor) res.get(0);
+ float[][] outputArr = (float[][]) output.getValue();
+ assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
+ }
+ }
+ tensor.close();
+ }
+ }
+
+ @Test
+ public void testOverridingInitializer() throws OrtException {
+ String modelPath = TestHelpers.getResourcePath("/java-matmul.onnx").toString();
+
+ // Run with the normal initializer
+ // initializer is 1...16 in a 4x4 matrix.
+ try (SessionOptions options = new SessionOptions()) {
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
+ OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
+ OnnxTensor output = (OnnxTensor) res.get(0);
+ float[][] outputArr = (float[][]) output.getValue();
+ assertArrayEquals(new float[] {90, 100, 110, 120}, outputArr[0]);
+ }
+ }
+ }
+ // Run by overriding the initializer with the identity matrix
+ try (SessionOptions options = new SessionOptions()) {
+ OnnxTensor tensor = TestHelpers.makeIdentityMatrixBuf(env, 4);
+ options.addInitializer("tensor", tensor);
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ try (OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}});
+ OrtSession.Result res = session.run(Collections.singletonMap("input", t))) {
+ OnnxTensor output = (OnnxTensor) res.get(0);
+ float[][] outputArr = (float[][]) output.getValue();
+ assertArrayEquals(new float[] {1, 2, 3, 4}, outputArr[0]);
+ }
+ }
+ tensor.close();
+ }
+ }
+
private static File getTestModelsDir() throws IOException {
// get build directory, append downloaded models location
String cwd = System.getProperty("user.dir");
diff --git a/java/src/test/java/ai/onnxruntime/ModelGenerators.java b/java/src/test/java/ai/onnxruntime/ModelGenerators.java
new file mode 100644
index 0000000000..6dcc4ce7f8
--- /dev/null
+++ b/java/src/test/java/ai/onnxruntime/ModelGenerators.java
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
+ * Licensed under the MIT License.
+ */
+package ai.onnxruntime;
+
+import ai.onnxruntime.OnnxMl.StringStringEntryProto;
+import ai.onnxruntime.OnnxMl.TensorProto.DataLocation;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+
+/** Methods to generate test models. */
+public final class ModelGenerators {
+ private ModelGenerators() {}
+
+ public static OnnxMl.TensorShapeProto getShapeProto(
+ long[] dimensions, String[] dimensionOverrides) {
+ OnnxMl.TensorShapeProto.Builder builder = OnnxMl.TensorShapeProto.newBuilder();
+ for (int i = 0; i < dimensions.length; i++) {
+ if (dimensions[i] == -1) {
+ builder.addDim(
+ OnnxMl.TensorShapeProto.Dimension.newBuilder()
+ .setDimParam(dimensionOverrides[i])
+ .build());
+ } else {
+ builder.addDim(
+ OnnxMl.TensorShapeProto.Dimension.newBuilder().setDimValue(dimensions[i]).build());
+ }
+ }
+ return builder.build();
+ }
+
+ public static OnnxMl.TypeProto buildTensorTypeNode(
+ long[] dimensions, String[] dimensionOverrides, OnnxMl.TensorProto.DataType type) {
+ OnnxMl.TypeProto.Builder builder = OnnxMl.TypeProto.newBuilder();
+
+ OnnxMl.TypeProto.Tensor.Builder tensorBuilder = OnnxMl.TypeProto.Tensor.newBuilder();
+ tensorBuilder.setElemType(type.getNumber());
+ tensorBuilder.setShape(getShapeProto(dimensions, dimensionOverrides));
+ builder.setTensorType(tensorBuilder.build());
+
+ return builder.build();
+ }
+
+ public void generateExternalMatMul() throws IOException {
+ OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder();
+ graph.setName("ort-test-matmul");
+
+ // Add placeholders
+ OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder();
+ input.setName("input");
+ OnnxMl.TypeProto inputType =
+ buildTensorTypeNode(
+ new long[] {-1, 4},
+ new String[] {"batch_size", null},
+ OnnxMl.TensorProto.DataType.FLOAT);
+ input.setType(inputType);
+ graph.addInput(input);
+ OnnxMl.ValueInfoProto.Builder output = OnnxMl.ValueInfoProto.newBuilder();
+ output.setName("output");
+ OnnxMl.TypeProto outputType =
+ buildTensorTypeNode(
+ new long[] {-1, 4},
+ new String[] {"batch_size", null},
+ OnnxMl.TensorProto.DataType.FLOAT);
+ output.setType(outputType);
+ graph.addOutput(output);
+
+ // Add initializers
+ OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder();
+ tensor.addDims(4);
+ tensor.addDims(4);
+ tensor.setDataLocation(DataLocation.EXTERNAL);
+ tensor.addExternalData(
+ StringStringEntryProto.newBuilder()
+ .setKey("location")
+ .setValue("external-matmul.out")
+ .build());
+ tensor.addExternalData(
+ StringStringEntryProto.newBuilder().setKey("offset").setValue("0").build());
+ tensor.addExternalData(
+ StringStringEntryProto.newBuilder().setKey("length").setValue("64").build());
+ float[] floats =
+ new float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f};
+ ByteBuffer buf = ByteBuffer.allocate(64).order(ByteOrder.LITTLE_ENDIAN);
+ FloatBuffer floatBuf = buf.asFloatBuffer();
+ floatBuf.put(floats);
+ floatBuf.rewind();
+ buf.rewind();
+ try (OutputStream os =
+ Files.newOutputStream(Paths.get("src", "test", "resources", "external-matmul.out"))) {
+ os.write(buf.array());
+ }
+ tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
+ tensor.setName("tensor");
+ graph.addInitializer(tensor);
+
+ // Add operations
+ OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder();
+ matmul.setName("matmul-0");
+ matmul.setOpType("MatMul");
+ matmul.addInput("input");
+ matmul.addInput("tensor");
+ matmul.addOutput("output");
+ graph.addNode(matmul);
+
+ // Build model
+ OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder();
+ model.setGraph(graph);
+ model.setDocString("ORT matmul test");
+ model.setModelVersion(0);
+ model.setIrVersion(8);
+ model.setDomain("ai.onnxruntime.test");
+ model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build());
+ try (OutputStream os =
+ Files.newOutputStream(Paths.get("src", "test", "resources", "java-external-matmul.onnx"))) {
+ model.build().writeTo(os);
+ }
+ }
+
+ public void generateMatMul() throws IOException {
+ OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder();
+ graph.setName("ort-test-matmul");
+
+ // Add placeholders
+ OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder();
+ input.setName("input");
+ OnnxMl.TypeProto inputType =
+ buildTensorTypeNode(
+ new long[] {-1, 4},
+ new String[] {"batch_size", null},
+ OnnxMl.TensorProto.DataType.FLOAT);
+ input.setType(inputType);
+ graph.addInput(input);
+ OnnxMl.ValueInfoProto.Builder output = OnnxMl.ValueInfoProto.newBuilder();
+ output.setName("output");
+ OnnxMl.TypeProto outputType =
+ buildTensorTypeNode(
+ new long[] {-1, 4},
+ new String[] {"batch_size", null},
+ OnnxMl.TensorProto.DataType.FLOAT);
+ output.setType(outputType);
+ graph.addOutput(output);
+
+ // Add initializers
+ OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder();
+ tensor.addDims(4);
+ tensor.addDims(4);
+ Float[] floats =
+ new Float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f};
+ tensor.addAllFloatData(Arrays.asList(floats));
+ tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
+ tensor.setName("tensor");
+ graph.addInitializer(tensor);
+
+ // Add operations
+ OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder();
+ matmul.setName("matmul-0");
+ matmul.setOpType("MatMul");
+ matmul.addInput("input");
+ matmul.addInput("tensor");
+ matmul.addOutput("output");
+ graph.addNode(matmul);
+
+ // Build model
+ OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder();
+ model.setGraph(graph);
+ model.setDocString("ORT matmul test");
+ model.setModelVersion(0);
+ model.setIrVersion(8);
+ model.setDomain("ai.onnxruntime.test");
+ model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build());
+ try (OutputStream os =
+ Files.newOutputStream(Paths.get("src", "test", "resources", "java-matmul.onnx"))) {
+ model.build().writeTo(os);
+ }
+ }
+}
diff --git a/java/src/test/java/ai/onnxruntime/SparseTensorTest.java b/java/src/test/java/ai/onnxruntime/SparseTensorTest.java
index 9b0533f5c4..99e55874f3 100644
--- a/java/src/test/java/ai/onnxruntime/SparseTensorTest.java
+++ b/java/src/test/java/ai/onnxruntime/SparseTensorTest.java
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@@ -31,7 +31,7 @@ public class SparseTensorTest {
try (OrtSession session = env.createSession(modelPath, options)) {
Map inputMap = new HashMap<>();
- OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
+ OnnxTensor denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 3);
long[] shape = new long[] {3, 3};
/*
* Sparse matrix:
@@ -152,7 +152,7 @@ public class SparseTensorTest {
inputMap.clear();
denseIdMatrix.close();
- denseIdMatrix = makeIdentityMatrix(env, 4);
+ denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 4);
long[] vectorShape = new long[] {1, 4};
/*
* Sparse matrix:
@@ -212,7 +212,7 @@ public class SparseTensorTest {
try (OrtSession session = env.createSession(modelPath, options)) {
Map inputMap = new HashMap<>();
- OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3);
+ OnnxTensor denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 3);
long[] shape = new long[] {3, 3};
/*
* Sparse matrix:
@@ -341,7 +341,7 @@ public class SparseTensorTest {
inputMap.clear();
denseIdMatrix.close();
- denseIdMatrix = makeIdentityMatrix(env, 4);
+ denseIdMatrix = TestHelpers.makeIdentityMatrix(env, 4);
long[] vectorShape = new long[] {1, 4};
/*
* Sparse matrix:
@@ -439,13 +439,4 @@ public class SparseTensorTest {
}
}
}
-
- private static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException {
- float[][] values = new float[size][size];
- for (int i = 0; i < values.length; i++) {
- values[i][i] = 1.0f;
- }
-
- return OnnxTensor.createTensor(env, values);
- }
}
diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java
index b82cb07aee..7d41918b1c 100644
--- a/java/src/test/java/ai/onnxruntime/TestHelpers.java
+++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@@ -13,6 +13,8 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
@@ -411,6 +413,25 @@ public class TestHelpers {
return new StringTensorPair(nodeName, onnxTensor);
}
+ public static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException {
+ float[][] values = new float[size][size];
+ for (int i = 0; i < values.length; i++) {
+ values[i][i] = 1.0f;
+ }
+
+ return OnnxTensor.createTensor(env, values);
+ }
+
+ public static OnnxTensor makeIdentityMatrixBuf(OrtEnvironment env, int size) throws OrtException {
+ FloatBuffer buf =
+ ByteBuffer.allocateDirect(size * size * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
+ for (int i = 0; i < size; i++) {
+ buf.put(i * size + i, 1.0f);
+ }
+
+ return OnnxTensor.createTensor(env, buf, new long[] {size, size});
+ }
+
private static class TypeWidth {
public final OnnxJavaType type;
public final int width;
diff --git a/java/src/test/resources/external-matmul.out b/java/src/test/resources/external-matmul.out
new file mode 100644
index 0000000000..bbed315933
Binary files /dev/null and b/java/src/test/resources/external-matmul.out differ
diff --git a/java/src/test/resources/java-external-matmul.onnx b/java/src/test/resources/java-external-matmul.onnx
new file mode 100644
index 0000000000..d0f31853d1
--- /dev/null
+++ b/java/src/test/resources/java-external-matmul.onnx
@@ -0,0 +1,18 @@
+"ai.onnxruntime.test2ORT matmul test:Ñ
+)
+input
+tensoroutputmatmul-0"MatMulort-test-matmul*L
+Btensorj
+locationexternal-matmul.outj
+offset0j
+length64pZ!
+input
+
+
+batch_size
+b"
+output
+
+
+batch_size
+B
\ No newline at end of file
diff --git a/java/src/test/resources/java-matmul.onnx b/java/src/test/resources/java-matmul.onnx
new file mode 100644
index 0000000000..2deeeebd13
Binary files /dev/null and b/java/src/test/resources/java-matmul.onnx differ