[java] Changes OrtEnvironment so it can't be closed by users (#10670)

* Changes OrtEnvironment so it can't be closed by users.

* Fix the formatting and add a same instance check.
This commit is contained in:
Adam Pocock 2022-03-01 00:03:40 -05:00 committed by GitHub
parent e23a224518
commit f856608599
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 649 additions and 674 deletions

View file

@ -167,6 +167,7 @@ test {
java {
dependsOn spotlessJava
}
forkEvery 1 // Forces each test class to be run in a separate JVM, which is necessary for testing the environment thread pool
useJUnitPlatform()
if (cmakeBuildDir != null) {
workingDir cmakeBuildDir

View file

@ -333,7 +333,7 @@ public class OnnxTensor implements OnnxValue {
*/
static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Object data)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
TensorInfo info = TensorInfo.constructFromJavaArray(data);
if (info.type == OnnxJavaType.STRING) {
if (info.shape.length == 0) {
@ -403,7 +403,7 @@ public class OnnxTensor implements OnnxValue {
*/
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, String[] data, long[] shape) throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
TensorInfo info =
new TensorInfo(
shape,
@ -451,7 +451,7 @@ public class OnnxTensor implements OnnxValue {
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, FloatBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.FLOAT;
return createTensor(type, allocator, data, shape);
} else {
@ -492,7 +492,7 @@ public class OnnxTensor implements OnnxValue {
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, DoubleBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.DOUBLE;
return createTensor(type, allocator, data, shape);
} else {
@ -571,7 +571,7 @@ public class OnnxTensor implements OnnxValue {
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ByteBuffer data, long[] shape, OnnxJavaType type)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
@ -611,7 +611,7 @@ public class OnnxTensor implements OnnxValue {
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT16;
return createTensor(type, allocator, data, shape);
} else {
@ -652,7 +652,7 @@ public class OnnxTensor implements OnnxValue {
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, IntBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT32;
return createTensor(type, allocator, data, shape);
} else {
@ -693,7 +693,7 @@ public class OnnxTensor implements OnnxValue {
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, LongBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT64;
return createTensor(type, allocator, data, shape);
} else {

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -7,14 +7,18 @@ package ai.onnxruntime;
import ai.onnxruntime.OrtSession.SessionOptions;
import java.io.IOException;
import java.util.EnumSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
/**
* The host object for the onnx-runtime system. Can create {@link OrtSession}s which encapsulate
* specific models.
*
* <p>There can be at most one OrtEnvironment object created in a JVM lifetime. This class
* implements {@link AutoCloseable} as before for backwards compatibility with 1.10 and earlier, but
* the {@link #close} method is a no-op. The environment is closed by a JVM shutdown hook registered
* on construction.
*/
public class OrtEnvironment implements AutoCloseable {
public final class OrtEnvironment implements AutoCloseable {
private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName());
@ -30,8 +34,6 @@ public class OrtEnvironment implements AutoCloseable {
private static volatile OrtEnvironment INSTANCE;
private static final AtomicInteger refCount = new AtomicInteger();
private static volatile OrtLoggingLevel curLogLevel;
private static volatile String curLoggingName;
@ -47,8 +49,6 @@ public class OrtEnvironment implements AutoCloseable {
// If there's no instance, create one.
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
} else {
// else return the current one.
refCount.incrementAndGet();
return INSTANCE;
}
}
@ -106,7 +106,6 @@ public class OrtEnvironment implements AutoCloseable {
"Tried to change OrtEnvironment's logging level or name while a reference exists.");
}
}
refCount.incrementAndGet();
return INSTANCE;
}
@ -131,7 +130,6 @@ public class OrtEnvironment implements AutoCloseable {
} catch (OrtException e) {
throw new IllegalStateException("Failed to create OrtEnvironment", e);
}
refCount.incrementAndGet();
return INSTANCE;
} else {
// As the thread pool state is unknown, and that's probably not what the user wanted.
@ -144,8 +142,6 @@ public class OrtEnvironment implements AutoCloseable {
final OrtAllocator defaultAllocator;
private volatile boolean closed = false;
/**
* Create an OrtEnvironment using a default name.
*
@ -165,6 +161,8 @@ public class OrtEnvironment implements AutoCloseable {
private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException {
nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
Runtime.getRuntime()
.addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
}
/**
@ -181,6 +179,8 @@ public class OrtEnvironment implements AutoCloseable {
createHandle(
OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name, threadOptions.nativeHandle);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
Runtime.getRuntime()
.addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
}
/**
@ -219,11 +219,7 @@ public class OrtEnvironment implements AutoCloseable {
*/
OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOptions options)
throws OrtException {
if (!closed) {
return new OrtSession(this, modelPath, allocator, options);
} else {
throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
}
return new OrtSession(this, modelPath, allocator, options);
}
/**
@ -262,11 +258,7 @@ public class OrtEnvironment implements AutoCloseable {
*/
OrtSession createSession(byte[] modelArray, OrtAllocator allocator, SessionOptions options)
throws OrtException {
if (!closed) {
return new OrtSession(this, modelArray, allocator, options);
} else {
throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
}
return new OrtSession(this, modelArray, allocator, options);
}
/**
@ -279,41 +271,11 @@ public class OrtEnvironment implements AutoCloseable {
setTelemetry(OnnxRuntime.ortApiHandle, nativeHandle, sendTelemetry);
}
/**
* Is this environment closed?
*
* @return True if the environment is closed.
*/
public boolean isClosed() {
return closed;
}
@Override
public String toString() {
return "OrtEnvironment(name=" + curLoggingName + ",logLevel=" + curLogLevel + ")";
}
/**
* Closes the OrtEnvironment. If this is the last reference to the environment then it closes the
* native handle.
*
* @throws OrtException If the close failed.
*/
@Override
public synchronized void close() throws OrtException {
synchronized (refCount) {
int curCount = refCount.get();
if (curCount != 0) {
refCount.decrementAndGet();
}
if (curCount == 1) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
INSTANCE = null;
}
}
}
/**
* Gets the providers available in this environment.
*
@ -377,11 +339,22 @@ public class OrtEnvironment implements AutoCloseable {
private static native void setTelemetry(long apiHandle, long nativeHandle, boolean sendTelemetry)
throws OrtException;
/** Close is a no-op on OrtEnvironment since ORT 1.11. */
@Override
public void close() {}
/**
* Controls the global thread pools in the environment. Only used if the session is constructed
* using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set.
*/
public static final class ThreadingOptions implements AutoCloseable {
static {
try {
OnnxRuntime.init();
} catch (IOException e) {
throw new RuntimeException("Failed to load onnx-runtime library", e);
}
}
private final long nativeHandle;
@ -486,4 +459,24 @@ public class OrtEnvironment implements AutoCloseable {
private native void closeThreadingOptions(long apiHandle, long nativeHandle);
}
private static final class OrtEnvCloser implements Runnable {
private final long apiHandle;
private final long nativeHandle;
OrtEnvCloser(long apiHandle, long nativeHandle) {
this.apiHandle = apiHandle;
this.nativeHandle = nativeHandle;
}
@Override
public void run() {
try {
OrtEnvironment.close(apiHandle, nativeHandle);
} catch (OrtException e) {
System.err.println("Error closing OrtEnvironment, " + e);
}
}
}
}

View file

@ -0,0 +1,88 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import static ai.onnxruntime.TestHelpers.getResourcePath;
import static ai.onnxruntime.TestHelpers.loadTensorFromFile;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
/** This test is in a separate class to ensure it is run in a clean JVM. */
public class EnvironmentThreadPoolTest {
@Test
public void environmentThreadPoolTest() throws OrtException {
Path squeezeNet = getResourcePath("/squeezenet.onnx");
String modelPath = squeezeNet.toString();
float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
Map<String, OnnxTensor> container = new HashMap<>();
OrtEnvironment.ThreadingOptions threadOpts = new OrtEnvironment.ThreadingOptions();
threadOpts.setGlobalInterOpNumThreads(2);
threadOpts.setGlobalIntraOpNumThreads(2);
threadOpts.setGlobalDenormalAsZero();
threadOpts.setGlobalSpinControl(true);
OrtEnvironment env =
OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "environmentThreadPoolTest", threadOpts);
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions();
OrtSession.SessionOptions disableThreadOptions = new OrtSession.SessionOptions()) {
disableThreadOptions.disablePerSessionThreads();
// Check that the regular session executes
try (OrtSession session = env.createSession(modelPath, options)) {
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
Object tensorData = OrtUtil.reshape(inputData, inputShape);
OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
container.put(inputMeta.getName(), tensor);
try (OrtSession.Result result = session.run(container)) {
OnnxValue resultTensor = result.get(0);
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
assertEquals(expectedOutput.length, resultArray.length);
assertArrayEquals(expectedOutput, resultArray, 1e-6f);
}
container.clear();
tensor.close();
}
// Check that the session using the env thread pool executes
try (OrtSession session = env.createSession(modelPath, disableThreadOptions)) {
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
Object tensorData = OrtUtil.reshape(inputData, inputShape);
OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
container.put(inputMeta.getName(), tensor);
try (OrtSession.Result result = session.run(container)) {
OnnxValue resultTensor = result.get(0);
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
assertEquals(expectedOutput.length, resultArray.length);
assertArrayEquals(expectedOutput, resultArray, 1e-6f);
}
container.clear();
tensor.close();
}
}
try {
OrtEnvironment newEnv =
OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "fail", threadOpts);
// fail as we can't recreate environments with different threading options
fail("Should have thrown IllegalStateException");
} catch (IllegalStateException e) {
// pass
}
threadOpts.close();
}
}

File diff suppressed because it is too large Load diff

View file

@ -12,117 +12,114 @@ public class TensorCreationTest {
@Test
public void testScalarCreation() throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
String[] stringValues = new String[] {"true", "false"};
for (String s : stringValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, s)) {
Assertions.assertEquals(s, t.getValue());
}
OrtEnvironment env = OrtEnvironment.getEnvironment();
String[] stringValues = new String[] {"true", "false"};
for (String s : stringValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, s)) {
Assertions.assertEquals(s, t.getValue());
}
}
boolean[] boolValues = new boolean[] {true, false};
for (boolean b : boolValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, b)) {
Assertions.assertEquals(b, t.getValue());
}
boolean[] boolValues = new boolean[] {true, false};
for (boolean b : boolValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, b)) {
Assertions.assertEquals(b, t.getValue());
}
}
int[] intValues =
new int[] {-1, 0, 1, 12345678, -12345678, Integer.MAX_VALUE, Integer.MIN_VALUE};
for (int i : intValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, i)) {
Assertions.assertEquals(i, t.getValue());
}
int[] intValues =
new int[] {-1, 0, 1, 12345678, -12345678, Integer.MAX_VALUE, Integer.MIN_VALUE};
for (int i : intValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, i)) {
Assertions.assertEquals(i, t.getValue());
}
}
long[] longValues =
new long[] {-1L, 0L, 1L, 12345678L, -12345678L, Long.MAX_VALUE, Long.MIN_VALUE};
for (long l : longValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, l)) {
Assertions.assertEquals(l, t.getValue());
}
long[] longValues =
new long[] {-1L, 0L, 1L, 12345678L, -12345678L, Long.MAX_VALUE, Long.MIN_VALUE};
for (long l : longValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, l)) {
Assertions.assertEquals(l, t.getValue());
}
}
float[] floatValues =
new float[] {
-1.0f,
0.0f,
-0.0f,
1.0f,
1234.5678f,
-1234.5678f,
(float) Math.PI,
(float) Math.E,
Float.MAX_VALUE,
Float.MIN_VALUE
};
for (float f : floatValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, f)) {
Assertions.assertEquals(f, t.getValue());
}
float[] floatValues =
new float[] {
-1.0f,
0.0f,
-0.0f,
1.0f,
1234.5678f,
-1234.5678f,
(float) Math.PI,
(float) Math.E,
Float.MAX_VALUE,
Float.MIN_VALUE
};
for (float f : floatValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, f)) {
Assertions.assertEquals(f, t.getValue());
}
}
double[] doubleValues =
new double[] {
-1.0,
0.0,
-0.0,
1.0,
1234.5678,
-1234.5678,
Math.PI,
Math.E,
Double.MAX_VALUE,
Double.MIN_VALUE
};
for (double d : doubleValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, d)) {
Assertions.assertEquals(d, t.getValue());
}
double[] doubleValues =
new double[] {
-1.0,
0.0,
-0.0,
1.0,
1234.5678,
-1234.5678,
Math.PI,
Math.E,
Double.MAX_VALUE,
Double.MIN_VALUE
};
for (double d : doubleValues) {
try (OnnxTensor t = OnnxTensor.createTensor(env, d)) {
Assertions.assertEquals(d, t.getValue());
}
}
}
@Test
public void testStringCreation() throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"};
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape);
String[] output = (String[]) t.getValue();
Assertions.assertArrayEquals(arrValues, output);
}
OrtEnvironment env = OrtEnvironment.getEnvironment();
String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"};
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape);
String[] output = (String[]) t.getValue();
Assertions.assertArrayEquals(arrValues, output);
}
String[][] stringValues =
new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}};
try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) {
Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape);
String[][] output = (String[][]) t.getValue();
Assertions.assertArrayEquals(stringValues, output);
}
String[][] stringValues =
new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}};
try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) {
Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape);
String[][] output = (String[][]) t.getValue();
Assertions.assertArrayEquals(stringValues, output);
}
String[][][] deepStringValues =
new String[][][] {
{{"this", "is", "a"}, {"multi", "dimensional", "string"}},
{{"with", "lots", "more"}, {"dimensions", "than", "before"}}
};
try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) {
Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape);
String[][][] output = (String[][][]) t.getValue();
Assertions.assertArrayEquals(deepStringValues, output);
}
String[][][] deepStringValues =
new String[][][] {
{{"this", "is", "a"}, {"multi", "dimensional", "string"}},
{{"with", "lots", "more"}, {"dimensions", "than", "before"}}
};
try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) {
Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape);
String[][][] output = (String[][][]) t.getValue();
Assertions.assertArrayEquals(deepStringValues, output);
}
}
@Test
public void testUint8Creation() throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
byte[] buf = new byte[] {0, 1};
ByteBuffer data = ByteBuffer.wrap(buf);
long[] shape = new long[] {2};
try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) {
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
}
OrtEnvironment env = OrtEnvironment.getEnvironment();
byte[] buf = new byte[] {0, 1};
ByteBuffer data = ByteBuffer.wrap(buf);
long[] shape = new long[] {2};
try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) {
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
}
}
}

