mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[java] Allows the creation and extraction of zero length tensors (#15116)
### Description Allows the creation of zero length tensors via the buffer path (the array path with zero length arrays still throws as the validation logic to check it's not ragged would require more intrusive revision), and allows the `tensor.getValue()` method to return a Java multidimensional array with a zero dimension. Also added a test for the creation and extraction behaviour. ### Motivation and Context The Python interface can return zero length tensors (e.g. if object detection doesn't find any objects), and before this PR in Java calling `tensor.getValue()` throws an exception with a confusing error message. Fixes #7270 & #15107.
This commit is contained in:
parent
9191e04259
commit
ef11032c89
4 changed files with 79 additions and 12 deletions
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -76,7 +76,10 @@ public class OnnxTensor extends OnnxTensorLike {
|
|||
}
|
||||
} else {
|
||||
Object carrier = info.makeCarrier();
|
||||
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
|
||||
if (info.getNumElements() > 0) {
|
||||
// If the tensor has values copy them out
|
||||
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
|
||||
}
|
||||
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
|
||||
// We read the strings out from native code in a flat array and then reshape
|
||||
// to the desired output shape.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -41,9 +41,9 @@ public final class OrtUtil {
|
|||
int[] newShape = new int[shape.length];
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
long curDim = shape[i];
|
||||
if (curDim < 1 || curDim > Integer.MAX_VALUE) {
|
||||
if (curDim < 0 || curDim > Integer.MAX_VALUE) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found "
|
||||
"Invalid shape for a Java array, expected non-negative entries smaller than Integer.MAX_VALUE. Found "
|
||||
+ Arrays.toString(shape));
|
||||
} else {
|
||||
newShape[i] = (int) curDim;
|
||||
|
|
@ -345,20 +345,23 @@ public final class OrtUtil {
|
|||
/**
|
||||
* Counts the number of elements stored in a Tensor of this shape.
|
||||
*
|
||||
* <p>Multiplies all the elements together if they are positive, throws an {@link
|
||||
* <p>Multiplies all the elements together if they are non-negative, throws an {@link
|
||||
* IllegalArgumentException} otherwise.
|
||||
*
|
||||
* @param shape The shape to use.
|
||||
* @return The number of elements.
|
||||
*/
|
||||
public static long elementCount(long[] shape) {
|
||||
// Java side tensors must be less than Integer.MAX_VALUE,
|
||||
// tensors created in native code can be larger, but are not usable in Java.
|
||||
// Tensors should not be able to be created which will overflow a 64-bit long.
|
||||
long count = 1;
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
if (shape[i] > 0) {
|
||||
if (shape[i] >= 0) {
|
||||
count *= shape[i];
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Received non-positive value in shape " + Arrays.toString(shape) + " .");
|
||||
"Received negative value in shape " + Arrays.toString(shape) + " .");
|
||||
}
|
||||
}
|
||||
return count;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -107,6 +107,9 @@ public class TensorInfo implements ValueInfo {
|
|||
/** The native type of this tensor. */
|
||||
public final OnnxTensorType onnxType;
|
||||
|
||||
/** The number of elements in this tensor. */
|
||||
final long numElements;
|
||||
|
||||
/**
|
||||
* Constructs a TensorInfo with the specified shape, Java type and native type.
|
||||
*
|
||||
|
|
@ -118,6 +121,7 @@ public class TensorInfo implements ValueInfo {
|
|||
this.shape = shape;
|
||||
this.type = type;
|
||||
this.onnxType = onnxType;
|
||||
this.numElements = elementCount(shape);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -132,6 +136,7 @@ public class TensorInfo implements ValueInfo {
|
|||
this.shape = shape;
|
||||
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
|
||||
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
|
||||
this.numElements = elementCount(shape);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -173,6 +178,39 @@ public class TensorInfo implements ValueInfo {
|
|||
return OrtUtil.validateShape(shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the number of elements in this tensor.
|
||||
*
|
||||
* <p>This replicates {@link OrtUtil#elementCount}, but does not throw on negative values which
|
||||
* are used for symbolic dimensions in input and output info objects.
|
||||
*
|
||||
* @param shape The tensor shape.
|
||||
* @return The number of elements.
|
||||
*/
|
||||
private static long elementCount(long[] shape) {
|
||||
// Java side tensors must be less than Integer.MAX_VALUE,
|
||||
// tensors created in native code can be larger, but are not usable in Java.
|
||||
// Tensors should not be able to be created which will overflow a 64-bit long.
|
||||
long output = 1;
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
output *= shape[i];
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of elements in this tensor.
|
||||
*
|
||||
* <p>If the returned value is negative, then this tensor info refers to an input or output
|
||||
* placeholder which has symbolic dimensions, and the element count cannot be computed without
|
||||
* specifying the symbolic dimensions.
|
||||
*
|
||||
* @return The number of elements.
|
||||
*/
|
||||
public long getNumElements() {
|
||||
return numElements;
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs an array the right shape and type to hold this tensor.
|
||||
*
|
||||
|
|
@ -181,11 +219,12 @@ public class TensorInfo implements ValueInfo {
|
|||
* correct shape using {@link OrtUtil#reshape(String[],long[])}.
|
||||
*
|
||||
* @return A multidimensional array of the appropriate primitive type (or String).
|
||||
* @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is
|
||||
* @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is
|
||||
* greater than an int).
|
||||
*/
|
||||
public Object makeCarrier() throws OrtException {
|
||||
if (!validateShape()) {
|
||||
// Zero length tensors are allowed to be returned.
|
||||
if (!validateShape() && numElements != 0) {
|
||||
throw new OrtException(
|
||||
"This tensor is not representable in Java, it's too big - shape = "
|
||||
+ Arrays.toString(shape));
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
/*
|
||||
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -122,4 +123,25 @@ public class TensorCreationTest {
|
|||
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyTensor() throws OrtException {
|
||||
OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
FloatBuffer buf = FloatBuffer.allocate(0);
|
||||
long[] shape = new long[] {4, 0};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
|
||||
Assertions.assertArrayEquals(shape, t.getInfo().getShape());
|
||||
float[][] output = (float[][]) t.getValue();
|
||||
Assertions.assertEquals(4, output.length);
|
||||
Assertions.assertEquals(0, output[0].length);
|
||||
FloatBuffer fb = t.getFloatBuffer();
|
||||
Assertions.assertEquals(0, fb.remaining());
|
||||
}
|
||||
shape = new long[] {0, 4};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
|
||||
Assertions.assertArrayEquals(shape, t.getInfo().getShape());
|
||||
float[][] output = (float[][]) t.getValue();
|
||||
Assertions.assertEquals(0, output.length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue