mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
[Java] Tidying up the sample MNIST code (#3824)
* Updating the Java sample to load MNIST in libsvm format. * java - code formatting fix. Co-authored-by: Adam Pocock <adam.pocock@oracle.com>
This commit is contained in:
parent
f7ff5a7aa1
commit
5db30a470e
4 changed files with 116 additions and 21 deletions
|
|
@ -9,7 +9,17 @@ TBD: maven distribution
|
|||
|
||||
The minimum supported Java Runtime is version 8.
|
||||
|
||||
An example implementation is located in [src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java)
|
||||
An example implementation is located in
|
||||
[src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java).
|
||||
Once compiled the sample code expects the following arguments `ScoreMNIST
|
||||
<path-to-mnist-model> <path-to-mnist> <scikit-learn-flag>`. MNIST is expected
|
||||
to be in libsvm format. If the optional scikit-learn flag is supplied the model
|
||||
is expected to be produced by skl2onnx (so expects a flat feature vector, and
|
||||
produces a structured output), otherwise the model is expected to be a CNN from
|
||||
pytorch (expecting a `[1][1][28][28]` input, producing a vector of
|
||||
probabilities). Two example models are provided in [testdata](testdata),
|
||||
`cnn_mnist_pytorch.onnx` and `lr_mnist_scikit.onnx`. The first is a LeNet5 style
|
||||
CNN trained using PyTorch, the second is a logistic regression trained using scikit-learn.
|
||||
|
||||
This project can be built manually using the instructions below.
|
||||
|
||||
|
|
|
|||
|
|
@ -12,15 +12,16 @@ import ai.onnxruntime.OrtSession;
|
|||
import ai.onnxruntime.OrtSession.Result;
|
||||
import ai.onnxruntime.OrtSession.SessionOptions;
|
||||
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* Demo code, supporting both a pytorch CNN trained on MNIST and a scikit-learn model trained on
|
||||
|
|
@ -29,6 +30,8 @@ import java.util.logging.Logger;
|
|||
public class ScoreMNIST {
|
||||
|
||||
private static final Logger logger = Logger.getLogger(ScoreMNIST.class.getName());
|
||||
/** Pattern for splitting libsvm format files. */
|
||||
private static final Pattern splitPattern = Pattern.compile("\\s+");
|
||||
|
||||
/** A named tuple for sparse classification data. */
|
||||
private static class SparseData {
|
||||
|
|
@ -38,29 +41,112 @@ public class ScoreMNIST {
|
|||
|
||||
public SparseData(int[] labels, List<int[]> indices, List<float[]> values) {
|
||||
this.labels = labels;
|
||||
this.indices = indices;
|
||||
this.values = values;
|
||||
this.indices = Collections.unmodifiableList(indices);
|
||||
this.values = Collections.unmodifiableList(values);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deserialises the data and puts it in a named tuple.
|
||||
* Converts a List of Integer into an int array.
|
||||
*
|
||||
* @param list The list to convert.
|
||||
* @return The int array.
|
||||
*/
|
||||
private static int[] convertInts(List<Integer> list) {
|
||||
int[] output = new int[list.size()];
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
output[i] = list.get(i);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a List of Float into a float array.
|
||||
*
|
||||
* @param list The list to convert.
|
||||
* @return The float array.
|
||||
*/
|
||||
private static float[] convertFloats(List<Float> list) {
|
||||
float[] output = new float[list.size()];
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
output[i] = list.get(i);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads data from a libsvm format file.
|
||||
*
|
||||
* @param path The path to load the data from.
|
||||
* @return A named tuple containing the data.
|
||||
* @throws IOException If it failed to read the file.
|
||||
* @throws ClassNotFoundException If a class wasn't found (only uses JDK types so this would be
|
||||
* very odd).
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
private static SparseData load(String path) throws IOException, ClassNotFoundException {
|
||||
try (ObjectInputStream ois =
|
||||
new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
|
||||
int[] labels = (int[]) ois.readObject();
|
||||
List<int[]> indices = (List<int[]>) ois.readObject();
|
||||
List<float[]> values = (List<float[]>) ois.readObject();
|
||||
return new SparseData(labels, indices, values);
|
||||
private static SparseData load(String path) throws IOException {
|
||||
int pos = 0;
|
||||
List<int[]> indices = new ArrayList<>();
|
||||
List<float[]> values = new ArrayList<>();
|
||||
List<Integer> labels = new ArrayList<>();
|
||||
String line;
|
||||
int maxFeatureID = Integer.MIN_VALUE;
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||
for (; ; ) {
|
||||
line = reader.readLine();
|
||||
if (line == null) {
|
||||
break;
|
||||
}
|
||||
pos++;
|
||||
String[] fields = splitPattern.split(line);
|
||||
int lastID = -1;
|
||||
try {
|
||||
boolean valid = true;
|
||||
List<Integer> curIndices = new ArrayList<>();
|
||||
List<Float> curValues = new ArrayList<>();
|
||||
for (int i = 1; i < fields.length && valid; i++) {
|
||||
int ind = fields[i].indexOf(':');
|
||||
if (ind < 0) {
|
||||
logger.warning(String.format("Weird line at %d", pos));
|
||||
valid = false;
|
||||
}
|
||||
String ids = fields[i].substring(0, ind);
|
||||
int id = Integer.parseInt(ids);
|
||||
curIndices.add(id);
|
||||
if (maxFeatureID < id) {
|
||||
maxFeatureID = id;
|
||||
}
|
||||
float val = Float.parseFloat(fields[i].substring(ind + 1));
|
||||
curValues.add(val);
|
||||
if (id <= lastID) {
|
||||
logger.warning(String.format("Repeated features at line %d", pos));
|
||||
valid = false;
|
||||
} else {
|
||||
lastID = id;
|
||||
}
|
||||
}
|
||||
if (valid) {
|
||||
// Store the label
|
||||
labels.add(Integer.parseInt(fields[0]));
|
||||
// Store the features
|
||||
indices.add(convertInts(curIndices));
|
||||
values.add(convertFloats(curValues));
|
||||
} else {
|
||||
throw new IOException("Invalid LibSVM format file at line " + pos);
|
||||
}
|
||||
} catch (NumberFormatException ex) {
|
||||
logger.warning(String.format("Weird line at %d", pos));
|
||||
throw new IOException("Invalid LibSVM format file", ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Loaded "
|
||||
+ maxFeatureID
|
||||
+ " features, "
|
||||
+ labels.size()
|
||||
+ " samples, from + '"
|
||||
+ path
|
||||
+ "'.");
|
||||
return new SparseData(convertInts(labels), indices, values);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -170,11 +256,10 @@ public class ScoreMNIST {
|
|||
return idx;
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws OrtException, IOException, ClassNotFoundException {
|
||||
public static void main(String[] args) throws OrtException, IOException {
|
||||
if (args.length < 2 || args.length > 3) {
|
||||
System.out.println("Usage: ScoreMNIST <model-path> <test-data> <optional:scikit-learn-flag>");
|
||||
System.out.println(
|
||||
"The test data input format is a Java serialized file containing an array of int labels, a list of int[] feature indices, and a list of float[] feature values");
|
||||
System.out.println("The test data input should be a libsvm format version of MNIST.");
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -232,7 +317,7 @@ public class ScoreMNIST {
|
|||
|
||||
confusionMatrix[data.labels[i]][predLabel]++;
|
||||
|
||||
if (i % 500 == 0) {
|
||||
if (i % 2000 == 0) {
|
||||
logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (i + 1));
|
||||
logger.log(Level.INFO, "Output type = " + output.get(0).toString());
|
||||
if (args.length == 3) {
|
||||
|
|
|
|||
BIN
java/testdata/cnn_mnist_pytorch.onnx
vendored
Normal file
BIN
java/testdata/cnn_mnist_pytorch.onnx
vendored
Normal file
Binary file not shown.
BIN
java/testdata/lr_mnist_scikit.onnx
vendored
Normal file
BIN
java/testdata/lr_mnist_scikit.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue