diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 69ccb954e8..1c21387b50 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -7,6 +7,7 @@ package ai.onnxruntime; import java.lang.reflect.Array; import java.nio.Buffer; import java.util.Arrays; +import java.util.stream.Collectors; /** Describes an {@link OnnxTensor}, including it's size, shape and element type. */ public class TensorInfo implements ValueInfo { @@ -159,6 +160,12 @@ public class TensorInfo implements ValueInfo { /** The shape of the tensor. */ final long[] shape; + /** The names of the unbound dimensions. */ + final String[] dimensionNames; + + /** If there are non-empty dimension names */ + private final boolean hasNames; + /** The Java type of this tensor. */ public final OnnxJavaType type; @@ -177,6 +184,9 @@ public class TensorInfo implements ValueInfo { */ TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) { this.shape = shape; + this.dimensionNames = new String[shape.length]; + Arrays.fill(dimensionNames, ""); + this.hasNames = false; this.type = type; this.onnxType = onnxType; this.numElements = elementCount(shape); @@ -188,10 +198,20 @@ public class TensorInfo implements ValueInfo { *
Called from JNI.
*
* @param shape The tensor shape.
+ * @param names The dimension names.
* @param typeInt The native type int.
*/
- TensorInfo(long[] shape, int typeInt) {
+ TensorInfo(long[] shape, String[] names, int typeInt) {
this.shape = shape;
+ this.dimensionNames = names;
+ boolean hasNames = false;
+ for (String s : names) {
+ if (!s.isEmpty()) {
+ hasNames = true;
+ break;
+ }
+ }
+ this.hasNames = hasNames;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
@@ -206,15 +226,42 @@ public class TensorInfo implements ValueInfo {
return Arrays.copyOf(shape, shape.length);
}
+ /**
+ * Get a copy of the tensor's named dimensions.
+ *
+ * @return A copof the tensor's named dimensions.
+ */
+ public String[] getDimensionNames() {
+ return Arrays.copyOf(dimensionNames, dimensionNames.length);
+ }
+
@Override
public String toString() {
- return "TensorInfo(javaType="
- + type.toString()
- + ",onnxType="
- + onnxType.toString()
- + ",shape="
- + Arrays.toString(shape)
- + ")";
+ String output =
+ "TensorInfo(javaType="
+ + type.toString()
+ + ",onnxType="
+ + onnxType.toString()
+ + ",shape="
+ + Arrays.toString(shape);
+ if (hasNames) {
+ output =
+ output
+ + ",dimNames=["
+ + Arrays.stream(dimensionNames)
+ .map(
+ a -> {
+ if (a.isEmpty()) {
+ return "\"\"";
+ } else {
+ return a;
+ }
+ })
+ .collect(Collectors.joining(","))
+ + "]";
+ }
+ output = output + ")";
+ return output;
}
/**
diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c
index 879ba8a310..7b26291581 100644
--- a/java/src/main/native/OrtJniUtil.c
+++ b/java/src/main/native/OrtJniUtil.c
@@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
if (code != ORT_OK) {
return NULL;
}
- //printf("numDim %d\n",numDim);
int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
if (code != ORT_OK) {
@@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
free(dimensions);
dimensions = NULL;
+ // Create the string array for the names.
+ const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim);
+ if (dimensionNames == NULL) {
+ throwOrtException(jniEnv, 1, "Not enough memory");
+ return NULL;
+ }
+ code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim));
+ if (code != ORT_OK) {
+ // extraction failed, exception has been thrown, return to Java.
+ free(dimensionNames);
+ return NULL;
+ }
+ jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
+ jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL);
+ for (size_t i = 0; i < numDim; i++) {
+ jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]);
+ (*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName);
+ }
+ free(dimensionNames);
+
// Create the TensorInfo object
static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
- jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "