diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 1314bbccc5..29b19ffef5 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -97,7 +97,13 @@ public class OnnxTensor implements OnnxValue { } else { Object carrier = info.makeCarrier(); getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier); - return carrier; + if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { + // We read the strings out from native code in a flat array and then reshape + // to the desired output shape. + return OrtUtil.reshape((String[]) carrier, info.shape); + } else { + return carrier; + } } } diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index da92939ee7..cddc34f65f 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -165,6 +165,19 @@ public final class OrtUtil { return Array.newInstance(double.class, intShape); } + /** + * Creates a new String array of up to 8 dimensions, using the supplied shape. + * + *
+ *
+ * @param shape The shape of array to create.
+ * @return A double array.
+ */
+ public static Object newStringArray(long[] shape) {
+ int[] intShape = transformShape(shape);
+ return Array.newInstance(String.class, intShape);
+ }
+
/**
* Reshapes a boolean array into the desired n-dimensional array assuming the boolean array is
* stored in n-dimensional row-major order. Throws {@link IllegalArgumentException} if the number
@@ -270,6 +283,21 @@ public final class OrtUtil {
return output;
}
+ /**
+ * Reshapes a String array into the desired n-dimensional array assuming the String array is
+ * stored in n-dimensional row-major order. Throws {@link IllegalArgumentException} if the number
+ * of elements doesn't match between the shape and the input or the shape is invalid.
+ *
+ * @param input The double array.
+ * @param shape The desired shape.
+ * @return An n-dimensional String array.
+ */
+ public static Object reshape(String[] input, long[] shape) {
+ Object output = OrtUtil.newStringArray(shape);
+ reshape(input, output, 0);
+ return output;
+ }
+
/**
* Copies elements from the flat input array to the appropriate primitive array of the output.
* Recursively calls itself as it traverses the output array.
@@ -285,7 +313,8 @@ public final class OrtUtil {
for (Object outputElement : outputArray) {
Class> outputElementClass = outputElement.getClass();
if (outputElementClass.isArray()) {
- if (outputElementClass.getComponentType().isPrimitive()) {
+ Class> componentType = outputElementClass.getComponentType();
+ if (componentType.isPrimitive() || componentType == String.class) {
int length = Array.getLength(outputElement);
System.arraycopy(input, position, outputElement, 0, length);
position += length;
@@ -351,11 +380,15 @@ public final class OrtUtil {
* @return A single dimensional String array.
*/
public static String[] flattenString(Object o) {
- ArrayList Note for String tensors, this carrier is a single dimensional array with enough space for
+ * all elements as that's the expected format of the native code. It can be reshaped to the
+ * correct shape using {@link OrtUtil#reshape(String[],long[])}.
+ *
* @return A multidimensional array of the appropriate primitive type (or String).
* @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is
* greater than an int).
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 67c0eb308e..38627d2ec7 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -1476,13 +1476,15 @@ public class InferenceTest {
OnnxValue firstOutput = outputs.get(0);
assertTrue(firstOutput instanceof OnnxTensor);
- String[] labelOutput = (String[]) firstOutput.getValue();
+ String[][] labelOutput = (String[][]) firstOutput.getValue();
- assertEquals("this", labelOutput[0]);
- assertEquals("is", labelOutput[1]);
- assertEquals("identity", labelOutput[2]);
- assertEquals("test \u263A", labelOutput[3]);
- assertEquals(4, labelOutput.length);
+ assertEquals("this", labelOutput[0][0]);
+ assertEquals("is", labelOutput[0][1]);
+ assertEquals("identity", labelOutput[1][0]);
+ assertEquals("test \u263A", labelOutput[1][1]);
+ assertEquals(2, labelOutput.length);
+ assertEquals(2, labelOutput[0].length);
+ assertEquals(2, labelOutput[1].length);
OnnxValue.close(container);
container.clear();
@@ -1498,13 +1500,15 @@ public class InferenceTest {
OnnxValue firstOutput = outputs.get(0);
assertTrue(firstOutput instanceof OnnxTensor);
- String[] labelOutput = (String[]) firstOutput.getValue();
+ String[][] labelOutput = (String[][]) firstOutput.getValue();
- assertEquals("this", labelOutput[0]);
- assertEquals("is", labelOutput[1]);
- assertEquals("identity", labelOutput[2]);
- assertEquals("test \u263A", labelOutput[3]);
- assertEquals(4, labelOutput.length);
+ assertEquals("this", labelOutput[0][0]);
+ assertEquals("is", labelOutput[0][1]);
+ assertEquals("identity", labelOutput[1][0]);
+ assertEquals("test \u263A", labelOutput[1][1]);
+ assertEquals(2, labelOutput.length);
+ assertEquals(2, labelOutput[0].length);
+ assertEquals(2, labelOutput[1].length);
}
}
}
diff --git a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java
index 3b26cfd9ce..8ebba91b31 100644
--- a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java
+++ b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java
@@ -81,4 +81,35 @@ public class TensorCreationTest {
}
}
}
+
+ @Test
+ public void testStringCreation() throws OrtException {
+ try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
+ String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"};
+ try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
+ Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape);
+ String[] output = (String[]) t.getValue();
+ Assertions.assertArrayEquals(arrValues, output);
+ }
+
+ String[][] stringValues =
+ new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}};
+ try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) {
+ Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape);
+ String[][] output = (String[][]) t.getValue();
+ Assertions.assertArrayEquals(stringValues, output);
+ }
+
+ String[][][] deepStringValues =
+ new String[][][] {
+ {{"this", "is", "a"}, {"multi", "dimensional", "string"}},
+ {{"with", "lots", "more"}, {"dimensions", "than", "before"}}
+ };
+ try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) {
+ Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape);
+ String[][][] output = (String[][][]) t.getValue();
+ Assertions.assertArrayEquals(deepStringValues, output);
+ }
+ }
+ }
}