[Java] Tidying up the sample MNIST code (#3824)

* Updating the Java sample to load MNIST in libsvm format.
* java - code formatting fix.
Co-authored-by: Adam Pocock <adam.pocock@oracle.com>
This commit is contained in:
Dmitri Smirnov 2020-05-05 14:34:13 -07:00 committed by GitHub
parent f7ff5a7aa1
commit 5db30a470e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 116 additions and 21 deletions

View file

@ -9,7 +9,17 @@ TBD: maven distribution
The minimum supported Java Runtime is version 8.
An example implementation is located in [src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java)
An example implementation is located in
[src/test/java/sample/ScoreMNIST.java](src/test/java/sample/ScoreMNIST.java).
Once compiled, the sample code expects the following arguments: `ScoreMNIST
<path-to-mnist-model> <path-to-mnist> <scikit-learn-flag>`. MNIST is expected
to be in libsvm format. If the optional scikit-learn flag is supplied, the model
is expected to have been produced by skl2onnx (so it expects a flat feature vector
and produces a structured output); otherwise the model is expected to be a CNN from
PyTorch (expecting a `[1][1][28][28]` input and producing a vector of
probabilities). Two example models are provided in [testdata](testdata):
`cnn_mnist_pytorch.onnx` and `lr_mnist_scikit.onnx`. The first is a LeNet5-style
CNN trained using PyTorch; the second is a logistic regression trained using scikit-learn.
This project can be built manually using the instructions below.

View file

@ -12,15 +12,16 @@ import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;
/**
* Demo code, supporting both a pytorch CNN trained on MNIST and a scikit-learn model trained on
@ -29,6 +30,8 @@ import java.util.logging.Logger;
public class ScoreMNIST {
private static final Logger logger = Logger.getLogger(ScoreMNIST.class.getName());
/** Pattern for splitting libsvm format files. */
private static final Pattern splitPattern = Pattern.compile("\\s+");
/** A named tuple for sparse classification data. */
private static class SparseData {
@ -38,29 +41,112 @@ public class ScoreMNIST {
/**
 * Creates the tuple from the supplied labels, feature indices and feature values.
 *
 * @param labels The labels, stored as supplied.
 * @param indices The feature index arrays; stored behind an unmodifiable list view.
 * @param values The feature value arrays; stored behind an unmodifiable list view.
 */
public SparseData(int[] labels, List<int[]> indices, List<float[]> values) {
this.labels = labels;
// NOTE(review): the two plain assignments below look like the pre-change lines of this diff
// hunk; the unmodifiable wrappers that follow overwrite them - confirm before relying on either.
this.indices = indices;
this.values = values;
// Wrap the lists so they cannot be structurally modified through the stored references.
this.indices = Collections.unmodifiableList(indices);
this.values = Collections.unmodifiableList(values);
}
}
/**
 * Unboxes a list of {@link Integer} values into a primitive array.
 *
 * @param list The boxed integers to unbox.
 * @return A new int array holding the elements of the list in iteration order.
 */
private static int[] convertInts(List<Integer> list) {
  return list.stream().mapToInt(Integer::intValue).toArray();
}
/**
 * Unboxes a list of {@link Float} values into a primitive array.
 *
 * @param list The boxed floats to unbox.
 * @return A new float array holding the elements of the list in iteration order.
 */
private static float[] convertFloats(List<Float> list) {
  float[] unboxed = new float[list.size()];
  int next = 0;
  for (Float value : list) {
    unboxed[next++] = value;
  }
  return unboxed;
}
/**
* Loads data from a libsvm format file.
*
* @param path The path to load the data from.
* @return A named tuple containing the data.
* @throws IOException If it failed to read the file.
* @throws ClassNotFoundException If a class wasn't found (only uses JDK types so this would be
* very odd).
*/
@SuppressWarnings("unchecked")
private static SparseData load(String path) throws IOException, ClassNotFoundException {
try (ObjectInputStream ois =
new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
int[] labels = (int[]) ois.readObject();
List<int[]> indices = (List<int[]>) ois.readObject();
List<float[]> values = (List<float[]>) ois.readObject();
return new SparseData(labels, indices, values);
/**
 * Loads sparse classification data from a libsvm format file.
 *
 * <p>Each non-blank line is expected to look like {@code <label> <id>:<value> <id>:<value> ...},
 * with an integer label and feature ids in strictly increasing order. Blank lines are skipped, and
 * leading/trailing whitespace on a line is ignored.
 *
 * @param path The path to load the data from.
 * @return A named tuple containing the labels, feature indices and feature values.
 * @throws IOException If the file could not be read, or a line is not valid libsvm format.
 */
private static SparseData load(String path) throws IOException {
  int pos = 0;
  List<int[]> indices = new ArrayList<>();
  List<float[]> values = new ArrayList<>();
  List<Integer> labels = new ArrayList<>();
  int maxFeatureID = Integer.MIN_VALUE;
  try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
    String line;
    while ((line = reader.readLine()) != null) {
      // Count every physical line so reported positions match the file.
      pos++;
      line = line.trim();
      if (line.isEmpty()) {
        continue;
      }
      String[] fields = splitPattern.split(line);
      int lastID = -1;
      try {
        boolean valid = true;
        List<Integer> curIndices = new ArrayList<>();
        List<Float> curValues = new ArrayList<>();
        for (int i = 1; i < fields.length && valid; i++) {
          int ind = fields[i].indexOf(':');
          if (ind < 0) {
            // No id:value separator. Stop parsing this line here; carrying on would call
            // substring(0, -1) and escape as an unchecked StringIndexOutOfBoundsException.
            logger.warning(String.format("Weird line at %d", pos));
            valid = false;
            break;
          }
          int id = Integer.parseInt(fields[i].substring(0, ind));
          curIndices.add(id);
          if (maxFeatureID < id) {
            maxFeatureID = id;
          }
          float val = Float.parseFloat(fields[i].substring(ind + 1));
          curValues.add(val);
          if (id <= lastID) {
            // Feature ids must be strictly increasing within a line.
            logger.warning(String.format("Repeated features at line %d", pos));
            valid = false;
          } else {
            lastID = id;
          }
        }
        if (valid) {
          // Store the label
          labels.add(Integer.parseInt(fields[0]));
          // Store the features
          indices.add(convertInts(curIndices));
          values.add(convertFloats(curValues));
        } else {
          throw new IOException("Invalid LibSVM format file at line " + pos);
        }
      } catch (NumberFormatException ex) {
        logger.warning(String.format("Weird line at %d", pos));
        throw new IOException("Invalid LibSVM format file at line " + pos, ex);
      }
    }
  }
  logger.info(
      "Loaded "
          + maxFeatureID
          + " features, "
          + labels.size()
          + " samples, from '"
          + path
          + "'.");
  return new SparseData(convertInts(labels), indices, values);
}
/**
@ -170,11 +256,10 @@ public class ScoreMNIST {
return idx;
}
public static void main(String[] args) throws OrtException, IOException, ClassNotFoundException {
public static void main(String[] args) throws OrtException, IOException {
if (args.length < 2 || args.length > 3) {
System.out.println("Usage: ScoreMNIST <model-path> <test-data> <optional:scikit-learn-flag>");
System.out.println(
"The test data input format is a Java serialized file containing an array of int labels, a list of int[] feature indices, and a list of float[] feature values");
System.out.println("The test data input should be a libsvm format version of MNIST.");
return;
}
@ -232,7 +317,7 @@ public class ScoreMNIST {
confusionMatrix[data.labels[i]][predLabel]++;
if (i % 500 == 0) {
if (i % 2000 == 0) {
logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (i + 1));
logger.log(Level.INFO, "Output type = " + output.get(0).toString());
if (args.length == 3) {

BIN
java/testdata/cnn_mnist_pytorch.onnx vendored Normal file

Binary file not shown.

BIN
java/testdata/lr_mnist_scikit.onnx vendored Normal file

Binary file not shown.