mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[js/rn] Package dependency change to manage ort-extensions for react_native app (#15641)
### Description <!-- Describe your changes. --> js/react_native package dependency change to manage ort-extensions for react-native app. Enable optional inclusion of ort-ext aar/ ort-ext pods for react-native extensions apps when specifiy `ortExtensionsEnabled` in user's package.json file ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: rachguo <rachguo@rachguos-Mac-mini.local> Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
This commit is contained in:
parent
41dcf0d32e
commit
c8bd34f975
6 changed files with 109 additions and 2 deletions
|
|
@ -20,6 +20,24 @@ def getExtOrIntegerDefault(name) {
|
|||
return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties['OnnxruntimeModule_' + name]).toInteger()
|
||||
}
|
||||
|
||||
def checkIfOrtExtensionsEnabled() {
|
||||
// locate user's project dir
|
||||
def reactnativeRootDir = project.rootDir.parentFile
|
||||
// get package.json file in root directory
|
||||
def packageJsonFile = new File(reactnativeRootDir, 'package.json')
|
||||
// read field 'onnxruntimeExtensionsEnabled'
|
||||
if (packageJsonFile.exists()) {
|
||||
def packageJsonContents = packageJsonFile.getText()
|
||||
def packageJson = new groovy.json.JsonSlurper().parseText(packageJsonContents)
|
||||
return packageJson.onnxruntimeExtensionsEnabled == "true"
|
||||
} else {
|
||||
logger.warn("Could not find package.json file in the expected directory: ${reactnativeRootDir}. ONNX Runtime Extensions will not be enabled.")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled()
|
||||
|
||||
android {
|
||||
compileSdkVersion getExtOrIntegerDefault('compileSdkVersion')
|
||||
buildToolsVersion getExtOrDefault('buildToolsVersion')
|
||||
|
|
@ -43,6 +61,17 @@ android {
|
|||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
|
||||
sourceSets {
|
||||
main {
|
||||
java.srcDirs = ['src/main/java/']
|
||||
if (ortExtensionsEnabled) {
|
||||
java.exclude '**/OnnxruntimeExtensionsDisabled.java'
|
||||
} else {
|
||||
java.exclude '**/OnnxruntimeExtensionsEnabled.java'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
repositories {
|
||||
|
|
@ -136,4 +165,9 @@ dependencies {
|
|||
// Mobile build:
|
||||
// implementation "com.microsoft.onnxruntime:onnxruntime-mobile:latest.integration@aar"
|
||||
implementation "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar"
|
||||
|
||||
// By default it will just include onnxruntime full aar package
|
||||
if (ortExtensionsEnabled) {
|
||||
implementation "com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.integration@aar"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
package ai.onnxruntime.reactnative;
|
||||
|
||||
import ai.onnxruntime.OrtSession.SessionOptions;
|
||||
import android.util.Log;
|
||||
|
||||
class OnnxruntimeExtensions {
|
||||
public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) {
|
||||
Log.i("OnnxruntimeExtensions",
|
||||
"ORT Extensions is not enabled in the current configuration. If you want to enable this support, "
|
||||
+ "please add \"onnxruntimeEnableExtensions\": \"true\" in your project root directory package.json.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
package ai.onnxruntime.reactnative;
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession.SessionOptions;
|
||||
import ai.onnxruntime.extensions.OrtxPackage;
|
||||
|
||||
class OnnxruntimeExtensions {
|
||||
public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) throws OrtException {
|
||||
sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath());
|
||||
}
|
||||
}
|
||||
|
|
@ -158,9 +158,15 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
|
|||
OrtSession ortSession;
|
||||
SessionOptions sessionOptions = parseSessionOptions(options);
|
||||
|
||||
if (modelData != null && modelData.length > 0) { // load model via model data array
|
||||
// optional call for registering custom ops when ort extensions enabled
|
||||
OnnxruntimeExtensions ortExt = new OnnxruntimeExtensions();
|
||||
ortExt.registerOrtExtensionsIfEnabled(sessionOptions);
|
||||
|
||||
if (modelData != null && modelData.length > 0) {
|
||||
// load model via model data array
|
||||
ortSession = ortEnvironment.createSession(modelData, sessionOptions);
|
||||
} else { // load model via model path string uri
|
||||
} else {
|
||||
// load model via model path string uri
|
||||
InputStream modelStream =
|
||||
reactContext.getApplicationContext().getContentResolver().openInputStream(Uri.parse(uri));
|
||||
Reader reader = new BufferedReader(new InputStreamReader(modelStream));
|
||||
|
|
|
|||
|
|
@ -8,6 +8,18 @@
|
|||
#import <React/RCTLog.h>
|
||||
#import <onnxruntime/onnxruntime_cxx_api.h>
|
||||
|
||||
#ifdef ORT_ENABLE_EXTENSIONS
|
||||
extern "C" {
|
||||
// Note: Declared in onnxruntime_extensions.h but forward declared here to resolve a build issue:
|
||||
// (A compilation error happened while building an expo react native ios app, onnxruntime_c_api.h header
|
||||
// included in the onnxruntime_extensions.h leads to a redefinition conflicts with multiple object defined in the ORT C
|
||||
// API.) So doing a forward declaration here instead of #include "onnxruntime_extensions.h" as a workaround for now
|
||||
// before we have a fix.
|
||||
// TODO: Investigate if we can include onnxruntime_extensions.h here
|
||||
OrtStatus *RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api);
|
||||
} // Extern C
|
||||
#endif
|
||||
|
||||
@implementation OnnxruntimeModule
|
||||
|
||||
struct SessionInfo {
|
||||
|
|
@ -135,6 +147,10 @@ RCT_EXPORT_METHOD(run
|
|||
sessionInfo = new SessionInfo();
|
||||
Ort::SessionOptions sessionOptions = [self parseSessionOptions:options];
|
||||
|
||||
#ifdef ORT_ENABLE_EXTENSIONS
|
||||
Ort::ThrowOnError(RegisterCustomOps(sessionOptions, OrtGetApiBase()));
|
||||
#endif
|
||||
|
||||
if (modelData == nil) {
|
||||
sessionInfo->session.reset(new Ort::Session(*ortEnv, [modelPath UTF8String], sessionOptions));
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ require "json"
|
|||
|
||||
package = JSON.parse(File.read(File.join(__dir__, "package.json")))
|
||||
|
||||
# Expect to return the absolute path of the react native root project dir
|
||||
root_dir = File.dirname(File.dirname(__dir__))
|
||||
|
||||
Pod::Spec.new do |spec|
|
||||
spec.static_framework = true
|
||||
|
||||
|
|
@ -19,4 +22,22 @@ Pod::Spec.new do |spec|
|
|||
|
||||
spec.dependency "React-Core"
|
||||
spec.dependency "onnxruntime-c"
|
||||
|
||||
spec.xcconfig = {
|
||||
'OTHER_CPLUSPLUSFLAGS' => '-Wall -Wextra',
|
||||
}
|
||||
|
||||
if (File.exist?(File.join(root_dir, 'package.json')))
|
||||
# Read the react native root project directory package.json file
|
||||
root_package = JSON.parse(File.read(File.join(root_dir, 'package.json')))
|
||||
if (root_package["onnxruntimeExtensionsEnabled"] == 'true')
|
||||
spec.dependency "onnxruntime-extensions-c"
|
||||
spec.xcconfig = {
|
||||
'OTHER_CPLUSPLUSFLAGS' => '-DORT_ENABLE_EXTENSIONS=1 -Wall -Wextra',
|
||||
}
|
||||
end
|
||||
else
|
||||
puts "Could not find package.json file in the expected directory: #{root_dir}. ONNX Runtime Extensions will not be enabled."
|
||||
end
|
||||
|
||||
end
|
||||
|
|
|
|||
Loading…
Reference in a new issue