diff --git a/js/react_native/android/build.gradle b/js/react_native/android/build.gradle index 4c8a318234..abf56a59a0 100644 --- a/js/react_native/android/build.gradle +++ b/js/react_native/android/build.gradle @@ -20,6 +20,24 @@ def getExtOrIntegerDefault(name) { return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties['OnnxruntimeModule_' + name]).toInteger() } +def checkIfOrtExtensionsEnabled() { + // locate user's project dir + def reactnativeRootDir = project.rootDir.parentFile + // get package.json file in root directory + def packageJsonFile = new File(reactnativeRootDir, 'package.json') + // read field 'onnxruntimeExtensionsEnabled' + if (packageJsonFile.exists()) { + def packageJsonContents = packageJsonFile.getText() + def packageJson = new groovy.json.JsonSlurper().parseText(packageJsonContents) + return packageJson.onnxruntimeExtensionsEnabled == "true" + } else { + logger.warn("Could not find package.json file in the expected directory: ${reactnativeRootDir}. ONNX Runtime Extensions will not be enabled.") + } + return false +} + +boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled() + android { compileSdkVersion getExtOrIntegerDefault('compileSdkVersion') buildToolsVersion getExtOrDefault('buildToolsVersion') @@ -43,6 +61,17 @@ android { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 } + + sourceSets { + main { + java.srcDirs = ['src/main/java/'] + if (ortExtensionsEnabled) { + java.exclude '**/OnnxruntimeExtensionsDisabled.java' + } else { + java.exclude '**/OnnxruntimeExtensionsEnabled.java' + } + } + } } repositories { @@ -136,4 +165,9 @@ dependencies { // Mobile build: // implementation "com.microsoft.onnxruntime:onnxruntime-mobile:latest.integration@aar" implementation "com.microsoft.onnxruntime:onnxruntime-android:latest.integration@aar" + + // By default it will just include onnxruntime full aar package + if (ortExtensionsEnabled) { + implementation "com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.integration@aar" + } } diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java new file mode 100644 index 0000000000..de4c880981 --- /dev/null +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsDisabled.java @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ai.onnxruntime.reactnative; + +import ai.onnxruntime.OrtSession.SessionOptions; +import android.util.Log; + +class OnnxruntimeExtensions { + public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) { + Log.i("OnnxruntimeExtensions", + "ORT Extensions is not enabled in the current configuration. If you want to enable this support, " + + "please add \"onnxruntimeEnableExtensions\": \"true\" in your project root directory package.json."); + return; + } +} diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java new file mode 100644 index 0000000000..9bbf41c8f1 --- /dev/null +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeExtensionsEnabled.java @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ai.onnxruntime.reactnative; + +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession.SessionOptions; +import ai.onnxruntime.extensions.OrtxPackage; + +class OnnxruntimeExtensions { + public void registerOrtExtensionsIfEnabled(SessionOptions sessionOptions) throws OrtException { + sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath()); + } +} diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java index 99f6b44eb8..81b19e3829 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java @@ -158,9 +158,15 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { OrtSession ortSession; SessionOptions sessionOptions = parseSessionOptions(options); - if (modelData != null && modelData.length > 0) { // load model via model data array + // optional call for registering custom ops when ort extensions enabled + OnnxruntimeExtensions ortExt = new OnnxruntimeExtensions(); + ortExt.registerOrtExtensionsIfEnabled(sessionOptions); + + if (modelData != null && modelData.length > 0) { + // load model via model data array ortSession = ortEnvironment.createSession(modelData, sessionOptions); - } else { // load model via model path string uri + } else { + // load model via model path string uri InputStream modelStream = reactContext.getApplicationContext().getContentResolver().openInputStream(Uri.parse(uri)); Reader reader = new BufferedReader(new InputStreamReader(modelStream)); diff --git a/js/react_native/ios/OnnxruntimeModule.mm b/js/react_native/ios/OnnxruntimeModule.mm index 09bcf7f752..1b2b52c231 100644 --- a/js/react_native/ios/OnnxruntimeModule.mm +++ b/js/react_native/ios/OnnxruntimeModule.mm @@ -8,6 +8,18 @@ #import #import +#ifdef ORT_ENABLE_EXTENSIONS +extern "C" { +// Note: Declared in onnxruntime_extensions.h but forward declared here to resolve a build issue: +// (A compilation error happened while building an expo react native ios app, onnxruntime_c_api.h header +// included in the onnxruntime_extensions.h leads to a redefinition conflicts with multiple object defined in the ORT C +// API.) So doing a forward declaration here instead of #include "onnxruntime_extensions.h" as a workaround for now +// before we have a fix. +// TODO: Investigate if we can include onnxruntime_extensions.h here +OrtStatus *RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api); +} // Extern C +#endif + @implementation OnnxruntimeModule struct SessionInfo { @@ -135,6 +147,10 @@ RCT_EXPORT_METHOD(run sessionInfo = new SessionInfo(); Ort::SessionOptions sessionOptions = [self parseSessionOptions:options]; +#ifdef ORT_ENABLE_EXTENSIONS + Ort::ThrowOnError(RegisterCustomOps(sessionOptions, OrtGetApiBase())); +#endif + if (modelData == nil) { sessionInfo->session.reset(new Ort::Session(*ortEnv, [modelPath UTF8String], sessionOptions)); } else { diff --git a/js/react_native/onnxruntime-react-native.podspec b/js/react_native/onnxruntime-react-native.podspec index aeb08b336b..914a396be1 100644 --- a/js/react_native/onnxruntime-react-native.podspec +++ b/js/react_native/onnxruntime-react-native.podspec @@ -2,6 +2,9 @@ require "json" package = JSON.parse(File.read(File.join(__dir__, "package.json"))) +# Expect to return the absolute path of the react native root project dir +root_dir = File.dirname(File.dirname(__dir__)) + Pod::Spec.new do |spec| spec.static_framework = true @@ -19,4 +22,22 @@ Pod::Spec.new do |spec| spec.dependency "React-Core" spec.dependency "onnxruntime-c" + + spec.xcconfig = { + 'OTHER_CPLUSPLUSFLAGS' => '-Wall -Wextra', + } + + if (File.exist?(File.join(root_dir, 'package.json'))) + # Read the react native root project directory package.json file + root_package = JSON.parse(File.read(File.join(root_dir, 'package.json'))) + if (root_package["onnxruntimeExtensionsEnabled"] == 'true') + spec.dependency "onnxruntime-extensions-c" + spec.xcconfig = { + 'OTHER_CPLUSPLUSFLAGS' => '-DORT_ENABLE_EXTENSIONS=1 -Wall -Wextra', + } + end + else + puts "Could not find package.json file in the expected directory: #{root_dir}. ONNX Runtime Extensions will not be enabled." + end + end