[java] Adding the graph description to the exposed model metadata. (#10318)

This commit is contained in:
Adam Pocock 2022-02-28 13:05:03 -05:00 committed by GitHub
parent 037f08f1ff
commit e47434ea12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 6 deletions

View file

@ -1,3 +1,7 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.util.Collections;
@ -17,6 +21,7 @@ public final class OnnxModelMetadata {
private final String producerName;
private final String graphName;
private final String graphDescription;
private final String domain;
private final String description;
private final long version;
@ -29,6 +34,7 @@ public final class OnnxModelMetadata {
*
* @param producerName The model producer name.
* @param graphName The model graph name.
* @param graphDescription The model graph description.
* @param domain The model domain name.
* @param description The model description.
* @param version The model version.
@ -37,12 +43,14 @@ public final class OnnxModelMetadata {
OnnxModelMetadata(
String producerName,
String graphName,
String graphDescription,
String domain,
String description,
long version,
String[] customMetadataArray) {
this.producerName = producerName == null ? "" : producerName;
this.graphName = graphName == null ? "" : graphName;
this.graphDescription = graphDescription == null ? "" : graphDescription;
this.domain = domain == null ? "" : domain;
this.description = description == null ? "" : description;
this.version = version;
@ -66,6 +74,7 @@ public final class OnnxModelMetadata {
*
* @param producerName The model producer name.
* @param graphName The model graph name.
* @param graphDescription The model graph name.
* @param domain The model domain name.
* @param description The model description.
* @param version The model version.
@ -74,12 +83,14 @@ public final class OnnxModelMetadata {
OnnxModelMetadata(
String producerName,
String graphName,
String graphDescription,
String domain,
String description,
long version,
Map<String, String> customMetadata) {
this.producerName = producerName == null ? "" : producerName;
this.graphName = graphName == null ? "" : graphName;
this.graphDescription = graphDescription == null ? "" : graphDescription;
this.domain = domain == null ? "" : domain;
this.description = description == null ? "" : description;
this.version = version;
@ -94,6 +105,7 @@ public final class OnnxModelMetadata {
public OnnxModelMetadata(OnnxModelMetadata other) {
this.producerName = other.producerName;
this.graphName = other.graphName;
this.graphDescription = other.graphDescription;
this.domain = other.domain;
this.description = other.description;
this.version = other.version;
@ -111,6 +123,7 @@ public final class OnnxModelMetadata {
return version == that.version
&& producerName.equals(that.producerName)
&& graphName.equals(that.graphName)
&& graphDescription.equals(that.graphDescription)
&& domain.equals(that.domain)
&& description.equals(that.description)
&& customMetadata.equals(that.customMetadata);
@ -118,7 +131,8 @@ public final class OnnxModelMetadata {
@Override
public int hashCode() {
return Objects.hash(producerName, graphName, domain, description, version, customMetadata);
return Objects.hash(
producerName, graphName, graphDescription, domain, description, version, customMetadata);
}
/**
@ -139,6 +153,15 @@ public final class OnnxModelMetadata {
return graphName;
}
/**
* Gets the graph description.
*
* @return The graph description.
*/
public String getGraphDescription() {
return graphDescription;
}
/**
* Gets the domain.
*
@ -195,6 +218,9 @@ public final class OnnxModelMetadata {
+ ", graphName='"
+ graphName
+ '\''
+ ", graphDescription='"
+ graphDescription
+ '\''
+ ", domain='"
+ domain
+ '\''

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@ -385,7 +385,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
jclass metadataClazz = (*jniEnv)->FindClass(jniEnv, metadataClassName);
//OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray)
jmethodID metadataConstructor = (*jniEnv)->GetMethodID(jniEnv, metadataClazz, "<init>",
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V");
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V");
// Get metadata
OrtModelMetadata* metadata;
@ -402,6 +402,11 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
jstring graphStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
// Read out the graph description and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataGetGraphDescription(metadata, allocator, &charBuffer));
jstring graphDescStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
// Read out the domain and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataGetDomain(metadata, allocator, &charBuffer));
jstring domainStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
@ -449,8 +454,8 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
}
// Invoke the metadata constructor
//OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray)
jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, domainStr, descriptionStr, (jlong) version, customArray);
//OnnxModelMetadata(String producerName, String graphName, String graphDescription, String domain, String description, long version, String[] customMetadataArray)
jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, graphDescStr, domainStr, descriptionStr, (jlong) version, customArray);
// Release the metadata
api->ReleaseModelMetadata(metadata);

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2021, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@ -1168,6 +1168,36 @@ public class InferenceTest {
}
}
@Test
public void testModelMetadata() throws OrtException {
String modelPath = getResourcePath("/model_with_valid_ort_config_json.onnx").toString();
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelMetadata")) {
try (OrtSession session = env.createSession(modelPath)) {
OnnxModelMetadata modelMetadata = session.getMetadata();
Assertions.assertEquals(1, modelMetadata.getVersion());
Assertions.assertEquals("Hari", modelMetadata.getProducerName());
Assertions.assertEquals("matmul test", modelMetadata.getGraphName());
Assertions.assertEquals("", modelMetadata.getDomain());
Assertions.assertEquals(
"This is a test model with a valid ORT config Json", modelMetadata.getDescription());
Assertions.assertEquals("graph description", modelMetadata.getGraphDescription());
Assertions.assertEquals(2, modelMetadata.getCustomMetadata().size());
Assertions.assertEquals("dummy_value", modelMetadata.getCustomMetadata().get("dummy_key"));
Assertions.assertEquals(
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}",
modelMetadata.getCustomMetadata().get("ort_config"));
}
}
}
@Test
public void testModelInputBOOL() throws OrtException {
// model takes 1x5 input of fixed type, echoes back