[java] Updating TensorInfo so it contains the named dimensions (#18962)

### Description
The Java `TensorInfo` object which is used to describe a tensor's shape,
along with the input and output placeholders for a model couldn't show
any symbolic/named dimensions in that tensor. Now this information is
stored in Java strings on construction and included in the toString.

### Motivation and Context
Setting symbolic dimensions required external information in Java, the
names were not discoverable from within the API.
This commit is contained in:
Adam Pocock 2024-01-15 17:42:50 -05:00 committed by GitHub
parent a97199c62d
commit 191525301f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 12 deletions

View file

@ -7,6 +7,7 @@ package ai.onnxruntime;
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.util.Arrays;
import java.util.stream.Collectors;
/** Describes an {@link OnnxTensor}, including it's size, shape and element type. */
public class TensorInfo implements ValueInfo {
@ -159,6 +160,12 @@ public class TensorInfo implements ValueInfo {
/** The shape of the tensor. */
final long[] shape;
/** The names of the unbound dimensions. */
final String[] dimensionNames;
/** If there are non-empty dimension names */
private final boolean hasNames;
/** The Java type of this tensor. */
public final OnnxJavaType type;
@ -177,6 +184,9 @@ public class TensorInfo implements ValueInfo {
*/
TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
this.shape = shape;
this.dimensionNames = new String[shape.length];
Arrays.fill(dimensionNames, "");
this.hasNames = false;
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
@ -188,10 +198,20 @@ public class TensorInfo implements ValueInfo {
* <p>Called from JNI.
*
* @param shape The tensor shape.
* @param names The dimension names.
* @param typeInt The native type int.
*/
TensorInfo(long[] shape, int typeInt) {
TensorInfo(long[] shape, String[] names, int typeInt) {
this.shape = shape;
this.dimensionNames = names;
boolean hasNames = false;
for (String s : names) {
if (!s.isEmpty()) {
hasNames = true;
break;
}
}
this.hasNames = hasNames;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
@ -206,15 +226,42 @@ public class TensorInfo implements ValueInfo {
return Arrays.copyOf(shape, shape.length);
}
/**
* Get a copy of the tensor's named dimensions.
*
* @return A copof the tensor's named dimensions.
*/
public String[] getDimensionNames() {
return Arrays.copyOf(dimensionNames, dimensionNames.length);
}
@Override
public String toString() {
return "TensorInfo(javaType="
+ type.toString()
+ ",onnxType="
+ onnxType.toString()
+ ",shape="
+ Arrays.toString(shape)
+ ")";
String output =
"TensorInfo(javaType="
+ type.toString()
+ ",onnxType="
+ onnxType.toString()
+ ",shape="
+ Arrays.toString(shape);
if (hasNames) {
output =
output
+ ",dimNames=["
+ Arrays.stream(dimensionNames)
.map(
a -> {
if (a.isEmpty()) {
return "\"\"";
} else {
return a;
}
})
.collect(Collectors.joining(","))
+ "]";
}
output = output + ")";
return output;
}
/**

View file

@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
if (code != ORT_OK) {
return NULL;
}
//printf("numDim %d\n",numDim);
int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
if (code != ORT_OK) {
@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
free(dimensions);
dimensions = NULL;
// Create the string array for the names.
const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim);
if (dimensionNames == NULL) {
throwOrtException(jniEnv, 1, "Not enough memory");
return NULL;
}
code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim));
if (code != ORT_OK) {
// extraction failed, exception has been thrown, return to Java.
free(dimensionNames);
return NULL;
}
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL);
for (size_t i = 0; i < numDim; i++) {
jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]);
(*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName);
}
free(dimensionNames);
// Create the TensorInfo object
static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([JI)V");
//printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor);
jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt);
jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([J[Ljava/lang/String;I)V");
jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt);
return tensorInfo;
}

View file

@ -590,6 +590,12 @@ public class InferenceTest {
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
assertEquals(2, aInfo.dimensionNames.length);
assertEquals("n", aInfo.dimensionNames[0]);
assertEquals("", aInfo.dimensionNames[1]);
TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo();
assertEquals(1, bInfo.dimensionNames.length);
assertEquals("m", bInfo.dimensionNames[0]);
}
}
// Check that when the options are assigned it overrides the symbolic dimension