From 9a6fa057c8ee536a75c13eba2968e11718accbe9 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 22 Jul 2021 16:19:49 -0400 Subject: [PATCH] [Java] Allow extraction of multidimensional String tensors (#8452) Fixing a bug where String tensors would always be single dimensional in Java. --- .../main/java/ai/onnxruntime/OnnxTensor.java | 8 +++- .../src/main/java/ai/onnxruntime/OrtUtil.java | 41 +++++++++++++++++-- .../main/java/ai/onnxruntime/TensorInfo.java | 4 ++ .../java/ai/onnxruntime/InferenceTest.java | 28 +++++++------ .../ai/onnxruntime/TensorCreationTest.java | 31 ++++++++++++++ 5 files changed, 95 insertions(+), 17 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 1314bbccc5..29b19ffef5 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -97,7 +97,13 @@ public class OnnxTensor implements OnnxValue { } else { Object carrier = info.makeCarrier(); getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier); - return carrier; + if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { + // We read the strings out from native code in a flat array and then reshape + // to the desired output shape. + return OrtUtil.reshape((String[]) carrier, info.shape); + } else { + return carrier; + } } } diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index da92939ee7..cddc34f65f 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -165,6 +165,19 @@ public final class OrtUtil { return Array.newInstance(double.class, intShape); } + /** + * Creates a new String array of up to 8 dimensions, using the supplied shape. + * + *

+ * + * @param shape The shape of array to create. + * @return A String array. + */ + public static Object newStringArray(long[] shape) { + int[] intShape = transformShape(shape); + return Array.newInstance(String.class, intShape); + } + /** * Reshapes a boolean array into the desired n-dimensional array assuming the boolean array is * stored in n-dimensional row-major order. Throws {@link IllegalArgumentException} if the number @@ -270,6 +283,21 @@ public final class OrtUtil { return output; } + /** + * Reshapes a String array into the desired n-dimensional array assuming the String array is + * stored in n-dimensional row-major order. Throws {@link IllegalArgumentException} if the number + * of elements doesn't match between the shape and the input or the shape is invalid. + * + * @param input The String array. + * @param shape The desired shape. + * @return An n-dimensional String array. + */ + public static Object reshape(String[] input, long[] shape) { + Object output = OrtUtil.newStringArray(shape); + reshape(input, output, 0); + return output; + } + /** * Copies elements from the flat input array to the appropriate primitive array of the output. * Recursively calls itself as it traverses the output array. @@ -285,7 +313,8 @@ public final class OrtUtil { for (Object outputElement : outputArray) { Class<?> outputElementClass = outputElement.getClass(); if (outputElementClass.isArray()) { - if (outputElementClass.getComponentType().isPrimitive()) { + Class<?> componentType = outputElementClass.getComponentType(); + if (componentType.isPrimitive() || componentType == String.class) { int length = Array.getLength(outputElement); System.arraycopy(input, position, outputElement, 0, length); position += length; @@ -351,11 +380,15 @@ public final class OrtUtil { * @return A single dimensional String array. 
*/ public static String[] flattenString(Object o) { - ArrayList<String> output = new ArrayList<>(); + if (o instanceof String[]) { + return (String[]) o; + } else { + ArrayList<String> output = new ArrayList<>(); - flattenString((Object[]) o, output); + flattenString((Object[]) o, output); - return output.toArray(new String[0]); + return output.toArray(new String[0]); + } } /** diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index a79f1f9786..550ddfe57f 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -160,6 +160,10 @@ public class TensorInfo implements ValueInfo { /** * Constructs an array the right shape and type to hold this tensor. * + *

Note for String tensors, this carrier is a single dimensional array with enough space for + * all elements as that's the expected format of the native code. It can be reshaped to the + * correct shape using {@link OrtUtil#reshape(String[],long[])}. + * * @return A multidimensional array of the appropriate primitive type (or String). * @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is * greater than an int). diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 67c0eb308e..38627d2ec7 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1476,13 +1476,15 @@ public class InferenceTest { OnnxValue firstOutput = outputs.get(0); assertTrue(firstOutput instanceof OnnxTensor); - String[] labelOutput = (String[]) firstOutput.getValue(); + String[][] labelOutput = (String[][]) firstOutput.getValue(); - assertEquals("this", labelOutput[0]); - assertEquals("is", labelOutput[1]); - assertEquals("identity", labelOutput[2]); - assertEquals("test \u263A", labelOutput[3]); - assertEquals(4, labelOutput.length); + assertEquals("this", labelOutput[0][0]); + assertEquals("is", labelOutput[0][1]); + assertEquals("identity", labelOutput[1][0]); + assertEquals("test \u263A", labelOutput[1][1]); + assertEquals(2, labelOutput.length); + assertEquals(2, labelOutput[0].length); + assertEquals(2, labelOutput[1].length); OnnxValue.close(container); container.clear(); @@ -1498,13 +1500,15 @@ public class InferenceTest { OnnxValue firstOutput = outputs.get(0); assertTrue(firstOutput instanceof OnnxTensor); - String[] labelOutput = (String[]) firstOutput.getValue(); + String[][] labelOutput = (String[][]) firstOutput.getValue(); - assertEquals("this", labelOutput[0]); - assertEquals("is", labelOutput[1]); - assertEquals("identity", labelOutput[2]); - assertEquals("test \u263A", labelOutput[3]); - 
assertEquals(4, labelOutput.length); + assertEquals("this", labelOutput[0][0]); + assertEquals("is", labelOutput[0][1]); + assertEquals("identity", labelOutput[1][0]); + assertEquals("test \u263A", labelOutput[1][1]); + assertEquals(2, labelOutput.length); + assertEquals(2, labelOutput[0].length); + assertEquals(2, labelOutput[1].length); } } } diff --git a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java index 3b26cfd9ce..8ebba91b31 100644 --- a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java +++ b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java @@ -81,4 +81,35 @@ public class TensorCreationTest { } } } + + @Test + public void testStringCreation() throws OrtException { + try (OrtEnvironment env = OrtEnvironment.getEnvironment()) { + String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"}; + try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { + Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape); + String[] output = (String[]) t.getValue(); + Assertions.assertArrayEquals(arrValues, output); + } + + String[][] stringValues = + new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}}; + try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) { + Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape); + String[][] output = (String[][]) t.getValue(); + Assertions.assertArrayEquals(stringValues, output); + } + + String[][][] deepStringValues = + new String[][][] { + {{"this", "is", "a"}, {"multi", "dimensional", "string"}}, + {{"with", "lots", "more"}, {"dimensions", "than", "before"}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) { + Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape); + String[][][] output = (String[][][]) t.getValue(); + Assertions.assertArrayEquals(deepStringValues, output); + } + } + } }