[js/rn] optimize exception message on Android (#12113)

This commit is contained in:
Yulong Wang 2022-07-07 13:26:50 -07:00 committed by GitHub
parent b50239251d
commit 3ce25db7eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -72,7 +72,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
WritableMap resultMap = loadModel(uri, options);
promise.resolve(resultMap);
} catch (Exception e) {
promise.reject("Can't read a model " + uri, e);
promise.reject("Can't load model \"" + uri + "\": " + e.getMessage(), e);
}
}
@ -91,7 +91,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
WritableMap resultMap = run(key, input, output, options);
promise.resolve(resultMap);
} catch (Exception e) {
promise.reject("Fail to inference", e);
promise.reject("Fail to inference: " + e.getMessage(), e);
}
}
@ -122,21 +122,14 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
if (!sessionMap.containsKey(uri)) {
byte[] modelArray = null;
try {
Reader reader = new BufferedReader(new InputStreamReader(modelStream));
modelArray = new byte[modelStream.available()];
modelStream.read(modelArray);
} catch (IOException e) {
throw new Exception("Can't read a model " + uri, e);
}
try {
SessionOptions sessionOptions = parseSessionOptions(options);
ortSession = ortEnvironment.createSession(modelArray, sessionOptions);
sessionMap.put(uri, ortSession);
} catch (OrtException e) {
throw new Exception("Can't create InferenceSession", e);
}
Reader reader = new BufferedReader(new InputStreamReader(modelStream));
modelArray = new byte[modelStream.available()];
modelStream.read(modelArray);
SessionOptions sessionOptions = parseSessionOptions(options);
ortSession = ortEnvironment.createSession(modelArray, sessionOptions);
sessionMap.put(uri, ortSession);
} else {
ortSession = sessionMap.get(uri);
}
@ -169,7 +162,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
public WritableMap run(String key, ReadableMap input, ReadableArray output, ReadableMap options) throws Exception {
OrtSession ortSession = sessionMap.get(key);
if (ortSession == null) {
throw new Exception("Model is not loaded " + key);
throw new Exception("Model is not loaded: " + key);
}
RunOptions runOptions = parseRunOptions(options);
@ -177,51 +170,62 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
long startTime = System.currentTimeMillis();
Map<String, OnnxTensor> feed = new HashMap<>();
Iterator<String> iterator = ortSession.getInputNames().iterator();
while (iterator.hasNext()) {
String inputName = iterator.next();
try {
while (iterator.hasNext()) {
String inputName = iterator.next();
ReadableMap inputMap = input.getMap(inputName);
if (inputMap == null) {
throw new Exception("Can't find input: " + inputName);
ReadableMap inputMap = input.getMap(inputName);
if (inputMap == null) {
throw new Exception("Can't find input: " + inputName);
}
if (inputMap.getType("data") != ReadableType.String) {
// NOTE:
//
// tensor data should always be a BASE64 encoded string.
// This is because the current React Native bridge supports limited data type as arguments.
// In order to pass data from JS to Java, we have to encode them into string.
//
// see also:
// https://reactnative.dev/docs/native-modules-android#argument-types
throw new Exception("Non string type of a tensor data is not allowed");
}
OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment);
feed.put(inputName, onnxTensor);
}
if (inputMap.getType("data") != ReadableType.String) {
throw new Exception("Non string type of a tensor data is not allowed");
Set<String> requestedOutputs = null;
if (output.size() > 0) {
requestedOutputs = new HashSet<>();
for (int i = 0; i < output.size(); ++i) {
requestedOutputs.add(output.getString(i));
}
}
OnnxTensor onnxTensor = TensorHelper.createInputTensor(inputMap, ortEnvironment);
feed.put(inputName, onnxTensor);
}
long duration = System.currentTimeMillis() - startTime;
Log.d("Duration", "createInputTensor: " + duration);
Set<String> requestedOutputs = null;
if (output.size() > 0) {
requestedOutputs = new HashSet<>();
for (int i = 0; i < output.size(); ++i) {
requestedOutputs.add(output.getString(i));
startTime = System.currentTimeMillis();
Result result = null;
if (requestedOutputs != null) {
result = ortSession.run(feed, requestedOutputs, runOptions);
} else {
result = ortSession.run(feed, runOptions);
}
duration = System.currentTimeMillis() - startTime;
Log.d("Duration", "inference: " + duration);
startTime = System.currentTimeMillis();
WritableMap resultMap = TensorHelper.createOutputTensor(result);
duration = System.currentTimeMillis() - startTime;
Log.d("Duration", "createOutputTensor: " + duration);
return resultMap;
} finally {
OnnxValue.close(feed);
}
long duration = System.currentTimeMillis() - startTime;
Log.d("Duration", "createInputTensor: " + duration);
startTime = System.currentTimeMillis();
Result result = null;
if (requestedOutputs != null) {
result = ortSession.run(feed, requestedOutputs, runOptions);
} else {
result = ortSession.run(feed, runOptions);
}
duration = System.currentTimeMillis() - startTime;
Log.d("Duration", "inference: " + duration);
startTime = System.currentTimeMillis();
WritableMap resultMap = TensorHelper.createOutputTensor(result);
duration = System.currentTimeMillis() - startTime;
Log.d("Duration", "createOutputTensor: " + duration);
OnnxValue.close(feed);
return resultMap;
}
private static final Map<String, SessionOptions.OptLevel> graphOptimizationLevelTable =