mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[js/rn] Implement blob exchange by JSI instead of use base64 (#16094)
### Description <!-- Describe your changes. --> - Create `OnnxruntimeJSIHelper` native module to provide two JSI functions - `jsiOnnxruntimeStoreArrayBuffer`: Store buffer in Blob Manager & return blob object (iOS: RCTBlobManager, Android: BlobModule) - `jsiOnnxruntimeResolveArrayBuffer`: Use blob object to get buffer - The part of implementation is reference to [react-native-blob-jsi-helper](https://github.com/mrousavy/react-native-blob-jsi-helper) - Replace base64 encode/decode - `loadModelFromBlob`: Rename from `loadModelFromBase64EncodedBuffer` - `run`: Use blob object to replace input.data & results[].data For [this context](https://github.com/microsoft/onnxruntime/issues/16031#issuecomment-1556527812), it saved a lot of time and avoid JS thread blocking in decode return type, it is 3700ms -> 5~20ms for the case. (resolve function only takes 0.x ms) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> It’s related to #16031, but not a full implementation for migrate to JSI. It just uses JSI through BlobManager to replace the slow part (base64 encode / decode). Rewriting it entirely in JSI could be complicated, like type convertion and threading. This PR might be considered a minor change. /cc @skottmckay
This commit is contained in:
parent
9110e5b9bd
commit
ea1a5cf920
23 changed files with 935 additions and 141 deletions
37
js/react_native/android/CMakeLists.txt
Normal file
37
js/react_native/android/CMakeLists.txt
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
project(OnnxruntimeJSIHelper)
|
||||
cmake_minimum_required(VERSION 3.9.0)
|
||||
|
||||
set (PACKAGE_NAME "onnxruntime-react-native")
|
||||
set (BUILD_DIR ${CMAKE_SOURCE_DIR}/build)
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
file(TO_CMAKE_PATH "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath)
|
||||
|
||||
include_directories(
|
||||
"${NODE_MODULES_DIR}/react-native/React"
|
||||
"${NODE_MODULES_DIR}/react-native/React/Base"
|
||||
"${NODE_MODULES_DIR}/react-native/ReactCommon/jsi"
|
||||
)
|
||||
|
||||
add_library(onnxruntimejsihelper
|
||||
SHARED
|
||||
${libPath}
|
||||
src/main/cpp/cpp-adapter.cpp
|
||||
)
|
||||
|
||||
# Configure C++ 17
|
||||
set_target_properties(
|
||||
onnxruntimejsihelper PROPERTIES
|
||||
CXX_STANDARD 17
|
||||
CXX_EXTENSIONS OFF
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
)
|
||||
|
||||
find_library(log-lib log)
|
||||
|
||||
target_link_libraries(
|
||||
onnxruntimejsihelper
|
||||
${log-lib} # <-- Logcat logger
|
||||
android # <-- Android JNI core
|
||||
)
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
import java.nio.file.Paths
|
||||
|
||||
buildscript {
|
||||
repositories {
|
||||
google()
|
||||
|
|
@ -20,6 +22,32 @@ def getExtOrIntegerDefault(name) {
|
|||
return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties['OnnxruntimeModule_' + name]).toInteger()
|
||||
}
|
||||
|
||||
def reactNativeArchitectures() {
|
||||
def value = project.getProperties().get("reactNativeArchitectures")
|
||||
return value ? value.split(",") : ["armeabi-v7a", "x86", "x86_64", "arm64-v8a"]
|
||||
}
|
||||
|
||||
def resolveBuildType() {
|
||||
Gradle gradle = getGradle()
|
||||
String tskReqStr = gradle.getStartParameter().getTaskRequests()['args'].toString()
|
||||
return tskReqStr.contains('Release') ? 'release' : 'debug'
|
||||
}
|
||||
|
||||
static def findNodeModules(baseDir) {
|
||||
def basePath = baseDir.toPath().normalize()
|
||||
while (basePath) {
|
||||
def nodeModulesPath = Paths.get(basePath.toString(), "node_modules")
|
||||
def reactNativePath = Paths.get(nodeModulesPath.toString(), "react-native")
|
||||
if (nodeModulesPath.toFile().exists() && reactNativePath.toFile().exists()) {
|
||||
return nodeModulesPath.toString()
|
||||
}
|
||||
basePath = basePath.getParent()
|
||||
}
|
||||
throw new GradleException("onnxruntime-react-native: Failed to find node_modules/ path!")
|
||||
}
|
||||
|
||||
def nodeModules = findNodeModules(projectDir);
|
||||
|
||||
def checkIfOrtExtensionsEnabled() {
|
||||
// locate user's project dir
|
||||
def reactnativeRootDir = project.rootDir.parentFile
|
||||
|
|
@ -38,6 +66,9 @@ def checkIfOrtExtensionsEnabled() {
|
|||
|
||||
boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled()
|
||||
|
||||
def REACT_NATIVE_VERSION = ['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim()
|
||||
def REACT_NATIVE_MINOR_VERSION = REACT_NATIVE_VERSION.split("\\.")[1].toInteger()
|
||||
|
||||
android {
|
||||
compileSdkVersion getExtOrIntegerDefault('compileSdkVersion')
|
||||
buildToolsVersion getExtOrDefault('buildToolsVersion')
|
||||
|
|
@ -47,6 +78,44 @@ android {
|
|||
versionCode 1
|
||||
versionName "1.0"
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
cppFlags "-O2 -frtti -fexceptions -Wall -Wno-unused-variable -fstack-protector-all"
|
||||
if (REACT_NATIVE_MINOR_VERSION >= 71) {
|
||||
// fabricjni required c++_shared
|
||||
arguments "-DANDROID_STL=c++_shared", "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}"
|
||||
} else {
|
||||
arguments "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}"
|
||||
}
|
||||
abiFilters (*reactNativeArchitectures())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (rootProject.hasProperty("ndkPath")) {
|
||||
ndkPath rootProject.ext.ndkPath
|
||||
}
|
||||
if (rootProject.hasProperty("ndkVersion")) {
|
||||
ndkVersion rootProject.ext.ndkVersion
|
||||
}
|
||||
|
||||
buildFeatures {
|
||||
prefab true
|
||||
}
|
||||
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
|
||||
packagingOptions {
|
||||
doNotStrip resolveBuildType() == 'debug' ? "**/**/*.so" : ''
|
||||
excludes = [
|
||||
"META-INF",
|
||||
"META-INF/**",
|
||||
"**/libjsi.so",
|
||||
]
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
|
|
@ -149,8 +218,6 @@ repositories {
|
|||
}
|
||||
}
|
||||
|
||||
def REACT_NATIVE_VERSION = new File(['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim())
|
||||
|
||||
dependencies {
|
||||
api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION
|
||||
api "org.mockito:mockito-core:2.28.2"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
package ai.onnxruntime.reactnative;
|
||||
|
||||
import com.facebook.react.bridge.Arguments;
|
||||
import com.facebook.react.bridge.JavaOnlyMap;
|
||||
import com.facebook.react.bridge.ReactApplicationContext;
|
||||
import com.facebook.react.bridge.ReadableMap;
|
||||
import com.facebook.react.modules.blob.BlobModule;
|
||||
|
||||
public class FakeBlobModule extends BlobModule {
|
||||
|
||||
public FakeBlobModule(ReactApplicationContext context) { super(null); }
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "BlobModule";
|
||||
}
|
||||
|
||||
public JavaOnlyMap testCreateData(byte[] bytes) {
|
||||
String blobId = store(bytes);
|
||||
JavaOnlyMap data = new JavaOnlyMap();
|
||||
data.putString("blobId", blobId);
|
||||
data.putInt("offset", 0);
|
||||
data.putInt("size", bytes.length);
|
||||
return data;
|
||||
}
|
||||
|
||||
public byte[] testGetData(ReadableMap data) {
|
||||
String blobId = data.getString("blobId");
|
||||
int offset = data.getInt("offset");
|
||||
int size = data.getInt("size");
|
||||
return resolve(blobId, offset, size);
|
||||
}
|
||||
}
|
||||
|
|
@ -10,11 +10,14 @@ import ai.onnxruntime.TensorInfo;
|
|||
import android.util.Base64;
|
||||
import androidx.test.platform.app.InstrumentationRegistry;
|
||||
import com.facebook.react.bridge.Arguments;
|
||||
import com.facebook.react.bridge.CatalystInstance;
|
||||
import com.facebook.react.bridge.JavaOnlyArray;
|
||||
import com.facebook.react.bridge.JavaOnlyMap;
|
||||
import com.facebook.react.bridge.ReactApplicationContext;
|
||||
import com.facebook.react.bridge.ReadableArray;
|
||||
import com.facebook.react.bridge.ReadableMap;
|
||||
import com.facebook.react.bridge.WritableMap;
|
||||
import com.facebook.react.modules.blob.BlobModule;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
|
|
@ -29,12 +32,17 @@ public class OnnxruntimeModuleTest {
|
|||
private ReactApplicationContext reactContext =
|
||||
new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext());
|
||||
|
||||
private FakeBlobModule blobModule;
|
||||
|
||||
@Before
|
||||
public void setUp() {}
|
||||
public void setUp() {
|
||||
blobModule = new FakeBlobModule(reactContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getName() throws Exception {
|
||||
OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
|
||||
ortModule.blobModule = blobModule;
|
||||
String name = "Onnxruntime";
|
||||
Assert.assertEquals(ortModule.getName(), name);
|
||||
}
|
||||
|
|
@ -47,6 +55,7 @@ public class OnnxruntimeModuleTest {
|
|||
when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray());
|
||||
|
||||
OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
|
||||
ortModule.blobModule = blobModule;
|
||||
String sessionKey = "";
|
||||
|
||||
// test loadModel()
|
||||
|
|
@ -104,8 +113,7 @@ public class OnnxruntimeModuleTest {
|
|||
floatBuffer.put(value);
|
||||
}
|
||||
floatBuffer.rewind();
|
||||
String dataEncoded = Base64.encodeToString(buffer.array(), Base64.DEFAULT);
|
||||
inputTensorMap.putString("data", dataEncoded);
|
||||
inputTensorMap.putMap("data", blobModule.testCreateData(buffer.array()));
|
||||
|
||||
inputDataMap.putMap("input", inputTensorMap);
|
||||
}
|
||||
|
|
@ -124,10 +132,9 @@ public class OnnxruntimeModuleTest {
|
|||
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
|
||||
}
|
||||
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat);
|
||||
String dataEncoded = outputMap.getString("data");
|
||||
FloatBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT))
|
||||
.order(ByteOrder.nativeOrder())
|
||||
.asFloatBuffer();
|
||||
ReadableMap data = outputMap.getMap("data");
|
||||
FloatBuffer buffer =
|
||||
ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ import androidx.test.platform.app.InstrumentationRegistry;
|
|||
import com.facebook.react.bridge.Arguments;
|
||||
import com.facebook.react.bridge.JavaOnlyArray;
|
||||
import com.facebook.react.bridge.JavaOnlyMap;
|
||||
import com.facebook.react.bridge.ReactApplicationContext;
|
||||
import com.facebook.react.bridge.ReadableMap;
|
||||
import com.facebook.react.modules.blob.BlobModule;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
|
|
@ -39,11 +41,17 @@ import org.mockito.MockitoSession;
|
|||
|
||||
@SmallTest
|
||||
public class TensorHelperTest {
|
||||
private ReactApplicationContext reactContext =
|
||||
new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext());
|
||||
|
||||
private OrtEnvironment ortEnvironment;
|
||||
|
||||
private FakeBlobModule blobModule;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
ortEnvironment = OrtEnvironment.getEnvironment("TensorHelperTest");
|
||||
blobModule = new FakeBlobModule(reactContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -64,10 +72,9 @@ public class TensorHelperTest {
|
|||
dataFloatBuffer.put(Float.MIN_VALUE);
|
||||
dataFloatBuffer.put(2.0f);
|
||||
dataFloatBuffer.put(Float.MAX_VALUE);
|
||||
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
|
||||
inputTensorMap.putString("data", dataEncoded);
|
||||
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
|
||||
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
|
||||
|
||||
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
|
@ -94,10 +101,9 @@ public class TensorHelperTest {
|
|||
dataByteBuffer.put(Byte.MIN_VALUE);
|
||||
dataByteBuffer.put((byte)2);
|
||||
dataByteBuffer.put(Byte.MAX_VALUE);
|
||||
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
|
||||
inputTensorMap.putString("data", dataEncoded);
|
||||
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
|
||||
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
|
||||
|
||||
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8);
|
||||
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8);
|
||||
|
|
@ -125,10 +131,9 @@ public class TensorHelperTest {
|
|||
dataByteBuffer.put((byte)0);
|
||||
dataByteBuffer.put((byte)2);
|
||||
dataByteBuffer.put((byte)255);
|
||||
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
|
||||
inputTensorMap.putString("data", dataEncoded);
|
||||
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
|
||||
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
|
||||
|
||||
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
|
||||
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
|
||||
|
|
@ -157,10 +162,9 @@ public class TensorHelperTest {
|
|||
dataIntBuffer.put(Integer.MIN_VALUE);
|
||||
dataIntBuffer.put(2);
|
||||
dataIntBuffer.put(Integer.MAX_VALUE);
|
||||
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
|
||||
inputTensorMap.putString("data", dataEncoded);
|
||||
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
|
||||
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
|
||||
|
||||
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
|
||||
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
|
||||
|
|
@ -189,10 +193,9 @@ public class TensorHelperTest {
|
|||
dataLongBuffer.put(Long.MIN_VALUE);
|
||||
dataLongBuffer.put(15000000001L);
|
||||
dataLongBuffer.put(Long.MAX_VALUE);
|
||||
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
|
||||
inputTensorMap.putString("data", dataEncoded);
|
||||
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
|
||||
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
|
||||
|
||||
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
|
||||
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
|
||||
|
|
@ -221,10 +224,9 @@ public class TensorHelperTest {
|
|||
dataDoubleBuffer.put(Double.MIN_VALUE);
|
||||
dataDoubleBuffer.put(1.8e+30);
|
||||
dataDoubleBuffer.put(Double.MAX_VALUE);
|
||||
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
|
||||
inputTensorMap.putString("data", dataEncoded);
|
||||
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
|
||||
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
|
||||
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
|
||||
|
||||
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE);
|
||||
Assert.assertEquals(outputTensor.getInfo().onnxType,
|
||||
|
|
@ -258,14 +260,14 @@ public class TensorHelperTest {
|
|||
|
||||
OrtSession.Result result = session.run(container);
|
||||
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
|
||||
ReadableMap outputMap = resultMap.getMap("output");
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
|
||||
}
|
||||
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeBool);
|
||||
String dataEncoded = outputMap.getString("data");
|
||||
ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT));
|
||||
ReadableMap data = outputMap.getMap("data");
|
||||
ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Assert.assertEquals(buffer.get(i) == 1, inputData[i]);
|
||||
}
|
||||
|
|
@ -298,15 +300,15 @@ public class TensorHelperTest {
|
|||
|
||||
OrtSession.Result result = session.run(container);
|
||||
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
|
||||
ReadableMap outputMap = resultMap.getMap("output");
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
|
||||
}
|
||||
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeDouble);
|
||||
String dataEncoded = outputMap.getString("data");
|
||||
ReadableMap data = outputMap.getMap("data");
|
||||
DoubleBuffer buffer =
|
||||
ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asDoubleBuffer();
|
||||
ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asDoubleBuffer();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f);
|
||||
}
|
||||
|
|
@ -339,15 +341,14 @@ public class TensorHelperTest {
|
|||
|
||||
OrtSession.Result result = session.run(container);
|
||||
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
|
||||
ReadableMap outputMap = resultMap.getMap("output");
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
|
||||
}
|
||||
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat);
|
||||
String dataEncoded = outputMap.getString("data");
|
||||
FloatBuffer buffer =
|
||||
ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asFloatBuffer();
|
||||
ReadableMap data = outputMap.getMap("data");
|
||||
FloatBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f);
|
||||
}
|
||||
|
|
@ -380,14 +381,14 @@ public class TensorHelperTest {
|
|||
|
||||
OrtSession.Result result = session.run(container);
|
||||
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
|
||||
ReadableMap outputMap = resultMap.getMap("output");
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
|
||||
}
|
||||
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeByte);
|
||||
String dataEncoded = outputMap.getString("data");
|
||||
ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT));
|
||||
ReadableMap data = outputMap.getMap("data");
|
||||
ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Assert.assertEquals(buffer.get(i), inputData[i]);
|
||||
}
|
||||
|
|
@ -420,15 +421,14 @@ public class TensorHelperTest {
|
|||
|
||||
OrtSession.Result result = session.run(container);
|
||||
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
|
||||
ReadableMap outputMap = resultMap.getMap("output");
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
|
||||
}
|
||||
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeInt);
|
||||
String dataEncoded = outputMap.getString("data");
|
||||
IntBuffer buffer =
|
||||
ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asIntBuffer();
|
||||
ReadableMap data = outputMap.getMap("data");
|
||||
IntBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asIntBuffer();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Assert.assertEquals(buffer.get(i), inputData[i]);
|
||||
}
|
||||
|
|
@ -461,15 +461,14 @@ public class TensorHelperTest {
|
|||
|
||||
OrtSession.Result result = session.run(container);
|
||||
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
|
||||
ReadableMap outputMap = resultMap.getMap("output");
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
|
||||
}
|
||||
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeLong);
|
||||
String dataEncoded = outputMap.getString("data");
|
||||
LongBuffer buffer =
|
||||
ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asLongBuffer();
|
||||
ReadableMap data = outputMap.getMap("data");
|
||||
LongBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asLongBuffer();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Assert.assertEquals(buffer.get(i), inputData[i]);
|
||||
}
|
||||
|
|
@ -502,14 +501,14 @@ public class TensorHelperTest {
|
|||
|
||||
OrtSession.Result result = session.run(container);
|
||||
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
|
||||
ReadableMap outputMap = resultMap.getMap("output");
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
|
||||
}
|
||||
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeUnsignedByte);
|
||||
String dataEncoded = outputMap.getString("data");
|
||||
ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT));
|
||||
ReadableMap data = outputMap.getMap("data");
|
||||
ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Assert.assertEquals(buffer.get(i), inputData[i]);
|
||||
}
|
||||
|
|
|
|||
127
js/react_native/android/src/main/cpp/cpp-adapter.cpp
Normal file
127
js/react_native/android/src/main/cpp/cpp-adapter.cpp
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
#include <jni.h>
|
||||
#include <jsi/jsi.h>
|
||||
#include <string>
|
||||
|
||||
using namespace facebook;
|
||||
|
||||
typedef u_int8_t byte;
|
||||
|
||||
std::string jstring2string(JNIEnv *env, jstring jStr) {
|
||||
if (!jStr) return "";
|
||||
|
||||
jclass stringClass = env->GetObjectClass(jStr);
|
||||
jmethodID getBytes = env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B");
|
||||
const auto stringJbytes = (jbyteArray) env->CallObjectMethod(jStr, getBytes, env->NewStringUTF("UTF-8"));
|
||||
|
||||
auto length = (size_t) env->GetArrayLength(stringJbytes);
|
||||
jbyte* pBytes = env->GetByteArrayElements(stringJbytes, nullptr);
|
||||
|
||||
std::string ret = std::string((char *)pBytes, length);
|
||||
env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT);
|
||||
|
||||
env->DeleteLocalRef(stringJbytes);
|
||||
env->DeleteLocalRef(stringClass);
|
||||
return ret;
|
||||
}
|
||||
|
||||
byte* getBytesFromBlob(JNIEnv *env, jobject instanceGlobal, const std::string& blobId, int offset, int size) {
|
||||
if (!env) throw std::runtime_error("JNI Environment is gone!");
|
||||
|
||||
// get java class
|
||||
jclass clazz = env->GetObjectClass(instanceGlobal);
|
||||
// get method in java class
|
||||
jmethodID getBufferJava = env->GetMethodID(clazz, "getBlobBuffer", "(Ljava/lang/String;II)[B");
|
||||
// call method
|
||||
auto jstring = env->NewStringUTF(blobId.c_str());
|
||||
auto boxedBytes = (jbyteArray) env->CallObjectMethod(instanceGlobal,
|
||||
getBufferJava,
|
||||
// arguments
|
||||
jstring,
|
||||
offset,
|
||||
size);
|
||||
env->DeleteLocalRef(jstring);
|
||||
|
||||
jboolean isCopy = true;
|
||||
jbyte* bytes = env->GetByteArrayElements(boxedBytes, &isCopy);
|
||||
env->DeleteLocalRef(boxedBytes);
|
||||
return reinterpret_cast<byte*>(bytes);
|
||||
};
|
||||
|
||||
std::string createBlob(JNIEnv *env, jobject instanceGlobal, byte* bytes, size_t size) {
|
||||
if (!env) throw std::runtime_error("JNI Environment is gone!");
|
||||
|
||||
// get java class
|
||||
jclass clazz = env->GetObjectClass(instanceGlobal);
|
||||
// get method in java class
|
||||
jmethodID getBufferJava = env->GetMethodID(clazz, "createBlob", "([B)Ljava/lang/String;");
|
||||
// call method
|
||||
auto byteArray = env->NewByteArray(size);
|
||||
env->SetByteArrayRegion(byteArray, 0, size, reinterpret_cast<jbyte*>(bytes));
|
||||
auto blobId = (jstring) env->CallObjectMethod(instanceGlobal, getBufferJava, byteArray);
|
||||
env->DeleteLocalRef(byteArray);
|
||||
|
||||
return jstring2string(env, blobId);
|
||||
};
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv *env, jclass _, jlong jsiPtr, jobject instance) {
|
||||
auto jsiRuntime = reinterpret_cast<jsi::Runtime*>(jsiPtr);
|
||||
|
||||
auto& runtime = *jsiRuntime;
|
||||
|
||||
auto instanceGlobal = env->NewGlobalRef(instance);
|
||||
|
||||
auto resolveArrayBuffer = jsi::Function::createFromHostFunction(runtime,
|
||||
jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeResolveArrayBuffer"),
|
||||
1,
|
||||
[=](jsi::Runtime& runtime,
|
||||
const jsi::Value& thisValue,
|
||||
const jsi::Value* arguments,
|
||||
size_t count) -> jsi::Value {
|
||||
if (count != 1) {
|
||||
throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!");
|
||||
}
|
||||
|
||||
jsi::Object data = arguments[0].asObject(runtime);
|
||||
auto blobId = data.getProperty(runtime, "blobId").asString(runtime);
|
||||
auto offset = data.getProperty(runtime, "offset").asNumber();
|
||||
auto size = data.getProperty(runtime, "size").asNumber();
|
||||
|
||||
auto bytes = getBytesFromBlob(env, instanceGlobal, blobId.utf8(runtime), offset, size);
|
||||
|
||||
size_t totalSize = size - offset;
|
||||
jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer");
|
||||
jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int) totalSize).getObject(runtime);
|
||||
jsi::ArrayBuffer buf = o.getArrayBuffer(runtime);
|
||||
memcpy(buf.data(runtime), reinterpret_cast<byte*>(bytes), totalSize);
|
||||
|
||||
return buf;
|
||||
});
|
||||
runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", std::move(resolveArrayBuffer));
|
||||
|
||||
auto storeArrayBuffer = jsi::Function::createFromHostFunction(runtime,
|
||||
jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeStoreArrayBuffer"),
|
||||
1,
|
||||
[=](jsi::Runtime& runtime,
|
||||
const jsi::Value& thisValue,
|
||||
const jsi::Value* arguments,
|
||||
size_t count) -> jsi::Value {
|
||||
if (count != 1) {
|
||||
throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) expects one argument (object)!");
|
||||
}
|
||||
|
||||
auto arrayBuffer = arguments[0].asObject(runtime).getArrayBuffer(runtime);
|
||||
auto size = arrayBuffer.size(runtime);
|
||||
|
||||
std::string blobId = createBlob(env, instanceGlobal, arrayBuffer.data(runtime), size);
|
||||
|
||||
jsi::Object result(runtime);
|
||||
auto blobIdString = jsi::String::createFromUtf8(runtime, blobId);
|
||||
result.setProperty(runtime, "blobId", blobIdString);
|
||||
result.setProperty(runtime, "offset", jsi::Value(0));
|
||||
result.setProperty(runtime, "size", jsi::Value(static_cast<double>(size)));
|
||||
return result;
|
||||
});
|
||||
runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", std::move(storeArrayBuffer));
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
package ai.onnxruntime.reactnative;
|
||||
|
||||
import androidx.annotation.NonNull;
|
||||
import com.facebook.react.bridge.JavaScriptContextHolder;
|
||||
import com.facebook.react.bridge.ReactApplicationContext;
|
||||
import com.facebook.react.bridge.ReactContextBaseJavaModule;
|
||||
import com.facebook.react.bridge.ReactMethod;
|
||||
import com.facebook.react.module.annotations.ReactModule;
|
||||
import com.facebook.react.modules.blob.BlobModule;
|
||||
|
||||
@ReactModule(name = OnnxruntimeJSIHelper.NAME)
|
||||
public class OnnxruntimeJSIHelper extends ReactContextBaseJavaModule {
|
||||
public static final String NAME = "OnnxruntimeJSIHelper";
|
||||
|
||||
private static ReactApplicationContext reactContext;
|
||||
protected BlobModule blobModule;
|
||||
|
||||
public OnnxruntimeJSIHelper(ReactApplicationContext context) {
|
||||
super(context);
|
||||
reactContext = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
@NonNull
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public void checkBlobModule() {
|
||||
if (blobModule == null) {
|
||||
blobModule = getReactApplicationContext().getNativeModule(BlobModule.class);
|
||||
if (blobModule == null) {
|
||||
throw new RuntimeException("BlobModule is not initialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ReactMethod(isBlockingSynchronousMethod = true)
|
||||
public boolean install() {
|
||||
try {
|
||||
System.loadLibrary("onnxruntimejsihelper");
|
||||
JavaScriptContextHolder jsContext = getReactApplicationContext().getJavaScriptContextHolder();
|
||||
nativeInstall(jsContext.get(), this);
|
||||
return true;
|
||||
} catch (Exception exception) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] getBlobBuffer(String blobId, int offset, int size) {
|
||||
checkBlobModule();
|
||||
byte[] bytes = blobModule.resolve(blobId, offset, size);
|
||||
blobModule.remove(blobId);
|
||||
if (bytes == null) {
|
||||
throw new RuntimeException("Failed to resolve Blob #" + blobId + "! Not found.");
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
|
||||
public String createBlob(byte[] buffer) {
|
||||
checkBlobModule();
|
||||
String blobId = blobModule.store(buffer);
|
||||
if (blobId == null) {
|
||||
throw new RuntimeException("Failed to create Blob!");
|
||||
}
|
||||
return blobId;
|
||||
}
|
||||
|
||||
public static native void nativeInstall(long jsiPointer, OnnxruntimeJSIHelper instance);
|
||||
}
|
||||
|
|
@ -13,7 +13,6 @@ import ai.onnxruntime.OrtSession.RunOptions;
|
|||
import ai.onnxruntime.OrtSession.SessionOptions;
|
||||
import android.net.Uri;
|
||||
import android.os.Build;
|
||||
import android.util.Base64;
|
||||
import android.util.Log;
|
||||
import androidx.annotation.NonNull;
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
|
@ -28,6 +27,7 @@ import com.facebook.react.bridge.ReadableMap;
|
|||
import com.facebook.react.bridge.ReadableType;
|
||||
import com.facebook.react.bridge.WritableArray;
|
||||
import com.facebook.react.bridge.WritableMap;
|
||||
import com.facebook.react.modules.blob.BlobModule;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
|
@ -56,6 +56,8 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
|
|||
return key;
|
||||
}
|
||||
|
||||
protected BlobModule blobModule;
|
||||
|
||||
public OnnxruntimeModule(ReactApplicationContext context) {
|
||||
super(context);
|
||||
reactContext = context;
|
||||
|
|
@ -67,6 +69,15 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
|
|||
return "Onnxruntime";
|
||||
}
|
||||
|
||||
public void checkBlobModule() {
|
||||
if (blobModule == null) {
|
||||
blobModule = getReactApplicationContext().getNativeModule(BlobModule.class);
|
||||
if (blobModule == null) {
|
||||
throw new RuntimeException("BlobModule is not initialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* React native binding API to load a model using given uri.
|
||||
*
|
||||
|
|
@ -87,19 +98,22 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
|
|||
}
|
||||
|
||||
/**
|
||||
* React native binding API to load a model using the BASE64 encoded model data.
|
||||
* React native binding API to load a model using blob object that data stored in BlobModule.
|
||||
*
|
||||
* @param data the BASE64 encoded model data.
|
||||
* @param data the blob object
|
||||
* @param options onnxruntime session options
|
||||
* @param promise output returning back to react native js
|
||||
* @note the value provided to `promise` includes a key representing the session.
|
||||
* when run() is called, the key must be passed into the first parameter.
|
||||
*/
|
||||
@ReactMethod
|
||||
public void loadModelFromBase64EncodedBuffer(String data, ReadableMap options, Promise promise) {
|
||||
public void loadModelFromBlob(ReadableMap data, ReadableMap options, Promise promise) {
|
||||
try {
|
||||
byte[] modelData = Base64.decode(data, Base64.DEFAULT);
|
||||
WritableMap resultMap = loadModel(modelData, options);
|
||||
checkBlobModule();
|
||||
String blobId = data.getString("blobId");
|
||||
byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size"));
|
||||
blobModule.remove(blobId);
|
||||
WritableMap resultMap = loadModel(bytes, options);
|
||||
promise.resolve(resultMap);
|
||||
} catch (Exception e) {
|
||||
promise.reject("Failed to load model from buffer: " + e.getMessage(), e);
|
||||
|
|
@ -242,6 +256,8 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
|
|||
|
||||
RunOptions runOptions = parseRunOptions(options);
|
||||
|
||||
checkBlobModule();
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
Map<String, OnnxTensor> feed = new HashMap<>();
|
||||
Iterator<String> iterator = ortSession.getInputNames().iterator();
|
||||
|
|
@ -255,19 +271,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
|
|||
throw new Exception("Can't find input: " + inputName);
|
||||
}
|
||||
|
||||
if (inputMap.getType("data") != ReadableType.String) {
|
||||
// NOTE:
|
||||
//
|
||||
// tensor data should always be a BASE64 encoded string.
|
||||
// This is because the current React Native bridge supports limited data type as arguments.
|
||||
// In order to pass data from JS to Java, we have to encode them into string.
|
||||
//
|
||||
// see also:
|
||||
// https://reactnative.dev/docs/native-modules-android#argument-types
|
||||
throw new Exception("Non string type of a tensor data is not allowed");
|
||||
}
|
||||
|
||||
OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment);
|
||||
OnnxTensor onnxTensor = TensorHelper.createInputTensor(blobModule, inputMap, ortEnvironment);
|
||||
feed.put(inputName, onnxTensor);
|
||||
}
|
||||
|
||||
|
|
@ -292,7 +296,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
|
|||
Log.d("Duration", "inference: " + duration);
|
||||
|
||||
startTime = System.currentTimeMillis();
|
||||
WritableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
WritableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
|
||||
duration = System.currentTimeMillis() - startTime;
|
||||
Log.d("Duration", "createOutputTensor: " + duration);
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import androidx.annotation.RequiresApi;
|
|||
import com.facebook.react.ReactPackage;
|
||||
import com.facebook.react.bridge.NativeModule;
|
||||
import com.facebook.react.bridge.ReactApplicationContext;
|
||||
import com.facebook.react.modules.blob.BlobModule;
|
||||
import com.facebook.react.uimanager.ViewManager;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
|
|
@ -21,6 +22,7 @@ public class OnnxruntimePackage implements ReactPackage {
|
|||
public List<NativeModule> createNativeModules(@NonNull ReactApplicationContext reactContext) {
|
||||
List<NativeModule> modules = new ArrayList<>();
|
||||
modules.add(new OnnxruntimeModule(reactContext));
|
||||
modules.add(new OnnxruntimeJSIHelper(reactContext));
|
||||
return modules;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import com.facebook.react.bridge.ReadableArray;
|
|||
import com.facebook.react.bridge.ReadableMap;
|
||||
import com.facebook.react.bridge.WritableArray;
|
||||
import com.facebook.react.bridge.WritableMap;
|
||||
import com.facebook.react.modules.blob.BlobModule;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.DoubleBuffer;
|
||||
|
|
@ -45,9 +46,10 @@ public class TensorHelper {
|
|||
|
||||
/**
|
||||
* It creates an input tensor from a map passed by react native js.
|
||||
* 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor.
|
||||
* 'data' is blob object and the buffer is stored in BlobModule. It first resolve it and creates a tensor.
|
||||
*/
|
||||
public static OnnxTensor createInputTensor(ReadableMap inputTensor, OrtEnvironment ortEnvironment) throws Exception {
|
||||
public static OnnxTensor createInputTensor(BlobModule blobModule, ReadableMap inputTensor,
|
||||
OrtEnvironment ortEnvironment) throws Exception {
|
||||
// shape
|
||||
ReadableArray dimsArray = inputTensor.getArray("dims");
|
||||
long[] dims = new long[dimsArray.size()];
|
||||
|
|
@ -68,8 +70,11 @@ public class TensorHelper {
|
|||
}
|
||||
onnxTensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims);
|
||||
} else {
|
||||
String data = inputTensor.getString("data");
|
||||
ByteBuffer values = ByteBuffer.wrap(Base64.decode(data, Base64.DEFAULT)).order(ByteOrder.nativeOrder());
|
||||
ReadableMap data = inputTensor.getMap("data");
|
||||
String blobId = data.getString("blobId");
|
||||
byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size"));
|
||||
blobModule.remove(blobId);
|
||||
ByteBuffer values = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder());
|
||||
onnxTensor = createInputTensor(tensorType, dims, values, ortEnvironment);
|
||||
}
|
||||
|
||||
|
|
@ -78,9 +83,9 @@ public class TensorHelper {
|
|||
|
||||
/**
|
||||
* It creates an output map from an output tensor.
|
||||
* a data array is encoded as base64 string.
|
||||
* a data array is store in BlobModule.
|
||||
*/
|
||||
public static WritableMap createOutputTensor(OrtSession.Result result) throws Exception {
|
||||
public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception {
|
||||
WritableMap outputTensorMap = Arguments.createMap();
|
||||
|
||||
Iterator<Map.Entry<String, OnnxValue>> iterator = result.iterator();
|
||||
|
|
@ -115,8 +120,13 @@ public class TensorHelper {
|
|||
}
|
||||
outputTensor.putArray("data", dataArray);
|
||||
} else {
|
||||
String data = createOutputTensor(onnxTensor);
|
||||
outputTensor.putString("data", data);
|
||||
// Store in BlobModule then create a blob object as data
|
||||
byte[] bufferArray = createOutputTensor(onnxTensor);
|
||||
WritableMap data = Arguments.createMap();
|
||||
data.putString("blobId", blobModule.store(bufferArray));
|
||||
data.putInt("offset", 0);
|
||||
data.putInt("size", bufferArray.length);
|
||||
outputTensor.putMap("data", data);
|
||||
}
|
||||
|
||||
outputTensorMap.putMap(outputName, outputTensor);
|
||||
|
|
@ -177,7 +187,7 @@ public class TensorHelper {
|
|||
return tensor;
|
||||
}
|
||||
|
||||
private static String createOutputTensor(OnnxTensor onnxTensor) throws Exception {
|
||||
private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception {
|
||||
TensorInfo tensorInfo = onnxTensor.getInfo();
|
||||
ByteBuffer buffer = null;
|
||||
|
||||
|
|
@ -224,8 +234,7 @@ public class TensorHelper {
|
|||
throw new IllegalStateException("Unexpected type: " + tensorInfo.onnxType.toString());
|
||||
}
|
||||
|
||||
String data = Base64.encodeToString(buffer.array(), Base64.DEFAULT);
|
||||
return data;
|
||||
return buffer.array();
|
||||
}
|
||||
|
||||
private static final Map<String, TensorInfo.OnnxTensorType> JsTensorTypeToOnnxTensorTypeMap =
|
||||
|
|
|
|||
5
js/react_native/ios/OnnxruntimeJSIHelper.h
Normal file
5
js/react_native/ios/OnnxruntimeJSIHelper.h
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#import <React/RCTBridgeModule.h>
|
||||
|
||||
@interface OnnxruntimeJSIHelper : NSObject <RCTBridgeModule>
|
||||
|
||||
@end
|
||||
85
js/react_native/ios/OnnxruntimeJSIHelper.mm
Normal file
85
js/react_native/ios/OnnxruntimeJSIHelper.mm
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
#import "OnnxruntimeJSIHelper.h"
|
||||
#import <React/RCTBlobManager.h>
|
||||
#import <React/RCTBridge+Private.h>
|
||||
#import <jsi/jsi.h>
|
||||
|
||||
@implementation OnnxruntimeJSIHelper
|
||||
|
||||
RCT_EXPORT_MODULE()
|
||||
|
||||
RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(install) {
|
||||
RCTBridge *bridge = [RCTBridge currentBridge];
|
||||
RCTCxxBridge *cxxBridge = (RCTCxxBridge *)bridge;
|
||||
if (cxxBridge == nil) {
|
||||
return @false;
|
||||
}
|
||||
|
||||
using namespace facebook;
|
||||
|
||||
auto jsiRuntime = (jsi::Runtime *)cxxBridge.runtime;
|
||||
if (jsiRuntime == nil) {
|
||||
return @false;
|
||||
}
|
||||
auto &runtime = *jsiRuntime;
|
||||
|
||||
auto resolveArrayBuffer = jsi::Function::createFromHostFunction(
|
||||
runtime, jsi::PropNameID::forUtf8(runtime, "jsiOnnxruntimeResolveArrayBuffer"), 1,
|
||||
[](jsi::Runtime &runtime, const jsi::Value &thisArg, const jsi::Value *args, size_t count) -> jsi::Value {
|
||||
if (count != 1) {
|
||||
throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!");
|
||||
}
|
||||
|
||||
auto data = args[0].asObject(runtime);
|
||||
auto blobId = data.getProperty(runtime, "blobId").asString(runtime).utf8(runtime);
|
||||
auto size = data.getProperty(runtime, "size").asNumber();
|
||||
auto offset = data.getProperty(runtime, "offset").asNumber();
|
||||
|
||||
RCTBlobManager *blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class];
|
||||
if (blobManager == nil) {
|
||||
throw jsi::JSError(runtime, "RCTBlobManager is not initialized");
|
||||
}
|
||||
|
||||
NSString *blobIdStr = [NSString stringWithUTF8String:blobId.c_str()];
|
||||
auto blob = [blobManager resolve:blobIdStr offset:(long)offset size:(long)size];
|
||||
|
||||
jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer");
|
||||
jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int)blob.length).getObject(runtime);
|
||||
jsi::ArrayBuffer buf = o.getArrayBuffer(runtime);
|
||||
memcpy(buf.data(runtime), blob.bytes, blob.length);
|
||||
[blobManager remove:blobIdStr];
|
||||
return buf;
|
||||
});
|
||||
runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", resolveArrayBuffer);
|
||||
|
||||
auto storeArrayBuffer = jsi::Function::createFromHostFunction(
|
||||
runtime, jsi::PropNameID::forUtf8(runtime, "jsiOnnxruntimeStoreArrayBuffer"), 1,
|
||||
[](jsi::Runtime &runtime, const jsi::Value &thisArg, const jsi::Value *args, size_t count) -> jsi::Value {
|
||||
if (count != 1) {
|
||||
throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) expects one argument (object)!");
|
||||
}
|
||||
|
||||
auto arrayBuffer = args[0].asObject(runtime).getArrayBuffer(runtime);
|
||||
auto size = arrayBuffer.length(runtime);
|
||||
NSData *data = [NSData dataWithBytesNoCopy:arrayBuffer.data(runtime) length:size freeWhenDone:NO];
|
||||
|
||||
RCTBlobManager *blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class];
|
||||
if (blobManager == nil) {
|
||||
throw jsi::JSError(runtime, "RCTBlobManager is not initialized");
|
||||
}
|
||||
|
||||
NSString *blobId = [blobManager store:data];
|
||||
|
||||
jsi::Object result(runtime);
|
||||
auto blobIdString = jsi::String::createFromUtf8(runtime, [blobId cStringUsingEncoding:NSUTF8StringEncoding]);
|
||||
result.setProperty(runtime, "blobId", blobIdString);
|
||||
result.setProperty(runtime, "offset", jsi::Value(0));
|
||||
result.setProperty(runtime, "size", jsi::Value(static_cast<double>(size)));
|
||||
return result;
|
||||
});
|
||||
|
||||
runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", storeArrayBuffer);
|
||||
|
||||
return @true;
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
@ -5,9 +5,12 @@
|
|||
#define OnnxruntimeModule_h
|
||||
|
||||
#import <React/RCTBridgeModule.h>
|
||||
#import <React/RCTBlobManager.h>
|
||||
|
||||
@interface OnnxruntimeModule : NSObject<RCTBridgeModule>
|
||||
|
||||
- (void)setBlobManager:(RCTBlobManager *)manager;
|
||||
|
||||
-(NSDictionary*)loadModel:(NSString*)modelPath
|
||||
options:(NSDictionary*)options;
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
#import "TensorHelper.h"
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <React/RCTBlobManager.h>
|
||||
#import <React/RCTBridge+Private.h>
|
||||
#import <React/RCTLog.h>
|
||||
|
||||
// Note: Using below syntax for including ort c api and ort extensions headers to resolve a compiling error happened
|
||||
|
|
@ -44,6 +46,21 @@ static int nextSessionId = 0;
|
|||
|
||||
RCT_EXPORT_MODULE(Onnxruntime)
|
||||
|
||||
RCTBlobManager *blobManager = nil;
|
||||
|
||||
- (void)checkBlobManager {
|
||||
if (blobManager == nil) {
|
||||
blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class];
|
||||
if (blobManager == nil) {
|
||||
@throw @"RCTBlobManager is not initialized";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- (void)setBlobManager:(RCTBlobManager *)manager {
|
||||
blobManager = manager;
|
||||
}
|
||||
|
||||
/**
|
||||
* React native binding API to load a model using given uri.
|
||||
*
|
||||
|
|
@ -68,22 +85,27 @@ RCT_EXPORT_METHOD(loadModel
|
|||
}
|
||||
|
||||
/**
|
||||
* React native binding API to load a model using BASE64 encoded model data string.
|
||||
* React native binding API to load a model using blob object that data stored in RCTBlobManager.
|
||||
*
|
||||
* @param modelData the BASE64 encoded model data string
|
||||
* @param modelDataBlob a model data blob object
|
||||
* @param options onnxruntime session options
|
||||
* @param resolve callback for returning output back to react native js
|
||||
* @param reject callback for returning an error back to react native js
|
||||
* @note when run() is called, the same modelPath must be passed into the first parameter.
|
||||
*/
|
||||
RCT_EXPORT_METHOD(loadModelFromBase64EncodedBuffer
|
||||
: (NSString *)modelDataBase64EncodedString options
|
||||
RCT_EXPORT_METHOD(loadModelFromBlob
|
||||
: (NSDictionary *)modelDataBlob options
|
||||
: (NSDictionary *)options resolver
|
||||
: (RCTPromiseResolveBlock)resolve rejecter
|
||||
: (RCTPromiseRejectBlock)reject) {
|
||||
@try {
|
||||
NSData *modelDataDecoded = [[NSData alloc] initWithBase64EncodedString:modelDataBase64EncodedString options:0];
|
||||
NSDictionary *resultMap = [self loadModelFromBuffer:modelDataDecoded options:options];
|
||||
[self checkBlobManager];
|
||||
NSString *blobId = [modelDataBlob objectForKey:@"blobId"];
|
||||
long size = [[modelDataBlob objectForKey:@"size"] longValue];
|
||||
long offset = [[modelDataBlob objectForKey:@"offset"] longValue];
|
||||
auto modelData = [blobManager resolve:blobId offset:offset size:size];
|
||||
NSDictionary *resultMap = [self loadModelFromBuffer:modelData options:options];
|
||||
[blobManager remove:blobId];
|
||||
resolve(resultMap);
|
||||
} @catch (...) {
|
||||
reject(@"onnxruntime", @"failed to load model from buffer", nil);
|
||||
|
|
@ -255,6 +277,8 @@ RCT_EXPORT_METHOD(run
|
|||
}
|
||||
SessionInfo *sessionInfo = (SessionInfo *)[value pointerValue];
|
||||
|
||||
[self checkBlobManager];
|
||||
|
||||
std::vector<Ort::Value> feeds;
|
||||
std::vector<Ort::MemoryAllocation> allocations;
|
||||
feeds.reserve(sessionInfo->inputNames.size());
|
||||
|
|
@ -265,7 +289,10 @@ RCT_EXPORT_METHOD(run
|
|||
@throw exception;
|
||||
}
|
||||
|
||||
Ort::Value value = [TensorHelper createInputTensor:inputTensor ortAllocator:ortAllocator allocations:allocations];
|
||||
Ort::Value value = [TensorHelper createInputTensor:blobManager
|
||||
input:inputTensor
|
||||
ortAllocator:ortAllocator
|
||||
allocations:allocations];
|
||||
feeds.emplace_back(std::move(value));
|
||||
}
|
||||
|
||||
|
|
@ -280,7 +307,7 @@ RCT_EXPORT_METHOD(run
|
|||
sessionInfo->session->Run(runOptions, sessionInfo->inputNames.data(), feeds.data(),
|
||||
sessionInfo->inputNames.size(), requestedOutputs.data(), requestedOutputs.size());
|
||||
|
||||
NSDictionary *resultMap = [TensorHelper createOutputTensor:requestedOutputs values:result];
|
||||
NSDictionary *resultMap = [TensorHelper createOutputTensor:blobManager outputNames:requestedOutputs values:result];
|
||||
|
||||
return resultMap;
|
||||
}
|
||||
|
|
@ -378,6 +405,7 @@ static NSDictionary *executionModeTable = @{@"sequential" : @(ORT_SEQUENTIAL), @
|
|||
while (NSString *key = [iterator nextObject]) {
|
||||
[self dispose:key];
|
||||
}
|
||||
blobManager = nullptr;
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@
|
|||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
0105483CF04B9471894F3EAA /* Pods_OnnxruntimeModuleTest.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */; };
|
||||
7FD234672A1F221700734B71 /* FakeRCTBlobManager.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */; };
|
||||
C60033360456900E26D6F96F /* Pods_OnnxruntimeModule.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */; };
|
||||
DB8FC9B525C2867800C72F26 /* OnnxruntimeModule.mm in Sources */ = {isa = PBXBuildFile; fileRef = DB8FC9B425C2867800C72F26 /* OnnxruntimeModule.mm */; };
|
||||
DB8FC9B825C2868700C72F26 /* TensorHelper.mm in Sources */ = {isa = PBXBuildFile; fileRef = DB8FC9B725C2868700C72F26 /* TensorHelper.mm */; };
|
||||
DBDB57DA2603211A004F16BE /* TensorHelperTest.mm in Sources */ = {isa = PBXBuildFile; fileRef = DBDB57D92603211A004F16BE /* TensorHelperTest.mm */; };
|
||||
|
|
@ -39,6 +42,14 @@
|
|||
|
||||
/* Begin PBXFileReference section */
|
||||
134814201AA4EA6300B7C361 /* libOnnxruntimeModule.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libOnnxruntimeModule.a; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_OnnxruntimeModuleTest.framework; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_OnnxruntimeModule.framework; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleTest.debug.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest.debug.xcconfig"; sourceTree = "<group>"; };
|
||||
548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModule.release.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModule/Pods-OnnxruntimeModule.release.xcconfig"; sourceTree = "<group>"; };
|
||||
63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModule.debug.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModule/Pods-OnnxruntimeModule.debug.xcconfig"; sourceTree = "<group>"; };
|
||||
7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = FakeRCTBlobManager.m; sourceTree = "<group>"; };
|
||||
7FD234682A1F234500734B71 /* FakeRCTBlobManager.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = FakeRCTBlobManager.h; sourceTree = "<group>"; };
|
||||
8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleTest.release.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest.release.xcconfig"; sourceTree = "<group>"; };
|
||||
DB8FC9B425C2867800C72F26 /* OnnxruntimeModule.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = OnnxruntimeModule.mm; sourceTree = SOURCE_ROOT; };
|
||||
DB8FC9B725C2868700C72F26 /* TensorHelper.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = TensorHelper.mm; sourceTree = SOURCE_ROOT; };
|
||||
DBDB57D72603211A004F16BE /* OnnxruntimeModuleTest.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = OnnxruntimeModuleTest.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
|
|
@ -53,6 +64,7 @@
|
|||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
C60033360456900E26D6F96F /* Pods_OnnxruntimeModule.framework in Frameworks */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
|
|
@ -61,6 +73,7 @@
|
|||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
DBDB57DC2603211A004F16BE /* libOnnxruntimeModule.a in Frameworks */,
|
||||
0105483CF04B9471894F3EAA /* Pods_OnnxruntimeModuleTest.framework in Frameworks */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
|
|
@ -84,16 +97,30 @@
|
|||
134814211AA4EA7D00B7C361 /* Products */,
|
||||
62ED2272D9F9CF7E3D0A8F87 /* Pods */,
|
||||
DBDB57D72603211A004F16BE /* OnnxruntimeModuleTest.xctest */,
|
||||
6FFDF1594C99DA125B013E34 /* Frameworks */,
|
||||
);
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
62ED2272D9F9CF7E3D0A8F87 /* Pods */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */,
|
||||
548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */,
|
||||
5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */,
|
||||
8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */,
|
||||
);
|
||||
path = Pods;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
6FFDF1594C99DA125B013E34 /* Frameworks */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */,
|
||||
38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */,
|
||||
);
|
||||
name = Frameworks;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
DB8FC9B325C2861300C72F26 /* OnnxruntimeModule */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
|
|
@ -109,6 +136,8 @@
|
|||
DBDB57D92603211A004F16BE /* TensorHelperTest.mm */,
|
||||
DBDB57DB2603211A004F16BE /* Info.plist */,
|
||||
DBDB58AF262A92D6004F16BE /* OnnxruntimeModuleTest.mm */,
|
||||
7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */,
|
||||
7FD234682A1F234500734B71 /* FakeRCTBlobManager.h */,
|
||||
);
|
||||
path = OnnxruntimeModuleTest;
|
||||
sourceTree = "<group>";
|
||||
|
|
@ -120,6 +149,7 @@
|
|||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = 58B511EF1A9E6C8500147676 /* Build configuration list for PBXNativeTarget "OnnxruntimeModule" */;
|
||||
buildPhases = (
|
||||
FA8BD7B76BD8BD02A6DB750A /* [CP] Check Pods Manifest.lock */,
|
||||
58B511D71A9E6C8500147676 /* Sources */,
|
||||
58B511D81A9E6C8500147676 /* Frameworks */,
|
||||
58B511D91A9E6C8500147676 /* CopyFiles */,
|
||||
|
|
@ -137,9 +167,11 @@
|
|||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = DBDB57E12603211A004F16BE /* Build configuration list for PBXNativeTarget "OnnxruntimeModuleTest" */;
|
||||
buildPhases = (
|
||||
896E89AEC864CBD0CC7E0AF1 /* [CP] Check Pods Manifest.lock */,
|
||||
DBDB57D32603211A004F16BE /* Sources */,
|
||||
DBDB57D42603211A004F16BE /* Frameworks */,
|
||||
DBDB57D52603211A004F16BE /* Resources */,
|
||||
015C75E59BC80D4507FB6E8A /* [CP] Embed Pods Frameworks */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
|
|
@ -200,6 +232,119 @@
|
|||
};
|
||||
/* End PBXResourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXShellScriptBuildPhase section */
|
||||
015C75E59BC80D4507FB6E8A /* [CP] Embed Pods Frameworks */ = {
|
||||
isa = PBXShellScriptBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
inputPaths = (
|
||||
"${PODS_ROOT}/Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest-frameworks.sh",
|
||||
"${BUILT_PRODUCTS_DIR}/DoubleConversion/DoubleConversion.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/RCT-Folly/folly.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/RCTTypeSafety/RCTTypeSafety.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-Codegen/React_Codegen.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-Core/React.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-CoreModules/CoreModules.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-RCTAnimation/RCTAnimation.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-RCTBlob/RCTBlob.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-RCTImage/RCTImage.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-RCTLinking/RCTLinking.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-RCTNetwork/RCTNetwork.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-RCTSettings/RCTSettings.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-RCTText/RCTText.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-RCTVibration/RCTVibration.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-bridging/react_bridging.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-cxxreact/cxxreact.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-jsi/jsi.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-jsiexecutor/jsireact.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-jsinspector/jsinspector.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-logger/logger.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/React-perflogger/reactperflogger.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/ReactCommon/ReactCommon.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/Yoga/yoga.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/fmt/fmt.framework",
|
||||
"${BUILT_PRODUCTS_DIR}/glog/glog.framework",
|
||||
);
|
||||
name = "[CP] Embed Pods Frameworks";
|
||||
outputPaths = (
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/DoubleConversion.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/folly.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTTypeSafety.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/React_Codegen.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/React.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/CoreModules.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTAnimation.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTBlob.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTImage.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTLinking.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTNetwork.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTSettings.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTText.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTVibration.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/react_bridging.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/cxxreact.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsi.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsireact.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsinspector.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/logger.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/reactperflogger.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/ReactCommon.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/yoga.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/fmt.framework",
|
||||
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/glog.framework",
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
shellPath = /bin/sh;
|
||||
shellScript = "\"${PODS_ROOT}/Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest-frameworks.sh\"\n";
|
||||
showEnvVarsInLog = 0;
|
||||
};
|
||||
896E89AEC864CBD0CC7E0AF1 /* [CP] Check Pods Manifest.lock */ = {
|
||||
isa = PBXShellScriptBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
inputFileListPaths = (
|
||||
);
|
||||
inputPaths = (
|
||||
"${PODS_PODFILE_DIR_PATH}/Podfile.lock",
|
||||
"${PODS_ROOT}/Manifest.lock",
|
||||
);
|
||||
name = "[CP] Check Pods Manifest.lock";
|
||||
outputFileListPaths = (
|
||||
);
|
||||
outputPaths = (
|
||||
"$(DERIVED_FILE_DIR)/Pods-OnnxruntimeModuleTest-checkManifestLockResult.txt",
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
shellPath = /bin/sh;
|
||||
shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
|
||||
showEnvVarsInLog = 0;
|
||||
};
|
||||
FA8BD7B76BD8BD02A6DB750A /* [CP] Check Pods Manifest.lock */ = {
|
||||
isa = PBXShellScriptBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
inputFileListPaths = (
|
||||
);
|
||||
inputPaths = (
|
||||
"${PODS_PODFILE_DIR_PATH}/Podfile.lock",
|
||||
"${PODS_ROOT}/Manifest.lock",
|
||||
);
|
||||
name = "[CP] Check Pods Manifest.lock";
|
||||
outputFileListPaths = (
|
||||
);
|
||||
outputPaths = (
|
||||
"$(DERIVED_FILE_DIR)/Pods-OnnxruntimeModule-checkManifestLockResult.txt",
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
shellPath = /bin/sh;
|
||||
shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
|
||||
showEnvVarsInLog = 0;
|
||||
};
|
||||
/* End PBXShellScriptBuildPhase section */
|
||||
|
||||
/* Begin PBXSourcesBuildPhase section */
|
||||
58B511D71A9E6C8500147676 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
|
|
@ -216,6 +361,7 @@
|
|||
files = (
|
||||
DBDB57DA2603211A004F16BE /* TensorHelperTest.mm in Sources */,
|
||||
DBDB58B0262A92D7004F16BE /* OnnxruntimeModuleTest.mm in Sources */,
|
||||
7FD234672A1F221700734B71 /* FakeRCTBlobManager.m in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
|
|
@ -329,6 +475,7 @@
|
|||
};
|
||||
58B511F01A9E6C8500147676 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
baseConfigurationReference = 63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */;
|
||||
buildSettings = {
|
||||
HEADER_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
|
|
@ -352,6 +499,7 @@
|
|||
};
|
||||
58B511F11A9E6C8500147676 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
baseConfigurationReference = 548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */;
|
||||
buildSettings = {
|
||||
HEADER_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
|
|
@ -374,6 +522,7 @@
|
|||
};
|
||||
DBDB57DF2603211A004F16BE /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
baseConfigurationReference = 5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */;
|
||||
buildSettings = {
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
|
|
@ -446,6 +595,7 @@
|
|||
};
|
||||
DBDB57E02603211A004F16BE /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
baseConfigurationReference = 8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */;
|
||||
buildSettings = {
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef FakeRCTBlobManager_h
|
||||
#define FakeRCTBlobManager_h
|
||||
|
||||
#import <React/RCTBlobManager.h>
|
||||
|
||||
@interface FakeRCTBlobManager : RCTBlobManager
|
||||
|
||||
@property (nonatomic, strong) NSMutableDictionary *blobs;
|
||||
|
||||
- (NSString *)store:(NSData *)data;
|
||||
|
||||
- (NSData *)resolve:(NSString *)blobId offset:(long)offset size:(long)size;
|
||||
|
||||
- (NSDictionary *)testCreateData:(NSData *)buffer;
|
||||
|
||||
- (NSString *)testGetData:(NSDictionary *)data;
|
||||
|
||||
@end
|
||||
|
||||
#endif /* FakeRCTBlobManager_h */
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import "FakeRCTBlobManager.h"
|
||||
|
||||
@implementation FakeRCTBlobManager
|
||||
|
||||
- (instancetype)init {
|
||||
if (self = [super init]) {
|
||||
_blobs = [NSMutableDictionary new];
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (NSString *)store:(NSData *)data {
|
||||
NSString *blobId = [[NSUUID UUID] UUIDString];
|
||||
_blobs[blobId] = data;
|
||||
return blobId;
|
||||
}
|
||||
|
||||
- (NSData *)resolve:(NSString *)blobId offset:(long)offset size:(long)size {
|
||||
NSData *data = _blobs[blobId];
|
||||
if (data == nil) {
|
||||
return nil;
|
||||
}
|
||||
return [data subdataWithRange:NSMakeRange(offset, size)];
|
||||
}
|
||||
|
||||
- (NSDictionary *)testCreateData:(NSData *)buffer {
|
||||
NSString* blobId = [self store:buffer];
|
||||
return @{
|
||||
@"blobId": blobId,
|
||||
@"offset": @0,
|
||||
@"size": @(buffer.length),
|
||||
};
|
||||
}
|
||||
|
||||
- (NSString *)testGetData:(NSDictionary *)data {
|
||||
NSString *blobId = [data objectForKey:@"blobId"];
|
||||
long size = [[data objectForKey:@"size"] longValue];
|
||||
long offset = [[data objectForKey:@"offset"] longValue];
|
||||
[self resolve:blobId offset:offset size:size];
|
||||
return blobId;
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#import "OnnxruntimeModule.h"
|
||||
#import "FakeRCTBlobManager.h"
|
||||
#import "TensorHelper.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
|
@ -13,6 +14,14 @@
|
|||
|
||||
@implementation OnnxruntimeModuleTest
|
||||
|
||||
FakeRCTBlobManager *fakeBlobManager = nil;
|
||||
|
||||
+ (void)initialize {
|
||||
if (self == [OnnxruntimeModuleTest class]) {
|
||||
fakeBlobManager = [FakeRCTBlobManager new];
|
||||
}
|
||||
}
|
||||
|
||||
- (void)testOnnxruntimeModule {
|
||||
NSBundle *bundle = [NSBundle bundleForClass:[OnnxruntimeModuleTest class]];
|
||||
NSString *dataPath = [bundle pathForResource:@"test_types_float" ofType:@"ort"];
|
||||
|
|
@ -20,6 +29,7 @@
|
|||
NSString *sessionKey2 = @"";
|
||||
|
||||
OnnxruntimeModule *onnxruntimeModule = [OnnxruntimeModule new];
|
||||
[onnxruntimeModule setBlobManager:fakeBlobManager];
|
||||
|
||||
{
|
||||
// test loadModelFromBuffer()
|
||||
|
|
@ -70,8 +80,8 @@
|
|||
}
|
||||
floatPtr = (float *)[byteBufferRef bytes];
|
||||
|
||||
NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0];
|
||||
inputTensorMap[@"data"] = dataEncoded;
|
||||
XCTAssertNotNil(fakeBlobManager);
|
||||
inputTensorMap[@"data"] = [fakeBlobManager testCreateData:byteBufferRef];
|
||||
|
||||
NSMutableDictionary *inputDataMap = [NSMutableDictionary dictionary];
|
||||
inputDataMap[@"input"] = inputTensorMap;
|
||||
|
|
@ -84,8 +94,18 @@
|
|||
NSDictionary *resultMap = [onnxruntimeModule run:sessionKey input:inputDataMap output:output options:options];
|
||||
NSDictionary *resultMap2 = [onnxruntimeModule run:sessionKey2 input:inputDataMap output:output options:options];
|
||||
|
||||
XCTAssertTrue([[resultMap objectForKey:@"output"] isEqualToDictionary:inputTensorMap]);
|
||||
XCTAssertTrue([[resultMap2 objectForKey:@"output"] isEqualToDictionary:inputTensorMap]);
|
||||
// Compare output & input, but data.blobId is different
|
||||
// dims
|
||||
XCTAssertTrue([[resultMap objectForKey:@"output"][@"dims"] isEqualToArray:inputTensorMap[@"dims"]]);
|
||||
XCTAssertTrue([[resultMap2 objectForKey:@"output"][@"dims"] isEqualToArray:inputTensorMap[@"dims"]]);
|
||||
|
||||
// type
|
||||
XCTAssertEqual([resultMap objectForKey:@"output"][@"type"], JsTensorTypeFloat);
|
||||
XCTAssertEqual([resultMap2 objectForKey:@"output"][@"type"], JsTensorTypeFloat);
|
||||
|
||||
// data ({ blobId, offset, size })
|
||||
XCTAssertEqual([[resultMap objectForKey:@"output"][@"data"][@"offset"] longValue], 0);
|
||||
XCTAssertEqual([[resultMap2 objectForKey:@"output"][@"data"][@"size"] longValue], byteBufferSize);
|
||||
}
|
||||
|
||||
// test dispose
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#import "TensorHelper.h"
|
||||
|
||||
#import "FakeRCTBlobManager.h"
|
||||
#import <XCTest/XCTest.h>
|
||||
#import <onnxruntime/onnxruntime_cxx_api.h>
|
||||
#include <vector>
|
||||
|
|
@ -13,6 +14,14 @@
|
|||
|
||||
@implementation TensorHelperTest
|
||||
|
||||
FakeRCTBlobManager *testBlobManager = nil;
|
||||
|
||||
+ (void)initialize {
|
||||
if (self == [TensorHelperTest class]) {
|
||||
testBlobManager = [FakeRCTBlobManager new];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void testCreateInputTensorT(const std::array<T, 3> &outValues, std::function<NSNumber *(T value)> &convert,
|
||||
ONNXTensorElementDataType onnxType, NSString *jsTensorType) {
|
||||
|
|
@ -34,12 +43,13 @@ static void testCreateInputTensorT(const std::array<T, 3> &outValues, std::funct
|
|||
typePtr[i] = outValues[i];
|
||||
}
|
||||
|
||||
NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0];
|
||||
inputTensorMap[@"data"] = dataEncoded;
|
||||
XCTAssertNotNil(testBlobManager);
|
||||
inputTensorMap[@"data"] = [testBlobManager testCreateData:byteBufferRef];
|
||||
|
||||
Ort::AllocatorWithDefaultOptions ortAllocator;
|
||||
std::vector<Ort::MemoryAllocation> allocations;
|
||||
Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap
|
||||
Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager
|
||||
input:inputTensorMap
|
||||
ortAllocator:ortAllocator
|
||||
allocations:allocations];
|
||||
|
||||
|
|
@ -126,7 +136,8 @@ static void testCreateInputTensorT(const std::array<T, 3> &outValues, std::funct
|
|||
|
||||
Ort::AllocatorWithDefaultOptions ortAllocator;
|
||||
std::vector<Ort::MemoryAllocation> allocations;
|
||||
Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap
|
||||
Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager
|
||||
input:inputTensorMap
|
||||
ortAllocator:ortAllocator
|
||||
allocations:allocations];
|
||||
|
||||
|
|
@ -194,10 +205,11 @@ static void testCreateOutputTensorT(const std::array<T, 5> &outValues, std::func
|
|||
typePtr[i] = outValues[i];
|
||||
}
|
||||
|
||||
NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0];
|
||||
inputTensorMap[@"data"] = dataEncoded;
|
||||
inputTensorMap[@"data"] = [testBlobManager testCreateData:byteBufferRef];
|
||||
;
|
||||
std::vector<Ort::MemoryAllocation> allocations;
|
||||
Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap
|
||||
Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager
|
||||
input:inputTensorMap
|
||||
ortAllocator:ortAllocator
|
||||
allocations:allocations];
|
||||
|
||||
|
|
@ -208,9 +220,24 @@ static void testCreateOutputTensorT(const std::array<T, 5> &outValues, std::func
|
|||
auto output = session.Run(runOptions, inputNames.data(), feeds.data(), inputNames.size(), outputNames.data(),
|
||||
outputNames.size());
|
||||
|
||||
NSDictionary *resultMap = [TensorHelper createOutputTensor:outputNames values:output];
|
||||
NSDictionary *resultMap = [TensorHelper createOutputTensor:testBlobManager outputNames:outputNames values:output];
|
||||
|
||||
XCTAssertTrue([[resultMap objectForKey:@"output"] isEqualToDictionary:inputTensorMap]);
|
||||
// Compare output & input, but data.blobId is different
|
||||
|
||||
NSDictionary *outputMap = [resultMap objectForKey:@"output"];
|
||||
|
||||
// dims
|
||||
XCTAssertTrue([outputMap[@"dims"] isEqualToArray:inputTensorMap[@"dims"]]);
|
||||
|
||||
// type
|
||||
XCTAssertEqual(outputMap[@"type"], jsTensorType);
|
||||
|
||||
// data ({ blobId, offset, size })
|
||||
NSDictionary *data = outputMap[@"data"];
|
||||
|
||||
XCTAssertNotNil(data[@"blobId"]);
|
||||
XCTAssertEqual([data[@"offset"] longValue], 0);
|
||||
XCTAssertEqual([data[@"size"] longValue], byteBufferSize);
|
||||
}
|
||||
|
||||
- (void)testCreateOutputTensorFloat {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#define TensorHelper_h
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <React/RCTBlobManager.h>
|
||||
|
||||
// Note: Using below syntax for including ort c api and ort extensions headers to resolve a compiling error happened
|
||||
// in an expo react native ios app (a redefinition error happened with multiple object types defined within
|
||||
|
|
@ -36,17 +37,19 @@ FOUNDATION_EXPORT NSString* const JsTensorTypeString;
|
|||
|
||||
/**
|
||||
* It creates an input tensor from a map passed by react native js.
|
||||
* 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor.
|
||||
* 'data' is blob object and the buffer is stored in RCTBlobManager. It first resolve it and creates a tensor.
|
||||
*/
|
||||
+(Ort::Value)createInputTensor:(NSDictionary*)input
|
||||
+(Ort::Value)createInputTensor:(RCTBlobManager *)blobManager
|
||||
input:(NSDictionary*)input
|
||||
ortAllocator:(OrtAllocator*)ortAllocator
|
||||
allocations:(std::vector<Ort::MemoryAllocation>&)allocatons;
|
||||
allocations:(std::vector<Ort::MemoryAllocation>&)allocations;
|
||||
|
||||
/**
|
||||
* It creates an output map from an output tensor.
|
||||
* a data array is encoded as base64 string.
|
||||
* a data array is store in RCTBlobManager.
|
||||
*/
|
||||
+(NSDictionary*)createOutputTensor:(const std::vector<const char*>&)outputNames
|
||||
+(NSDictionary*)createOutputTensor:(RCTBlobManager *)blobManager
|
||||
outputNames:(const std::vector<const char*>&)outputNames
|
||||
values:(const std::vector<Ort::Value>&)values;
|
||||
|
||||
@end
|
||||
|
|
|
|||
|
|
@ -21,11 +21,12 @@ NSString *const JsTensorTypeString = @"string";
|
|||
|
||||
/**
|
||||
* It creates an input tensor from a map passed by react native js.
|
||||
* 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor.
|
||||
* 'data' is blob object and the buffer is stored in RCTBlobManager. It first resolve it and creates a tensor.
|
||||
*/
|
||||
+ (Ort::Value)createInputTensor:(NSDictionary *)input
|
||||
+ (Ort::Value)createInputTensor:(RCTBlobManager *)blobManager
|
||||
input:(NSDictionary *)input
|
||||
ortAllocator:(OrtAllocator *)ortAllocator
|
||||
allocations:(std::vector<Ort::MemoryAllocation> &)allocatons {
|
||||
allocations:(std::vector<Ort::MemoryAllocation> &)allocations {
|
||||
// shape
|
||||
NSArray *dimsArray = [input objectForKey:@"dims"];
|
||||
std::vector<int64_t> dims;
|
||||
|
|
@ -48,22 +49,27 @@ NSString *const JsTensorTypeString = @"string";
|
|||
}
|
||||
return inputTensor;
|
||||
} else {
|
||||
NSString *data = [input objectForKey:@"data"];
|
||||
NSData *buffer = [[NSData alloc] initWithBase64EncodedString:data options:0];
|
||||
NSDictionary *data = [input objectForKey:@"data"];
|
||||
NSString *blobId = [data objectForKey:@"blobId"];
|
||||
long size = [[data objectForKey:@"size"] longValue];
|
||||
long offset = [[data objectForKey:@"offset"] longValue];
|
||||
auto buffer = [blobManager resolve:blobId offset:offset size:size];
|
||||
Ort::Value inputTensor = [self createInputTensor:tensorType
|
||||
dims:dims
|
||||
buffer:buffer
|
||||
ortAllocator:ortAllocator
|
||||
allocations:allocatons];
|
||||
allocations:allocations];
|
||||
[blobManager remove:blobId];
|
||||
return inputTensor;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* It creates an output map from an output tensor.
|
||||
* a data array is encoded as base64 string.
|
||||
* a data array is store in RCTBlobManager.
|
||||
*/
|
||||
+ (NSDictionary *)createOutputTensor:(const std::vector<const char *> &)outputNames
|
||||
+ (NSDictionary *)createOutputTensor:(RCTBlobManager *)blobManager
|
||||
outputNames:(const std::vector<const char *> &)outputNames
|
||||
values:(const std::vector<Ort::Value> &)values {
|
||||
if (outputNames.size() != values.size()) {
|
||||
NSException *exception = [NSException exceptionWithName:@"create output tensor"
|
||||
|
|
@ -109,8 +115,13 @@ NSString *const JsTensorTypeString = @"string";
|
|||
}
|
||||
outputTensor[@"data"] = buffer;
|
||||
} else {
|
||||
NSString *data = [self createOutputTensor:value];
|
||||
outputTensor[@"data"] = data;
|
||||
NSData *data = [self createOutputTensor:value];
|
||||
NSString *blobId = [blobManager store:data];
|
||||
outputTensor[@"data"] = @{
|
||||
@"blobId" : blobId,
|
||||
@"offset" : @0,
|
||||
@"size" : @(data.length),
|
||||
};
|
||||
}
|
||||
|
||||
outputTensorMap[[NSString stringWithUTF8String:outputName]] = outputTensor;
|
||||
|
|
@ -170,15 +181,14 @@ static Ort::Value createInputTensorT(OrtAllocator *ortAllocator, const std::vect
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T> static NSString *createOutputTensorT(const Ort::Value &tensor) {
|
||||
template <typename T> static NSData *createOutputTensorT(const Ort::Value &tensor) {
|
||||
const auto data = tensor.GetTensorData<T>();
|
||||
NSData *buffer = [NSData dataWithBytesNoCopy:(void *)data
|
||||
length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T)
|
||||
freeWhenDone:false];
|
||||
return [buffer base64EncodedStringWithOptions:0];
|
||||
return [NSData dataWithBytesNoCopy:(void *)data
|
||||
length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T)
|
||||
freeWhenDone:false];
|
||||
}
|
||||
|
||||
+ (NSString *)createOutputTensor:(const Ort::Value &)tensor {
|
||||
+ (NSData *)createOutputTensor:(const Ort::Value &)tensor {
|
||||
ONNXTensorElementDataType tensorType = tensor.GetTensorTypeAndShapeInfo().GetElementType();
|
||||
|
||||
switch (tensorType) {
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Buffer} from 'buffer';
|
||||
import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common';
|
||||
import {Platform} from 'react-native';
|
||||
|
||||
import {binding, Binding} from './binding';
|
||||
import {binding, Binding, JSIBlob, jsiHelper} from './binding';
|
||||
|
||||
type SupportedTypedArray = Exclude<Tensor.DataType, string[]>;
|
||||
|
||||
|
|
@ -69,11 +68,11 @@ class OnnxruntimeSessionHandler implements SessionHandler {
|
|||
if (typeof this.#pathOrBuffer === 'string') {
|
||||
results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options);
|
||||
} else {
|
||||
if (!this.#inferenceSession.loadModelFromBase64EncodedBuffer) {
|
||||
throw new Error('Native module method "loadModelFromBase64EncodedBuffer" is not defined');
|
||||
if (!this.#inferenceSession.loadModelFromBlob) {
|
||||
throw new Error('Native module method "loadModelFromBlob" is not defined');
|
||||
}
|
||||
const modelInBase64String = Buffer.from(this.#pathOrBuffer).toString('base64');
|
||||
results = await this.#inferenceSession.loadModelFromBase64EncodedBuffer(modelInBase64String, options);
|
||||
const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer);
|
||||
results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options);
|
||||
}
|
||||
// resolve promise if onnxruntime session is successfully created
|
||||
this.#key = results.key;
|
||||
|
|
@ -113,18 +112,18 @@ class OnnxruntimeSessionHandler implements SessionHandler {
|
|||
return output;
|
||||
}
|
||||
|
||||
|
||||
encodeFeedsType(feeds: SessionHandler.FeedsType): Binding.FeedsType {
|
||||
const returnValue: {[name: string]: Binding.EncodedTensorType} = {};
|
||||
for (const key in feeds) {
|
||||
if (Object.hasOwnProperty.call(feeds, key)) {
|
||||
let data: string|string[];
|
||||
let data: JSIBlob|string[];
|
||||
|
||||
if (Array.isArray(feeds[key].data)) {
|
||||
data = feeds[key].data as string[];
|
||||
} else {
|
||||
// Base64-encode tensor data
|
||||
const buffer = (feeds[key].data as SupportedTypedArray).buffer;
|
||||
data = Buffer.from(buffer, 0, buffer.byteLength).toString('base64');
|
||||
data = jsiHelper.storeArrayBuffer(buffer);
|
||||
}
|
||||
|
||||
returnValue[key] = {
|
||||
|
|
@ -146,9 +145,9 @@ class OnnxruntimeSessionHandler implements SessionHandler {
|
|||
if (Array.isArray(results[key].data)) {
|
||||
tensorData = results[key].data as string[];
|
||||
} else {
|
||||
const buffer: Buffer = Buffer.from(results[key].data as string, 'base64');
|
||||
const buffer = jsiHelper.resolveArrayBuffer(results[key].data as JSIBlob) as SupportedTypedArray;
|
||||
const typedArray = tensorTypeToTypedArray(results[key].type as Tensor.Type);
|
||||
tensorData = new typedArray(buffer.buffer, buffer.byteOffset, buffer.length / typedArray.BYTES_PER_ELEMENT);
|
||||
tensorData = new typedArray(buffer, buffer.byteOffset, buffer.byteLength / typedArray.BYTES_PER_ELEMENT);
|
||||
}
|
||||
|
||||
returnValue[key] = new Tensor(results[key].type as Tensor.Type, tensorData, results[key].dims);
|
||||
|
|
|
|||
|
|
@ -26,7 +26,14 @@ interface ModelLoadInfo {
|
|||
}
|
||||
|
||||
/**
|
||||
* Tensor type for react native, which doesn't allow ArrayBuffer, so data will be encoded as Base64 string.
|
||||
* JSIBlob is a blob object that exchange ArrayBuffer by OnnxruntimeJSIHelper.
|
||||
*/
|
||||
export type JSIBlob = {
|
||||
blobId: string; offset: number; size: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Tensor type for react native, which doesn't allow ArrayBuffer in native bridge, so data will be stored as JSIBlob.
|
||||
*/
|
||||
interface EncodedTensor {
|
||||
/**
|
||||
|
|
@ -38,10 +45,10 @@ interface EncodedTensor {
|
|||
*/
|
||||
readonly type: string;
|
||||
/**
|
||||
* the Base64 encoded string of the buffer data of the tensor.
|
||||
* if data is string array, it won't be encoded as Base64 string.
|
||||
* the JSIBlob object of the buffer data of the tensor.
|
||||
* if data is string array, it won't be stored as JSIBlob.
|
||||
*/
|
||||
readonly data: string|string[];
|
||||
readonly data: JSIBlob|string[];
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -64,12 +71,41 @@ export declare namespace Binding {
|
|||
|
||||
interface InferenceSession {
|
||||
loadModel(modelPath: string, options: SessionOptions): Promise<ModelLoadInfoType>;
|
||||
loadModelFromBase64EncodedBuffer?(buffer: string, options: SessionOptions): Promise<ModelLoadInfoType>;
|
||||
loadModelFromBlob?(blob: JSIBlob, options: SessionOptions): Promise<ModelLoadInfoType>;
|
||||
dispose(key: string): Promise<void>;
|
||||
run(key: string, feeds: FeedsType, fetches: FetchesType, options: RunOptions): Promise<ReturnType>;
|
||||
}
|
||||
}
|
||||
|
||||
// export native binding
|
||||
const {Onnxruntime} = NativeModules;
|
||||
const {Onnxruntime, OnnxruntimeJSIHelper} = NativeModules;
|
||||
export const binding = Onnxruntime as Binding.InferenceSession;
|
||||
|
||||
// install JSI helper global functions
|
||||
OnnxruntimeJSIHelper.install();
|
||||
|
||||
declare global {
|
||||
// eslint-disable-next-line no-var
|
||||
var jsiOnnxruntimeStoreArrayBuffer: ((buffer: ArrayBuffer) => JSIBlob)|undefined;
|
||||
// eslint-disable-next-line no-var
|
||||
var jsiOnnxruntimeResolveArrayBuffer: ((blob: JSIBlob) => ArrayBuffer)|undefined;
|
||||
}
|
||||
|
||||
export const jsiHelper = {
|
||||
storeArrayBuffer: globalThis.jsiOnnxruntimeStoreArrayBuffer || (() => {
|
||||
throw new Error(
|
||||
'jsiOnnxruntimeStoreArrayBuffer is not found, ' +
|
||||
'please make sure OnnxruntimeJSIHelper installation is successful.');
|
||||
}),
|
||||
resolveArrayBuffer: globalThis.jsiOnnxruntimeResolveArrayBuffer || (() => {
|
||||
throw new Error(
|
||||
'jsiOnnxruntimeResolveArrayBuffer is not found, ' +
|
||||
'please make sure OnnxruntimeJSIHelper installation is successful.');
|
||||
}),
|
||||
};
|
||||
|
||||
// Remove global functions after installation
|
||||
{
|
||||
delete globalThis.jsiOnnxruntimeStoreArrayBuffer;
|
||||
delete globalThis.jsiOnnxruntimeResolveArrayBuffer;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue