diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java index 53a9bb51f9..fe59cefbee 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java @@ -72,7 +72,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { WritableMap resultMap = loadModel(uri, options); promise.resolve(resultMap); } catch (Exception e) { - promise.reject("Can't read a model " + uri, e); + promise.reject("Can't load model \"" + uri + "\": " + e.getMessage(), e); } } @@ -91,7 +91,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { WritableMap resultMap = run(key, input, output, options); promise.resolve(resultMap); } catch (Exception e) { - promise.reject("Fail to inference", e); + promise.reject("Fail to inference: " + e.getMessage(), e); } } @@ -122,21 +122,14 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { if (!sessionMap.containsKey(uri)) { byte[] modelArray = null; - try { - Reader reader = new BufferedReader(new InputStreamReader(modelStream)); - modelArray = new byte[modelStream.available()]; - modelStream.read(modelArray); - } catch (IOException e) { - throw new Exception("Can't read a model " + uri, e); - } - try { - SessionOptions sessionOptions = parseSessionOptions(options); - ortSession = ortEnvironment.createSession(modelArray, sessionOptions); - sessionMap.put(uri, ortSession); - } catch (OrtException e) { - throw new Exception("Can't create InferenceSession", e); - } + Reader reader = new BufferedReader(new InputStreamReader(modelStream)); + modelArray = new byte[modelStream.available()]; + modelStream.read(modelArray); + + SessionOptions sessionOptions = parseSessionOptions(options); + ortSession = ortEnvironment.createSession(modelArray, sessionOptions); + sessionMap.put(uri, ortSession); } else { ortSession = sessionMap.get(uri); } @@ -169,7 +162,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { public WritableMap run(String key, ReadableMap input, ReadableArray output, ReadableMap options) throws Exception { OrtSession ortSession = sessionMap.get(key); if (ortSession == null) { - throw new Exception("Model is not loaded " + key); + throw new Exception("Model is not loaded: " + key); } RunOptions runOptions = parseRunOptions(options); @@ -177,51 +170,62 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { long startTime = System.currentTimeMillis(); Map feed = new HashMap<>(); Iterator iterator = ortSession.getInputNames().iterator(); - while (iterator.hasNext()) { - String inputName = iterator.next(); + try { + while (iterator.hasNext()) { + String inputName = iterator.next(); - ReadableMap inputMap = input.getMap(inputName); - if (inputMap == null) { - throw new Exception("Can't find input: " + inputName); + ReadableMap inputMap = input.getMap(inputName); + if (inputMap == null) { + throw new Exception("Can't find input: " + inputName); + } + + if (inputMap.getType("data") != ReadableType.String) { + // NOTE: + // + // tensor data should always be a BASE64 encoded string. + // This is because the current React Native bridge supports limited data type as arguments. + // In order to pass data from JS to Java, we have to encode them into string. + // + // see also: + // https://reactnative.dev/docs/native-modules-android#argument-types + throw new Exception("Non string type of a tensor data is not allowed"); + } + + OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment); + feed.put(inputName, onnxTensor); } - if (inputMap.getType("data") != ReadableType.String) { - throw new Exception("Non string type of a tensor data is not allowed"); + Set requestedOutputs = null; + if (output.size() > 0) { + requestedOutputs = new HashSet<>(); + for (int i = 0; i < output.size(); ++i) { + requestedOutputs.add(output.getString(i)); + } } - OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment); - feed.put(inputName, onnxTensor); - } + long duration = System.currentTimeMillis() - startTime; + Log.d("Duration", "createInputTensor: " + duration); - Set requestedOutputs = null; - if (output.size() > 0) { - requestedOutputs = new HashSet<>(); - for (int i = 0; i < output.size(); ++i) { - requestedOutputs.add(output.getString(i)); + startTime = System.currentTimeMillis(); + Result result = null; + if (requestedOutputs != null) { + result = ortSession.run(feed, requestedOutputs, runOptions); + } else { + result = ortSession.run(feed, runOptions); } + duration = System.currentTimeMillis() - startTime; + Log.d("Duration", "inference: " + duration); + + startTime = System.currentTimeMillis(); + WritableMap resultMap = TensorHelper.createOutputTensor(result); + duration = System.currentTimeMillis() - startTime; + Log.d("Duration", "createOutputTensor: " + duration); + + return resultMap; + + } finally { + OnnxValue.close(feed); } - - long duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "createInputTensor: " + duration); - - startTime = System.currentTimeMillis(); - Result result = null; - if (requestedOutputs != null) { - result = ortSession.run(feed, requestedOutputs, runOptions); - } else { - result = ortSession.run(feed, runOptions); - } - duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "inference: " + duration); - - startTime = System.currentTimeMillis(); - WritableMap resultMap = TensorHelper.createOutputTensor(result); - duration = System.currentTimeMillis() - startTime; - Log.d("Duration", "createOutputTensor: " + duration); - - OnnxValue.close(feed); - - return resultMap; } private static final Map graphOptimizationLevelTable =