[js/rn] Implement blob exchange by JSI instead of use base64 (#16094)

### Description
<!-- Describe your changes. -->

- Create `OnnxruntimeJSIHelper` native module to provide two JSI
functions
- `jsiOnnxruntimeStoreArrayBuffer`: Store buffer in Blob Manager &
return blob object (iOS: RCTBlobManager, Android: BlobModule)
  - `jsiOnnxruntimeResolveArrayBuffer`: Use blob object to get buffer
- The part of implementation is reference to
[react-native-blob-jsi-helper](https://github.com/mrousavy/react-native-blob-jsi-helper)
- Replace base64 encode/decode
  - `loadModelFromBlob`: Rename from `loadModelFromBase64EncodedBuffer`
  - `run`: Use blob object to replace input.data & results[].data

For [this
context](https://github.com/microsoft/onnxruntime/issues/16031#issuecomment-1556527812),
it saved a lot of time and avoid JS thread blocking in decode return
type, it is 3700ms -> 5~20ms for the case. (resolve function only takes
0.x ms)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

It’s related to #16031, but not a full implementation for migrate to
JSI.

It just uses JSI through BlobManager to replace the slow part (base64
encode / decode).

Rewriting it entirely in JSI could be complicated, like type convertion
and threading. This PR might be considered a minor change.

/cc @skottmckay
This commit is contained in:
Jhen-Jie Hong 2023-06-16 17:37:02 +08:00 committed by GitHub
parent 9110e5b9bd
commit ea1a5cf920
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 935 additions and 141 deletions

View file

@ -0,0 +1,37 @@
project(OnnxruntimeJSIHelper)
cmake_minimum_required(VERSION 3.9.0)
set (PACKAGE_NAME "onnxruntime-react-native")
set (BUILD_DIR ${CMAKE_SOURCE_DIR}/build)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_CXX_STANDARD 17)
file(TO_CMAKE_PATH "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath)
include_directories(
"${NODE_MODULES_DIR}/react-native/React"
"${NODE_MODULES_DIR}/react-native/React/Base"
"${NODE_MODULES_DIR}/react-native/ReactCommon/jsi"
)
add_library(onnxruntimejsihelper
SHARED
${libPath}
src/main/cpp/cpp-adapter.cpp
)
# Configure C++ 17
set_target_properties(
onnxruntimejsihelper PROPERTIES
CXX_STANDARD 17
CXX_EXTENSIONS OFF
POSITION_INDEPENDENT_CODE ON
)
find_library(log-lib log)
target_link_libraries(
onnxruntimejsihelper
${log-lib} # <-- Logcat logger
android # <-- Android JNI core
)

View file

@ -1,3 +1,5 @@
import java.nio.file.Paths
buildscript {
repositories {
google()
@ -20,6 +22,32 @@ def getExtOrIntegerDefault(name) {
return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties['OnnxruntimeModule_' + name]).toInteger()
}
def reactNativeArchitectures() {
def value = project.getProperties().get("reactNativeArchitectures")
return value ? value.split(",") : ["armeabi-v7a", "x86", "x86_64", "arm64-v8a"]
}
def resolveBuildType() {
Gradle gradle = getGradle()
String tskReqStr = gradle.getStartParameter().getTaskRequests()['args'].toString()
return tskReqStr.contains('Release') ? 'release' : 'debug'
}
static def findNodeModules(baseDir) {
def basePath = baseDir.toPath().normalize()
while (basePath) {
def nodeModulesPath = Paths.get(basePath.toString(), "node_modules")
def reactNativePath = Paths.get(nodeModulesPath.toString(), "react-native")
if (nodeModulesPath.toFile().exists() && reactNativePath.toFile().exists()) {
return nodeModulesPath.toString()
}
basePath = basePath.getParent()
}
throw new GradleException("onnxruntime-react-native: Failed to find node_modules/ path!")
}
def nodeModules = findNodeModules(projectDir);
def checkIfOrtExtensionsEnabled() {
// locate user's project dir
def reactnativeRootDir = project.rootDir.parentFile
@ -38,6 +66,9 @@ def checkIfOrtExtensionsEnabled() {
boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled()
def REACT_NATIVE_VERSION = ['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim()
def REACT_NATIVE_MINOR_VERSION = REACT_NATIVE_VERSION.split("\\.")[1].toInteger()
android {
compileSdkVersion getExtOrIntegerDefault('compileSdkVersion')
buildToolsVersion getExtOrDefault('buildToolsVersion')
@ -47,6 +78,44 @@ android {
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
externalNativeBuild {
cmake {
cppFlags "-O2 -frtti -fexceptions -Wall -Wno-unused-variable -fstack-protector-all"
if (REACT_NATIVE_MINOR_VERSION >= 71) {
// fabricjni required c++_shared
arguments "-DANDROID_STL=c++_shared", "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}"
} else {
arguments "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}"
}
abiFilters (*reactNativeArchitectures())
}
}
}
if (rootProject.hasProperty("ndkPath")) {
ndkPath rootProject.ext.ndkPath
}
if (rootProject.hasProperty("ndkVersion")) {
ndkVersion rootProject.ext.ndkVersion
}
buildFeatures {
prefab true
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
packagingOptions {
doNotStrip resolveBuildType() == 'debug' ? "**/**/*.so" : ''
excludes = [
"META-INF",
"META-INF/**",
"**/libjsi.so",
]
}
buildTypes {
@ -149,8 +218,6 @@ repositories {
}
}
def REACT_NATIVE_VERSION = new File(['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim())
dependencies {
api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION
api "org.mockito:mockito-core:2.28.2"

View file

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package ai.onnxruntime.reactnative;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.modules.blob.BlobModule;
public class FakeBlobModule extends BlobModule {
public FakeBlobModule(ReactApplicationContext context) { super(null); }
@Override
public String getName() {
return "BlobModule";
}
public JavaOnlyMap testCreateData(byte[] bytes) {
String blobId = store(bytes);
JavaOnlyMap data = new JavaOnlyMap();
data.putString("blobId", blobId);
data.putInt("offset", 0);
data.putInt("size", bytes.length);
return data;
}
public byte[] testGetData(ReadableMap data) {
String blobId = data.getString("blobId");
int offset = data.getInt("offset");
int size = data.getInt("size");
return resolve(blobId, offset, size);
}
}

View file

@ -10,11 +10,14 @@ import ai.onnxruntime.TensorInfo;
import android.util.Base64;
import androidx.test.platform.app.InstrumentationRegistry;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.CatalystInstance;
import com.facebook.react.bridge.JavaOnlyArray;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.modules.blob.BlobModule;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
@ -29,12 +32,17 @@ public class OnnxruntimeModuleTest {
private ReactApplicationContext reactContext =
new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext());
private FakeBlobModule blobModule;
@Before
public void setUp() {}
public void setUp() {
blobModule = new FakeBlobModule(reactContext);
}
@Test
public void getName() throws Exception {
OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
String name = "Onnxruntime";
Assert.assertEquals(ortModule.getName(), name);
}
@ -47,6 +55,7 @@ public class OnnxruntimeModuleTest {
when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray());
OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
String sessionKey = "";
// test loadModel()
@ -104,8 +113,7 @@ public class OnnxruntimeModuleTest {
floatBuffer.put(value);
}
floatBuffer.rewind();
String dataEncoded = Base64.encodeToString(buffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(buffer.array()));
inputDataMap.putMap("input", inputTensorMap);
}
@ -124,10 +132,9 @@ public class OnnxruntimeModuleTest {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat);
String dataEncoded = outputMap.getString("data");
FloatBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT))
.order(ByteOrder.nativeOrder())
.asFloatBuffer();
ReadableMap data = outputMap.getMap("data");
FloatBuffer buffer =
ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer();
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f);
}

View file

@ -20,7 +20,9 @@ import androidx.test.platform.app.InstrumentationRegistry;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.JavaOnlyArray;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.modules.blob.BlobModule;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
@ -39,11 +41,17 @@ import org.mockito.MockitoSession;
@SmallTest
public class TensorHelperTest {
private ReactApplicationContext reactContext =
new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext());
private OrtEnvironment ortEnvironment;
private FakeBlobModule blobModule;
@Before
public void setUp() {
ortEnvironment = OrtEnvironment.getEnvironment("TensorHelperTest");
blobModule = new FakeBlobModule(reactContext);
}
@Test
@ -64,10 +72,9 @@ public class TensorHelperTest {
dataFloatBuffer.put(Float.MIN_VALUE);
dataFloatBuffer.put(2.0f);
dataFloatBuffer.put(Float.MAX_VALUE);
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
@ -94,10 +101,9 @@ public class TensorHelperTest {
dataByteBuffer.put(Byte.MIN_VALUE);
dataByteBuffer.put((byte)2);
dataByteBuffer.put(Byte.MAX_VALUE);
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8);
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8);
@ -125,10 +131,9 @@ public class TensorHelperTest {
dataByteBuffer.put((byte)0);
dataByteBuffer.put((byte)2);
dataByteBuffer.put((byte)255);
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
@ -157,10 +162,9 @@ public class TensorHelperTest {
dataIntBuffer.put(Integer.MIN_VALUE);
dataIntBuffer.put(2);
dataIntBuffer.put(Integer.MAX_VALUE);
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
@ -189,10 +193,9 @@ public class TensorHelperTest {
dataLongBuffer.put(Long.MIN_VALUE);
dataLongBuffer.put(15000000001L);
dataLongBuffer.put(Long.MAX_VALUE);
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
@ -221,10 +224,9 @@ public class TensorHelperTest {
dataDoubleBuffer.put(Double.MIN_VALUE);
dataDoubleBuffer.put(1.8e+30);
dataDoubleBuffer.put(Double.MAX_VALUE);
String dataEncoded = Base64.encodeToString(dataByteBuffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array()));
OnnxTensor inputTensor = TensorHelper.createInputTensor(inputTensorMap, ortEnvironment);
OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment);
Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE);
Assert.assertEquals(outputTensor.getInfo().onnxType,
@ -258,14 +260,14 @@ public class TensorHelperTest {
OrtSession.Result result = session.run(container);
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeBool);
String dataEncoded = outputMap.getString("data");
ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT));
ReadableMap data = outputMap.getMap("data");
ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data));
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i) == 1, inputData[i]);
}
@ -298,15 +300,15 @@ public class TensorHelperTest {
OrtSession.Result result = session.run(container);
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeDouble);
String dataEncoded = outputMap.getString("data");
ReadableMap data = outputMap.getMap("data");
DoubleBuffer buffer =
ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asDoubleBuffer();
ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asDoubleBuffer();
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f);
}
@ -339,15 +341,14 @@ public class TensorHelperTest {
OrtSession.Result result = session.run(container);
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat);
String dataEncoded = outputMap.getString("data");
FloatBuffer buffer =
ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asFloatBuffer();
ReadableMap data = outputMap.getMap("data");
FloatBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer();
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f);
}
@ -380,14 +381,14 @@ public class TensorHelperTest {
OrtSession.Result result = session.run(container);
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeByte);
String dataEncoded = outputMap.getString("data");
ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT));
ReadableMap data = outputMap.getMap("data");
ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data));
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i]);
}
@ -420,15 +421,14 @@ public class TensorHelperTest {
OrtSession.Result result = session.run(container);
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeInt);
String dataEncoded = outputMap.getString("data");
IntBuffer buffer =
ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asIntBuffer();
ReadableMap data = outputMap.getMap("data");
IntBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asIntBuffer();
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i]);
}
@ -461,15 +461,14 @@ public class TensorHelperTest {
OrtSession.Result result = session.run(container);
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeLong);
String dataEncoded = outputMap.getString("data");
LongBuffer buffer =
ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT)).order(ByteOrder.nativeOrder()).asLongBuffer();
ReadableMap data = outputMap.getMap("data");
LongBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asLongBuffer();
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i]);
}
@ -502,14 +501,14 @@ public class TensorHelperTest {
OrtSession.Result result = session.run(container);
ReadableMap resultMap = TensorHelper.createOutputTensor(result);
ReadableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeUnsignedByte);
String dataEncoded = outputMap.getString("data");
ByteBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT));
ReadableMap data = outputMap.getMap("data");
ByteBuffer buffer = ByteBuffer.wrap(blobModule.testGetData(data));
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i]);
}

