mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
[java] Adding the graph description to the exposed model metadata. (#10318)
This commit is contained in:
parent
037f08f1ff
commit
e47434ea12
3 changed files with 67 additions and 6 deletions
|
|
@ -1,3 +1,7 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
||||
import java.util.Collections;
|
||||
|
|
@ -17,6 +21,7 @@ public final class OnnxModelMetadata {
|
|||
|
||||
private final String producerName;
|
||||
private final String graphName;
|
||||
private final String graphDescription;
|
||||
private final String domain;
|
||||
private final String description;
|
||||
private final long version;
|
||||
|
|
@ -29,6 +34,7 @@ public final class OnnxModelMetadata {
|
|||
*
|
||||
* @param producerName The model producer name.
|
||||
* @param graphName The model graph name.
|
||||
* @param graphDescription The model graph description.
|
||||
* @param domain The model domain name.
|
||||
* @param description The model description.
|
||||
* @param version The model version.
|
||||
|
|
@ -37,12 +43,14 @@ public final class OnnxModelMetadata {
|
|||
OnnxModelMetadata(
|
||||
String producerName,
|
||||
String graphName,
|
||||
String graphDescription,
|
||||
String domain,
|
||||
String description,
|
||||
long version,
|
||||
String[] customMetadataArray) {
|
||||
this.producerName = producerName == null ? "" : producerName;
|
||||
this.graphName = graphName == null ? "" : graphName;
|
||||
this.graphDescription = graphDescription == null ? "" : graphDescription;
|
||||
this.domain = domain == null ? "" : domain;
|
||||
this.description = description == null ? "" : description;
|
||||
this.version = version;
|
||||
|
|
@ -66,6 +74,7 @@ public final class OnnxModelMetadata {
|
|||
*
|
||||
* @param producerName The model producer name.
|
||||
* @param graphName The model graph name.
|
||||
* @param graphDescription The model graph name.
|
||||
* @param domain The model domain name.
|
||||
* @param description The model description.
|
||||
* @param version The model version.
|
||||
|
|
@ -74,12 +83,14 @@ public final class OnnxModelMetadata {
|
|||
OnnxModelMetadata(
|
||||
String producerName,
|
||||
String graphName,
|
||||
String graphDescription,
|
||||
String domain,
|
||||
String description,
|
||||
long version,
|
||||
Map<String, String> customMetadata) {
|
||||
this.producerName = producerName == null ? "" : producerName;
|
||||
this.graphName = graphName == null ? "" : graphName;
|
||||
this.graphDescription = graphDescription == null ? "" : graphDescription;
|
||||
this.domain = domain == null ? "" : domain;
|
||||
this.description = description == null ? "" : description;
|
||||
this.version = version;
|
||||
|
|
@ -94,6 +105,7 @@ public final class OnnxModelMetadata {
|
|||
public OnnxModelMetadata(OnnxModelMetadata other) {
|
||||
this.producerName = other.producerName;
|
||||
this.graphName = other.graphName;
|
||||
this.graphDescription = other.graphDescription;
|
||||
this.domain = other.domain;
|
||||
this.description = other.description;
|
||||
this.version = other.version;
|
||||
|
|
@ -111,6 +123,7 @@ public final class OnnxModelMetadata {
|
|||
return version == that.version
|
||||
&& producerName.equals(that.producerName)
|
||||
&& graphName.equals(that.graphName)
|
||||
&& graphDescription.equals(that.graphDescription)
|
||||
&& domain.equals(that.domain)
|
||||
&& description.equals(that.description)
|
||||
&& customMetadata.equals(that.customMetadata);
|
||||
|
|
@ -118,7 +131,8 @@ public final class OnnxModelMetadata {
|
|||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(producerName, graphName, domain, description, version, customMetadata);
|
||||
return Objects.hash(
|
||||
producerName, graphName, graphDescription, domain, description, version, customMetadata);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -139,6 +153,15 @@ public final class OnnxModelMetadata {
|
|||
return graphName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the graph description.
|
||||
*
|
||||
* @return The graph description.
|
||||
*/
|
||||
public String getGraphDescription() {
|
||||
return graphDescription;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the domain.
|
||||
*
|
||||
|
|
@ -195,6 +218,9 @@ public final class OnnxModelMetadata {
|
|||
+ ", graphName='"
|
||||
+ graphName
|
||||
+ '\''
|
||||
+ ", graphDescription='"
|
||||
+ graphDescription
|
||||
+ '\''
|
||||
+ ", domain='"
|
||||
+ domain
|
||||
+ '\''
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
#include <jni.h>
|
||||
|
|
@ -385,7 +385,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
|
|||
jclass metadataClazz = (*jniEnv)->FindClass(jniEnv, metadataClassName);
|
||||
//OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray)
|
||||
jmethodID metadataConstructor = (*jniEnv)->GetMethodID(jniEnv, metadataClazz, "<init>",
|
||||
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V");
|
||||
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V");
|
||||
|
||||
// Get metadata
|
||||
OrtModelMetadata* metadata;
|
||||
|
|
@ -402,6 +402,11 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
|
|||
jstring graphStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
|
||||
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
|
||||
|
||||
// Read out the graph description and convert it to a java.lang.String
|
||||
checkOrtStatus(jniEnv,api,api->ModelMetadataGetGraphDescription(metadata, allocator, &charBuffer));
|
||||
jstring graphDescStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
|
||||
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));
|
||||
|
||||
// Read out the domain and convert it to a java.lang.String
|
||||
checkOrtStatus(jniEnv,api,api->ModelMetadataGetDomain(metadata, allocator, &charBuffer));
|
||||
jstring domainStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
|
||||
|
|
@ -449,8 +454,8 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
|
|||
}
|
||||
|
||||
// Invoke the metadata constructor
|
||||
//OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray)
|
||||
jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, domainStr, descriptionStr, (jlong) version, customArray);
|
||||
//OnnxModelMetadata(String producerName, String graphName, String graphDescription, String domain, String description, long version, String[] customMetadataArray)
|
||||
jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, graphDescStr, domainStr, descriptionStr, (jlong) version, customArray);
|
||||
|
||||
// Release the metadata
|
||||
api->ReleaseModelMetadata(metadata);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright (c) 2019, 2021, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2019, 2021, 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* Licensed under the MIT License.
|
||||
*/
|
||||
package ai.onnxruntime;
|
||||
|
|
@ -1168,6 +1168,36 @@ public class InferenceTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelMetadata() throws OrtException {
|
||||
String modelPath = getResourcePath("/model_with_valid_ort_config_json.onnx").toString();
|
||||
|
||||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelMetadata")) {
|
||||
try (OrtSession session = env.createSession(modelPath)) {
|
||||
OnnxModelMetadata modelMetadata = session.getMetadata();
|
||||
|
||||
Assertions.assertEquals(1, modelMetadata.getVersion());
|
||||
|
||||
Assertions.assertEquals("Hari", modelMetadata.getProducerName());
|
||||
|
||||
Assertions.assertEquals("matmul test", modelMetadata.getGraphName());
|
||||
|
||||
Assertions.assertEquals("", modelMetadata.getDomain());
|
||||
|
||||
Assertions.assertEquals(
|
||||
"This is a test model with a valid ORT config Json", modelMetadata.getDescription());
|
||||
|
||||
Assertions.assertEquals("graph description", modelMetadata.getGraphDescription());
|
||||
|
||||
Assertions.assertEquals(2, modelMetadata.getCustomMetadata().size());
|
||||
Assertions.assertEquals("dummy_value", modelMetadata.getCustomMetadata().get("dummy_key"));
|
||||
Assertions.assertEquals(
|
||||
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}",
|
||||
modelMetadata.getCustomMetadata().get("ort_config"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelInputBOOL() throws OrtException {
|
||||
// model takes 1x5 input of fixed type, echoes back
|
||||
|
|
|
|||
Loading…
Reference in a new issue