View file

@ -1,16 +1,31 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.junit.jupiter.api.Assertions;
/** Test helpers for manipulating primitive arrays. */
class TestHelpers {
private static final Pattern LOAD_PATTERN = Pattern.compile("[,\\[\\] ]");
static boolean[] toPrimitiveBoolean(List<Boolean> input) {
boolean[] output = new boolean[input.size()];
@ -234,4 +249,179 @@ class TestHelpers {
static void flattenStringBase(String[] input, List<String> output) {
output.addAll(Arrays.asList(input));
}
static Path getResourcePath(String path) {
return new File(InferenceTest.class.getResource(path).getFile()).toPath();
}
static float[] loadTensorFromFile(Path filename) {
return loadTensorFromFile(filename, true);
}
static float[] loadTensorFromFile(Path filename, boolean skipHeader) {
// read data from file
try (BufferedReader reader = new BufferedReader(new FileReader(filename.toFile()))) {
if (skipHeader) {
reader.readLine(); // skip the input name
}
String[] dataStr = LOAD_PATTERN.split(reader.readLine());
List<Float> tensorData = new ArrayList<>();
for (int i = 0; i < dataStr.length; i++) {
if (!dataStr[i].isEmpty()) {
tensorData.add(Float.parseFloat(dataStr[i]));
}
}
return toPrimitiveFloat(tensorData);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
private static TypeWidth getTypeAndWidth(OnnxMl.TensorProto.DataType elemType) {
OnnxJavaType type;
int width;
switch (elemType) {
case FLOAT:
type = OnnxJavaType.FLOAT;
width = 4;
break;
case UINT8:
case INT8:
type = OnnxJavaType.INT8;
width = 1;
break;
case UINT16:
case INT16:
type = OnnxJavaType.INT16;
width = 2;
break;
case INT32:
case UINT32:
type = OnnxJavaType.INT32;
width = 4;
break;
case INT64:
case UINT64:
type = OnnxJavaType.INT64;
width = 8;
break;
case STRING:
type = OnnxJavaType.STRING;
width = 1;
break;
case BOOL:
type = OnnxJavaType.BOOL;
width = 1;
break;
case FLOAT16:
type = OnnxJavaType.FLOAT;
width = 2;
break;
case DOUBLE:
type = OnnxJavaType.DOUBLE;
width = 8;
break;
default:
type = null;
width = 0;
break;
}
return new TypeWidth(type, width);
}
static StringTensorPair loadTensorFromFilePb(
OrtEnvironment env, File filename, Map<String, NodeInfo> nodeMetaDict)
throws IOException, OrtException {
InputStream is = new BufferedInputStream(new FileInputStream(filename), 1024 * 1024 * 4);
OnnxMl.TensorProto tensor = OnnxMl.TensorProto.parseFrom(is);
is.close();
TypeWidth tw = getTypeAndWidth(OnnxMl.TensorProto.DataType.forNumber(tensor.getDataType()));
int width = tw.width;
OnnxJavaType tensorElemType = tw.type;
long[] intDims = new long[tensor.getDimsCount()];
for (int i = 0; i < tensor.getDimsCount(); i++) {
intDims[i] = tensor.getDims(i);
}
TensorInfo nodeMeta = null;
String nodeName = "";
if (nodeMetaDict.size() == 1) {
for (Map.Entry<String, NodeInfo> e : nodeMetaDict.entrySet()) {
nodeMeta = (TensorInfo) e.getValue().getInfo();
nodeName = e.getKey(); // valid for single node input
}
} else if (nodeMetaDict.size() > 1) {
if (!tensor.getName().isEmpty()) {
nodeMeta = (TensorInfo) nodeMetaDict.get(tensor.getName()).getInfo();
nodeName = tensor.getName();
} else {
boolean matchfound = false;
// try to find from matching type and shape
for (Map.Entry<String, NodeInfo> e : nodeMetaDict.entrySet()) {
if (e.getValue().getInfo() instanceof TensorInfo) {
TensorInfo meta = (TensorInfo) e.getValue().getInfo();
if (tensorElemType == meta.type && tensor.getDimsCount() == meta.shape.length) {
int i = 0;
for (; i < meta.shape.length; i++) {
if (meta.shape[i] != -1 && meta.shape[i] != intDims[i]) {
break;
}
}
if (i >= meta.shape.length) {
matchfound = true;
nodeMeta = meta;
nodeName = e.getKey();
break;
}
}
}
}
if (!matchfound) {
// throw error
throw new IllegalStateException(
"No matching Tensor found in InputOutputMetadata corresponding to the serialized tensor loaded from "
+ filename);
}
}
} else {
// throw error
throw new IllegalStateException(
"While reading the serialized tensor loaded from "
+ filename
+ ", metaDataDict has 0 elements");
}
Assertions.assertEquals(tensorElemType, nodeMeta.type);
Assertions.assertEquals(nodeMeta.shape.length, tensor.getDimsCount());
for (int i = 0; i < nodeMeta.shape.length; i++) {
Assertions.assertTrue((nodeMeta.shape[i] == -1) || (nodeMeta.shape[i] == intDims[i]));
}
ByteBuffer buffer = ByteBuffer.wrap(tensor.getRawData().toByteArray());
OnnxTensor onnxTensor = OnnxTensor.createTensor(env, buffer, intDims, tensorElemType);
return new StringTensorPair(nodeName, onnxTensor);
}
private static class TypeWidth {
public final OnnxJavaType type;
public final int width;
public TypeWidth(OnnxJavaType type, int width) {
this.type = type;
this.width = width;
}
}
static class StringTensorPair {
public final String string;
public final OnnxTensor tensor;
public StringTensorPair(String string, OnnxTensor tensor) {
this.string = string;
this.tensor = tensor;
}
}
}

View file

@ -263,8 +263,8 @@ public class ScoreMNIST {
return;
}
try (OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new SessionOptions()) {
OrtEnvironment env = OrtEnvironment.getEnvironment();
try (OrtSession.SessionOptions opts = new SessionOptions()) {
opts.setOptimizationLevel(OptLevel.BASIC_OPT);