View file

@ -0,0 +1,127 @@
#include <jni.h>
#include <jsi/jsi.h>
#include <string>
using namespace facebook;
typedef u_int8_t byte;
std::string jstring2string(JNIEnv *env, jstring jStr) {
if (!jStr) return "";
jclass stringClass = env->GetObjectClass(jStr);
jmethodID getBytes = env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B");
const auto stringJbytes = (jbyteArray) env->CallObjectMethod(jStr, getBytes, env->NewStringUTF("UTF-8"));
auto length = (size_t) env->GetArrayLength(stringJbytes);
jbyte* pBytes = env->GetByteArrayElements(stringJbytes, nullptr);
std::string ret = std::string((char *)pBytes, length);
env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT);
env->DeleteLocalRef(stringJbytes);
env->DeleteLocalRef(stringClass);
return ret;
}
byte* getBytesFromBlob(JNIEnv *env, jobject instanceGlobal, const std::string& blobId, int offset, int size) {
if (!env) throw std::runtime_error("JNI Environment is gone!");
// get java class
jclass clazz = env->GetObjectClass(instanceGlobal);
// get method in java class
jmethodID getBufferJava = env->GetMethodID(clazz, "getBlobBuffer", "(Ljava/lang/String;II)[B");
// call method
auto jstring = env->NewStringUTF(blobId.c_str());
auto boxedBytes = (jbyteArray) env->CallObjectMethod(instanceGlobal,
getBufferJava,
// arguments
jstring,
offset,
size);
env->DeleteLocalRef(jstring);
jboolean isCopy = true;
jbyte* bytes = env->GetByteArrayElements(boxedBytes, &isCopy);
env->DeleteLocalRef(boxedBytes);
return reinterpret_cast<byte*>(bytes);
};
std::string createBlob(JNIEnv *env, jobject instanceGlobal, byte* bytes, size_t size) {
if (!env) throw std::runtime_error("JNI Environment is gone!");
// get java class
jclass clazz = env->GetObjectClass(instanceGlobal);
// get method in java class
jmethodID getBufferJava = env->GetMethodID(clazz, "createBlob", "([B)Ljava/lang/String;");
// call method
auto byteArray = env->NewByteArray(size);
env->SetByteArrayRegion(byteArray, 0, size, reinterpret_cast<jbyte*>(bytes));
auto blobId = (jstring) env->CallObjectMethod(instanceGlobal, getBufferJava, byteArray);
env->DeleteLocalRef(byteArray);
return jstring2string(env, blobId);
};
extern "C"
JNIEXPORT void JNICALL
Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv *env, jclass _, jlong jsiPtr, jobject instance) {
auto jsiRuntime = reinterpret_cast<jsi::Runtime*>(jsiPtr);
auto& runtime = *jsiRuntime;
auto instanceGlobal = env->NewGlobalRef(instance);
auto resolveArrayBuffer = jsi::Function::createFromHostFunction(runtime,
jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeResolveArrayBuffer"),
1,
[=](jsi::Runtime& runtime,
const jsi::Value& thisValue,
const jsi::Value* arguments,
size_t count) -> jsi::Value {
if (count != 1) {
throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!");
}
jsi::Object data = arguments[0].asObject(runtime);
auto blobId = data.getProperty(runtime, "blobId").asString(runtime);
auto offset = data.getProperty(runtime, "offset").asNumber();
auto size = data.getProperty(runtime, "size").asNumber();
auto bytes = getBytesFromBlob(env, instanceGlobal, blobId.utf8(runtime), offset, size);
size_t totalSize = size - offset;
jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer");
jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int) totalSize).getObject(runtime);
jsi::ArrayBuffer buf = o.getArrayBuffer(runtime);
memcpy(buf.data(runtime), reinterpret_cast<byte*>(bytes), totalSize);
return buf;
});
runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", std::move(resolveArrayBuffer));
auto storeArrayBuffer = jsi::Function::createFromHostFunction(runtime,
jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeStoreArrayBuffer"),
1,
[=](jsi::Runtime& runtime,
const jsi::Value& thisValue,
const jsi::Value* arguments,
size_t count) -> jsi::Value {
if (count != 1) {
throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) expects one argument (object)!");
}
auto arrayBuffer = arguments[0].asObject(runtime).getArrayBuffer(runtime);
auto size = arrayBuffer.size(runtime);
std::string blobId = createBlob(env, instanceGlobal, arrayBuffer.data(runtime), size);
jsi::Object result(runtime);
auto blobIdString = jsi::String::createFromUtf8(runtime, blobId);
result.setProperty(runtime, "blobId", blobIdString);
result.setProperty(runtime, "offset", jsi::Value(0));
result.setProperty(runtime, "size", jsi::Value(static_cast<double>(size)));
return result;
});
runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", std::move(storeArrayBuffer));
}

View file

@ -0,0 +1,70 @@
package ai.onnxruntime.reactnative;
import androidx.annotation.NonNull;
import com.facebook.react.bridge.JavaScriptContextHolder;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReactContextBaseJavaModule;
import com.facebook.react.bridge.ReactMethod;
import com.facebook.react.module.annotations.ReactModule;
import com.facebook.react.modules.blob.BlobModule;
@ReactModule(name = OnnxruntimeJSIHelper.NAME)
public class OnnxruntimeJSIHelper extends ReactContextBaseJavaModule {
public static final String NAME = "OnnxruntimeJSIHelper";
private static ReactApplicationContext reactContext;
protected BlobModule blobModule;
public OnnxruntimeJSIHelper(ReactApplicationContext context) {
super(context);
reactContext = context;
}
@Override
@NonNull
public String getName() {
return NAME;
}
public void checkBlobModule() {
if (blobModule == null) {
blobModule = getReactApplicationContext().getNativeModule(BlobModule.class);
if (blobModule == null) {
throw new RuntimeException("BlobModule is not initialized");
}
}
}
@ReactMethod(isBlockingSynchronousMethod = true)
public boolean install() {
try {
System.loadLibrary("onnxruntimejsihelper");
JavaScriptContextHolder jsContext = getReactApplicationContext().getJavaScriptContextHolder();
nativeInstall(jsContext.get(), this);
return true;
} catch (Exception exception) {
return false;
}
}
public byte[] getBlobBuffer(String blobId, int offset, int size) {
checkBlobModule();
byte[] bytes = blobModule.resolve(blobId, offset, size);
blobModule.remove(blobId);
if (bytes == null) {
throw new RuntimeException("Failed to resolve Blob #" + blobId + "! Not found.");
}
return bytes;
}
public String createBlob(byte[] buffer) {
checkBlobModule();
String blobId = blobModule.store(buffer);
if (blobId == null) {
throw new RuntimeException("Failed to create Blob!");
}
return blobId;
}
public static native void nativeInstall(long jsiPointer, OnnxruntimeJSIHelper instance);
}

View file

@ -13,7 +13,6 @@ import ai.onnxruntime.OrtSession.RunOptions;
import ai.onnxruntime.OrtSession.SessionOptions;
import android.net.Uri;
import android.os.Build;
import android.util.Base64;
import android.util.Log;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
@ -28,6 +27,7 @@ import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.ReadableType;
import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.modules.blob.BlobModule;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
@ -56,6 +56,8 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
return key;
}
protected BlobModule blobModule;
public OnnxruntimeModule(ReactApplicationContext context) {
super(context);
reactContext = context;
@ -67,6 +69,15 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
return "Onnxruntime";
}
public void checkBlobModule() {
if (blobModule == null) {
blobModule = getReactApplicationContext().getNativeModule(BlobModule.class);
if (blobModule == null) {
throw new RuntimeException("BlobModule is not initialized");
}
}
}
/**
* React native binding API to load a model using given uri.
*
@ -87,19 +98,22 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
}
/**
* React native binding API to load a model using the BASE64 encoded model data.
* React native binding API to load a model using blob object that data stored in BlobModule.
*
* @param data the BASE64 encoded model data.
* @param data the blob object
* @param options onnxruntime session options
* @param promise output returning back to react native js
* @note the value provided to `promise` includes a key representing the session.
* when run() is called, the key must be passed into the first parameter.
*/
@ReactMethod
public void loadModelFromBase64EncodedBuffer(String data, ReadableMap options, Promise promise) {
public void loadModelFromBlob(ReadableMap data, ReadableMap options, Promise promise) {
try {
byte[] modelData = Base64.decode(data, Base64.DEFAULT);
WritableMap resultMap = loadModel(modelData, options);
checkBlobModule();
String blobId = data.getString("blobId");
byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size"));
blobModule.remove(blobId);
WritableMap resultMap = loadModel(bytes, options);
promise.resolve(resultMap);
} catch (Exception e) {
promise.reject("Failed to load model from buffer: " + e.getMessage(), e);
@ -242,6 +256,8 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
RunOptions runOptions = parseRunOptions(options);
checkBlobModule();
long startTime = System.currentTimeMillis();
Map<String, OnnxTensor> feed = new HashMap<>();
Iterator<String> iterator = ortSession.getInputNames().iterator();
@ -255,19 +271,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
throw new Exception("Can't find input: " + inputName);
}
if (inputMap.getType("data") != ReadableType.String) {
// NOTE:
//
// tensor data should always be a BASE64 encoded string.
// This is because the current React Native bridge supports limited data type as arguments.
// In order to pass data from JS to Java, we have to encode them into string.
//
// see also:
// https://reactnative.dev/docs/native-modules-android#argument-types
throw new Exception("Non string type of a tensor data is not allowed");
}
OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment);
OnnxTensor onnxTensor = TensorHelper.createInputTensor(blobModule, inputMap, ortEnvironment);
feed.put(inputName, onnxTensor);
}
@ -292,7 +296,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule implements Lif
Log.d("Duration", "inference: " + duration);
startTime = System.currentTimeMillis();
WritableMap resultMap = TensorHelper.createOutputTensor(result);
WritableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
duration = System.currentTimeMillis() - startTime;
Log.d("Duration", "createOutputTensor: " + duration);

View file

