mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[Java] Allow extraction of multidimensional String tensors (#8452)
Fixing a bug where String tensors would always be single dimensional in Java.
This commit is contained in:
parent
287a2a778f
commit
9a6fa057c8
5 changed files with 95 additions and 17 deletions
|
|
@ -97,7 +97,13 @@ public class OnnxTensor implements OnnxValue {
|
|||
} else {
|
||||
Object carrier = info.makeCarrier();
|
||||
getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier);
|
||||
return carrier;
|
||||
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
|
||||
// We read the strings out from native code in a flat array and then reshape
|
||||
// to the desired output shape.
|
||||
return OrtUtil.reshape((String[]) carrier, info.shape);
|
||||
} else {
|
||||
return carrier;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -165,6 +165,19 @@ public final class OrtUtil {
|
|||
return Array.newInstance(double.class, intShape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new String array of up to 8 dimensions, using the supplied shape.
|
||||
*
|
||||
* <p>
|
||||
*
|
||||
* @param shape The shape of array to create.
|
||||
* @return A double array.
|
||||
*/
|
||||
public static Object newStringArray(long[] shape) {
|
||||
int[] intShape = transformShape(shape);
|
||||
return Array.newInstance(String.class, intShape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reshapes a boolean array into the desired n-dimensional array assuming the boolean array is
|
||||
* stored in n-dimensional row-major order. Throws {@link IllegalArgumentException} if the number
|
||||
|
|
@ -270,6 +283,21 @@ public final class OrtUtil {
|
|||
return output;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reshapes a String array into the desired n-dimensional array assuming the String array is
|
||||
* stored in n-dimensional row-major order. Throws {@link IllegalArgumentException} if the number
|
||||
* of elements doesn't match between the shape and the input or the shape is invalid.
|
||||
*
|
||||
* @param input The double array.
|
||||
* @param shape The desired shape.
|
||||
* @return An n-dimensional String array.
|
||||
*/
|
||||
public static Object reshape(String[] input, long[] shape) {
|
||||
Object output = OrtUtil.newStringArray(shape);
|
||||
reshape(input, output, 0);
|
||||
return output;
|
||||
}
|
||||
|
||||
/**
|
||||
* Copies elements from the flat input array to the appropriate primitive array of the output.
|
||||
* Recursively calls itself as it traverses the output array.
|
||||
|
|
@ -285,7 +313,8 @@ public final class OrtUtil {
|
|||
for (Object outputElement : outputArray) {
|
||||
Class<?> outputElementClass = outputElement.getClass();
|
||||
if (outputElementClass.isArray()) {
|
||||
if (outputElementClass.getComponentType().isPrimitive()) {
|
||||
Class<?> componentType = outputElementClass.getComponentType();
|
||||
if (componentType.isPrimitive() || componentType == String.class) {
|
||||
int length = Array.getLength(outputElement);
|
||||
System.arraycopy(input, position, outputElement, 0, length);
|
||||
position += length;
|
||||
|
|
@ -351,11 +380,15 @@ public final class OrtUtil {
|
|||
* @return A single dimensional String array.
|
||||
*/
|
||||
public static String[] flattenString(Object o) {
|
||||
ArrayList<String> output = new ArrayList<>();
|
||||
if (o instanceof String[]) {
|
||||
return (String[]) o;
|
||||
} else {
|
||||
ArrayList<String> output = new ArrayList<>();
|
||||
|
||||
flattenString((Object[]) o, output);
|
||||
flattenString((Object[]) o, output);
|
||||
|
||||
return output.toArray(new String[0]);
|
||||
return output.toArray(new String[0]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -160,6 +160,10 @@ public class TensorInfo implements ValueInfo {
|
|||
/**
|
||||
* Constructs an array the right shape and type to hold this tensor.
|
||||
*
|
||||
* <p>Note for String tensors, this carrier is a single dimensional array with enough space for
|
||||
* all elements as that's the expected format of the native code. It can be reshaped to the
|
||||
* correct shape using {@link OrtUtil#reshape(String[],long[])}.
|
||||
*
|
||||
* @return A multidimensional array of the appropriate primitive type (or String).
|
||||
* @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is
|
||||
* greater than an int).
|
||||
|
|
|
|||
|
|
@ -1476,13 +1476,15 @@ public class InferenceTest {
|
|||
OnnxValue firstOutput = outputs.get(0);
|
||||
assertTrue(firstOutput instanceof OnnxTensor);
|
||||
|
||||
String[] labelOutput = (String[]) firstOutput.getValue();
|
||||
String[][] labelOutput = (String[][]) firstOutput.getValue();
|
||||
|
||||
assertEquals("this", labelOutput[0]);
|
||||
assertEquals("is", labelOutput[1]);
|
||||
assertEquals("identity", labelOutput[2]);
|
||||
assertEquals("test \u263A", labelOutput[3]);
|
||||
assertEquals(4, labelOutput.length);
|
||||
assertEquals("this", labelOutput[0][0]);
|
||||
assertEquals("is", labelOutput[0][1]);
|
||||
assertEquals("identity", labelOutput[1][0]);
|
||||
assertEquals("test \u263A", labelOutput[1][1]);
|
||||
assertEquals(2, labelOutput.length);
|
||||
assertEquals(2, labelOutput[0].length);
|
||||
assertEquals(2, labelOutput[1].length);
|
||||
|
||||
OnnxValue.close(container);
|
||||
container.clear();
|
||||
|
|
@ -1498,13 +1500,15 @@ public class InferenceTest {
|
|||
OnnxValue firstOutput = outputs.get(0);
|
||||
assertTrue(firstOutput instanceof OnnxTensor);
|
||||
|
||||
String[] labelOutput = (String[]) firstOutput.getValue();
|
||||
String[][] labelOutput = (String[][]) firstOutput.getValue();
|
||||
|
||||
assertEquals("this", labelOutput[0]);
|
||||
assertEquals("is", labelOutput[1]);
|
||||
assertEquals("identity", labelOutput[2]);
|
||||
assertEquals("test \u263A", labelOutput[3]);
|
||||
assertEquals(4, labelOutput.length);
|
||||
assertEquals("this", labelOutput[0][0]);
|
||||
assertEquals("is", labelOutput[0][1]);
|
||||
assertEquals("identity", labelOutput[1][0]);
|
||||
assertEquals("test \u263A", labelOutput[1][1]);
|
||||
assertEquals(2, labelOutput.length);
|
||||
assertEquals(2, labelOutput[0].length);
|
||||
assertEquals(2, labelOutput[1].length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,4 +81,35 @@ public class TensorCreationTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringCreation() throws OrtException {
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
|
||||
String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape);
|
||||
String[] output = (String[]) t.getValue();
|
||||
Assertions.assertArrayEquals(arrValues, output);
|
||||
}
|
||||
|
||||
String[][] stringValues =
|
||||
new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape);
|
||||
String[][] output = (String[][]) t.getValue();
|
||||
Assertions.assertArrayEquals(stringValues, output);
|
||||
}
|
||||
|
||||
String[][][] deepStringValues =
|
||||
new String[][][] {
|
||||
{{"this", "is", "a"}, {"multi", "dimensional", "string"}},
|
||||
{{"with", "lots", "more"}, {"dimensions", "than", "before"}}
|
||||
};
|
||||
try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) {
|
||||
Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape);
|
||||
String[][][] output = (String[][][]) t.getValue();
|
||||
Assertions.assertArrayEquals(deepStringValues, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue