mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[java] Updating TensorInfo so it contains the named dimensions (#18962)
### Description The Java `TensorInfo` object which is used to describe a tensor's shape, along with the input and output placeholders for a model couldn't show any symbolic/named dimensions in that tensor. Now this information is stored in Java strings on construction and included in the toString. ### Motivation and Context Setting symbolic dimensions required external information in Java, the names were not discoverable from within the API.
This commit is contained in:
parent
a97199c62d
commit
191525301f
3 changed files with 83 additions and 12 deletions
|
|
@ -7,6 +7,7 @@ package ai.onnxruntime;
|
|||
import java.lang.reflect.Array;
|
||||
import java.nio.Buffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** Describes an {@link OnnxTensor}, including it's size, shape and element type. */
|
||||
public class TensorInfo implements ValueInfo {
|
||||
|
|
@ -159,6 +160,12 @@ public class TensorInfo implements ValueInfo {
|
|||
/** The shape of the tensor. */
|
||||
final long[] shape;
|
||||
|
||||
/** The names of the unbound dimensions. */
|
||||
final String[] dimensionNames;
|
||||
|
||||
/** If there are non-empty dimension names */
|
||||
private final boolean hasNames;
|
||||
|
||||
/** The Java type of this tensor. */
|
||||
public final OnnxJavaType type;
|
||||
|
||||
|
|
@ -177,6 +184,9 @@ public class TensorInfo implements ValueInfo {
|
|||
*/
|
||||
TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
|
||||
this.shape = shape;
|
||||
this.dimensionNames = new String[shape.length];
|
||||
Arrays.fill(dimensionNames, "");
|
||||
this.hasNames = false;
|
||||
this.type = type;
|
||||
this.onnxType = onnxType;
|
||||
this.numElements = elementCount(shape);
|
||||
|
|
@ -188,10 +198,20 @@ public class TensorInfo implements ValueInfo {
|
|||
* <p>Called from JNI.
|
||||
*
|
||||
* @param shape The tensor shape.
|
||||
* @param names The dimension names.
|
||||
* @param typeInt The native type int.
|
||||
*/
|
||||
TensorInfo(long[] shape, int typeInt) {
|
||||
TensorInfo(long[] shape, String[] names, int typeInt) {
|
||||
this.shape = shape;
|
||||
this.dimensionNames = names;
|
||||
boolean hasNames = false;
|
||||
for (String s : names) {
|
||||
if (!s.isEmpty()) {
|
||||
hasNames = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
this.hasNames = hasNames;
|
||||
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
|
||||
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
|
||||
this.numElements = elementCount(shape);
|
||||
|
|
@ -206,15 +226,42 @@ public class TensorInfo implements ValueInfo {
|
|||
return Arrays.copyOf(shape, shape.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a copy of the tensor's named dimensions.
|
||||
*
|
||||
* @return A copof the tensor's named dimensions.
|
||||
*/
|
||||
public String[] getDimensionNames() {
|
||||
return Arrays.copyOf(dimensionNames, dimensionNames.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "TensorInfo(javaType="
|
||||
+ type.toString()
|
||||
+ ",onnxType="
|
||||
+ onnxType.toString()
|
||||
+ ",shape="
|
||||
+ Arrays.toString(shape)
|
||||
+ ")";
|
||||
String output =
|
||||
"TensorInfo(javaType="
|
||||
+ type.toString()
|
||||
+ ",onnxType="
|
||||
+ onnxType.toString()
|
||||
+ ",shape="
|
||||
+ Arrays.toString(shape);
|
||||
if (hasNames) {
|
||||
output =
|
||||
output
|
||||
+ ",dimNames=["
|
||||
+ Arrays.stream(dimensionNames)
|
||||
.map(
|
||||
a -> {
|
||||
if (a.isEmpty()) {
|
||||
return "\"\"";
|
||||
} else {
|
||||
return a;
|
||||
}
|
||||
})
|
||||
.collect(Collectors.joining(","))
|
||||
+ "]";
|
||||
}
|
||||
output = output + ")";
|
||||
return output;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
|
|||
if (code != ORT_OK) {
|
||||
return NULL;
|
||||
}
|
||||
//printf("numDim %d\n",numDim);
|
||||
int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
|
||||
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
|
||||
if (code != ORT_OK) {
|
||||
|
|
@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
|
|||
free(dimensions);
|
||||
dimensions = NULL;
|
||||
|
||||
// Create the string array for the names.
|
||||
const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim);
|
||||
if (dimensionNames == NULL) {
|
||||
throwOrtException(jniEnv, 1, "Not enough memory");
|
||||
return NULL;
|
||||
}
|
||||
code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim));
|
||||
if (code != ORT_OK) {
|
||||
// extraction failed, exception has been thrown, return to Java.
|
||||
free(dimensionNames);
|
||||
return NULL;
|
||||
}
|
||||
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
|
||||
jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL);
|
||||
for (size_t i = 0; i < numDim; i++) {
|
||||
jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]);
|
||||
(*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName);
|
||||
}
|
||||
free(dimensionNames);
|
||||
|
||||
// Create the TensorInfo object
|
||||
static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
|
||||
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
|
||||
jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([JI)V");
|
||||
//printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor);
|
||||
jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt);
|
||||
jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([J[Ljava/lang/String;I)V");
|
||||
jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt);
|
||||
return tensorInfo;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -590,6 +590,12 @@ public class InferenceTest {
|
|||
Map<String, NodeInfo> infoMap = session.getInputInfo();
|
||||
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
|
||||
assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
|
||||
assertEquals(2, aInfo.dimensionNames.length);
|
||||
assertEquals("n", aInfo.dimensionNames[0]);
|
||||
assertEquals("", aInfo.dimensionNames[1]);
|
||||
TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo();
|
||||
assertEquals(1, bInfo.dimensionNames.length);
|
||||
assertEquals("m", bInfo.dimensionNames[0]);
|
||||
}
|
||||
}
|
||||
// Check that when the options are assigned it overrides the symbolic dimension
|
||||
|
|
|
|||
Loading…
Reference in a new issue