@ -9,6 +9,7 @@ import androidx.annotation.RequiresApi;
import com.facebook.react.ReactPackage;
import com.facebook.react.bridge.NativeModule;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.modules.blob.BlobModule;
import com.facebook.react.uimanager.ViewManager;
import java.util.ArrayList;
import java.util.Collections;
@ -21,6 +22,7 @@ public class OnnxruntimePackage implements ReactPackage {
public List<NativeModule> createNativeModules(@NonNull ReactApplicationContext reactContext) {
List<NativeModule> modules = new ArrayList<>();
modules.add(new OnnxruntimeModule(reactContext));
modules.add(new OnnxruntimeJSIHelper(reactContext));
return modules;
}

View file

@ -16,6 +16,7 @@ import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.modules.blob.BlobModule;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
@ -45,9 +46,10 @@ public class TensorHelper {
/**
* It creates an input tensor from a map passed by react native js.
* 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor.
* 'data' is blob object and the buffer is stored in BlobModule. It first resolve it and creates a tensor.
*/
public static OnnxTensor createInputTensor(ReadableMap inputTensor, OrtEnvironment ortEnvironment) throws Exception {
public static OnnxTensor createInputTensor(BlobModule blobModule, ReadableMap inputTensor,
OrtEnvironment ortEnvironment) throws Exception {
// shape
ReadableArray dimsArray = inputTensor.getArray("dims");
long[] dims = new long[dimsArray.size()];
@ -68,8 +70,11 @@ public class TensorHelper {
}
onnxTensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims);
} else {
String data = inputTensor.getString("data");
ByteBuffer values = ByteBuffer.wrap(Base64.decode(data, Base64.DEFAULT)).order(ByteOrder.nativeOrder());
ReadableMap data = inputTensor.getMap("data");
String blobId = data.getString("blobId");
byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size"));
blobModule.remove(blobId);
ByteBuffer values = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder());
onnxTensor = createInputTensor(tensorType, dims, values, ortEnvironment);
}
@ -78,9 +83,9 @@ public class TensorHelper {
/**
* It creates an output map from an output tensor.
* a data array is encoded as base64 string.
* a data array is store in BlobModule.
*/
public static WritableMap createOutputTensor(OrtSession.Result result) throws Exception {
public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception {
WritableMap outputTensorMap = Arguments.createMap();
Iterator<Map.Entry<String, OnnxValue>> iterator = result.iterator();
@ -115,8 +120,13 @@ public class TensorHelper {
}
outputTensor.putArray("data", dataArray);
} else {
String data = createOutputTensor(onnxTensor);
outputTensor.putString("data", data);
// Store in BlobModule then create a blob object as data
byte[] bufferArray = createOutputTensor(onnxTensor);
WritableMap data = Arguments.createMap();
data.putString("blobId", blobModule.store(bufferArray));
data.putInt("offset", 0);
data.putInt("size", bufferArray.length);
outputTensor.putMap("data", data);
}
outputTensorMap.putMap(outputName, outputTensor);
@ -177,7 +187,7 @@ public class TensorHelper {
return tensor;
}
private static String createOutputTensor(OnnxTensor onnxTensor) throws Exception {
private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception {
TensorInfo tensorInfo = onnxTensor.getInfo();
ByteBuffer buffer = null;
@ -224,8 +234,7 @@ public class TensorHelper {
throw new IllegalStateException("Unexpected type: " + tensorInfo.onnxType.toString());
}
String data = Base64.encodeToString(buffer.array(), Base64.DEFAULT);
return data;
return buffer.array();
}
private static final Map<String, TensorInfo.OnnxTensorType> JsTensorTypeToOnnxTensorTypeMap =

View file

@ -0,0 +1,5 @@
#import <React/RCTBridgeModule.h>
@interface OnnxruntimeJSIHelper : NSObject <RCTBridgeModule>
@end

View file

@ -0,0 +1,85 @@
#import "OnnxruntimeJSIHelper.h"
#import <React/RCTBlobManager.h>
#import <React/RCTBridge+Private.h>
#import <jsi/jsi.h>
@implementation OnnxruntimeJSIHelper
RCT_EXPORT_MODULE()
RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(install) {
RCTBridge *bridge = [RCTBridge currentBridge];
RCTCxxBridge *cxxBridge = (RCTCxxBridge *)bridge;
if (cxxBridge == nil) {
return @false;
}
using namespace facebook;
auto jsiRuntime = (jsi::Runtime *)cxxBridge.runtime;
if (jsiRuntime == nil) {
return @false;
}
auto &runtime = *jsiRuntime;
auto resolveArrayBuffer = jsi::Function::createFromHostFunction(
runtime, jsi::PropNameID::forUtf8(runtime, "jsiOnnxruntimeResolveArrayBuffer"), 1,
[](jsi::Runtime &runtime, const jsi::Value &thisArg, const jsi::Value *args, size_t count) -> jsi::Value {
if (count != 1) {
throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!");
}
auto data = args[0].asObject(runtime);
auto blobId = data.getProperty(runtime, "blobId").asString(runtime).utf8(runtime);
auto size = data.getProperty(runtime, "size").asNumber();
auto offset = data.getProperty(runtime, "offset").asNumber();
RCTBlobManager *blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class];
if (blobManager == nil) {
throw jsi::JSError(runtime, "RCTBlobManager is not initialized");
}
NSString *blobIdStr = [NSString stringWithUTF8String:blobId.c_str()];
auto blob = [blobManager resolve:blobIdStr offset:(long)offset size:(long)size];
jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer");
jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int)blob.length).getObject(runtime);
jsi::ArrayBuffer buf = o.getArrayBuffer(runtime);
memcpy(buf.data(runtime), blob.bytes, blob.length);
[blobManager remove:blobIdStr];
return buf;
});
runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", resolveArrayBuffer);
auto storeArrayBuffer = jsi::Function::createFromHostFunction(
runtime, jsi::PropNameID::forUtf8(runtime, "jsiOnnxruntimeStoreArrayBuffer"), 1,
[](jsi::Runtime &runtime, const jsi::Value &thisArg, const jsi::Value *args, size_t count) -> jsi::Value {
if (count != 1) {
throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) expects one argument (object)!");
}
auto arrayBuffer = args[0].asObject(runtime).getArrayBuffer(runtime);
auto size = arrayBuffer.length(runtime);
NSData *data = [NSData dataWithBytesNoCopy:arrayBuffer.data(runtime) length:size freeWhenDone:NO];
RCTBlobManager *blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class];
if (blobManager == nil) {
throw jsi::JSError(runtime, "RCTBlobManager is not initialized");
}
NSString *blobId = [blobManager store:data];
jsi::Object result(runtime);
auto blobIdString = jsi::String::createFromUtf8(runtime, [blobId cStringUsingEncoding:NSUTF8StringEncoding]);
result.setProperty(runtime, "blobId", blobIdString);
result.setProperty(runtime, "offset", jsi::Value(0));
result.setProperty(runtime, "size", jsi::Value(static_cast<double>(size)));
return result;
});
runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", storeArrayBuffer);
return @true;
}
@end

View file

@ -5,9 +5,12 @@
#define OnnxruntimeModule_h
#import <React/RCTBridgeModule.h>
#import <React/RCTBlobManager.h>
@interface OnnxruntimeModule : NSObject<RCTBridgeModule>
- (void)setBlobManager:(RCTBlobManager *)manager;
-(NSDictionary*)loadModel:(NSString*)modelPath
options:(NSDictionary*)options;

View file

