mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[js/rn] optimize exception message on Android (#12113)
This commit is contained in:
parent
b50239251d
commit
3ce25db7eb
1 changed files with 58 additions and 54 deletions
|
|
@ -72,7 +72,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
|
|||
WritableMap resultMap = loadModel(uri, options);
|
||||
promise.resolve(resultMap);
|
||||
} catch (Exception e) {
|
||||
promise.reject("Can't read a model " + uri, e);
|
||||
promise.reject("Can't load model \"" + uri + "\": " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
|
|||
WritableMap resultMap = run(key, input, output, options);
|
||||
promise.resolve(resultMap);
|
||||
} catch (Exception e) {
|
||||
promise.reject("Fail to inference", e);
|
||||
promise.reject("Fail to inference: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -122,21 +122,14 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
|
|||
|
||||
if (!sessionMap.containsKey(uri)) {
|
||||
byte[] modelArray = null;
|
||||
try {
|
||||
Reader reader = new BufferedReader(new InputStreamReader(modelStream));
|
||||
modelArray = new byte[modelStream.available()];
|
||||
modelStream.read(modelArray);
|
||||
} catch (IOException e) {
|
||||
throw new Exception("Can't read a model " + uri, e);
|
||||
}
|
||||
|
||||
try {
|
||||
SessionOptions sessionOptions = parseSessionOptions(options);
|
||||
ortSession = ortEnvironment.createSession(modelArray, sessionOptions);
|
||||
sessionMap.put(uri, ortSession);
|
||||
} catch (OrtException e) {
|
||||
throw new Exception("Can't create InferenceSession", e);
|
||||
}
|
||||
Reader reader = new BufferedReader(new InputStreamReader(modelStream));
|
||||
modelArray = new byte[modelStream.available()];
|
||||
modelStream.read(modelArray);
|
||||
|
||||
SessionOptions sessionOptions = parseSessionOptions(options);
|
||||
ortSession = ortEnvironment.createSession(modelArray, sessionOptions);
|
||||
sessionMap.put(uri, ortSession);
|
||||
} else {
|
||||
ortSession = sessionMap.get(uri);
|
||||
}
|
||||
|
|
@ -169,7 +162,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
|
|||
public WritableMap run(String key, ReadableMap input, ReadableArray output, ReadableMap options) throws Exception {
|
||||
OrtSession ortSession = sessionMap.get(key);
|
||||
if (ortSession == null) {
|
||||
throw new Exception("Model is not loaded " + key);
|
||||
throw new Exception("Model is not loaded: " + key);
|
||||
}
|
||||
|
||||
RunOptions runOptions = parseRunOptions(options);
|
||||
|
|
@ -177,51 +170,62 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
|
|||
long startTime = System.currentTimeMillis();
|
||||
Map<String, OnnxTensor> feed = new HashMap<>();
|
||||
Iterator<String> iterator = ortSession.getInputNames().iterator();
|
||||
while (iterator.hasNext()) {
|
||||
String inputName = iterator.next();
|
||||
try {
|
||||
while (iterator.hasNext()) {
|
||||
String inputName = iterator.next();
|
||||
|
||||
ReadableMap inputMap = input.getMap(inputName);
|
||||
if (inputMap == null) {
|
||||
throw new Exception("Can't find input: " + inputName);
|
||||
ReadableMap inputMap = input.getMap(inputName);
|
||||
if (inputMap == null) {
|
||||
throw new Exception("Can't find input: " + inputName);
|
||||
}
|
||||
|
||||
if (inputMap.getType("data") != ReadableType.String) {
|
||||
// NOTE:
|
||||
//
|
||||
// tensor data should always be a BASE64 encoded string.
|
||||
// This is because the current React Native bridge supports limited data type as arguments.
|
||||
// In order to pass data from JS to Java, we have to encode them into string.
|
||||
//
|
||||
// see also:
|
||||
// https://reactnative.dev/docs/native-modules-android#argument-types
|
||||
throw new Exception("Non string type of a tensor data is not allowed");
|
||||
}
|
||||
|
||||
OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment);
|
||||
feed.put(inputName, onnxTensor);
|
||||
}
|
||||
|
||||
if (inputMap.getType("data") != ReadableType.String) {
|
||||
throw new Exception("Non string type of a tensor data is not allowed");
|
||||
Set<String> requestedOutputs = null;
|
||||
if (output.size() > 0) {
|
||||
requestedOutputs = new HashSet<>();
|
||||
for (int i = 0; i < output.size(); ++i) {
|
||||
requestedOutputs.add(output.getString(i));
|
||||
}
|
||||
}
|
||||
|
||||
OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment);
|
||||
feed.put(inputName, onnxTensor);
|
||||
}
|
||||
long duration = System.currentTimeMillis() - startTime;
|
||||
Log.d("Duration", "createInputTensor: " + duration);
|
||||
|
||||
Set<String> requestedOutputs = null;
|
||||
if (output.size() > 0) {
|
||||
requestedOutputs = new HashSet<>();
|
||||
for (int i = 0; i < output.size(); ++i) {
|
||||
requestedOutputs.add(output.getString(i));
|
||||
startTime = System.currentTimeMillis();
|
||||
Result result = null;
|
||||
if (requestedOutputs != null) {
|
||||
result = ortSession.run(feed, requestedOutputs, runOptions);
|
||||
} else {
|
||||
result = ortSession.run(feed, runOptions);
|
||||
}
|
||||
duration = System.currentTimeMillis() - startTime;
|
||||
Log.d("Duration", "inference: " + duration);
|
||||
|
||||
startTime = System.currentTimeMillis();
|
||||
WritableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
duration = System.currentTimeMillis() - startTime;
|
||||
Log.d("Duration", "createOutputTensor: " + duration);
|
||||
|
||||
return resultMap;
|
||||
|
||||
} finally {
|
||||
OnnxValue.close(feed);
|
||||
}
|
||||
|
||||
long duration = System.currentTimeMillis() - startTime;
|
||||
Log.d("Duration", "createInputTensor: " + duration);
|
||||
|
||||
startTime = System.currentTimeMillis();
|
||||
Result result = null;
|
||||
if (requestedOutputs != null) {
|
||||
result = ortSession.run(feed, requestedOutputs, runOptions);
|
||||
} else {
|
||||
result = ortSession.run(feed, runOptions);
|
||||
}
|
||||
duration = System.currentTimeMillis() - startTime;
|
||||
Log.d("Duration", "inference: " + duration);
|
||||
|
||||
startTime = System.currentTimeMillis();
|
||||
WritableMap resultMap = TensorHelper.createOutputTensor(result);
|
||||
duration = System.currentTimeMillis() - startTime;
|
||||
Log.d("Duration", "createOutputTensor: " + duration);
|
||||
|
||||
OnnxValue.close(feed);
|
||||
|
||||
return resultMap;
|
||||
}
|
||||
|
||||
private static final Map<String, SessionOptions.OptLevel> graphOptimizationLevelTable =
|
||||
|
|
|
|||
Loading…
Reference in a new issue