[java] Allows the creation and extraction of zero length tensors (#15116)

### Description
Allows the creation of zero length tensors via the buffer path (the
array path with zero length arrays still throws as the validation logic
to check it's not ragged would require more intrusive revision), and
allows the `tensor.getValue()` method to return a Java multidimensional
array with a zero dimension. Also added a test for the creation and
extraction behaviour.

### Motivation and Context
The Python interface can return zero length tensors (e.g. if object
detection doesn't find any objects), and before this PR in Java calling
`tensor.getValue()` throws an exception with a confusing error message.
Fixes #7270 & #15107.
This commit is contained in:
Adam Pocock 2023-04-05 18:49:59 +01:00 committed by GitHub
parent 9191e04259
commit ef11032c89
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 79 additions and 12 deletions

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -76,7 +76,10 @@ public class OnnxTensor extends OnnxTensorLike {
}
} else {
Object carrier = info.makeCarrier();
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
if (info.getNumElements() > 0) {
// If the tensor has values copy them out
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
}
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -41,9 +41,9 @@ public final class OrtUtil {
int[] newShape = new int[shape.length];
for (int i = 0; i < shape.length; i++) {
long curDim = shape[i];
if (curDim < 1 || curDim > Integer.MAX_VALUE) {
if (curDim < 0 || curDim > Integer.MAX_VALUE) {
throw new IllegalArgumentException(
"Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found "
"Invalid shape for a Java array, expected non-negative entries smaller than Integer.MAX_VALUE. Found "
+ Arrays.toString(shape));
} else {
newShape[i] = (int) curDim;
@ -345,20 +345,23 @@ public final class OrtUtil {
/**
* Counts the number of elements stored in a Tensor of this shape.
*
* <p>Multiplies all the elements together if they are positive, throws an {@link
* <p>Multiplies all the elements together if they are non-negative, throws an {@link
* IllegalArgumentException} otherwise.
*
* @param shape The shape to use.
* @return The number of elements.
*/
public static long elementCount(long[] shape) {
// Java side tensors must be less than Integer.MAX_VALUE,
// tensors created in native code can be larger, but are not usable in Java.
// Tensors should not be able to be created which will overflow a 64-bit long.
long count = 1;
for (int i = 0; i < shape.length; i++) {
if (shape[i] > 0) {
if (shape[i] >= 0) {
count *= shape[i];
} else {
throw new IllegalArgumentException(
"Received non-positive value in shape " + Arrays.toString(shape) + " .");
"Received negative value in shape " + Arrays.toString(shape) + " .");
}
}
return count;

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -107,6 +107,9 @@ public class TensorInfo implements ValueInfo {
/** The native type of this tensor. */
public final OnnxTensorType onnxType;
/** The number of elements in this tensor. */
final long numElements;
/**
* Constructs a TensorInfo with the specified shape, Java type and native type.
*
@ -118,6 +121,7 @@ public class TensorInfo implements ValueInfo {
this.shape = shape;
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
}
/**
@ -132,6 +136,7 @@ public class TensorInfo implements ValueInfo {
this.shape = shape;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
}
/**
@ -173,6 +178,39 @@ public class TensorInfo implements ValueInfo {
return OrtUtil.validateShape(shape);
}
/**
* Computes the number of elements in this tensor.
*
* <p>This replicates {@link OrtUtil#elementCount}, but does not throw on negative values which
* are used for symbolic dimensions in input and output info objects.
*
* @param shape The tensor shape.
* @return The number of elements.
*/
private static long elementCount(long[] shape) {
// Java side tensors must be less than Integer.MAX_VALUE,
// tensors created in native code can be larger, but are not usable in Java.
// Tensors should not be able to be created which will overflow a 64-bit long.
long output = 1;
for (int i = 0; i < shape.length; i++) {
output *= shape[i];
}
return output;
}
/**
* Returns the number of elements in this tensor.
*
* <p>If the returned value is negative, then this tensor info refers to an input or output
* placeholder which has symbolic dimensions, and the element count cannot be computed without
* specifying the symbolic dimensions.
*
* @return The number of elements.
*/
public long getNumElements() {
return numElements;
}
/**
* Constructs an array the right shape and type to hold this tensor.
*
@ -181,11 +219,12 @@ public class TensorInfo implements ValueInfo {
* correct shape using {@link OrtUtil#reshape(String[],long[])}.
*
* @return A multidimensional array of the appropriate primitive type (or String).
* @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is
* @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is
* greater than an int).
*/
public Object makeCarrier() throws OrtException {
if (!validateShape()) {
// Zero length tensors are allowed to be returned.
if (!validateShape() && numElements != 0) {
throw new OrtException(
"This tensor is not representable in Java, it's too big - shape = "
+ Arrays.toString(shape));

View file

@ -1,10 +1,11 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@ -122,4 +123,25 @@ public class TensorCreationTest {
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
}
}
@Test
public void testEmptyTensor() throws OrtException {
OrtEnvironment env = OrtEnvironment.getEnvironment();
FloatBuffer buf = FloatBuffer.allocate(0);
long[] shape = new long[] {4, 0};
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
Assertions.assertArrayEquals(shape, t.getInfo().getShape());
float[][] output = (float[][]) t.getValue();
Assertions.assertEquals(4, output.length);
Assertions.assertEquals(0, output[0].length);
FloatBuffer fb = t.getFloatBuffer();
Assertions.assertEquals(0, fb.remaining());
}
shape = new long[] {0, 4};
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
Assertions.assertArrayEquals(shape, t.getInfo().getShape());
float[][] output = (float[][]) t.getValue();
Assertions.assertEquals(0, output.length);
}
}
}