@ -5,6 +5,8 @@
#import "TensorHelper.h"
#import <Foundation/Foundation.h>
#import <React/RCTBlobManager.h>
#import <React/RCTBridge+Private.h>
#import <React/RCTLog.h>
// Note: Using below syntax for including ort c api and ort extensions headers to resolve a compiling error happened
@ -44,6 +46,21 @@ static int nextSessionId = 0;
RCT_EXPORT_MODULE(Onnxruntime)
RCTBlobManager *blobManager = nil;
- (void)checkBlobManager {
if (blobManager == nil) {
blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class];
if (blobManager == nil) {
@throw @"RCTBlobManager is not initialized";
}
}
}
- (void)setBlobManager:(RCTBlobManager *)manager {
blobManager = manager;
}
/**
* React native binding API to load a model using given uri.
*
@ -68,22 +85,27 @@ RCT_EXPORT_METHOD(loadModel
}
/**
* React native binding API to load a model using BASE64 encoded model data string.
* React native binding API to load a model using blob object that data stored in RCTBlobManager.
*
* @param modelData the BASE64 encoded model data string
* @param modelDataBlob a model data blob object
* @param options onnxruntime session options
* @param resolve callback for returning output back to react native js
* @param reject callback for returning an error back to react native js
* @note when run() is called, the same modelPath must be passed into the first parameter.
*/
RCT_EXPORT_METHOD(loadModelFromBase64EncodedBuffer
: (NSString *)modelDataBase64EncodedString options
RCT_EXPORT_METHOD(loadModelFromBlob
: (NSDictionary *)modelDataBlob options
: (NSDictionary *)options resolver
: (RCTPromiseResolveBlock)resolve rejecter
: (RCTPromiseRejectBlock)reject) {
@try {
NSData *modelDataDecoded = [[NSData alloc] initWithBase64EncodedString:modelDataBase64EncodedString options:0];
NSDictionary *resultMap = [self loadModelFromBuffer:modelDataDecoded options:options];
[self checkBlobManager];
NSString *blobId = [modelDataBlob objectForKey:@"blobId"];
long size = [[modelDataBlob objectForKey:@"size"] longValue];
long offset = [[modelDataBlob objectForKey:@"offset"] longValue];
auto modelData = [blobManager resolve:blobId offset:offset size:size];
NSDictionary *resultMap = [self loadModelFromBuffer:modelData options:options];
[blobManager remove:blobId];
resolve(resultMap);
} @catch (...) {
reject(@"onnxruntime", @"failed to load model from buffer", nil);
@ -255,6 +277,8 @@ RCT_EXPORT_METHOD(run
}
SessionInfo *sessionInfo = (SessionInfo *)[value pointerValue];
[self checkBlobManager];
std::vector<Ort::Value> feeds;
std::vector<Ort::MemoryAllocation> allocations;
feeds.reserve(sessionInfo->inputNames.size());
@ -265,7 +289,10 @@ RCT_EXPORT_METHOD(run
@throw exception;
}
Ort::Value value = [TensorHelper createInputTensor:inputTensor ortAllocator:ortAllocator allocations:allocations];
Ort::Value value = [TensorHelper createInputTensor:blobManager
input:inputTensor
ortAllocator:ortAllocator
allocations:allocations];
feeds.emplace_back(std::move(value));
}
@ -280,7 +307,7 @@ RCT_EXPORT_METHOD(run
sessionInfo->session->Run(runOptions, sessionInfo->inputNames.data(), feeds.data(),
sessionInfo->inputNames.size(), requestedOutputs.data(), requestedOutputs.size());
NSDictionary *resultMap = [TensorHelper createOutputTensor:requestedOutputs values:result];
NSDictionary *resultMap = [TensorHelper createOutputTensor:blobManager outputNames:requestedOutputs values:result];
return resultMap;
}
@ -378,6 +405,7 @@ static NSDictionary *executionModeTable = @{@"sequential" : @(ORT_SEQUENTIAL), @
while (NSString *key = [iterator nextObject]) {
[self dispose:key];
}
blobManager = nullptr;
}
@end

View file

@ -7,6 +7,9 @@
objects = {
/* Begin PBXBuildFile section */
0105483CF04B9471894F3EAA /* Pods_OnnxruntimeModuleTest.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */; };
7FD234672A1F221700734B71 /* FakeRCTBlobManager.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */; };
C60033360456900E26D6F96F /* Pods_OnnxruntimeModule.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */; };
DB8FC9B525C2867800C72F26 /* OnnxruntimeModule.mm in Sources */ = {isa = PBXBuildFile; fileRef = DB8FC9B425C2867800C72F26 /* OnnxruntimeModule.mm */; };
DB8FC9B825C2868700C72F26 /* TensorHelper.mm in Sources */ = {isa = PBXBuildFile; fileRef = DB8FC9B725C2868700C72F26 /* TensorHelper.mm */; };
DBDB57DA2603211A004F16BE /* TensorHelperTest.mm in Sources */ = {isa = PBXBuildFile; fileRef = DBDB57D92603211A004F16BE /* TensorHelperTest.mm */; };
@ -39,6 +42,14 @@
/* Begin PBXFileReference section */
134814201AA4EA6300B7C361 /* libOnnxruntimeModule.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libOnnxruntimeModule.a; sourceTree = BUILT_PRODUCTS_DIR; };
38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_OnnxruntimeModuleTest.framework; sourceTree = BUILT_PRODUCTS_DIR; };
49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_OnnxruntimeModule.framework; sourceTree = BUILT_PRODUCTS_DIR; };
5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleTest.debug.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest.debug.xcconfig"; sourceTree = "<group>"; };
548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModule.release.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModule/Pods-OnnxruntimeModule.release.xcconfig"; sourceTree = "<group>"; };
63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModule.debug.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModule/Pods-OnnxruntimeModule.debug.xcconfig"; sourceTree = "<group>"; };
7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = FakeRCTBlobManager.m; sourceTree = "<group>"; };
7FD234682A1F234500734B71 /* FakeRCTBlobManager.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = FakeRCTBlobManager.h; sourceTree = "<group>"; };
8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-OnnxruntimeModuleTest.release.xcconfig"; path = "Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest.release.xcconfig"; sourceTree = "<group>"; };
DB8FC9B425C2867800C72F26 /* OnnxruntimeModule.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = OnnxruntimeModule.mm; sourceTree = SOURCE_ROOT; };
DB8FC9B725C2868700C72F26 /* TensorHelper.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = TensorHelper.mm; sourceTree = SOURCE_ROOT; };
DBDB57D72603211A004F16BE /* OnnxruntimeModuleTest.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = OnnxruntimeModuleTest.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
@ -53,6 +64,7 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
C60033360456900E26D6F96F /* Pods_OnnxruntimeModule.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@ -61,6 +73,7 @@
buildActionMask = 2147483647;
files = (
DBDB57DC2603211A004F16BE /* libOnnxruntimeModule.a in Frameworks */,
0105483CF04B9471894F3EAA /* Pods_OnnxruntimeModuleTest.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@ -84,16 +97,30 @@
134814211AA4EA7D00B7C361 /* Products */,
62ED2272D9F9CF7E3D0A8F87 /* Pods */,
DBDB57D72603211A004F16BE /* OnnxruntimeModuleTest.xctest */,
6FFDF1594C99DA125B013E34 /* Frameworks */,
);
sourceTree = "<group>";
};
62ED2272D9F9CF7E3D0A8F87 /* Pods */ = {
isa = PBXGroup;
children = (
63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */,
548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */,
5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */,
8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */,
);
path = Pods;
sourceTree = "<group>";
};
6FFDF1594C99DA125B013E34 /* Frameworks */ = {
isa = PBXGroup;
children = (
49D0ADD02E7162A5F0DE8BAB /* Pods_OnnxruntimeModule.framework */,
38EB61A518C2DF782F7CD433 /* Pods_OnnxruntimeModuleTest.framework */,
);
name = Frameworks;
sourceTree = "<group>";
};
DB8FC9B325C2861300C72F26 /* OnnxruntimeModule */ = {
isa = PBXGroup;
children = (
@ -109,6 +136,8 @@
DBDB57D92603211A004F16BE /* TensorHelperTest.mm */,
DBDB57DB2603211A004F16BE /* Info.plist */,
DBDB58AF262A92D6004F16BE /* OnnxruntimeModuleTest.mm */,
7FD234662A1F221700734B71 /* FakeRCTBlobManager.m */,
7FD234682A1F234500734B71 /* FakeRCTBlobManager.h */,
);
path = OnnxruntimeModuleTest;
sourceTree = "<group>";
@ -120,6 +149,7 @@
isa = PBXNativeTarget;
buildConfigurationList = 58B511EF1A9E6C8500147676 /* Build configuration list for PBXNativeTarget "OnnxruntimeModule" */;
buildPhases = (
FA8BD7B76BD8BD02A6DB750A /* [CP] Check Pods Manifest.lock */,
58B511D71A9E6C8500147676 /* Sources */,
58B511D81A9E6C8500147676 /* Frameworks */,
58B511D91A9E6C8500147676 /* CopyFiles */,
@ -137,9 +167,11 @@
isa = PBXNativeTarget;
buildConfigurationList = DBDB57E12603211A004F16BE /* Build configuration list for PBXNativeTarget "OnnxruntimeModuleTest" */;
buildPhases = (
896E89AEC864CBD0CC7E0AF1 /* [CP] Check Pods Manifest.lock */,
DBDB57D32603211A004F16BE /* Sources */,
DBDB57D42603211A004F16BE /* Frameworks */,
DBDB57D52603211A004F16BE /* Resources */,
015C75E59BC80D4507FB6E8A /* [CP] Embed Pods Frameworks */,
);
buildRules = (
);
@ -200,6 +232,119 @@
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXShellScriptBuildPhase section */
015C75E59BC80D4507FB6E8A /* [CP] Embed Pods Frameworks */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
);
inputPaths = (
"${PODS_ROOT}/Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest-frameworks.sh",
"${BUILT_PRODUCTS_DIR}/DoubleConversion/DoubleConversion.framework",
"${BUILT_PRODUCTS_DIR}/RCT-Folly/folly.framework",
"${BUILT_PRODUCTS_DIR}/RCTTypeSafety/RCTTypeSafety.framework",
"${BUILT_PRODUCTS_DIR}/React-Codegen/React_Codegen.framework",
"${BUILT_PRODUCTS_DIR}/React-Core/React.framework",
"${BUILT_PRODUCTS_DIR}/React-CoreModules/CoreModules.framework",
"${BUILT_PRODUCTS_DIR}/React-RCTAnimation/RCTAnimation.framework",
"${BUILT_PRODUCTS_DIR}/React-RCTBlob/RCTBlob.framework",
"${BUILT_PRODUCTS_DIR}/React-RCTImage/RCTImage.framework",
"${BUILT_PRODUCTS_DIR}/React-RCTLinking/RCTLinking.framework",
"${BUILT_PRODUCTS_DIR}/React-RCTNetwork/RCTNetwork.framework",
"${BUILT_PRODUCTS_DIR}/React-RCTSettings/RCTSettings.framework",
"${BUILT_PRODUCTS_DIR}/React-RCTText/RCTText.framework",
"${BUILT_PRODUCTS_DIR}/React-RCTVibration/RCTVibration.framework",
"${BUILT_PRODUCTS_DIR}/React-bridging/react_bridging.framework",
"${BUILT_PRODUCTS_DIR}/React-cxxreact/cxxreact.framework",
"${BUILT_PRODUCTS_DIR}/React-jsi/jsi.framework",
"${BUILT_PRODUCTS_DIR}/React-jsiexecutor/jsireact.framework",
"${BUILT_PRODUCTS_DIR}/React-jsinspector/jsinspector.framework",
"${BUILT_PRODUCTS_DIR}/React-logger/logger.framework",
"${BUILT_PRODUCTS_DIR}/React-perflogger/reactperflogger.framework",
"${BUILT_PRODUCTS_DIR}/ReactCommon/ReactCommon.framework",
"${BUILT_PRODUCTS_DIR}/Yoga/yoga.framework",
"${BUILT_PRODUCTS_DIR}/fmt/fmt.framework",
"${BUILT_PRODUCTS_DIR}/glog/glog.framework",
);
name = "[CP] Embed Pods Frameworks";
outputPaths = (
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/DoubleConversion.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/folly.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTTypeSafety.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/React_Codegen.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/React.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/CoreModules.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTAnimation.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTBlob.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTImage.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTLinking.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTNetwork.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTSettings.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTText.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/RCTVibration.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/react_bridging.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/cxxreact.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsi.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsireact.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/jsinspector.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/logger.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/reactperflogger.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/ReactCommon.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/yoga.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/fmt.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/glog.framework",
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "\"${PODS_ROOT}/Target Support Files/Pods-OnnxruntimeModuleTest/Pods-OnnxruntimeModuleTest-frameworks.sh\"\n";
showEnvVarsInLog = 0;
};
896E89AEC864CBD0CC7E0AF1 /* [CP] Check Pods Manifest.lock */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
);
inputFileListPaths = (
);
inputPaths = (
"${PODS_PODFILE_DIR_PATH}/Podfile.lock",
"${PODS_ROOT}/Manifest.lock",
);
name = "[CP] Check Pods Manifest.lock";
outputFileListPaths = (
);
outputPaths = (
"$(DERIVED_FILE_DIR)/Pods-OnnxruntimeModuleTest-checkManifestLockResult.txt",
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
showEnvVarsInLog = 0;
};
FA8BD7B76BD8BD02A6DB750A /* [CP] Check Pods Manifest.lock */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
);
inputFileListPaths = (
);
inputPaths = (
"${PODS_PODFILE_DIR_PATH}/Podfile.lock",
"${PODS_ROOT}/Manifest.lock",
);
name = "[CP] Check Pods Manifest.lock";
outputFileListPaths = (
);
outputPaths = (
"$(DERIVED_FILE_DIR)/Pods-OnnxruntimeModule-checkManifestLockResult.txt",
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
showEnvVarsInLog = 0;
};
/* End PBXShellScriptBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
58B511D71A9E6C8500147676 /* Sources */ = {
isa = PBXSourcesBuildPhase;
@ -216,6 +361,7 @@
files = (
DBDB57DA2603211A004F16BE /* TensorHelperTest.mm in Sources */,
DBDB58B0262A92D7004F16BE /* OnnxruntimeModuleTest.mm in Sources */,
7FD234672A1F221700734B71 /* FakeRCTBlobManager.m in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@ -329,6 +475,7 @@
};
58B511F01A9E6C8500147676 /* Debug */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = 63B05EB079B0A4D99448F1D3 /* Pods-OnnxruntimeModule.debug.xcconfig */;
buildSettings = {
HEADER_SEARCH_PATHS = (
"$(inherited)",
@ -352,6 +499,7 @@
};
58B511F11A9E6C8500147676 /* Release */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = 548638FE75FCC69C842C9545 /* Pods-OnnxruntimeModule.release.xcconfig */;
buildSettings = {
HEADER_SEARCH_PATHS = (
"$(inherited)",
@ -374,6 +522,7 @@
};
DBDB57DF2603211A004F16BE /* Debug */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = 5391B4C0B7C168594AA0DD0B /* Pods-OnnxruntimeModuleTest.debug.xcconfig */;
buildSettings = {
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
@ -446,6 +595,7 @@
};
DBDB57E02603211A004F16BE /* Release */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = 8529D8A6F40E462E62B38B52 /* Pods-OnnxruntimeModuleTest.release.xcconfig */;
buildSettings = {
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;

View file

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef FakeRCTBlobManager_h
#define FakeRCTBlobManager_h
#import <React/RCTBlobManager.h>
@interface FakeRCTBlobManager : RCTBlobManager
@property (nonatomic, strong) NSMutableDictionary *blobs;
- (NSString *)store:(NSData *)data;
- (NSData *)resolve:(NSString *)blobId offset:(long)offset size:(long)size;
- (NSDictionary *)testCreateData:(NSData *)buffer;
- (NSString *)testGetData:(NSDictionary *)data;
@end
#endif /* FakeRCTBlobManager_h */

View file

@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import <Foundation/Foundation.h>
#import "FakeRCTBlobManager.h"
@implementation FakeRCTBlobManager
- (instancetype)init {
if (self = [super init]) {
_blobs = [NSMutableDictionary new];
}
return self;
}
- (NSString *)store:(NSData *)data {
NSString *blobId = [[NSUUID UUID] UUIDString];
_blobs[blobId] = data;
return blobId;
}
- (NSData *)resolve:(NSString *)blobId offset:(long)offset size:(long)size {
NSData *data = _blobs[blobId];
if (data == nil) {
return nil;
}
return [data subdataWithRange:NSMakeRange(offset, size)];
}
- (NSDictionary *)testCreateData:(NSData *)buffer {
NSString* blobId = [self store:buffer];
return @{
@"blobId": blobId,
@"offset": @0,
@"size": @(buffer.length),
};
}
- (NSString *)testGetData:(NSDictionary *)data {
NSString *blobId = [data objectForKey:@"blobId"];
long size = [[data objectForKey:@"size"] longValue];
long offset = [[data objectForKey:@"offset"] longValue];
[self resolve:blobId offset:offset size:size];
return blobId;
}
@end

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#import "OnnxruntimeModule.h"
#import "FakeRCTBlobManager.h"
#import "TensorHelper.h"
#import <XCTest/XCTest.h>
@ -13,6 +14,14 @@
@implementation OnnxruntimeModuleTest
FakeRCTBlobManager *fakeBlobManager = nil;
+ (void)initialize {
if (self == [OnnxruntimeModuleTest class]) {
fakeBlobManager = [FakeRCTBlobManager new];
}
}
- (void)testOnnxruntimeModule {
NSBundle *bundle = [NSBundle bundleForClass:[OnnxruntimeModuleTest class]];
NSString *dataPath = [bundle pathForResource:@"test_types_float" ofType:@"ort"];
@ -20,6 +29,7 @@
NSString *sessionKey2 = @"";
OnnxruntimeModule *onnxruntimeModule = [OnnxruntimeModule new];
[onnxruntimeModule setBlobManager:fakeBlobManager];
{
// test loadModelFromBuffer()
@ -70,8 +80,8 @@
}
floatPtr = (float *)[byteBufferRef bytes];
NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0];
inputTensorMap[@"data"] = dataEncoded;
XCTAssertNotNil(fakeBlobManager);
inputTensorMap[@"data"] = [fakeBlobManager testCreateData:byteBufferRef];
NSMutableDictionary *inputDataMap = [NSMutableDictionary dictionary];
inputDataMap[@"input"] = inputTensorMap;
@ -84,8 +94,18 @@
NSDictionary *resultMap = [onnxruntimeModule run:sessionKey input:inputDataMap output:output options:options];
NSDictionary *resultMap2 = [onnxruntimeModule run:sessionKey2 input:inputDataMap output:output options:options];
XCTAssertTrue([[resultMap objectForKey:@"output"] isEqualToDictionary:inputTensorMap]);
XCTAssertTrue([[resultMap2 objectForKey:@"output"] isEqualToDictionary:inputTensorMap]);
// Compare output & input, but data.blobId is different
// dims
XCTAssertTrue([[resultMap objectForKey:@"output"][@"dims"] isEqualToArray:inputTensorMap[@"dims"]]);
XCTAssertTrue([[resultMap2 objectForKey:@"output"][@"dims"] isEqualToArray:inputTensorMap[@"dims"]]);
// type
XCTAssertEqual([resultMap objectForKey:@"output"][@"type"], JsTensorTypeFloat);
XCTAssertEqual([resultMap2 objectForKey:@"output"][@"type"], JsTensorTypeFloat);
// data ({ blobId, offset, size })
XCTAssertEqual([[resultMap objectForKey:@"output"][@"data"][@"offset"] longValue], 0);
XCTAssertEqual([[resultMap2 objectForKey:@"output"][@"data"][@"size"] longValue], byteBufferSize);
}
// test dispose

View file

@ -3,6 +3,7 @@
#import "TensorHelper.h"
#import "FakeRCTBlobManager.h"
#import <XCTest/XCTest.h>
#import <onnxruntime/onnxruntime_cxx_api.h>
#include <vector>
@ -13,6 +14,14 @@
@implementation TensorHelperTest
FakeRCTBlobManager *testBlobManager = nil;
+ (void)initialize {
if (self == [TensorHelperTest class]) {
testBlobManager = [FakeRCTBlobManager new];
}
}
template <typename T>
static void testCreateInputTensorT(const std::array<T, 3> &outValues, std::function<NSNumber *(T value)> &convert,
ONNXTensorElementDataType onnxType, NSString *jsTensorType) {
@ -34,12 +43,13 @@ static void testCreateInputTensorT(const std::array<T, 3> &outValues, std::funct
typePtr[i] = outValues[i];
}
NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0];
inputTensorMap[@"data"] = dataEncoded;
XCTAssertNotNil(testBlobManager);
inputTensorMap[@"data"] = [testBlobManager testCreateData:byteBufferRef];
Ort::AllocatorWithDefaultOptions ortAllocator;
std::vector<Ort::MemoryAllocation> allocations;
Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap
Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager
input:inputTensorMap
ortAllocator:ortAllocator
allocations:allocations];
@ -126,7 +136,8 @@ static void testCreateInputTensorT(const std::array<T, 3> &outValues, std::funct
Ort::AllocatorWithDefaultOptions ortAllocator;
std::vector<Ort::MemoryAllocation> allocations;
Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap
Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager
input:inputTensorMap
ortAllocator:ortAllocator
allocations:allocations];
@ -194,10 +205,11 @@ static void testCreateOutputTensorT(const std::array<T, 5> &outValues, std::func
typePtr[i] = outValues[i];
}
NSString *dataEncoded = [byteBufferRef base64EncodedStringWithOptions:0];
inputTensorMap[@"data"] = dataEncoded;
inputTensorMap[@"data"] = [testBlobManager testCreateData:byteBufferRef];
;
std::vector<Ort::MemoryAllocation> allocations;
Ort::Value inputTensor = [TensorHelper createInputTensor:inputTensorMap
Ort::Value inputTensor = [TensorHelper createInputTensor:testBlobManager
input:inputTensorMap
ortAllocator:ortAllocator
allocations:allocations];
@ -208,9 +220,24 @@ static void testCreateOutputTensorT(const std::array<T, 5> &outValues, std::func
auto output = session.Run(runOptions, inputNames.data(), feeds.data(), inputNames.size(), outputNames.data(),
outputNames.size());
NSDictionary *resultMap = [TensorHelper createOutputTensor:outputNames values:output];
NSDictionary *resultMap = [TensorHelper createOutputTensor:testBlobManager outputNames:outputNames values:output];
XCTAssertTrue([[resultMap objectForKey:@"output"] isEqualToDictionary:inputTensorMap]);
// Compare output & input, but data.blobId is different
NSDictionary *outputMap = [resultMap objectForKey:@"output"];
// dims
XCTAssertTrue([outputMap[@"dims"] isEqualToArray:inputTensorMap[@"dims"]]);
// type
XCTAssertEqual(outputMap[@"type"], jsTensorType);
// data ({ blobId, offset, size })
NSDictionary *data = outputMap[@"data"];
XCTAssertNotNil(data[@"blobId"]);
XCTAssertEqual([data[@"offset"] longValue], 0);
XCTAssertEqual([data[@"size"] longValue], byteBufferSize);
}
- (void)testCreateOutputTensorFloat {

View file

@ -5,6 +5,7 @@
#define TensorHelper_h
#import <Foundation/Foundation.h>
#import <React/RCTBlobManager.h>
// Note: Using below syntax for including ort c api and ort extensions headers to resolve a compiling error happened
// in an expo react native ios app (a redefinition error happened with multiple object types defined within
@ -36,17 +37,19 @@ FOUNDATION_EXPORT NSString* const JsTensorTypeString;
/**
* It creates an input tensor from a map passed by react native js.
* 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor.
* 'data' is blob object and the buffer is stored in RCTBlobManager. It first resolve it and creates a tensor.
*/
+(Ort::Value)createInputTensor:(NSDictionary*)input
+(Ort::Value)createInputTensor:(RCTBlobManager *)blobManager
input:(NSDictionary*)input
ortAllocator:(OrtAllocator*)ortAllocator
allocations:(std::vector<Ort::MemoryAllocation>&)allocatons;
allocations:(std::vector<Ort::MemoryAllocation>&)allocations;
/**
* It creates an output map from an output tensor.
* a data array is encoded as base64 string.
* a data array is store in RCTBlobManager.
*/
+(NSDictionary*)createOutputTensor:(const std::vector<const char*>&)outputNames
+(NSDictionary*)createOutputTensor:(RCTBlobManager *)blobManager
outputNames:(const std::vector<const char*>&)outputNames
values:(const std::vector<Ort::Value>&)values;
@end

View file

@ -21,11 +21,12 @@ NSString *const JsTensorTypeString = @"string";
/**
* It creates an input tensor from a map passed by react native js.
* 'data' must be a string type as data is encoded as base64. It first decodes it and creates a tensor.
* 'data' is blob object and the buffer is stored in RCTBlobManager. It first resolve it and creates a tensor.
*/
+ (Ort::Value)createInputTensor:(NSDictionary *)input
+ (Ort::Value)createInputTensor:(RCTBlobManager *)blobManager
input:(NSDictionary *)input
ortAllocator:(OrtAllocator *)ortAllocator
allocations:(std::vector<Ort::MemoryAllocation> &)allocatons {
allocations:(std::vector<Ort::MemoryAllocation> &)allocations {
// shape
NSArray *dimsArray = [input objectForKey:@"dims"];
std::vector<int64_t> dims;
@ -48,22 +49,27 @@ NSString *const JsTensorTypeString = @"string";
}
return inputTensor;
} else {
NSString *data = [input objectForKey:@"data"];
NSData *buffer = [[NSData alloc] initWithBase64EncodedString:data options:0];
NSDictionary *data = [input objectForKey:@"data"];
NSString *blobId = [data objectForKey:@"blobId"];
long size = [[data objectForKey:@"size"] longValue];
long offset = [[data objectForKey:@"offset"] longValue];
auto buffer = [blobManager resolve:blobId offset:offset size:size];
Ort::Value inputTensor = [self createInputTensor:tensorType
dims:dims
buffer:buffer
ortAllocator:ortAllocator
allocations:allocatons];
allocations:allocations];
[blobManager remove:blobId];
return inputTensor;
}
}
/**
* It creates an output map from an output tensor.
* a data array is encoded as base64 string.
* a data array is store in RCTBlobManager.
*/
+ (NSDictionary *)createOutputTensor:(const std::vector<const char *> &)outputNames
+ (NSDictionary *)createOutputTensor:(RCTBlobManager *)blobManager
outputNames:(const std::vector<const char *> &)outputNames
values:(const std::vector<Ort::Value> &)values {
if (outputNames.size() != values.size()) {
NSException *exception = [NSException exceptionWithName:@"create output tensor"
@ -109,8 +115,13 @@ NSString *const JsTensorTypeString = @"string";
}
outputTensor[@"data"] = buffer;
} else {
NSString *data = [self createOutputTensor:value];
outputTensor[@"data"] = data;
NSData *data = [self createOutputTensor:value];
NSString *blobId = [blobManager store:data];
outputTensor[@"data"] = @{
@"blobId" : blobId,
@"offset" : @0,
@"size" : @(data.length),
};
}
outputTensorMap[[NSString stringWithUTF8String:outputName]] = outputTensor;
@ -170,15 +181,14 @@ static Ort::Value createInputTensorT(OrtAllocator *ortAllocator, const std::vect
}
}
template <typename T> static NSString *createOutputTensorT(const Ort::Value &tensor) {
template <typename T> static NSData *createOutputTensorT(const Ort::Value &tensor) {
const auto data = tensor.GetTensorData<T>();
NSData *buffer = [NSData dataWithBytesNoCopy:(void *)data
length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T)
freeWhenDone:false];
return [buffer base64EncodedStringWithOptions:0];
return [NSData dataWithBytesNoCopy:(void *)data
length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T)
freeWhenDone:false];
}
+ (NSString *)createOutputTensor:(const Ort::Value &)tensor {
+ (NSData *)createOutputTensor:(const Ort::Value &)tensor {
ONNXTensorElementDataType tensorType = tensor.GetTensorTypeAndShapeInfo().GetElementType();
switch (tensorType) {

View file

@ -1,11 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Buffer} from 'buffer';
import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common';
import {Platform} from 'react-native';
import {binding, Binding} from './binding';
import {binding, Binding, JSIBlob, jsiHelper} from './binding';
type SupportedTypedArray = Exclude<Tensor.DataType, string[]>;
@ -69,11 +68,11 @@ class OnnxruntimeSessionHandler implements SessionHandler {
if (typeof this.#pathOrBuffer === 'string') {
results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options);
} else {
if (!this.#inferenceSession.loadModelFromBase64EncodedBuffer) {
throw new Error('Native module method "loadModelFromBase64EncodedBuffer" is not defined');
if (!this.#inferenceSession.loadModelFromBlob) {
throw new Error('Native module method "loadModelFromBlob" is not defined');
}
const modelInBase64String = Buffer.from(this.#pathOrBuffer).toString('base64');
results = await this.#inferenceSession.loadModelFromBase64EncodedBuffer(modelInBase64String, options);
const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer);
results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options);
}
// resolve promise if onnxruntime session is successfully created
this.#key = results.key;
@ -113,18 +112,18 @@ class OnnxruntimeSessionHandler implements SessionHandler {
return output;
}
encodeFeedsType(feeds: SessionHandler.FeedsType): Binding.FeedsType {
const returnValue: {[name: string]: Binding.EncodedTensorType} = {};
for (const key in feeds) {
if (Object.hasOwnProperty.call(feeds, key)) {
let data: string|string[];
let data: JSIBlob|string[];
if (Array.isArray(feeds[key].data)) {
data = feeds[key].data as string[];
} else {
// Base64-encode tensor data
const buffer = (feeds[key].data as SupportedTypedArray).buffer;
data = Buffer.from(buffer, 0, buffer.byteLength).toString('base64');
data = jsiHelper.storeArrayBuffer(buffer);
}
returnValue[key] = {
@ -146,9 +145,9 @@ class OnnxruntimeSessionHandler implements SessionHandler {
if (Array.isArray(results[key].data)) {
tensorData = results[key].data as string[];
} else {
const buffer: Buffer = Buffer.from(results[key].data as string, 'base64');
const buffer = jsiHelper.resolveArrayBuffer(results[key].data as JSIBlob) as SupportedTypedArray;
const typedArray = tensorTypeToTypedArray(results[key].type as Tensor.Type);
tensorData = new typedArray(buffer.buffer, buffer.byteOffset, buffer.length / typedArray.BYTES_PER_ELEMENT);
tensorData = new typedArray(buffer, buffer.byteOffset, buffer.byteLength / typedArray.BYTES_PER_ELEMENT);
}
returnValue[key] = new Tensor(results[key].type as Tensor.Type, tensorData, results[key].dims);

View file

@ -26,7 +26,14 @@ interface ModelLoadInfo {
}
/**
* Tensor type for react native, which doesn't allow ArrayBuffer, so data will be encoded as Base64 string.
* JSIBlob is a blob object that exchange ArrayBuffer by OnnxruntimeJSIHelper.
*/
export type JSIBlob = {
blobId: string; offset: number; size: number;
};
/**
* Tensor type for react native, which doesn't allow ArrayBuffer in native bridge, so data will be stored as JSIBlob.
*/
interface EncodedTensor {
/**
@ -38,10 +45,10 @@ interface EncodedTensor {
*/
readonly type: string;
/**
* the Base64 encoded string of the buffer data of the tensor.
* if data is string array, it won't be encoded as Base64 string.
* the JSIBlob object of the buffer data of the tensor.
* if data is string array, it won't be stored as JSIBlob.
*/
readonly data: string|string[];
readonly data: JSIBlob|string[];
}
/**
@ -64,12 +71,41 @@ export declare namespace Binding {
interface InferenceSession {
loadModel(modelPath: string, options: SessionOptions): Promise<ModelLoadInfoType>;
loadModelFromBase64EncodedBuffer?(buffer: string, options: SessionOptions): Promise<ModelLoadInfoType>;
loadModelFromBlob?(blob: JSIBlob, options: SessionOptions): Promise<ModelLoadInfoType>;
dispose(key: string): Promise<void>;
run(key: string, feeds: FeedsType, fetches: FetchesType, options: RunOptions): Promise<ReturnType>;
}
}
// export native binding
const {Onnxruntime} = NativeModules;
const {Onnxruntime, OnnxruntimeJSIHelper} = NativeModules;
export const binding = Onnxruntime as Binding.InferenceSession;
// install JSI helper global functions
OnnxruntimeJSIHelper.install();
declare global {
// eslint-disable-next-line no-var
var jsiOnnxruntimeStoreArrayBuffer: ((buffer: ArrayBuffer) => JSIBlob)|undefined;
// eslint-disable-next-line no-var
var jsiOnnxruntimeResolveArrayBuffer: ((blob: JSIBlob) => ArrayBuffer)|undefined;
}
export const jsiHelper = {
storeArrayBuffer: globalThis.jsiOnnxruntimeStoreArrayBuffer || (() => {
throw new Error(
'jsiOnnxruntimeStoreArrayBuffer is not found, ' +
'please make sure OnnxruntimeJSIHelper installation is successful.');
}),
resolveArrayBuffer: globalThis.jsiOnnxruntimeResolveArrayBuffer || (() => {
throw new Error(
'jsiOnnxruntimeResolveArrayBuffer is not found, ' +
'please make sure OnnxruntimeJSIHelper installation is successful.');
}),
};
// Remove global functions after installation
{
delete globalThis.jsiOnnxruntimeStoreArrayBuffer;
delete globalThis.jsiOnnxruntimeResolveArrayBuffer;
}