[js/rn] Package dependency change to manage ort-extensions for react_native app (#15641)

### Description
<!-- Describe your changes. -->

js/react_native package dependency change to manage ort-extensions for
react-native app.

Enable optional inclusion of ort-ext aar/ ort-ext pods for react-native
extensions apps when specifiy `ortExtensionsEnabled` in user's
package.json file


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: rachguo <rachguo@rachguos-Mac-mini.local>
Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
This commit is contained in:
Rachel Guo 2023-04-29 00:07:12 -07:00 committed by GitHub
parent 41dcf0d32e
commit c8bd34f975
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 109 additions and 2 deletions

View file

@ -20,6 +20,24 @@ def getExtOrIntegerDefault(name) {
return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties['OnnxruntimeModule_' + name]).toInteger()
}
def checkIfOrtExtensionsEnabled() {
// locate user's project dir
def reactnativeRootDir = project.rootDir.parentFile
// get package.json file in root directory
def packageJsonFile = new File(reactnativeRootDir, 'package.json')
// read field 'onnxruntimeExtensionsEnabled'
if (packageJsonFile.exists()) {
def packageJsonContents = packageJsonFile.getText()
def packageJson = new groovy.json.JsonSlurper().parseText(packageJsonContents)
return packageJson.onnxruntimeExtensionsEnabled == "true"
} else {
logger.warn("Could not find package.json file in the expected directory: ${reactnativeRootDir}. ONNX Runtime Extensions will not be enabled.")
}
return false
}
boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled()
android {
compileSdkVersion getExtOrIntegerDefault('compileSdkVersion')
buildToolsVersion getExtOrDefault('buildToolsVersion')
@ -43,6 +61,17 @@ android {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
sourceSets {
main {
java.srcDirs = ['src/main/java/']
if (ortExtensionsEnabled) {
java.exclude '**/OnnxruntimeExtensionsDisabled.java'
} else {
java.exclude '**/OnnxruntimeExtensionsEnabled.java'
}
}
}
}
repositories {
@ -136,4 +165,9 @@ dependencies {
// Mobile build:
// implementation "com.microsoft.onnxruntime:onnxruntime-mobile:latest.integration@aar"
implementation "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar"
// By default it will just include onnxruntime full aar package
if (ortExtensionsEnabled) {
implementation "com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.integration@aar"
}
}

View file

@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package ai.onnxruntime.reactnative;
import ai.onnxruntime.OrtSession.SessionOptions;
import android.util.Log;
class OnnxruntimeExtensions {
public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) {
Log.i("OnnxruntimeExtensions",
"ORT Extensions is not enabled in the current configuration. If you want to enable this support, "
+ "please add \"onnxruntimeEnableExtensions\": \"true\" in your project root directory package.json.");
return;
}
}

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package ai.onnxruntime.reactnative;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.extensions.OrtxPackage;
class OnnxruntimeExtensions {
public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) throws OrtException {
sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath());
}
}

View file

@ -158,9 +158,15 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
OrtSession ortSession;
SessionOptions sessionOptions = parseSessionOptions(options);
if (modelData != null && modelData.length > 0) { // load model via model data array
// optional call for registering custom ops when ort extensions enabled
OnnxruntimeExtensions ortExt = new OnnxruntimeExtensions();
ortExt.registerOrtExtensionsIfEnabled(sessionOptions);
if (modelData != null && modelData.length > 0) {
// load model via model data array
ortSession = ortEnvironment.createSession(modelData, sessionOptions);
} else { // load model via model path string uri
} else {
// load model via model path string uri
InputStream modelStream =
reactContext.getApplicationContext().getContentResolver().openInputStream(Uri.parse(uri));
Reader reader = new BufferedReader(new InputStreamReader(modelStream));

View file

@ -8,6 +8,18 @@
#import <React/RCTLog.h>
#import <onnxruntime/onnxruntime_cxx_api.h>
#ifdef ORT_ENABLE_EXTENSIONS
extern "C" {
// Note: Declared in onnxruntime_extensions.h but forward declared here to resolve a build issue:
// (A compilation error happened while building an expo react native ios app, onnxruntime_c_api.h header
// included in the onnxruntime_extensions.h leads to a redefinition conflicts with multiple object defined in the ORT C
// API.) So doing a forward declaration here instead of #include "onnxruntime_extensions.h" as a workaround for now
// before we have a fix.
// TODO: Investigate if we can include onnxruntime_extensions.h here
OrtStatus *RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api);
} // Extern C
#endif
@implementation OnnxruntimeModule
struct SessionInfo {
@ -135,6 +147,10 @@ RCT_EXPORT_METHOD(run
sessionInfo = new SessionInfo();
Ort::SessionOptions sessionOptions = [self parseSessionOptions:options];
#ifdef ORT_ENABLE_EXTENSIONS
Ort::ThrowOnError(RegisterCustomOps(sessionOptions, OrtGetApiBase()));
#endif
if (modelData == nil) {
sessionInfo->session.reset(new Ort::Session(*ortEnv, [modelPath UTF8String], sessionOptions));
} else {

View file

@ -2,6 +2,9 @@ require "json"
package = JSON.parse(File.read(File.join(__dir__, "package.json")))
# Expect to return the absolute path of the react native root project dir
root_dir = File.dirname(File.dirname(__dir__))
Pod::Spec.new do |spec|
spec.static_framework = true
@ -19,4 +22,22 @@ Pod::Spec.new do |spec|
spec.dependency "React-Core"
spec.dependency "onnxruntime-c"
spec.xcconfig = {
'OTHER_CPLUSPLUSFLAGS' => '-Wall -Wextra',
}
if (File.exist?(File.join(root_dir, 'package.json')))
# Read the react native root project directory package.json file
root_package = JSON.parse(File.read(File.join(root_dir, 'package.json')))
if (root_package["onnxruntimeExtensionsEnabled"] == 'true')
spec.dependency "onnxruntime-extensions-c"
spec.xcconfig = {
'OTHER_CPLUSPLUSFLAGS' => '-DORT_ENABLE_EXTENSIONS=1 -Wall -Wextra',
}
end
else
puts "Could not find package.json file in the expected directory: #{root_dir}. ONNX Runtime Extensions will not be enabled."
end
end