[Java] Allow extraction of multidimensional String tensors (#8452)

Fixing a bug where String tensors would always be single dimensional in Java.
This commit is contained in:
Adam Pocock 2021-07-22 16:19:49 -04:00 committed by GitHub
parent 287a2a778f
commit 9a6fa057c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 95 additions and 17 deletions

View file

@ -97,7 +97,13 @@ public class OnnxTensor implements OnnxValue {
} else {
Object carrier = info.makeCarrier();
getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier);
return carrier;
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
return OrtUtil.reshape((String[]) carrier, info.shape);
} else {
return carrier;
}
}
}

View file

@ -165,6 +165,19 @@ public final class OrtUtil {
return Array.newInstance(double.class, intShape);
}
/**
* Creates a new String array of up to 8 dimensions, using the supplied shape.
*
* <p>
*
* @param shape The shape of array to create.
* @return A double array.
*/
public static Object newStringArray(long[] shape) {
int[] intShape = transformShape(shape);
return Array.newInstance(String.class, intShape);
}
/**
* Reshapes a boolean array into the desired n-dimensional array assuming the boolean array is
* stored in n-dimensional row-major order. Throws {@link IllegalArgumentException} if the number
@ -270,6 +283,21 @@ public final class OrtUtil {
return output;
}
/**
* Reshapes a String array into the desired n-dimensional array assuming the String array is
* stored in n-dimensional row-major order. Throws {@link IllegalArgumentException} if the number
* of elements doesn't match between the shape and the input or the shape is invalid.
*
* @param input The double array.
* @param shape The desired shape.
* @return An n-dimensional String array.
*/
public static Object reshape(String[] input, long[] shape) {
Object output = OrtUtil.newStringArray(shape);
reshape(input, output, 0);
return output;
}
/**
* Copies elements from the flat input array to the appropriate primitive array of the output.
* Recursively calls itself as it traverses the output array.
@ -285,7 +313,8 @@ public final class OrtUtil {
for (Object outputElement : outputArray) {
Class<?> outputElementClass = outputElement.getClass();
if (outputElementClass.isArray()) {
if (outputElementClass.getComponentType().isPrimitive()) {
Class<?> componentType = outputElementClass.getComponentType();
if (componentType.isPrimitive() || componentType == String.class) {
int length = Array.getLength(outputElement);
System.arraycopy(input, position, outputElement, 0, length);
position += length;
@ -351,11 +380,15 @@ public final class OrtUtil {
* @return A single dimensional String array.
*/
public static String[] flattenString(Object o) {
ArrayList<String> output = new ArrayList<>();
if (o instanceof String[]) {
return (String[]) o;
} else {
ArrayList<String> output = new ArrayList<>();
flattenString((Object[]) o, output);
flattenString((Object[]) o, output);
return output.toArray(new String[0]);
return output.toArray(new String[0]);
}
}
/**

View file

@ -160,6 +160,10 @@ public class TensorInfo implements ValueInfo {
/**
* Constructs an array the right shape and type to hold this tensor.
*
* <p>Note for String tensors, this carrier is a single dimensional array with enough space for
* all elements as that's the expected format of the native code. It can be reshaped to the
* correct shape using {@link OrtUtil#reshape(String[],long[])}.
*
* @return A multidimensional array of the appropriate primitive type (or String).
* @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is
* greater than an int).

View file

@ -1476,13 +1476,15 @@ public class InferenceTest {
OnnxValue firstOutput = outputs.get(0);
assertTrue(firstOutput instanceof OnnxTensor);
String[] labelOutput = (String[]) firstOutput.getValue();
String[][] labelOutput = (String[][]) firstOutput.getValue();
assertEquals("this", labelOutput[0]);
assertEquals("is", labelOutput[1]);
assertEquals("identity", labelOutput[2]);
assertEquals("test \u263A", labelOutput[3]);
assertEquals(4, labelOutput.length);
assertEquals("this", labelOutput[0][0]);
assertEquals("is", labelOutput[0][1]);
assertEquals("identity", labelOutput[1][0]);
assertEquals("test \u263A", labelOutput[1][1]);
assertEquals(2, labelOutput.length);
assertEquals(2, labelOutput[0].length);
assertEquals(2, labelOutput[1].length);
OnnxValue.close(container);
container.clear();
@ -1498,13 +1500,15 @@ public class InferenceTest {
OnnxValue firstOutput = outputs.get(0);
assertTrue(firstOutput instanceof OnnxTensor);
String[] labelOutput = (String[]) firstOutput.getValue();
String[][] labelOutput = (String[][]) firstOutput.getValue();
assertEquals("this", labelOutput[0]);
assertEquals("is", labelOutput[1]);
assertEquals("identity", labelOutput[2]);
assertEquals("test \u263A", labelOutput[3]);
assertEquals(4, labelOutput.length);
assertEquals("this", labelOutput[0][0]);
assertEquals("is", labelOutput[0][1]);
assertEquals("identity", labelOutput[1][0]);
assertEquals("test \u263A", labelOutput[1][1]);
assertEquals(2, labelOutput.length);
assertEquals(2, labelOutput[0].length);
assertEquals(2, labelOutput[1].length);
}
}
}

View file

@ -81,4 +81,35 @@ public class TensorCreationTest {
}
}
}
@Test
public void testStringCreation() throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"};
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape);
String[] output = (String[]) t.getValue();
Assertions.assertArrayEquals(arrValues, output);
}
String[][] stringValues =
new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}};
try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) {
Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape);
String[][] output = (String[][]) t.getValue();
Assertions.assertArrayEquals(stringValues, output);
}
String[][][] deepStringValues =
new String[][][] {
{{"this", "is", "a"}, {"multi", "dimensional", "string"}},
{{"with", "lots", "more"}, {"dimensions", "than", "before"}}
};
try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) {
Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape);
String[][][] output = (String[][][]) t.getValue();
Assertions.assertArrayEquals(deepStringValues, output);
}
}
}
}