mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[java] Changes OrtEnvironment so it can't be closed by users (#10670)
* Changes OrtEnvironment so it can't be closed by users. * Fix the formatting and add a same instance check.
This commit is contained in:
parent
e23a224518
commit
f856608599
8 changed files with 649 additions and 674 deletions
|
|
@ -167,6 +167,7 @@ test {
|
|||
java {
|
||||
dependsOn spotlessJava
|
||||
}
|
||||
forkEvery 1 // Forces each test class to be run in a separate JVM, which is necessary for testing the environment thread pool
|
||||
useJUnitPlatform()
|
||||
if (cmakeBuildDir != null) {
|
||||
workingDir cmakeBuildDir
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
*/
|
||||
static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Object data)
|
||||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
if (!allocator.isClosed()) {
|
||||
TensorInfo info = TensorInfo.constructFromJavaArray(data);
|
||||
if (info.type == OnnxJavaType.STRING) {
|
||||
if (info.shape.length == 0) {
|
||||
|
|
@ -403,7 +403,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
*/
|
||||
static OnnxTensor createTensor(
|
||||
OrtEnvironment env, OrtAllocator allocator, String[] data, long[] shape) throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
if (!allocator.isClosed()) {
|
||||
TensorInfo info =
|
||||
new TensorInfo(
|
||||
shape,
|
||||
|
|
@ -451,7 +451,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
static OnnxTensor createTensor(
|
||||
OrtEnvironment env, OrtAllocator allocator, FloatBuffer data, long[] shape)
|
||||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
if (!allocator.isClosed()) {
|
||||
OnnxJavaType type = OnnxJavaType.FLOAT;
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
|
|
@ -492,7 +492,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
static OnnxTensor createTensor(
|
||||
OrtEnvironment env, OrtAllocator allocator, DoubleBuffer data, long[] shape)
|
||||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
if (!allocator.isClosed()) {
|
||||
OnnxJavaType type = OnnxJavaType.DOUBLE;
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
|
|
@ -571,7 +571,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
static OnnxTensor createTensor(
|
||||
OrtEnvironment env, OrtAllocator allocator, ByteBuffer data, long[] shape, OnnxJavaType type)
|
||||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
if (!allocator.isClosed()) {
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
|
||||
|
|
@ -611,7 +611,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
static OnnxTensor createTensor(
|
||||
OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape)
|
||||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
if (!allocator.isClosed()) {
|
||||
OnnxJavaType type = OnnxJavaType.INT16;
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
|
|
@ -652,7 +652,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
static OnnxTensor createTensor(
|
||||
OrtEnvironment env, OrtAllocator allocator, IntBuffer data, long[] shape)
|
||||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
if (!allocator.isClosed()) {
|
||||
OnnxJavaType type = OnnxJavaType.INT32;
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
|
|
@ -693,7 +693,7 @@ public class OnnxTensor implements OnnxValue {
|
|||
static OnnxTensor createTensor(
|
||||
OrtEnvironment env, OrtAllocator allocator, LongBuffer data, long[] shape)
|
||||
throws OrtException {
|
||||
if ((!env.isClosed()) && (!allocator.isClosed())) {
|
||||
if (!allocator.isClosed()) {
|
||||
OnnxJavaType type = OnnxJavaType.INT64;
|
||||
return createTensor(type, allocator, data, shape);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019-2021 Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -7,14 +7,18 @@ package ai.onnxruntime;
|
|||
import ai.onnxruntime.OrtSession.SessionOptions;
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* The host object for the onnx-runtime system. Can create {@link OrtSession}s which encapsulate
|
||||
* specific models.
|
||||
*
|
||||
* <p>There can be at most one OrtEnvironment object created in a JVM lifetime. This class
|
||||
* implements {@link AutoCloseable} as before for backwards compatibility with 1.10 and earlier, but
|
||||
* the {@link #close} method is a no-op. The environment is closed by a JVM shutdown hook registered
|
||||
* on construction.
|
||||
*/
|
||||
public class OrtEnvironment implements AutoCloseable {
|
||||
public final class OrtEnvironment implements AutoCloseable {
|
||||
|
||||
private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName());
|
||||
|
||||
|
|
@ -30,8 +34,6 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
|
||||
private static volatile OrtEnvironment INSTANCE;
|
||||
|
||||
private static final AtomicInteger refCount = new AtomicInteger();
|
||||
|
||||
private static volatile OrtLoggingLevel curLogLevel;
|
||||
|
||||
private static volatile String curLoggingName;
|
||||
|
|
@ -47,8 +49,6 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
// If there's no instance, create one.
|
||||
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
|
||||
} else {
|
||||
// else return the current one.
|
||||
refCount.incrementAndGet();
|
||||
return INSTANCE;
|
||||
}
|
||||
}
|
||||
|
|
@ -106,7 +106,6 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
"Tried to change OrtEnvironment's logging level or name while a reference exists.");
|
||||
}
|
||||
}
|
||||
refCount.incrementAndGet();
|
||||
return INSTANCE;
|
||||
}
|
||||
|
||||
|
|
@ -131,7 +130,6 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
} catch (OrtException e) {
|
||||
throw new IllegalStateException("Failed to create OrtEnvironment", e);
|
||||
}
|
||||
refCount.incrementAndGet();
|
||||
return INSTANCE;
|
||||
} else {
|
||||
// As the thread pool state is unknown, and that's probably not what the user wanted.
|
||||
|
|
@ -144,8 +142,6 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
|
||||
final OrtAllocator defaultAllocator;
|
||||
|
||||
private volatile boolean closed = false;
|
||||
|
||||
/**
|
||||
* Create an OrtEnvironment using a default name.
|
||||
*
|
||||
|
|
@ -165,6 +161,8 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException {
|
||||
nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name);
|
||||
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
|
||||
Runtime.getRuntime()
|
||||
.addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -181,6 +179,8 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
createHandle(
|
||||
OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name, threadOptions.nativeHandle);
|
||||
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
|
||||
Runtime.getRuntime()
|
||||
.addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -219,11 +219,7 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
*/
|
||||
OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOptions options)
|
||||
throws OrtException {
|
||||
if (!closed) {
|
||||
return new OrtSession(this, modelPath, allocator, options);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
|
||||
}
|
||||
return new OrtSession(this, modelPath, allocator, options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -262,11 +258,7 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
*/
|
||||
OrtSession createSession(byte[] modelArray, OrtAllocator allocator, SessionOptions options)
|
||||
throws OrtException {
|
||||
if (!closed) {
|
||||
return new OrtSession(this, modelArray, allocator, options);
|
||||
} else {
|
||||
throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
|
||||
}
|
||||
return new OrtSession(this, modelArray, allocator, options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -279,41 +271,11 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
setTelemetry(OnnxRuntime.ortApiHandle, nativeHandle, sendTelemetry);
|
||||
}
|
||||
|
||||
/**
|
||||
* Is this environment closed?
|
||||
*
|
||||
* @return True if the environment is closed.
|
||||
*/
|
||||
public boolean isClosed() {
|
||||
return closed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "OrtEnvironment(name=" + curLoggingName + ",logLevel=" + curLogLevel + ")";
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the OrtEnvironment. If this is the last reference to the environment then it closes the
|
||||
* native handle.
|
||||
*
|
||||
* @throws OrtException If the close failed.
|
||||
*/
|
||||
@Override
|
||||
public synchronized void close() throws OrtException {
|
||||
synchronized (refCount) {
|
||||
int curCount = refCount.get();
|
||||
if (curCount != 0) {
|
||||
refCount.decrementAndGet();
|
||||
}
|
||||
if (curCount == 1) {
|
||||
close(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
closed = true;
|
||||
INSTANCE = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the providers available in this environment.
|
||||
*
|
||||
|
|
@ -377,11 +339,22 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
private static native void setTelemetry(long apiHandle, long nativeHandle, boolean sendTelemetry)
|
||||
throws OrtException;
|
||||
|
||||
/** Close is a no-op on OrtEnvironment since ORT 1.11. */
|
||||
@Override
|
||||
public void close() {}
|
||||
|
||||
/**
|
||||
* Controls the global thread pools in the environment. Only used if the session is constructed
|
||||
* using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set.
|
||||
*/
|
||||
public static final class ThreadingOptions implements AutoCloseable {
|
||||
static {
|
||||
try {
|
||||
OnnxRuntime.init();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load onnx-runtime library", e);
|
||||
}
|
||||
}
|
||||
|
||||
private final long nativeHandle;
|
||||
|
||||
|
|
@ -486,4 +459,24 @@ public class OrtEnvironment implements AutoCloseable {
|
|||
|
||||
private native void closeThreadingOptions(long apiHandle, long nativeHandle);
|
||||
}
|
||||
|
||||
private static final class OrtEnvCloser implements Runnable {
|
||||
|
||||
private final long apiHandle;
|
||||
private final long nativeHandle;
|
||||
|
||||
OrtEnvCloser(long apiHandle, long nativeHandle) {
|
||||
this.apiHandle = apiHandle;
|
||||
this.nativeHandle = nativeHandle;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
OrtEnvironment.close(apiHandle, nativeHandle);
|
||||
} catch (OrtException e) {
|
||||
System.err.println("Error closing OrtEnvironment, " + e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,88 @@
|
|||
/*
|
||||
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import static ai.onnxruntime.TestHelpers.getResourcePath;
|
||||
import static ai.onnxruntime.TestHelpers.loadTensorFromFile;
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
/** This test is in a separate class to ensure it is run in a clean JVM. */
|
||||
public class EnvironmentThreadPoolTest {
|
||||
|
||||
@Test
|
||||
public void environmentThreadPoolTest() throws OrtException {
|
||||
Path squeezeNet = getResourcePath("/squeezenet.onnx");
|
||||
String modelPath = squeezeNet.toString();
|
||||
float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
|
||||
float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
|
||||
OrtEnvironment.ThreadingOptions threadOpts = new OrtEnvironment.ThreadingOptions();
|
||||
threadOpts.setGlobalInterOpNumThreads(2);
|
||||
threadOpts.setGlobalIntraOpNumThreads(2);
|
||||
threadOpts.setGlobalDenormalAsZero();
|
||||
threadOpts.setGlobalSpinControl(true);
|
||||
OrtEnvironment env =
|
||||
OrtEnvironment.getEnvironment(
|
||||
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "environmentThreadPoolTest", threadOpts);
|
||||
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions();
|
||||
OrtSession.SessionOptions disableThreadOptions = new OrtSession.SessionOptions()) {
|
||||
disableThreadOptions.disablePerSessionThreads();
|
||||
|
||||
// Check that the regular session executes
|
||||
try (OrtSession session = env.createSession(modelPath, options)) {
|
||||
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
|
||||
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
|
||||
Object tensorData = OrtUtil.reshape(inputData, inputShape);
|
||||
OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
|
||||
container.put(inputMeta.getName(), tensor);
|
||||
try (OrtSession.Result result = session.run(container)) {
|
||||
OnnxValue resultTensor = result.get(0);
|
||||
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
|
||||
assertEquals(expectedOutput.length, resultArray.length);
|
||||
assertArrayEquals(expectedOutput, resultArray, 1e-6f);
|
||||
}
|
||||
container.clear();
|
||||
tensor.close();
|
||||
}
|
||||
|
||||
// Check that the session using the env thread pool executes
|
||||
try (OrtSession session = env.createSession(modelPath, disableThreadOptions)) {
|
||||
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
|
||||
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
|
||||
Object tensorData = OrtUtil.reshape(inputData, inputShape);
|
||||
OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
|
||||
container.put(inputMeta.getName(), tensor);
|
||||
try (OrtSession.Result result = session.run(container)) {
|
||||
OnnxValue resultTensor = result.get(0);
|
||||
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
|
||||
assertEquals(expectedOutput.length, resultArray.length);
|
||||
assertArrayEquals(expectedOutput, resultArray, 1e-6f);
|
||||
}
|
||||
container.clear();
|
||||
tensor.close();
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
OrtEnvironment newEnv =
|
||||
OrtEnvironment.getEnvironment(
|
||||
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "fail", threadOpts);
|
||||
// fail as we can't recreate environments with different threading options
|
||||
fail("Should have thrown IllegalStateException");
|
||||
} catch (IllegalStateException e) {
|
||||
// pass
|
||||
}
|
||||
|
||||
threadOpts.close();
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -12,117 +12,114 @@ public class TensorCreationTest {
|
|||
|
||||
@Test
|
||||
public void testScalarCreation() throws OrtException {
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
|
||||
String[] stringValues = new String[] {"true", "false"};
|
||||
for (String s : stringValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, s)) {
|
||||
Assertions.assertEquals(s, t.getValue());
|
||||
}
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
String[] stringValues = new String[] {"true", "false"};
|
||||
for (String s : stringValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, s)) {
|
||||
Assertions.assertEquals(s, t.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
boolean[] boolValues = new boolean[] {true, false};
|
||||
for (boolean b : boolValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, b)) {
|
||||
Assertions.assertEquals(b, t.getValue());
|
||||
}
|
||||
boolean[] boolValues = new boolean[] {true, false};
|
||||
for (boolean b : boolValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, b)) {
|
||||
Assertions.assertEquals(b, t.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
int[] intValues =
|
||||
new int[] {-1, 0, 1, 12345678, -12345678, Integer.MAX_VALUE, Integer.MIN_VALUE};
|
||||
for (int i : intValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, i)) {
|
||||
Assertions.assertEquals(i, t.getValue());
|
||||
}
|
||||
int[] intValues =
|
||||
new int[] {-1, 0, 1, 12345678, -12345678, Integer.MAX_VALUE, Integer.MIN_VALUE};
|
||||
for (int i : intValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, i)) {
|
||||
Assertions.assertEquals(i, t.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
long[] longValues =
|
||||
new long[] {-1L, 0L, 1L, 12345678L, -12345678L, Long.MAX_VALUE, Long.MIN_VALUE};
|
||||
for (long l : longValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, l)) {
|
||||
Assertions.assertEquals(l, t.getValue());
|
||||
}
|
||||
long[] longValues =
|
||||
new long[] {-1L, 0L, 1L, 12345678L, -12345678L, Long.MAX_VALUE, Long.MIN_VALUE};
|
||||
for (long l : longValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, l)) {
|
||||
Assertions.assertEquals(l, t.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
float[] floatValues =
|
||||
new float[] {
|
||||
-1.0f,
|
||||
0.0f,
|
||||
-0.0f,
|
||||
1.0f,
|
||||
1234.5678f,
|
||||
-1234.5678f,
|
||||
(float) Math.PI,
|
||||
(float) Math.E,
|
||||
Float.MAX_VALUE,
|
||||
Float.MIN_VALUE
|
||||
};
|
||||
for (float f : floatValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, f)) {
|
||||
Assertions.assertEquals(f, t.getValue());
|
||||
}
|
||||
float[] floatValues =
|
||||
new float[] {
|
||||
-1.0f,
|
||||
0.0f,
|
||||
-0.0f,
|
||||
1.0f,
|
||||
1234.5678f,
|
||||
-1234.5678f,
|
||||
(float) Math.PI,
|
||||
(float) Math.E,
|
||||
Float.MAX_VALUE,
|
||||
Float.MIN_VALUE
|
||||
};
|
||||
for (float f : floatValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, f)) {
|
||||
Assertions.assertEquals(f, t.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
double[] doubleValues =
|
||||
new double[] {
|
||||
-1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1234.5678,
|
||||
-1234.5678,
|
||||
Math.PI,
|
||||
Math.E,
|
||||
Double.MAX_VALUE,
|
||||
Double.MIN_VALUE
|
||||
};
|
||||
for (double d : doubleValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, d)) {
|
||||
Assertions.assertEquals(d, t.getValue());
|
||||
}
|
||||
double[] doubleValues =
|
||||
new double[] {
|
||||
-1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1234.5678,
|
||||
-1234.5678,
|
||||
Math.PI,
|
||||
Math.E,
|
||||
Double.MAX_VALUE,
|
||||
Double.MIN_VALUE
|
||||
};
|
||||
for (double d : doubleValues) {
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, d)) {
|
||||
Assertions.assertEquals(d, t.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringCreation() throws OrtException {
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
|
||||
String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape);
|
||||
String[] output = (String[]) t.getValue();
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
}
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape);
|
||||
String[] output = (String[]) t.getValue();
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
}
|
||||
|
||||
String[][] stringValues =
|
||||
new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape);
|
||||
String[][] output = (String[][]) t.getValue();
|
||||
Assertions.assertArrayEquals(stringValues, output);
|
||||
}
|
||||
String[][] stringValues =
|
||||
new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape);
|
||||
String[][] output = (String[][]) t.getValue();
|
||||
Assertions.assertArrayEquals(stringValues, output);
|
||||
}
|
||||
|
||||
String[][][] deepStringValues =
|
||||
new String[][][] {
|
||||
{{"this", "is", "a"}, {"multi", "dimensional", "string"}},
|
||||
{{"with", "lots", "more"}, {"dimensions", "than", "before"}}
|
||||
};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape);
|
||||
String[][][] output = (String[][][]) t.getValue();
|
||||
Assertions.assertArrayEquals(deepStringValues, output);
|
||||
}
|
||||
String[][][] deepStringValues =
|
||||
new String[][][] {
|
||||
{{"this", "is", "a"}, {"multi", "dimensional", "string"}},
|
||||
{{"with", "lots", "more"}, {"dimensions", "than", "before"}}
|
||||
};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape);
|
||||
String[][][] output = (String[][][]) t.getValue();
|
||||
Assertions.assertArrayEquals(deepStringValues, output);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUint8Creation() throws OrtException {
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
|
||||
byte[] buf = new byte[] {0, 1};
|
||||
ByteBuffer data = ByteBuffer.wrap(buf);
|
||||
long[] shape = new long[] {2};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) {
|
||||
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
|
||||
}
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
byte[] buf = new byte[] {0, 1};
|
||||
ByteBuffer data = ByteBuffer.wrap(buf);
|
||||
long[] shape = new long[] {2};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) {
|
||||
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,16 +1,31 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Pattern;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
|
||||
/** Test helpers for manipulating primitive arrays. */
|
||||
class TestHelpers {
|
||||
|
||||
private static final Pattern LOAD_PATTERN = Pattern.compile("[,\\[\\] ]");
|
||||
|
||||
static boolean[] toPrimitiveBoolean(List<Boolean> input) {
|
||||
boolean[] output = new boolean[input.size()];
|
||||
|
||||
|
|
@ -234,4 +249,179 @@ class TestHelpers {
|
|||
static void flattenStringBase(String[] input, List<String> output) {
|
||||
output.addAll(Arrays.asList(input));
|
||||
}
|
||||
|
||||
static Path getResourcePath(String path) {
|
||||
return new File(InferenceTest.class.getResource(path).getFile()).toPath();
|
||||
}
|
||||
|
||||
static float[] loadTensorFromFile(Path filename) {
|
||||
return loadTensorFromFile(filename, true);
|
||||
}
|
||||
|
||||
static float[] loadTensorFromFile(Path filename, boolean skipHeader) {
|
||||
// read data from file
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(filename.toFile()))) {
|
||||
if (skipHeader) {
|
||||
reader.readLine(); // skip the input name
|
||||
}
|
||||
String[] dataStr = LOAD_PATTERN.split(reader.readLine());
|
||||
List<Float> tensorData = new ArrayList<>();
|
||||
for (int i = 0; i < dataStr.length; i++) {
|
||||
if (!dataStr[i].isEmpty()) {
|
||||
tensorData.add(Float.parseFloat(dataStr[i]));
|
||||
}
|
||||
}
|
||||
return toPrimitiveFloat(tensorData);
|
||||
} catch (IOException e) {
|
||||
throw new UncheckedIOException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static TypeWidth getTypeAndWidth(OnnxMl.TensorProto.DataType elemType) {
|
||||
OnnxJavaType type;
|
||||
int width;
|
||||
switch (elemType) {
|
||||
case FLOAT:
|
||||
type = OnnxJavaType.FLOAT;
|
||||
width = 4;
|
||||
break;
|
||||
case UINT8:
|
||||
case INT8:
|
||||
type = OnnxJavaType.INT8;
|
||||
width = 1;
|
||||
break;
|
||||
case UINT16:
|
||||
case INT16:
|
||||
type = OnnxJavaType.INT16;
|
||||
width = 2;
|
||||
break;
|
||||
case INT32:
|
||||
case UINT32:
|
||||
type = OnnxJavaType.INT32;
|
||||
width = 4;
|
||||
break;
|
||||
case INT64:
|
||||
case UINT64:
|
||||
type = OnnxJavaType.INT64;
|
||||
width = 8;
|
||||
break;
|
||||
case STRING:
|
||||
type = OnnxJavaType.STRING;
|
||||
width = 1;
|
||||
break;
|
||||
case BOOL:
|
||||
type = OnnxJavaType.BOOL;
|
||||
width = 1;
|
||||
break;
|
||||
case FLOAT16:
|
||||
type = OnnxJavaType.FLOAT;
|
||||
width = 2;
|
||||
break;
|
||||
case DOUBLE:
|
||||
type = OnnxJavaType.DOUBLE;
|
||||
width = 8;
|
||||
break;
|
||||
default:
|
||||
type = null;
|
||||
width = 0;
|
||||
break;
|
||||
}
|
||||
return new TypeWidth(type, width);
|
||||
}
|
||||
|
||||
static StringTensorPair loadTensorFromFilePb(
|
||||
OrtEnvironment env, File filename, Map<String, NodeInfo> nodeMetaDict)
|
||||
throws IOException, OrtException {
|
||||
InputStream is = new BufferedInputStream(new FileInputStream(filename), 1024 * 1024 * 4);
|
||||
OnnxMl.TensorProto tensor = OnnxMl.TensorProto.parseFrom(is);
|
||||
is.close();
|
||||
|
||||
TypeWidth tw = getTypeAndWidth(OnnxMl.TensorProto.DataType.forNumber(tensor.getDataType()));
|
||||
int width = tw.width;
|
||||
OnnxJavaType tensorElemType = tw.type;
|
||||
long[] intDims = new long[tensor.getDimsCount()];
|
||||
for (int i = 0; i < tensor.getDimsCount(); i++) {
|
||||
intDims[i] = tensor.getDims(i);
|
||||
}
|
||||
|
||||
TensorInfo nodeMeta = null;
|
||||
String nodeName = "";
|
||||
if (nodeMetaDict.size() == 1) {
|
||||
for (Map.Entry<String, NodeInfo> e : nodeMetaDict.entrySet()) {
|
||||
nodeMeta = (TensorInfo) e.getValue().getInfo();
|
||||
nodeName = e.getKey(); // valid for single node input
|
||||
}
|
||||
} else if (nodeMetaDict.size() > 1) {
|
||||
if (!tensor.getName().isEmpty()) {
|
||||
nodeMeta = (TensorInfo) nodeMetaDict.get(tensor.getName()).getInfo();
|
||||
nodeName = tensor.getName();
|
||||
} else {
|
||||
boolean matchfound = false;
|
||||
// try to find from matching type and shape
|
||||
for (Map.Entry<String, NodeInfo> e : nodeMetaDict.entrySet()) {
|
||||
if (e.getValue().getInfo() instanceof TensorInfo) {
|
||||
TensorInfo meta = (TensorInfo) e.getValue().getInfo();
|
||||
if (tensorElemType == meta.type && tensor.getDimsCount() == meta.shape.length) {
|
||||
int i = 0;
|
||||
for (; i < meta.shape.length; i++) {
|
||||
if (meta.shape[i] != -1 && meta.shape[i] != intDims[i]) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (i >= meta.shape.length) {
|
||||
matchfound = true;
|
||||
nodeMeta = meta;
|
||||
nodeName = e.getKey();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!matchfound) {
|
||||
// throw error
|
||||
throw new IllegalStateException(
|
||||
"No matching Tensor found in InputOutputMetadata corresponding to the serialized tensor loaded from "
|
||||
+ filename);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// throw error
|
||||
throw new IllegalStateException(
|
||||
"While reading the serialized tensor loaded from "
|
||||
+ filename
|
||||
+ ", metaDataDict has 0 elements");
|
||||
}
|
||||
|
||||
Assertions.assertEquals(tensorElemType, nodeMeta.type);
|
||||
Assertions.assertEquals(nodeMeta.shape.length, tensor.getDimsCount());
|
||||
for (int i = 0; i < nodeMeta.shape.length; i++) {
|
||||
Assertions.assertTrue((nodeMeta.shape[i] == -1) || (nodeMeta.shape[i] == intDims[i]));
|
||||
}
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.wrap(tensor.getRawData().toByteArray());
|
||||
|
||||
OnnxTensor onnxTensor = OnnxTensor.createTensor(env, buffer, intDims, tensorElemType);
|
||||
|
||||
return new StringTensorPair(nodeName, onnxTensor);
|
||||
}
|
||||
|
||||
private static class TypeWidth {
|
||||
public final OnnxJavaType type;
|
||||
public final int width;
|
||||
|
||||
public TypeWidth(OnnxJavaType type, int width) {
|
||||
this.type = type;
|
||||
this.width = width;
|
||||
}
|
||||
}
|
||||
|
||||
static class StringTensorPair {
|
||||
public final String string;
|
||||
public final OnnxTensor tensor;
|
||||
|
||||
public StringTensorPair(String string, OnnxTensor tensor) {
|
||||
this.string = string;
|
||||
this.tensor = tensor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -263,8 +263,8 @@ public class ScoreMNIST {
|
|||
return;
|
||||
}
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions opts = new SessionOptions()) {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
try (OrtSession.SessionOptions opts = new SessionOptions()) {
|
||||
|
||||
opts.setOptimizationLevel(OptLevel.BASIC_OPT);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue