diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 69ccb954e8..1c21387b50 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -7,6 +7,7 @@ package ai.onnxruntime; import java.lang.reflect.Array; import java.nio.Buffer; import java.util.Arrays; +import java.util.stream.Collectors; /** Describes an {@link OnnxTensor}, including it's size, shape and element type. */ public class TensorInfo implements ValueInfo { @@ -159,6 +160,12 @@ public class TensorInfo implements ValueInfo { /** The shape of the tensor. */ final long[] shape; + /** The names of the unbound dimensions. */ + final String[] dimensionNames; + + /** If there are non-empty dimension names */ + private final boolean hasNames; + /** The Java type of this tensor. */ public final OnnxJavaType type; @@ -177,6 +184,9 @@ public class TensorInfo implements ValueInfo { */ TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) { this.shape = shape; + this.dimensionNames = new String[shape.length]; + Arrays.fill(dimensionNames, ""); + this.hasNames = false; this.type = type; this.onnxType = onnxType; this.numElements = elementCount(shape); @@ -188,10 +198,20 @@ public class TensorInfo implements ValueInfo { *

Called from JNI. * * @param shape The tensor shape. + * @param names The dimension names. * @param typeInt The native type int. */ - TensorInfo(long[] shape, int typeInt) { + TensorInfo(long[] shape, String[] names, int typeInt) { this.shape = shape; + this.dimensionNames = names; + boolean hasNames = false; + for (String s : names) { + if (!s.isEmpty()) { + hasNames = true; + break; + } + } + this.hasNames = hasNames; this.onnxType = OnnxTensorType.mapFromInt(typeInt); this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType); this.numElements = elementCount(shape); @@ -206,15 +226,42 @@ public class TensorInfo implements ValueInfo { return Arrays.copyOf(shape, shape.length); } + /** + * Get a copy of the tensor's named dimensions. + * + * @return A copof the tensor's named dimensions. + */ + public String[] getDimensionNames() { + return Arrays.copyOf(dimensionNames, dimensionNames.length); + } + @Override public String toString() { - return "TensorInfo(javaType=" - + type.toString() - + ",onnxType=" - + onnxType.toString() - + ",shape=" - + Arrays.toString(shape) - + ")"; + String output = + "TensorInfo(javaType=" + + type.toString() + + ",onnxType=" + + onnxType.toString() + + ",shape=" + + Arrays.toString(shape); + if (hasNames) { + output = + output + + ",dimNames=[" + + Arrays.stream(dimensionNames) + .map( + a -> { + if (a.isEmpty()) { + return "\"\""; + } else { + return a; + } + }) + .collect(Collectors.joining(",")) + + "]"; + } + output = output + ")"; + return output; } /** diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 879ba8a310..7b26291581 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT if (code != ORT_OK) { return NULL; } - //printf("numDim %d\n",numDim); int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim); code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim)); if (code != ORT_OK) { @@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT free(dimensions); dimensions = NULL; + // Create the string array for the names. + const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim); + if (dimensionNames == NULL) { + throwOrtException(jniEnv, 1, "Not enough memory"); + return NULL; + } + code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim)); + if (code != ORT_OK) { + // extraction failed, exception has been thrown, return to Java. + free(dimensionNames); + return NULL; + } + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL); + for (size_t i = 0; i < numDim; i++) { + jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName); + } + free(dimensionNames); + // Create the TensorInfo object static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo"; jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName); - jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JI)V"); - //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor); - jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt); + jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([J[Ljava/lang/String;I)V"); + jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt); return tensorInfo; } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index f6f9da1829..7fef2dc784 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -590,6 +590,12 @@ public class InferenceTest { Map infoMap = session.getInputInfo(); TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); assertArrayEquals(new long[] {-1, 2}, aInfo.shape); + assertEquals(2, aInfo.dimensionNames.length); + assertEquals("n", aInfo.dimensionNames[0]); + assertEquals("", aInfo.dimensionNames[1]); + TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo(); + assertEquals(1, bInfo.dimensionNames.length); + assertEquals("m", bInfo.dimensionNames[0]); } } // Check that when the options are assigned it overrides the symbolic dimension