onnxruntime/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c
Adam Pocock a36692066d
[java] CUDA & TensorRT options fix (#20549)
### Description
I misunderstood how UpdateCUDAProviderOptions and
UpdateTensorRTProviderOptions work in the C API. I had assumed that they
updated the options struct in place; however, they re-initialize the struct
to the defaults and then apply only the values in the update. I've rewritten
the Java bindings for those classes so that they aggregate all the updates
and apply them in one go. I also updated the C API documentation to note
that these functions have this behaviour. I've not checked whether any of the
other providers with an options struct have this behaviour, as we only
expose CUDA and TensorRT's options in Java.

There's a small unrelated update to add a private constructor to the
Fp16Conversions classes to remove a documentation warning (they
shouldn't be instantiated anyway as they are utility classes containing
static methods).

### Motivation and Context
Fixes #20544.
2024-05-05 00:16:55 -07:00

78 lines
3.3 KiB
C

/*
* Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include <stdlib.h>
#include <string.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_providers_OrtTensorRTProviderOptions.h"
/*
 * Class:     ai_onnxruntime_providers_OrtTensorRTProviderOptions
 * Method:    create
 * Signature: (J)J
 *
 * Creates a native TensorRT provider options struct through the OrtApi
 * referenced by apiHandle and returns its address as a jlong handle.
 * The handle is released by the matching close() native method.
 * NOTE(review): failures are reported through checkOrtStatus, which is
 * presumed to raise a Java exception - confirm in OrtJniUtil.
 */
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_create
  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle) {
  (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  // Start from NULL so a failed create never hands an indeterminate pointer
  // back to Java (reading an uninitialised local is undefined behaviour).
  OrtTensorRTProviderOptionsV2* opts = NULL;
  checkOrtStatus(jniEnv,api,api->CreateTensorRTProviderOptions(&opts));
  return (jlong) opts;
}
/*
* Class: ai_onnxruntime_providers_OrtTensorRTProviderOptions
* Method: applyToNative
* Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_applyToNative
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtTensorRTProviderOptionsV2* opts = (OrtTensorRTProviderOptionsV2*) optionsHandle;
jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr);
const char** keys = (const char**) allocarray(keyLength, sizeof(const char*));
const char** values = (const char**) allocarray(keyLength, sizeof(const char*));
if ((keys == NULL) || (values == NULL)) {
if (keys != NULL) {
free((void*)keys);
}
if (values != NULL) {
free((void*)values);
}
throwOrtException(jniEnv, 1, "Not enough memory");
} else {
// Copy out strings into UTF-8.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i);
values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL);
}
// Write to the provider options.
checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, keys, values, keyLength));
// Release allocated strings.
for (jsize i = 0; i < keyLength; i++) {
jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]);
jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i);
(*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]);
}
free((void*)keys);
free((void*)values);
}
}
/*
 * Class:     ai_onnxruntime_providers_OrtTensorRTProviderOptions
 * Method:    close
 * Signature: (JJ)V
 *
 * Hands the native TensorRT options struct identified by `handle` back to
 * ReleaseTensorRTProviderOptions on the OrtApi identified by `apiHandle`.
 */
JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_close
  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
  // JNI mandates both of these parameters, but releasing needs neither.
  (void) jniEnv;
  (void) jobj;
  OrtTensorRTProviderOptionsV2* trtOptions = (OrtTensorRTProviderOptionsV2*) handle;
  const OrtApi* ortApi = (const OrtApi*) apiHandle;
  ortApi->ReleaseTensorRTProviderOptions(trtOptions);
}